diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 2375a391..b6f24777 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -14,31 +14,36 @@ parameters: - name: matrix type: object default: - # - job_name: 'UbuntuPython38' - # py: '3.8' - # img: 'ubuntu-latest' - # every_time: false - # publish_coverage: false - - job_name: 'UbuntuPython311' - py: '3.11' + - job_name: 'UbuntuPython39' + py: '3.9' + img: 'ubuntu-latest' + every_time: false + publish_coverage: false + - job_name: 'UbuntuPython310' + py: '3.10' img: 'ubuntu-latest' every_time: true publish_coverage: true - # - job_name: 'WindowsPython38' - # py: '3.8' - # img: 'windows-latest' - # every_time: false - # publish_coverage: false + - job_name: 'UbuntuPython311' + py: '3.11' + img: 'ubuntu-latest' + every_time: false + publish_coverage: false + - job_name: 'WindowsPython39' + py: '3.9' + img: 'windows-latest' + every_time: false + publish_coverage: false - job_name: 'WindowsPython311' py: '3.11' img: 'windows-latest' every_time: false publish_coverage: false - # - job_name: 'MacOSPython38' - # py: '3.8' - # img: 'macOS-latest' - # every_time: false - # publish_coverage: false + - job_name: 'MacOSPython39' + py: '3.9' + img: 'macOS-latest' + every_time: false + publish_coverage: false - job_name: 'MacOSPython311' py: '3.11' img: 'macOS-latest' @@ -63,7 +68,7 @@ stages: displayName: 'Use Python ${{ item.py }}' - script: | - python -m pip install pre-commit + python -m pip install pre-commit>=6.1 pre-commit install pre-commit run --all-files displayName: 'Run pre-commits' @@ -71,7 +76,6 @@ stages: - script: | python -m pip install --upgrade pip==23.0.1 pip install wheel==0.38.4 --upgrade - pip install setuptools==66 --upgrade pip install build==0.10.0 pip install pytest-azurepipelines displayName: 'Install build dependencies' @@ -109,10 +113,8 @@ stages: - script: | pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml notebooks_exit_code=$? - pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml - package_notebooks_exit_code=$? - # Fail step if either of these do not have exit code 0 - if [ $notebooks_exit_code -ne 0 ] || [ $package_notebooks_exit_code -ne 0 ]; then + # Fail step if exit code not equal to 0 + if [ $notebooks_exit_code -ne 0 ]; then exit 1 fi displayName: 'Run notebooks on Linux and macOS' @@ -122,11 +124,8 @@ stages: - script: | pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml set notebooks_exit_code=%ERRORLEVEL% - pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml - set package_notebooks_exit_code=%ERRORLEVEL% - rem Fail step if either of these do not have exit code 0 + rem Fail step if exit code not equal to 0 if %notebooks_exit_code% NEQ 0 exit /b 1 - if %package_notebooks_exit_code% NEQ 0 exit /b 1 displayName: 'Run notebooks on Windows' condition: eq(variables['Agent.OS'], 'Windows_NT') diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3088dc1d..d004dd6c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,10 @@ repos: - - repo: local - hooks: - - id: ensure-copyright-clause - name: ensure copyright clause - entry: python copyright_clause_pre_commit_hook.py - language: python + # - repo: local + # hooks: + # - id: ensure-copyright-clause + # name: ensure copyright clause + # entry: python copyright_clause_pre_commit_hook.py + # language: python - repo: http://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: @@ -31,7 +31,7 @@ repos: - id: isort args: [ "--profile", "black" ] - repo: http://github.com/PyCQA/flake8 - rev: 6.0.0 + rev: 6.1.0 hooks: - id: flake8 additional_dependencies: diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a1f7919..c7597bb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,42 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [4.0.0] = TBC + +### Added +- Log observation space data by episode and step. +- Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `do-nothing` actions are omitted. +- New ``node-send-local-command`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only) +- Added ability to set the observation threshold for NMNE, file access and application executions + +### Changed +- Agents now follow a common configuration format, simplifying the configuration of agents and their extensibilty. +- Actions within PrimAITE are now extensible, allowing for plugin support. +- Added a config schema to `ObservationManager`, `ActionManager`, and `RewardFunction`. +- Streamlined the way agents are created from config +- Agent config no longer requires a dummy action space if the action space is empty, the same applies for observation space and reward function +- Actions now support a config schema, to allow yaml data validation and default parameter values +- Action parameters are no longer defined through IDs, instead meaningful data is provided directly in the action map +- Test and example YAMLs have been updated to match the new agent and action schemas, such as: + - Removed empty action spaces, observation spaces, or reward spaces for agent which didn't use them + - Relabeled action parameters to match the new action config schemas, and updated the values to no longer rely on indices + - Removed action space options which were previously used for assigning meaning to action space IDs +- Updated tests that don't use YAMLs to still use the new action and agent schemas +- Nodes now use a config schema and are extensible, allowing for plugin support. +- Node tests have been updated to use the new node config schemas when not using YAML files. +- ACLs are no longer applied to layer-2 traffic. +- Random number seed values are recorded in simulation/seed.log if the seed is set in the config file + or `generate_seed_value` is set to `true`. +- ARP .show() method will now include the port number associated with each entry. +- Added `services_requires_scan` and `applications_requires_scan` to agent observation space config to allow the agents to be able to see actual health states of services and applications without requiring scans (Default `True`, set to `False` to allow agents to see actual health state without scanning). +- Updated the `Terminal` class to provide response information when sending remote command execution. + +### Fixed +- DNS client no longer fails to check its cache if a DNS server address is missing. +- DNS client now correctly inherits the node's DNS address configuration setting. + + +## [3.3.0] - 2024-09-04 ## [3.4.0] @@ -36,11 +72,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added reward calculation details to AgentHistoryItem. - Added a new Privilege-Escalation-and Data-Loss-Example.ipynb notebook with a realistic cyber scenario focusing on internal privilege escalation and data loss through the manipulation of SSH access and Access Control Lists (ACLs). +- Added a new extensible `NetworkNodeAdder` class for convenient addition of sets of nodes based on a simplified config. ### Changed - File and folder observations can now be configured to always show the true health status, or require scanning like before. - It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where agent don't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty` - Node observations can now be configured to show the number of active local and remote logins. +- Ports and IP Protocols no longer use enums. They are defined in dictionary lookups and are handled by custom validation to enable extensibility with plugins. +- Changed AirSpaceFrequency to a data transfer object with a registry to allow extensibility +- Changed the Office LAN creation convenience function to follow the new `NetworkNodeAdder` pattern. Office LANs can now also be defined in YAML config. ### Fixed - Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) @@ -48,6 +88,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 and `uninstall` methods in the `Node` class. - Updated the `receive_payload_from_session_manager` method in `SoftwareManager` so that it now sends a copy of the payload to any software listening on the destination port of the `Frame`. +- Made the `show` method of `Network` show all node types, including ones registered at runtime ### Removed - Removed the `install` and `uninstall` methods in the `Node` class. diff --git a/MANIFEST.in b/MANIFEST.in index 2ac7b306..51ae4ddf 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,2 @@ include src/primaite/setup/_package_data/primaite_config.yaml include src/primaite/config/_package_data/*.yaml -include src/primaite/simulator/_package_data/*.ipynb diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 4ad398b9..ddedebb7 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Any, Dict, Optional, Tuple from gymnasium.core import ObsType diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index 86ed22a9..70ea8900 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import json import shutil from datetime import datetime diff --git a/benchmark/report.py b/benchmark/report.py index 4035ceca..c11528ab 100644 --- a/benchmark/report.py +++ b/benchmark/report.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import json import sys from datetime import datetime diff --git a/benchmark/utils.py b/benchmark/utils.py index 2e92d80d..f17c64b7 100644 --- a/benchmark/utils.py +++ b/benchmark/utils.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import platform from typing import Dict diff --git a/docs/_templates/custom-class-template.rst b/docs/_templates/custom-class-template.rst index 920158d5..71e992bc 100644 --- a/docs/_templates/custom-class-template.rst +++ b/docs/_templates/custom-class-template.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates. diff --git a/docs/_templates/custom-module-template.rst b/docs/_templates/custom-module-template.rst index 98627e43..3a2ced35 100644 --- a/docs/_templates/custom-module-template.rst +++ b/docs/_templates/custom-module-template.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates. diff --git a/docs/api.rst b/docs/api.rst index 977f9e87..eb7e4719 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -2,7 +2,7 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. DO NOT DELETE THIS FILE! It contains the all-important `.. autosummary::` directive with `:recursive:` option, without diff --git a/docs/conf.py b/docs/conf.py index 318829fd..60739499 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # Configuration file for the Sphinx documentation builder. # # For the full list of built-in configuration values, see the documentation: diff --git a/docs/index.rst b/docs/index.rst index 118f7ebf..aa7d16e0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK Welcome to PrimAITE's documentation ==================================== @@ -30,6 +30,7 @@ What is PrimAITE? source/varying_config_files source/environment source/action_masking + source/node_sets .. toctree:: :caption: Notebooks: @@ -69,7 +70,7 @@ PrimAITE incorporates the following features: - Architected with a separate Simulation layer and Game layer. This separation of concerns defines a clear path towards transfer learning with environments of differing fidelity; - Ability to reconfigure an RL reward function based on (a) the ability to counter the modelled adversarial cyber-attack, and (b) the ability to ensure success for green agents; -- Access Control List (ACL) functions for network devices (routers and firewalls), following standard ACL rule format (e.g., DENY / ALLOW, source / destination IP addresses, protocol and port); +- Access Control List (ACL) functions for network devices (routers and firewalls), following standard ACL rule format (e.g., DENY / PERMIT, source / destination IP addresses, protocol and port); - Application of traffic to the links of the system laydown adheres to the ACL rulesets and routing tables contained within each network device; - Provides RL environments adherent to the Farama Foundation Gymnasium (Previously OpenAI Gym) API, allowing integration with any compliant RL Agent frameworks; - Provides RL environments adherent to Ray RLlib environment specifications for single-agent and multi-agent scenarios; diff --git a/docs/source/about.rst b/docs/source/about.rst index da87102a..839bbb0b 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -184,7 +184,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! - 192.168.1.5 - ANY - ANY - All ACL rules are considered when applying an IER. Logic follows the order of rules, so a DENY or ALLOW for the same parameters will override an earlier entry. + All ACL rules are considered when applying an IER. Logic follows the order of rules, so a DENY or PERMIT for the same parameters will override an earlier entry. Observation Spaces ****************** The observation space provides the blue agent with information about the current status of nodes and links. @@ -331,7 +331,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! * Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...} The placeholders inside the list under the key '1' mean the following: * [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) - * [0, 1] - Permission (0 = DENY, 1 = ALLOW) + * [0, 1] - Permission (0 = DENY, 1 = PERMIT) * [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses) * [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) * [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) diff --git a/docs/source/action_masking.rst b/docs/source/action_masking.rst index 264ab254..bee4674b 100644 --- a/docs/source/action_masking.rst +++ b/docs/source/action_masking.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK Action Masking ************** @@ -23,123 +23,117 @@ The following logic is applied: +------------------------------------------+---------------------------------------------------------------------+ | Action | Action Mask Logic | +==========================================+=====================================================================+ -| **DONOTHING** | Always Possible. | +| **do-nothing** | Always Possible. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_SERVICE_SCAN** | Node is on. Service is running. | +| **node-service-scan** | Node is on. Service is running. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_SERVICE_STOP** | Node is on. Service is running. | +| **node-service-stop** | Node is on. Service is running. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_SERVICE_START** | Node is on. Service is stopped. | +| **node-service-start** | Node is on. Service is stopped. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_SERVICE_PAUSE** | Node is on. Service is running. | +| **node-service-pause** | Node is on. Service is running. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_SERVICE_RESUME** | Node is on. Service is paused. | +| **node-service-resume** | Node is on. Service is paused. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_SERVICE_RESTART** | Node is on. Service is running. | +| **node-service-restart** | Node is on. Service is running. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_SERVICE_DISABLE** | Node is on. | +| **node-service-disable** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_SERVICE_ENABLE** | Node is on. Service is disabled. | +| **node-service-enable** | Node is on. Service is disabled. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_SERVICE_FIX** | Node is on. Service is running. | +| **node-service-fix** | Node is on. Service is running. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_APPLICATION_EXECUTE** | Node is on. | +| **node-application-execute** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_APPLICATION_SCAN** | Node is on. Application is running. | +| **node-application-scan** | Node is on. Application is running. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_APPLICATION_CLOSE** | Node is on. Application is running. | +| **node-application-close** | Node is on. Application is running. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_APPLICATION_FIX** | Node is on. Application is running. | +| **node-application-fix** | Node is on. Application is running. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_APPLICATION_INSTALL** | Node is on. | +| **node-application-install** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_APPLICATION_REMOVE** | Node is on. | +| **node-application-remove** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FILE_SCAN** | Node is on. File exists. File not deleted. | +| **node-file-scan** | Node is on. File exists. File not deleted. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FILE_CREATE** | Node is on. | +| **node-file-create** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FILE_CHECKHASH** | Node is on. File exists. File not deleted. | +| **node-file-checkhash** | Node is on. File exists. File not deleted. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FILE_DELETE** | Node is on. File exists. | +| **node-file-delete** | Node is on. File exists. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FILE_REPAIR** | Node is on. File exists. File not deleted. | +| **node-file-repair** | Node is on. File exists. File not deleted. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FILE_RESTORE** | Node is on. File exists. File is deleted. | +| **node-file-restore** | Node is on. File exists. File is deleted. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FILE_CORRUPT** | Node is on. File exists. File not deleted. | +| **node-file-corrupt** | Node is on. File exists. File not deleted. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FILE_ACCESS** | Node is on. File exists. File not deleted. | +| **node-file-access** | Node is on. File exists. File not deleted. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FOLDER_CREATE** | Node is on. | +| **node-folder-create** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FOLDER_SCAN** | Node is on. Folder exists. Folder not deleted. | +| **node-folder-scan** | Node is on. Folder exists. Folder not deleted. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FOLDER_CHECKHASH** | Node is on. Folder exists. Folder not deleted. | +| **node-folder-checkhash** | Node is on. Folder exists. Folder not deleted. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FOLDER_REPAIR** | Node is on. Folder exists. Folder not deleted. | +| **node-folder-repair** | Node is on. Folder exists. Folder not deleted. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_FOLDER_RESTORE** | Node is on. Folder exists. Folder is deleted. | +| **node-folder-restore** | Node is on. Folder exists. Folder is deleted. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_OS_SCAN** | Node is on. | +| **node-os-scan** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **HOST_NIC_ENABLE** | NIC is disabled. Node is on. | +| **host-nic-enable** | NIC is disabled. Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **HOST_NIC_DISABLE** | NIC is enabled. Node is on. | +| **host-nic-disable** | NIC is enabled. Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_SHUTDOWN** | Node is on. | +| **node-shutdown** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_STARTUP** | Node is off. | +| **node-startup** | Node is off. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_RESET** | Node is on. | +| **node-reset** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_NMAP_PING_SCAN** | Node is on. | +| **node-nmap-ping-scan** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_NMAP_PORT_SCAN** | Node is on. | +| **node-nmap-port-scan** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_NMAP_NETWORK_SERVICE_RECON** | Node is on. | +| **node-network-service-recon** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NETWORK_PORT_ENABLE** | Node is on. Router is on. | +| **network-port-enable** | Node is on. Router is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NETWORK_PORT_DISABLE** | Router is on. | +| **network-port-disable** | Router is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **ROUTER_ACL_ADDRULE** | Router is on. | +| **router-acl-add-rule** | Router is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **ROUTER_ACL_REMOVERULE** | Router is on. | +| **router-acl-remove-rule** | Router is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **FIREWALL_ACL_ADDRULE** | Firewall is on. | +| **firewall-acl-add-rule** | Firewall is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **FIREWALL_ACL_REMOVERULE** | Firewall is on. | +| **firewall-acl-remove-rule** | Firewall is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_NMAP_PING_SCAN** | Node is on. | +| **configure-database-client** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_NMAP_PORT_SCAN** | Node is on. | +| **configure-ransomware-script** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_NMAP_NETWORK_SERVICE_RECON** | Node is on. | +| **c2-server-ransomware-configure** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **CONFIGURE_DATABASE_CLIENT** | Node is on. | +| **configure-dos-bot** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **CONFIGURE_RANSOMWARE_SCRIPT** | Node is on. | +| **configure-c2-beacon** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **CONFIGURE_DOSBOT** | Node is on. | +| **c2-server-ransomware-launch** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **CONFIGURE_C2_BEACON** | Node is on. | +| **c2-server-terminal-command** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **C2_SERVER_RANSOMWARE_LAUNCH** | Node is on. | +| **c2-server-data-exfiltrate** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **C2_SERVER_RANSOMWARE_CONFIGURE** | Node is on. | +| **node-account-change-password** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **C2_SERVER_TERMINAL_COMMAND** | Node is on. | +| **node-session-remote-login** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **C2_SERVER_DATA_EXFILTRATE** | Node is on. | +| **node-session-remote-logoff** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ -| **NODE_ACCOUNTS_CHANGE_PASSWORD** | Node is on. | -+------------------------------------------+---------------------------------------------------------------------+ -| **SSH_TO_REMOTE** | Node is on. | -+------------------------------------------+---------------------------------------------------------------------+ -| **SESSIONS_REMOTE_LOGOFF** | Node is on. | -+------------------------------------------+---------------------------------------------------------------------+ -| **NODE_SEND_REMOTE_COMMAND** | Node is on. | +| **node-send-remote-command** | Node is on. | +------------------------------------------+---------------------------------------------------------------------+ diff --git a/docs/source/config.rst b/docs/source/config.rst index eb0b9906..0fa4a4d5 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK PrimAITE |VERSION| Configuration ******************************** diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index 74571cf2..c2674e31 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ``agents`` @@ -19,26 +19,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo ... - ref: green_agent_example team: GREEN - type: ProbabilisticAgent - observation_space: - type: UC2GreenObservation - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY + type: probabilistic-agent agent_settings: start_settings: @@ -57,13 +38,13 @@ Specifies if the agent is malicious (``RED``), benign (``GREEN``), or defensive ``type`` -------- -Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``ProbabilisticAgent`` generate their own behaviour. +Specifies which class should be used for the agent. ``proxy-agent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``red-database-corrupting-agent`` and ``probabilistic-agent`` generate their own behaviour. Available agent types: -- ``ProbabilisticAgent`` -- ``ProxyAgent`` -- ``RedDatabaseCorruptingAgent`` +- ``probabilistic-agent`` +- ``proxy-agent`` +- ``red-database-corrupting-agent`` ``observation_space`` --------------------- @@ -79,10 +60,10 @@ selects which python class from the :py:mod:`primaite.game.agent.observation` mo Allows configuration of the chosen observation type. These are optional. - * ``num_services_per_node``, ``num_folders_per_node``, ``num_files_per_folder``, ``num_nics_per_node`` all define the shape of the observation space. The size and shape of the obs space must remain constant, but the number of files, folders, ACL rules, and other components can change within an episode. Therefore padding is performed and these options set the size of the obs space. + * ``num_services_per_node``, ``num_folders_per_node``, ``num_files_per_folder``, ``num_nics_per_node`` all define the shape of the observation space. The size and shape of the obs space must remain constant, but the number of files, folders, acl rules, and other components can change within an episode. Therefore padding is performed and these options set the size of the obs space. * ``nodes``: list of nodes that will be present in this agent's observation space. The ``node_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config. Each node can also be configured with services, and files that should be monitored. * ``links``: list of links that will be present in this agent's observation space. The ``link_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config. - * ``acl``: configure how the agent reads the access control list on the router in the simulation. ``router_node_ref`` is for selecting which router's ACL table should be used. ``ip_list`` sets the encoding of ip addresses as integers within the observation space. + * ``acl``: configure how the agent reads the access control list on the router in the simulation. ``router_node_ref`` is for selecting which router's acl table should be used. ``ip_list`` sets the encoding of ip addresses as integers within the observation space. For more information see :py:mod:`primaite.game.agent.observations` @@ -91,10 +72,6 @@ For more information see :py:mod:`primaite.game.agent.observations` The action space is configured to be made up of individual action types. Once configured, the agent can select an action type and some optional action parameters at every step. For example: The ``NODE_SERVICE_SCAN`` action takes the parameters ``node_id`` and ``service_id``. -``action_list`` -^^^^^^^^^^^^^^^ - -A list of action modules. The options are listed in the :py:mod:`primaite.game.agent.actions.ActionManager.act_class_identifiers` module. ``action_map`` ^^^^^^^^^^^^^^ @@ -120,7 +97,7 @@ Similar to action space, this is defined as a list of components from the :py:mo ``reward_components`` ^^^^^^^^^^^^^^^^^^^^^ - +TODO: update description A list of reward types from :py:mod:`primaite.game.agent.rewards.RewardFunction.rew_class_identifiers` e.g. @@ -128,8 +105,8 @@ e.g. .. code-block:: yaml reward_components: - - type: DUMMY - - type: DATABASE_FILE_INTEGRITY + - type: dummy + - type: database-file-integrity ``agent_settings`` diff --git a/docs/source/configuration/game.rst b/docs/source/configuration/game.rst index 2048708c..b3c139b2 100644 --- a/docs/source/configuration/game.rst +++ b/docs/source/configuration/game.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ``game`` diff --git a/docs/source/configuration/io_settings.rst b/docs/source/configuration/io_settings.rst index 1c9585c9..ab3a978e 100644 --- a/docs/source/configuration/io_settings.rst +++ b/docs/source/configuration/io_settings.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ``io_settings`` diff --git a/docs/source/configuration/simulation.rst b/docs/source/configuration/simulation.rst index fa1d774a..47ff6832 100644 --- a/docs/source/configuration/simulation.rst +++ b/docs/source/configuration/simulation.rst @@ -1,12 +1,12 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ``simulation`` ============== In this section the network layout is defined. This part of the config follows a hierarchical structure. Almost every component defines a ``ref`` field which acts as a human-readable unique identifier, used by other parts of the config, such as agents. - +# TODO: ref field is no longer real At the top level of the network are ``nodes``, ``links`` and ``airspace``. e.g. diff --git a/docs/source/configuration/simulation/nodes/common/common.rst b/docs/source/configuration/simulation/nodes/common/common.rst index a0f2eb13..c45eccf6 100644 --- a/docs/source/configuration/simulation/nodes/common/common.rst +++ b/docs/source/configuration/simulation/nodes/common/common.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _Node Attributes: diff --git a/docs/source/configuration/simulation/nodes/common/common_host_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_host_node_attributes.rst index bb3b2a52..b717340e 100644 --- a/docs/source/configuration/simulation/nodes/common/common_host_node_attributes.rst +++ b/docs/source/configuration/simulation/nodes/common/common_host_node_attributes.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _common_host_node_attributes: diff --git a/docs/source/configuration/simulation/nodes/common/common_network_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_network_node_attributes.rst index d556e2dc..035c7e55 100644 --- a/docs/source/configuration/simulation/nodes/common/common_network_node_attributes.rst +++ b/docs/source/configuration/simulation/nodes/common/common_network_node_attributes.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _common_network_node_attributes: diff --git a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst index 5c055ecd..e6d5da67 100644 --- a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst +++ b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _common_node_attributes: diff --git a/docs/source/configuration/simulation/nodes/common/node_type_list.rst b/docs/source/configuration/simulation/nodes/common/node_type_list.rst index 1ec496d9..21181019 100644 --- a/docs/source/configuration/simulation/nodes/common/node_type_list.rst +++ b/docs/source/configuration/simulation/nodes/common/node_type_list.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ``type`` -------- diff --git a/docs/source/configuration/simulation/nodes/computer.rst b/docs/source/configuration/simulation/nodes/computer.rst index 32e0b2b9..456d11a2 100644 --- a/docs/source/configuration/simulation/nodes/computer.rst +++ b/docs/source/configuration/simulation/nodes/computer.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _computer_configuration: diff --git a/docs/source/configuration/simulation/nodes/firewall.rst b/docs/source/configuration/simulation/nodes/firewall.rst index 775ffabd..84b5c99e 100644 --- a/docs/source/configuration/simulation/nodes/firewall.rst +++ b/docs/source/configuration/simulation/nodes/firewall.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _firewall_configuration: diff --git a/docs/source/configuration/simulation/nodes/network_examples.rst b/docs/source/configuration/simulation/nodes/network_examples.rst index 2a34a206..84ee4c60 100644 --- a/docs/source/configuration/simulation/nodes/network_examples.rst +++ b/docs/source/configuration/simulation/nodes/network_examples.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _network_examples: @@ -617,10 +617,10 @@ Each node is configured to ensure it meets the specific security and operational default_gateway: 192.168.1.1 dns_server: 8.8.8.2 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 - - type: WebBrowser + - type: web-browser options: target_url: http://sometech.ai @@ -631,10 +631,10 @@ Each node is configured to ensure it meets the specific security and operational default_gateway: 192.168.1.1 dns_server: 8.8.8.2 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 - - type: WebBrowser + - type: web-browser options: target_url: http://sometech.ai @@ -700,7 +700,7 @@ Each node is configured to ensure it meets the specific security and operational default_gateway: 8.8.8.1 services: - ref: dns_server - type: DNSServer + type: dns-server options: domain_mapping: sometech.ai: 94.10.180.6 @@ -794,9 +794,9 @@ Each node is configured to ensure it meets the specific security and operational dns_server: 8.8.8.2 services: - ref: web_server - type: WebServer + type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 @@ -903,10 +903,10 @@ Each node is configured to ensure it meets the specific security and operational default_gateway: 10.10.1.1 dns_server: 8.8.8.2 services: - - type: DatabaseService + - type: database-service options: backup_server_ip: 10.10.1.12 # The some_tech_storage_srv server - - type: FTPClient + - type: ftp-client - hostname: some_tech_storage_srv type: server @@ -915,7 +915,7 @@ Each node is configured to ensure it meets the specific security and operational default_gateway: 10.10.1.1 dns_server: 8.8.8.2 services: - - type: FTPServer + - type: ftp-server - hostname: some_tech_hr_1 type: computer @@ -924,10 +924,10 @@ Each node is configured to ensure it meets the specific security and operational default_gateway: 10.10.3.1 dns_server: 8.8.8.2 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 - - type: WebBrowser + - type: web-browser options: target_url: http://sometech.ai @@ -938,10 +938,10 @@ Each node is configured to ensure it meets the specific security and operational default_gateway: 10.10.2.1 dns_server: 8.8.8.2 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 - - type: WebBrowser + - type: web-browser options: target_url: http://sometech.ai @@ -952,10 +952,10 @@ Each node is configured to ensure it meets the specific security and operational default_gateway: 10.10.2.1 dns_server: 8.8.8.2 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 - - type: WebBrowser + - type: web-browser options: target_url: http://sometech.ai @@ -1177,8 +1177,8 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv = network.get_node_by_hostname("some_tech_storage_srv") some_tech_storage_srv.file_system.create_file(file_name="test.png") - pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["FTPClient"] - pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["FTPClient"] + pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["ftp-client"] + pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["ftp-client"] assert not pc_1_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1224,7 +1224,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. web_server: Server = network.get_node_by_hostname("some_tech_web_srv") - web_ftp_client: FTPClient = web_server.software_manager.software["FTPClient"] + web_ftp_client: FTPClient = web_server.software_manager.software["ftp-client"] assert not web_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1269,7 +1269,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_snr_dev_pc: Computer = network.get_node_by_hostname("some_tech_snr_dev_pc") - snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["FTPClient"] + snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["ftp-client"] assert snr_dev_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1294,7 +1294,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_jnr_dev_pc: Computer = network.get_node_by_hostname("some_tech_jnr_dev_pc") - jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["FTPClient"] + jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["ftp-client"] assert not jnr_dev_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1337,7 +1337,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_hr_pc: Computer = network.get_node_by_hostname("some_tech_hr_1") - hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["FTPClient"] + hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["ftp-client"] assert not hr_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, diff --git a/docs/source/configuration/simulation/nodes/router.rst b/docs/source/configuration/simulation/nodes/router.rst index b8741521..ee278a98 100644 --- a/docs/source/configuration/simulation/nodes/router.rst +++ b/docs/source/configuration/simulation/nodes/router.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _router_configuration: diff --git a/docs/source/configuration/simulation/nodes/server.rst b/docs/source/configuration/simulation/nodes/server.rst index 92b33ca7..616efb38 100644 --- a/docs/source/configuration/simulation/nodes/server.rst +++ b/docs/source/configuration/simulation/nodes/server.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _server_configuration: diff --git a/docs/source/configuration/simulation/nodes/switch.rst b/docs/source/configuration/simulation/nodes/switch.rst index 17cf76f9..d09f5ba7 100644 --- a/docs/source/configuration/simulation/nodes/switch.rst +++ b/docs/source/configuration/simulation/nodes/switch.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _switch_configuration: diff --git a/docs/source/configuration/simulation/software/applications.rst b/docs/source/configuration/simulation/software/applications.rst index 8c590d53..9973a167 100644 --- a/docs/source/configuration/simulation/software/applications.rst +++ b/docs/source/configuration/simulation/software/applications.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ``applications`` ---------------- diff --git a/docs/source/configuration/simulation/software/services.rst b/docs/source/configuration/simulation/software/services.rst index fafdf2e8..ec6bbba9 100644 --- a/docs/source/configuration/simulation/software/services.rst +++ b/docs/source/configuration/simulation/software/services.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ``services`` ------------ diff --git a/docs/source/customising_scenarios.rst b/docs/source/customising_scenarios.rst index 092f306b..df7d4b1e 100644 --- a/docs/source/customising_scenarios.rst +++ b/docs/source/customising_scenarios.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK Customising Agents ****************** diff --git a/docs/source/dependencies.rst b/docs/source/dependencies.rst index 74f3cd14..e8be00d3 100644 --- a/docs/source/dependencies.rst +++ b/docs/source/dependencies.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. role:: raw-html(raw) :format: html diff --git a/docs/source/developer_tools.rst b/docs/source/developer_tools.rst index a66b7902..b3d81a27 100644 --- a/docs/source/developer_tools.rst +++ b/docs/source/developer_tools.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _Developer Tools: diff --git a/docs/source/environment.rst b/docs/source/environment.rst index a282c09e..251b1090 100644 --- a/docs/source/environment.rst +++ b/docs/source/environment.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK RL Environments *************** diff --git a/docs/source/example_notebooks.rst b/docs/source/example_notebooks.rst index 920175c9..6caeae3d 100644 --- a/docs/source/example_notebooks.rst +++ b/docs/source/example_notebooks.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _example jupyter notebooks: diff --git a/docs/source/game_layer.rst b/docs/source/game_layer.rst index 775c02b5..36ec016d 100644 --- a/docs/source/game_layer.rst +++ b/docs/source/game_layer.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK PrimAITE Game layer ******************* @@ -57,13 +57,13 @@ An agent's reward can be based on rewards of other agents. This is particularly reward_components: # When the webpage loads, the reward goes up by 0.25 when it fails to load, it goes down to -0.25 - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_2 # When the database is reachable, the reward goes up by 0.05, when it is unreachable it goes down to -0.05 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_2 @@ -74,7 +74,7 @@ An agent's reward can be based on rewards of other agents. This is particularly reward_components: # When the database file is in a good state, blue's reward is 0.4, when it's in a corrupted state the reward is -0.4 - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.40 options: node_hostname: database_server @@ -82,7 +82,7 @@ An agent's reward can be based on rewards of other agents. This is particularly file_name: database.db # The green's reward is added onto the blue's reward. - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_2_green_user diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index ded92c60..427d1823 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _getting-started: diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst index 8fff0ea3..02c578d1 100644 --- a/docs/source/glossary.rst +++ b/docs/source/glossary.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK Glossary ============= diff --git a/docs/source/how_to_guides/extensible_actions.rst b/docs/source/how_to_guides/extensible_actions.rst new file mode 100644 index 00000000..c2cc07bf --- /dev/null +++ b/docs/source/how_to_guides/extensible_actions.rst @@ -0,0 +1,67 @@ +.. only:: comment + + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + + +Extensible Actions +****************** + + +Changes to Actions class Structure. +=================================== + +Actions within PrimAITE have been updated to inherit from a base class, AbstractAction, standardising their format and allowing for easier creation of custom actions. Actions now use a ``ConfigSchema`` to define the possible configuration variables, and use pydantic to enforce correct parameters are passed through. + + +Developing Custom Actions. +========================== + +Custom actions within PrimAITE must be a sub-class of `AbstractAction`, and contain 3 key items: + +#. ConfigSchema class + +#. Unique discriminator + +#. `form_request` method. + + +ConfigSchema +############ + +The ConfigSchema sub-class of the action must contain all `configurable` variables within the action, that would be specified within the environments configuration YAML file. + + +Unique discriminator +################# + +When declaring a custom class, it must have a unique discriminator string, that allows PrimAITE to generate the correct action when needed. + +.. code:: Python + + class CreateDirectoryAction(AbstractAction, discriminator="node-folder-create") + + config: CreateDirectoryAction.ConfigSchema + + class ConfigSchema(AbstractAction.ConfigSchema): + + verb: ClassVar[str] = "create" + node_name: str + directory_name: str + + def form_request(cls, config: ConfigSchema) -> RequestFormat: + return ["network", + "node", + config.node_name, + "file_system", + config.verb, + "folder", + config.directory_name, + ] + +The above action would fail pydantic validation as the discriminator "node-folder-create" is already used by the `NodeFolderCreateAction`, and would create a duplicate listing within `AbstractAction._registry`. + + +form_request method +################### + +PrimAITE actions need to have a `form_request` method, which can be passed to the `RequestManager` for processing. This allows the custom action to be actioned within the simulation environment. diff --git a/docs/source/how_to_guides/extensible_agents.rst b/docs/source/how_to_guides/extensible_agents.rst new file mode 100644 index 00000000..3236c21a --- /dev/null +++ b/docs/source/how_to_guides/extensible_agents.rst @@ -0,0 +1,74 @@ +.. only:: comment + + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +.. _about: + +Extensible Agents +***************** + +Agents defined within PrimAITE have been updated to allow for easier creation of new bespoke agents for use in custom environments. + + +Developing Agents for PrimAITE +============================== + +All agent types within PrimAITE must be subclassed from ``AbstractAgent`` in order to be used from configuration YAML files. This then allows you to implement any custom agent logic for the new agent in your training scenario. Examples of implementing custom agent logic can be seen in pre-existing agents, such as the ``DataManipulationBot`` and ``RandomAgent``. + +The core features that should be implemented in any new agent are detailed below: + +#. **ConfigSchema**: + + Configurable items within a new agent within PrimAITE should contain a ``ConfigSchema`` which holds all configurable variables of the agent. This should not include parameters related to its *state*, these would be listed seperately. + Agent generation will fail pydantic checks if incorrect or invalid parameters are passed to the ConfigSchema of the chosen Agent. + + + .. code-block:: python + + class ExampleAgent(AbstractAgent, discriminator = "ExampleAgent"): + """An example agent for demonstration purposes.""" + + config: "ExampleAgent.ConfigSchema" = Field(default_factory= lambda: ExampleAgent.ConfigSchema()) + """Agent configuration""" + num_executions: int = 0 + """Number of action executions by agent""" + + class ConfigSchema(AbstractAgent.ConfigSchema): + """ExampleAgent configuration schema""" + + type: str = "ExampleAgent + """Name of agent""" + starting_host: int + """Host node that this agent should start from in the given environment.""" + + + .. code-block:: yaml + + - ref: example_green_agent + team: GREEN + type: example-agent + + action_space: + action_map: + 0: + action: do-nothing + options: {} + agent_settings: + start_step: 25 + frequency: 20 + variance: 5 + starting_host: "Server_1" + + +#. **discriminators**: + + All agent classes should have an ``discriminator`` attribute, a unique kebab-case string, for when they are added to the base ``AbstractAgent`` registry. This is then specified in your configuration YAML, and used by PrimAITE to generate the correct Agent. + +Changes to YAML file +==================== + +PrimAITE v4.0.0 introduces some breaking changes to how environment configuration yaml files are created. YAML files created for Primaite versions 3.3.0 should be compatible through a translation function, though it is encouraged that these are updated to reflect the updated format of 4.0.0+. + +Agents now follow a more standardised settings definition, so should be more consistent across YAML files and the available agent types with PrimAITE. + +All configurable items for agents sit under the ``agent_settings`` heading within your YAML files. There is no need for the inclusion of a ``start_settings``. Please see the above YAML example for full changes to agents. diff --git a/docs/source/how_to_guides/extensible_nodes.rst b/docs/source/how_to_guides/extensible_nodes.rst new file mode 100644 index 00000000..18d64ca8 --- /dev/null +++ b/docs/source/how_to_guides/extensible_nodes.rst @@ -0,0 +1,55 @@ +.. only:: comment + + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +.. _about: + + +Extensible Nodes +**************** + +Node classes within PrimAITE have been updated to allow for easier generation of custom nodes within simulations. + + +Changes to Node Class structure. +================================ + +Node classes all inherit from the base Node Class, though new classes should inherit from either HostNode or NetworkNode, subject to the intended application of the Node. + +The use of an `__init__` method is not necessary, as configurable variables for the class should be specified within the `config` of the class, and passed at run time via your YAML configuration using the `from_config` method. + +An example of how additional Node classes is below, taken from `router.py` within PrimAITE. + +.. code-block:: Python + +class Router(NetworkNode, identifier="router"): + """ Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces.""" + + SYSTEM_SOFTWARE: ClassVar[Dict] = { + "user-session-manager": UserSessionManager, + "user-manager": UserManager, + "terminal": Terminal, + } + + network_interfaces: Dict[str, RouterInterface] = {} + "The Router Interfaces on the node." + network_interface: Dict[int, RouterInterface] = {} + "The Router Interfaces on the node by port id." + + sys_log: SysLog + + config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema()) + + class ConfigSchema(NetworkNode.ConfigSchema): + """Configuration Schema for Router Objects.""" + + num_ports: int = 5 + + hostname: str = "Router" + + + +Changes to YAML file. +===================== + +While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed. diff --git a/docs/source/how_to_guides/extensible_rewards.rst b/docs/source/how_to_guides/extensible_rewards.rst new file mode 100644 index 00000000..d3053a49 --- /dev/null +++ b/docs/source/how_to_guides/extensible_rewards.rst @@ -0,0 +1,57 @@ +.. only:: comment + + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +.. _about: + +Extensible Rewards +****************** +Extensible Rewards differ from the previous reward mechanism used in PrimAITE v3.x as new reward +types can be added without requiring a change to the RewardFunction class in rewards.py (PrimAITE +core repository). + +Changes to reward class structure. +================================== + +Reward classes are inherited from AbstractReward (a sub-class of Pydantic's BaseModel). +Within the reward class there is a ConfigSchema class responsible for ensuring the config file data +is in the correct format. This also means there is little (if no) requirement for and `__init__` +method. The `.from_config` method is no longer required as it's inherited from `AbstractReward`. +Each class requires an discriminator string which is used by the ConfigSchema class to verify that it +hasn't previously been added to the registry. + +Inheriting from `BaseModel` removes the need for an `__init__` method but means that object +attributes need to be passed by keyword. + +To add a new reward class follow the example below. Note that the type attribute in the +`ConfigSchema` class should match the type used in the config file to define the reward. + +.. code-block:: Python + +class DatabaseFileIntegrity(AbstractReward, discriminator="database-file-integrity"): + """Reward function component which rewards the agent for maintaining the integrity of a database file.""" + + config: "DatabaseFileIntegrity.ConfigSchema" + location_in_state: List[str] = [""] + reward: float = 0.0 + + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for DatabaseFileIntegrity.""" + + type: str = "database-file-integrity" + node_hostname: str + folder_name: str + file_name: str + + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: + """Calculate the reward for the current state. + pass + + + +Changes to YAML file. +===================== +.. code:: YAML + + There's no longer a need to provide a `dns_server` as an option in the simulation section + of the config file. diff --git a/docs/source/node_sets.rst b/docs/source/node_sets.rst new file mode 100644 index 00000000..75047d54 --- /dev/null +++ b/docs/source/node_sets.rst @@ -0,0 +1,115 @@ +.. only:: comment + + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +.. _network_node_adder: + +Network Node Adder Module +######################### + +This module provides a framework for adding nodes to a network in a standardised way. It defines a base class ``NetworkNodeAdder``, which can be extended to create specific node adders, and utility functions to calculate network infrastructure requirements. + +The module allows you to use the pre-defined node adders, ``OfficeLANAdder``, or create custom ones by extending the base class. + +How It Works +============ + +The main class in the module is ``NetworkNodeAdder``, which defines the interface for adding nodes to a network. Child classes are expected to: + +1. Define a ``ConfigSchema`` nested class to define configuration options. +2. Implement the ``add_nodes_to_net(config, network)`` method, which adds the nodes to the network according to the configuration object. + +The ``NetworkNodeAdder`` base class handles node adders defined in the primAITE config YAML file as well. It does this by keeping a registry of node adder classes, and uses the ``type`` field of the config to select the appropriate class to which to pass the configuration. + +Example Usage +============= + +Via Python API +-------------- + +Adding nodes to a network can be done using the python API by constructing the relevant ``ConfigSchema`` object like this: + +.. code-block:: python + + net = Network() + + office_lan_config = OfficeLANAdder.ConfigSchema( + lan_name="CORP-LAN", + subnet_base=2, + pcs_ip_block_start=10, + num_pcs=8, + include_router=False, + bandwidth=150, + ) + OfficeLANAdder.add_nodes_to_net(config=office_lan_config, network=net) + +In this example, a network with 8 computers connected by a switch will be added to the network object. + + +Via YAML Config +--------------- + +.. code-block:: yaml + simulation: + network: + nodes: + # ... nodes go here + node_sets: + - type: office-lan + lan_name: CORP_LAN + subnet_base: 2 + pcs_ip_block_start: 10 + num_pcs: 8 + include_router: False + bandwidth: 150 + # ... additional node sets can be added below + +``NetworkNodeAdder`` reads the ``type`` property of the config, then constructs and passes the configuration to ``OfficeLANAdder.add_nodes_to_net()``. + +In this example, a network with 8 computers connected by a switch will be added to the network object. Equivalent to the above. + + +Creating Custom Node Adders +=========================== +To create a custom node adder, subclass NetworkNodeAdder and define: + +* A ConfigSchema class that defines the configuration schema for the node adder. +* The add_nodes_to_net method that implements how nodes should be added to the network. + +Example: DataCenterAdder +------------------------ +Here is an example of creating a custom node adder, DataCenterAdder: + +.. code-block:: python + + class DataCenterAdder(NetworkNodeAdder, discriminator="data-center"): + class ConfigSchema(NetworkNodeAdder.ConfigSchema): + type: Literal["data-center"] = "data-center" + num_servers: int + data_center_name: str + + @classmethod + def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None: + for i in range(config.num_servers): + server = Computer( + hostname=f"server_{i}_{config.data_center_name}", + ip_address=f"192.168.100.{i + 8}", + subnet_mask="255.255.255.0", + default_gateway="192.168.100.1", + start_up_duration=0 + ) + server.power_on() + network.add_node(server) + +**Using the Custom Node Adder:** + +.. code-block:: python + + config = { + "type": "data-center", + "num_servers": 5, + "data_center_name": "dc1" + } + + network = Network() + DataCenterAdder.from_config(config, network) diff --git a/docs/source/notebooks/executed_notebooks.rst b/docs/source/notebooks/executed_notebooks.rst index 3431d344..f4acfad6 100644 --- a/docs/source/notebooks/executed_notebooks.rst +++ b/docs/source/notebooks/executed_notebooks.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _Executed Notebooks: diff --git a/docs/source/primaite-dependencies.rst b/docs/source/primaite-dependencies.rst index 8367ee61..ce2087ca 100644 --- a/docs/source/primaite-dependencies.rst +++ b/docs/source/primaite-dependencies.rst @@ -1,45 +1,45 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| Name | Version | License | Description | URL | -+===================+=========+====================================+=======================================================================================================+====================================================================+ -| gymnasium | 0.28.1 | MIT License | A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym). | https://farama.org | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| ipywidgets | 8.1.5 | BSD License | Jupyter interactive widgets | http://jupyter.org | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| jupyterlab | 3.6.1 | BSD License | JupyterLab computational environment | https://jupyter.org | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| kaleido | 0.2.1 | MIT | Static image export for web-based visualization libraries with zero dependencies | https://github.com/plotly/Kaleido | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| matplotlib | 3.7.1 | Python Software Foundation License | Python plotting package | https://matplotlib.org | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| networkx | 3.1 | BSD License | Python package for creating and manipulating graphs and networks | https://networkx.org/ | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| numpy | 1.23.5 | BSD License | NumPy is the fundamental package for array computing with Python. | https://www.numpy.org | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| platformdirs | 3.5.1 | MIT License | A small Python package for determining appropriate platform-specific dirs, e.g. a "user data dir". | https://github.com/platformdirs/platformdirs | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| plotly | 5.15.0 | MIT License | An open-source, interactive data visualization library for Python | https://plotly.com/python/ | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| polars | 0.20.30 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| prettytable | 3.8.0 | BSD License (BSD (3 clause)) | A simple Python library for easily displaying tabular data in a visually appealing ASCII table format | https://github.com/jazzband/prettytable | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| pydantic | 2.7.0 | MIT License | Data validation using Python type hints | https://github.com/pydantic/pydantic | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| PyYAML | 6.0 | MIT License | YAML parser and emitter for Python | https://pyyaml.org/ | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| ray | 2.32.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| stable-baselines3 | 2.1.0 | MIT | Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms. | https://github.com/DLR-RM/stable-baselines3 | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| tensorflow | 2.12.0 | Apache Software License | TensorFlow is an open source machine learning framework for everyone. | https://www.tensorflow.org/ | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| typer | 0.9.0 | MIT License | Typer, build great CLIs. Easy to code. Based on Python type hints. | https://github.com/tiangolo/typer | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| Deepdiff | 8.0.1 | MIT License | Deep difference of dictionaries, iterables, strings, and any other object objects. | https://github.com/seperman/deepdiff | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ -| sb3_contrib | 2.1.0 | MIT License | Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code (Action Masking)| https://github.com/Stable-Baselines-Team/stable-baselines3-contrib | -+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+ ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| Name | Supported Version | Built Version | License | Description | URL | ++===================+=====================+===============+======================================+========================================================================================================+=====================================================================+ +| gymnasium | 0.28.1 | 0.28.1 | MIT License | A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym). | https://farama.org | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| ipywidgets | ~=8.0 | 8.1.5 | BSD License | Jupyter interactive widgets | http://jupyter.org | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| jupyterlab | 3.6.1 | 3.6.1 | BSD License | JupyterLab computational environment | https://jupyter.org | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| kaleido | ==0.2.1 | 0.2.1 | MIT | Static image export for web-based visualization libraries with zero dependencies | https://github.com/plotly/Kaleido | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| matplotlib | >=3.7.1 | 3.7.1 | Python Software Foundation License | Python plotting package | https://matplotlib.org | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| networkx | 3.1 | 3.1 | BSD License | Python package for creating and manipulating graphs and networks | https://networkx.org/ | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| numpy | ~1.23 | 1.23.5 | BSD License | NumPy is the fundamental package for array computing with Python. | https://www.numpy.org | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| platformdirs | 3.5.1 | 3.5.1 | MIT License | A small Python package for determining appropriate platform-specific dirs, e.g. a "user data dir". | https://github.com/platformdirs/platformdirs | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| plotly | 5.15 | 5.15.0 | MIT License | An open-source, interactive data visualization library for Python | https://plotly.com/python/ | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| polars | 0.20.30 | 0.20.30 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| prettytable | 3.8.0 | 3.8.0 | BSD License (BSD (3 clause)) | A simple Python library for easily displaying tabular data in a visually appealing ASCII table format | https://github.com/jazzband/prettytable | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| pydantic | 2.7.0 | 2.7.0 | MIT License | Data validation using Python type hints | https://github.com/pydantic/pydantic | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| PyYAML | >=6.0 | 6.0 | MIT License | YAML parser and emitter for Python | https://pyyaml.org/ | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| ray | >=2.20, <2.33 | 2.32.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| stable-baselines3 | 2.1.0 | 2.1.0 | MIT | Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms. | https://github.com/DLR-RM/stable-baselines3 | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| tensorflow | ~=2.12 | 2.12.0 | Apache Software License | TensorFlow is an open source machine learning framework for everyone. | https://www.tensorflow.org/ | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| typer | >=0.9 | 0.9.0 | MIT License | Typer, build great CLIs. Easy to code. Based on Python type hints. | https://github.com/tiangolo/typer | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| Deepdiff | 8.0.1 | 8.0.1 | MIT License | Deep difference of dictionaries, iterables, strings, and any other object objects. | https://github.com/seperman/deepdiff | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ +| sb3_contrib | 2.1.0 | 2.1.0 | MIT License | Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code (Action Masking) | https://github.com/Stable-Baselines-Team/stable-baselines3-contrib | ++-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+ diff --git a/docs/source/request_system.rst b/docs/source/request_system.rst index 6b71bf25..93fc2a9f 100644 --- a/docs/source/request_system.rst +++ b/docs/source/request_system.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _request_system: @@ -53,10 +53,10 @@ Request responses When the simulator receives a request, it returns a response with a success status. The possible statuses are: * **success**: The request was received and successfully executed. - * For example, the agent tries to add an ACL rule and specifies correct parameters, and the ACL rule is added successfully. + * For example, the agent tries to add an acl rule and specifies correct parameters, and the acl rule is added successfully. * **failure**: The request was received, but it could not be executed, or it failed while executing. - * For example, the agent tries to execute the ``WebBrowser`` application, but the webpage wasn't retrieved because the DNS server is not setup on the node. + * For example, the agent tries to execute the ``web-browser`` application, but the webpage wasn't retrieved because the DNS server is not setup on the node. * **unreachable**: The request was sent to a simulation component that does not exist. * For example, the agent tries to scan a file that has not been created yet. diff --git a/docs/source/rewards.rst b/docs/source/rewards.rst index 0163284c..1f588f36 100644 --- a/docs/source/rewards.rst +++ b/docs/source/rewards.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK Rewards ####### @@ -23,7 +23,7 @@ The following API pages describe the use of each reward component and the possib # ... reward_function: reward_components: - - type: DUMMY + - type: dummy weight: 1.0 @@ -36,7 +36,7 @@ The following API pages describe the use of each reward component and the possib # ... reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 1.0 options: node_hostname: server_1 @@ -53,7 +53,7 @@ The following API pages describe the use of each reward component and the possib # ... reward_function: reward_components: - - type: WEB_SERVER_404_PENALTY + - type: web-server-404-penalty node_hostname: web_server weight: 1.0 options: @@ -70,7 +70,7 @@ The following API pages describe the use of each reward component and the possib # ... reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty node_hostname: computer_1 weight: 1.0 options: @@ -86,7 +86,7 @@ The following API pages describe the use of each reward component and the possib # ... reward_function: reward_components: - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 1.0 options: node_hostname: admin_pc_1 @@ -104,7 +104,7 @@ The following API pages describe the use of each reward component and the possib # ... reward_function: reward_components: - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: scripted_agent @@ -119,7 +119,7 @@ The following API pages describe the use of each reward component and the possib # ... reward_function: reward_components: - - type: ACTION_PENALTY + - type: action-penalty weight: 1.0 options: action_penalty: -0.3 diff --git a/docs/source/simulation.rst b/docs/source/simulation.rst index cc723e40..95807703 100644 --- a/docs/source/simulation.rst +++ b/docs/source/simulation.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK Simulation diff --git a/docs/source/simulation_components/network/airspace.rst b/docs/source/simulation_components/network/airspace.rst index 06a884a7..a6967b91 100644 --- a/docs/source/simulation_components/network/airspace.rst +++ b/docs/source/simulation_components/network/airspace.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _airspace: diff --git a/docs/source/simulation_components/network/base_hardware.rst b/docs/source/simulation_components/network/base_hardware.rst index ce1e5c74..8b325ffc 100644 --- a/docs/source/simulation_components/network/base_hardware.rst +++ b/docs/source/simulation_components/network/base_hardware.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ############# Base Hardware diff --git a/docs/source/simulation_components/network/network.rst b/docs/source/simulation_components/network/network.rst index b04d6ecf..a6fe4070 100644 --- a/docs/source/simulation_components/network/network.rst +++ b/docs/source/simulation_components/network/network.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _network: @@ -103,6 +103,13 @@ we'll use the following Network that has a client, server, two switches, and a r router_1.acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.ICMP, + src_port=PORT_LOOKUP["ARP"], + dst_port=PORT_LOOKUP["ARP"], + position=22 + ) + + router_1.acl.add_rule( + action=ACLAction.PERMIT, + protocol=PROTOCOL_LOOKUP["ICMP"], position=23 ) diff --git a/docs/source/simulation_components/network/network_interfaces.rst b/docs/source/simulation_components/network/network_interfaces.rst index c6b97a8e..663af7ba 100644 --- a/docs/source/simulation_components/network/network_interfaces.rst +++ b/docs/source/simulation_components/network/network_interfaces.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ################################# Network Interface Hierarchy Model diff --git a/docs/source/simulation_components/network/nodes/firewall.rst b/docs/source/simulation_components/network/nodes/firewall.rst index 149d3e67..f2d7e61a 100644 --- a/docs/source/simulation_components/network/nodes/firewall.rst +++ b/docs/source/simulation_components/network/nodes/firewall.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ######## Firewall @@ -156,8 +156,8 @@ To prevent all external traffic from accessing the internal network, with except # Exception rule to allow HTTP traffic from external to internal network firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="192.168.1.0", dst_wildcard_mask="0.0.0.255", position=2 @@ -172,16 +172,16 @@ To enable external traffic to access specific services hosted within the DMZ: # Allow HTTP and HTTPS traffic to the DMZ firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="172.16.0.0", dst_wildcard_mask="0.0.0.255", position=3 ) firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTPS, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.0", dst_wildcard_mask="0.0.0.255", position=4 @@ -196,9 +196,9 @@ To permit SSH access from a designated external IP to a specific server within t # Allow SSH from a specific external IP to an internal server firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.2", - dst_port=Port.SSH, + dst_port=Port["SSH"], dst_ip_address="192.168.1.10", position=5 ) @@ -212,9 +212,9 @@ To limit database server access to selected external IP addresses: # Allow PostgreSQL traffic from an authorized external IP to the internal DB server firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.3", - dst_port=Port.POSTGRES_SERVER, + dst_port=Port["POSTGRES_SERVER"], dst_ip_address="192.168.1.20", position=6 ) @@ -222,8 +222,8 @@ To limit database server access to selected external IP addresses: # Deny all other PostgreSQL traffic from external sources firewall.internal_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, - dst_port=Port.POSTGRES_SERVER, + protocol=IPProtocol["TCP"], + dst_port=Port["POSTGRES_SERVER"], dst_ip_address="192.168.1.0", dst_wildcard_mask="0.0.0.255", position=7 @@ -247,15 +247,15 @@ To authorize HTTP/HTTPS access to a DMZ-hosted web server, excluding known malic # Allow HTTP/HTTPS traffic to the DMZ web server firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="172.16.0.2", position=9 ) firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTPS, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.2", position=10 ) @@ -269,9 +269,9 @@ To facilitate restricted access from the internal network to DMZ-hosted services # Permit specific internal application server HTTPS access to a DMZ-hosted API firewall.internal_outbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.30", # Internal application server IP - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=11 ) @@ -289,9 +289,9 @@ To facilitate restricted access from the internal network to DMZ-hosted services # Corresponding rule in DMZ inbound ACL to allow the traffic from the specific internal server firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.30", # Ensuring this specific source is allowed - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=13 ) @@ -301,7 +301,7 @@ To facilitate restricted access from the internal network to DMZ-hosted services action=ACLAction.DENY, src_ip_address="192.168.1.0", src_wildcard_mask="0.0.0.255", - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=14 ) @@ -315,8 +315,8 @@ To block all SSH access attempts from the external network: # Deny all SSH traffic from any external source firewall.external_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, - dst_port=Port.SSH, + protocol=IPProtocol["TCP"], + dst_port=Port["SSH"], position=1 ) @@ -329,8 +329,8 @@ To allow the internal network to initiate HTTP connections to the external netwo # Permit outgoing HTTP traffic from the internal network to any external destination firewall.external_outbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], position=2 ) diff --git a/docs/source/simulation_components/network/nodes/host_node.rst b/docs/source/simulation_components/network/nodes/host_node.rst index b8aae098..2c1e75d0 100644 --- a/docs/source/simulation_components/network/nodes/host_node.rst +++ b/docs/source/simulation_components/network/nodes/host_node.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ######### diff --git a/docs/source/simulation_components/network/nodes/network_node.rst b/docs/source/simulation_components/network/nodes/network_node.rst index e1fa976c..4aebe09f 100644 --- a/docs/source/simulation_components/network/nodes/network_node.rst +++ b/docs/source/simulation_components/network/nodes/network_node.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ############ Network Node diff --git a/docs/source/simulation_components/network/nodes/router.rst b/docs/source/simulation_components/network/nodes/router.rst index 5d3de60f..fb582b23 100644 --- a/docs/source/simulation_components/network/nodes/router.rst +++ b/docs/source/simulation_components/network/nodes/router.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ###### Router diff --git a/docs/source/simulation_components/network/nodes/switch.rst b/docs/source/simulation_components/network/nodes/switch.rst index 0ecbcbf3..e7143f0c 100644 --- a/docs/source/simulation_components/network/nodes/switch.rst +++ b/docs/source/simulation_components/network/nodes/switch.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ###### Switch diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index 0c875801..4078ffda 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ###### Wireless Router @@ -49,7 +49,7 @@ additional steps to configure wireless settings: wireless_router.configure_wireless_access_point( port=1, ip_address="192.168.2.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency="WIFI_2_4", ) @@ -102,7 +102,8 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. network.connect(pc_a.network_interface[1], router_1.router_interface) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B pc_b = Computer( @@ -129,13 +130,13 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency="WIFI_2_4", ) router_2.configure_wireless_access_point( port=1, ip_address="192.168.1.2", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency="WIFI_2_4", ) # Configure routes for inter-router communication diff --git a/docs/source/simulation_components/network/transport_to_data_link_layer.rst b/docs/source/simulation_components/network/transport_to_data_link_layer.rst index cc546021..54118c90 100644 --- a/docs/source/simulation_components/network/transport_to_data_link_layer.rst +++ b/docs/source/simulation_components/network/transport_to_data_link_layer.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK Transport Layer to Data Link Layer ================================== @@ -104,7 +104,7 @@ address of 'aa:bb:cc:dd:ee:ff' to port 8080 on the host 10.0.0.10 which has a NI ip_packet = IPPacket( src_ip_address="192.168.0.100", dst_ip_address="10.0.0.10", - protocol=IPProtocol.TCP + protocol=IPProtocol["TCP"] ) # Data Link Layer ethernet_header = EthernetHeader( diff --git a/docs/source/simulation_components/system/applications/c2_suite.rst b/docs/source/simulation_components/system/applications/c2_suite.rst index d045949a..c780485a 100644 --- a/docs/source/simulation_components/system/applications/c2_suite.rst +++ b/docs/source/simulation_components/system/applications/c2_suite.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _C2_Suite: @@ -183,7 +183,7 @@ Python # Example command: Installing and configuring Ransomware: ransomware_installation_command = { "commands": [ - ["software_manager","application","install","RansomwareScript"], + ["software_manager","application","install","ransomware-script"], ], "username": "admin", "password": "admin", @@ -229,7 +229,7 @@ Via Configuration type: computer ... applications: - type: C2Server + type: c2-server ... hostname: computer_b type: computer @@ -238,7 +238,7 @@ Via Configuration # Either an agent must use application_execute. # Or a if using the simulation layer - .establish(). applications: - type: C2Beacon + type: c2-beacon options: c2_server_ip_address: ... keep_alive_frequency: 5 diff --git a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst index 1a387514..04c581bd 100644 --- a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _DataManipulationBot: @@ -77,7 +77,7 @@ Python network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(DataManipulationBot) - data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") + data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot") data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") data_manipulation_bot.run() @@ -95,39 +95,7 @@ If not using the data manipulation bot manually, it needs to be used with a data agents: - ref: data_manipulation_red_bot team: RED - type: RedDatabaseCorruptingAgent - - observation_space: - type: UC2RedObservation - options: - nodes: - - node_name: client_1 - observations: - - logon_status - - operating_status - applications: - - application_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_ref: data_manipulation_bot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY + type: red-database-corrupting-agent agent_settings: start_settings: @@ -144,14 +112,14 @@ If not using the data manipulation bot manually, it needs to be used with a data # ... additional configuration here applications: - ref: data_manipulation_bot - type: DataManipulationBot + type: data-manipulation-bot options: port_scan_p_of_success: 0.1 data_manipulation_p_of_success: 0.1 payload: "DELETE" server_ip: 192.168.1.14 - ref: web_server_database_client - type: DatabaseClient + type: database-client options: db_server_ip: 192.168.1.14 diff --git a/docs/source/simulation_components/system/applications/database_client.rst b/docs/source/simulation_components/system/applications/database_client.rst index 1fea78ab..465827d9 100644 --- a/docs/source/simulation_components/system/applications/database_client.rst +++ b/docs/source/simulation_components/system/applications/database_client.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _DatabaseClient: @@ -59,7 +59,7 @@ Python # install DatabaseClient client.software_manager.install(DatabaseClient) - database_client: DatabaseClient = client.software_manager.software.get("DatabaseClient") + database_client: DatabaseClient = client.software_manager.software.get("database-client") # Configure the DatabaseClient database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) # address of the DatabaseService @@ -83,7 +83,7 @@ Via Configuration ... applications: - ref: database_client - type: DatabaseClient + type: database-client options: db_server_ip: 192.168.0.1 diff --git a/docs/source/simulation_components/system/applications/dos_bot.rst b/docs/source/simulation_components/system/applications/dos_bot.rst index 6ad45424..47b72be7 100644 --- a/docs/source/simulation_components/system/applications/dos_bot.rst +++ b/docs/source/simulation_components/system/applications/dos_bot.rst @@ -1,13 +1,13 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _DoSBot: -DoSBot +dos-bot ###### -The ``DoSBot`` is an implementation of a Denial of Service attack within the PrimAITE simulation. +The ``dos-bot`` is an implementation of a Denial of Service attack within the PrimAITE simulation. This specifically simulates a `Slow Loris attack`_. .. _Slow Loris Attack: https://en.wikipedia.org/wiki/Slowloris_(computer_security) @@ -15,20 +15,20 @@ This specifically simulates a `Slow Loris attack`_. Key features ============ -- Connects to the :ref:`DatabaseService` via the ``SoftwareManager``. -- Makes many connections to the :ref:`DatabaseService` which ends up using up the available connections. +- Connects to the :ref:`database-service` via the ``SoftwareManager``. +- Makes many connections to the :ref:`database-service` which ends up using up the available connections. Usage ===== - Configure with target IP address and optional password. -- use ``run`` to run the application_loop of DoSBot to begin attacks -- DoSBot runs through different actions at each timestep +- use ``run`` to run the application_loop of dos-bot to begin attacks +- dos-bot runs through different actions at each timestep Implementation ============== -- Leverages :ref:`DatabaseClient` to create connections with :ref`DatabaseServer`. +- Leverages :ref:`database-client` to create connections with :ref`DatabaseServer`. - Extends base Application class. Examples @@ -42,7 +42,7 @@ Python from ipaddress import IPv4Address from primaite.simulator.network.hardware.nodes.host.computer import Computer - from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot + from primaite.simulator.system.applications.red_applications.dos_bot import dos-bot # Create Computer computer = Computer( @@ -54,11 +54,11 @@ Python ) computer.power_on() - # Install DoSBot on computer - computer.software_manager.install(DoSBot) - dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") + # Install dos-bot on computer + computer.software_manager.install(dos-bot) + dos_bot: dos-bot = computer.software_manager.software.get("dos-bot") - # Configure the DoSBot + # Configure the dos-bot dos_bot.configure( target_ip_address=IPv4Address("192.168.0.10"), payload="SPOOF DATA", @@ -68,7 +68,7 @@ Python max_sessions=1000 ) - # run DoSBot + # run dos-bot dos_bot.run() @@ -86,7 +86,7 @@ Via Configuration ... applications: - ref: dos_bot - type: DoSBot + type: dos-bot options: target_ip_address: 192.168.0.10 payload: SPOOF DATA @@ -101,7 +101,7 @@ Configuration ``target_ip_address`` """"""""""""""""""""" -IP address of the :ref:`DatabaseService` which the ``DataManipulationBot`` will try to attack. +IP address of the :ref:`database-service` which the ``data-manipulation-bot`` will try to attack. This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. @@ -119,7 +119,7 @@ See :ref:`List of IPProtocols ` for a list of protocols. Optional. Default value is ``None``. -The payload that the ``DoSBot`` sends as part of its attack. +The payload that the ``dos-bot`` sends as part of its attack. .. include:: ../common/db_payload_list.rst @@ -128,14 +128,14 @@ The payload that the ``DoSBot`` sends as part of its attack. Optional. Default value is ``False``. -If ``True`` the ``DoSBot`` will maintain its attack. +If ``True`` the ``dos-bot`` will maintain its attack. ``port_scan_p_of_success`` """""""""""""""""""""""""" Optional. Default value is ``0.1``. -The chance of the ``DoSBot`` to succeed with a port scan (and therefore continue the attack). +The chance of the ``dos-bot`` to succeed with a port scan (and therefore continue the attack). This must be a float value between ``0`` and ``1``. @@ -153,7 +153,7 @@ This must be a float value between ``0`` and ``1``. Optional. Default value is ``1000``. -The maximum number of sessions the ``DoSBot`` is able to make. +The maximum number of sessions the ``dos-bot`` is able to make. This must be an integer value equal to or greater than ``0``. diff --git a/docs/source/simulation_components/system/applications/nmap.rst b/docs/source/simulation_components/system/applications/nmap.rst index dbb8a022..a82735c8 100644 --- a/docs/source/simulation_components/system/applications/nmap.rst +++ b/docs/source/simulation_components/system/applications/nmap.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _NMAP: @@ -165,8 +165,8 @@ Perform a horizontal port scan on port 5432 across multiple IP addresses: { IPv4Address('192.168.1.12'): { - : [ - + : [ + ] } } @@ -192,7 +192,7 @@ Perform a vertical port scan on multiple ports on a single IP address: vertical_scan_results = pc_1_nmap.port_scan( target_ip_address=[IPv4Address("192.168.1.12")], - target_port=[Port(21), Port(22), Port(80), Port(443)] + target_port=[21, 22, 80, 443] ) .. code-block:: python @@ -200,9 +200,9 @@ Perform a vertical port scan on multiple ports on a single IP address: { IPv4Address('192.168.1.12'): { - : [ - , - + : [ + , + ] } } @@ -233,7 +233,7 @@ Perform a box scan on multiple ports across multiple IP addresses: box_scan_results = pc_1_nmap.port_scan( target_ip_address=[IPv4Address("192.168.1.12"), IPv4Address("192.168.1.13")], - target_port=[Port(21), Port(22), Port(80), Port(443)] + target_port=[21, 22, 80, 443] ) .. code-block:: python @@ -241,15 +241,15 @@ Perform a box scan on multiple ports across multiple IP addresses: { IPv4Address('192.168.1.13'): { - : [ - , - + : [ + , + ] }, IPv4Address('192.168.1.12'): { - : [ - , - + : [ + , + ] } } @@ -289,36 +289,36 @@ Perform a full box scan on all ports, over both TCP and UDP, on a whole subnet: { IPv4Address('192.168.1.11'): { - : [ - + : [ + ] }, IPv4Address('192.168.1.1'): { - : [ - + : [ + ] }, IPv4Address('192.168.1.12'): { - : [ - , - , - , - + : [ + , + , + , + ], - : [ - , - + : [ + , + ] }, IPv4Address('192.168.1.13'): { - : [ - , - , - + : [ + , + , + ], - : [ - , - + : [ + , + ] } } diff --git a/docs/source/simulation_components/system/applications/ransomware_script.rst b/docs/source/simulation_components/system/applications/ransomware_script.rst index 5bff6991..a8975f32 100644 --- a/docs/source/simulation_components/system/applications/ransomware_script.rst +++ b/docs/source/simulation_components/system/applications/ransomware_script.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _RansomwareScript: @@ -62,7 +62,7 @@ Python network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(RansomwareScript) - RansomwareScript: RansomwareScript = client_1.software_manager.software.get("RansomwareScript") + RansomwareScript: RansomwareScript = client_1.software_manager.software.get("ransomware-script") RansomwareScript.configure(server_ip_address=IPv4Address("192.168.1.14")) RansomwareScript.execute() @@ -70,7 +70,7 @@ Python Configuration ============= -The RansomwareScript inherits configuration options such as ``fix_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``. +The RansomwareScript inherits configuration options such as ``fixing_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``. ``server_ip`` diff --git a/docs/source/simulation_components/system/applications/web_browser.rst b/docs/source/simulation_components/system/applications/web_browser.rst index c56c450d..659caa09 100644 --- a/docs/source/simulation_components/system/applications/web_browser.rst +++ b/docs/source/simulation_components/system/applications/web_browser.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _WebBrowser: @@ -61,7 +61,7 @@ The :ref:`DNSClient` must be configured to use the :ref:`DNSServer`. The :ref:`D # Install WebBrowser on computer computer.software_manager.install(WebBrowser) - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") web_browser.run() # configure the WebBrowser @@ -85,7 +85,7 @@ Via Configuration ... applications: - ref: web_browser - type: WebBrowser + type: web-browser options: target_url: http://arcd.com/ diff --git a/docs/source/simulation_components/system/common/common_configuration.rst b/docs/source/simulation_components/system/common/common_configuration.rst index c53ac8b8..c1bbd4b2 100644 --- a/docs/source/simulation_components/system/common/common_configuration.rst +++ b/docs/source/simulation_components/system/common/common_configuration.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _Common Configuration: @@ -22,8 +22,8 @@ options The configuration options are the attributes that fall under the options for an application or service. -fix_duration -"""""""""""" +fixing_duration +""""""""""""""" Optional. Default value is ``2``. diff --git a/docs/source/simulation_components/system/common/db_payload_list.rst b/docs/source/simulation_components/system/common/db_payload_list.rst index 0930f09d..89668665 100644 --- a/docs/source/simulation_components/system/common/db_payload_list.rst +++ b/docs/source/simulation_components/system/common/db_payload_list.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _Database Payload List: diff --git a/docs/source/simulation_components/system/internal_frame_processing.rst b/docs/source/simulation_components/system/internal_frame_processing.rst index 65336f9b..f82dec13 100644 --- a/docs/source/simulation_components/system/internal_frame_processing.rst +++ b/docs/source/simulation_components/system/internal_frame_processing.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _internal_frame_processing: diff --git a/docs/source/simulation_components/system/list_of_applications.rst b/docs/source/simulation_components/system/list_of_applications.rst index 94090d93..a7e05ea6 100644 --- a/docs/source/simulation_components/system/list_of_applications.rst +++ b/docs/source/simulation_components/system/list_of_applications.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. toctree:: :maxdepth: 1 diff --git a/docs/source/simulation_components/system/list_of_services.rst b/docs/source/simulation_components/system/list_of_services.rst index b6995647..2082ac6f 100644 --- a/docs/source/simulation_components/system/list_of_services.rst +++ b/docs/source/simulation_components/system/list_of_services.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. toctree:: :maxdepth: 1 diff --git a/docs/source/simulation_components/system/list_of_system_applications.rst b/docs/source/simulation_components/system/list_of_system_applications.rst index c8807ef0..0c66662f 100644 --- a/docs/source/simulation_components/system/list_of_system_applications.rst +++ b/docs/source/simulation_components/system/list_of_system_applications.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ``system applications`` """"""""""""""""""""""" diff --git a/docs/source/simulation_components/system/list_of_system_services.rst b/docs/source/simulation_components/system/list_of_system_services.rst index 9b5c3265..01df4dc8 100644 --- a/docs/source/simulation_components/system/list_of_system_services.rst +++ b/docs/source/simulation_components/system/list_of_system_services.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK ``system services`` """"""""""""""""""" diff --git a/docs/source/simulation_components/system/pcap.rst b/docs/source/simulation_components/system/pcap.rst index 830c28bd..0da28a39 100644 --- a/docs/source/simulation_components/system/pcap.rst +++ b/docs/source/simulation_components/system/pcap.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK PCAP ==== diff --git a/docs/source/simulation_components/system/services/database_service.rst b/docs/source/simulation_components/system/services/database_service.rst index f3e800cd..c819a0f7 100644 --- a/docs/source/simulation_components/system/services/database_service.rst +++ b/docs/source/simulation_components/system/services/database_service.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _DatabaseService: @@ -66,7 +66,7 @@ Python # Install DatabaseService on server server.software_manager.install(DatabaseService) - db_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = server.software_manager.software.get("database-service") db_service.start() # configure DatabaseService @@ -87,7 +87,7 @@ Via Configuration ... services: - ref: database_service - type: DatabaseService + type: database-service options: backup_server_ip: 192.168.0.10 diff --git a/docs/source/simulation_components/system/services/dns_client.rst b/docs/source/simulation_components/system/services/dns_client.rst index eca152f0..40762bfc 100644 --- a/docs/source/simulation_components/system/services/dns_client.rst +++ b/docs/source/simulation_components/system/services/dns_client.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _DNSClient: @@ -56,7 +56,7 @@ Python # Install DNSClient on server server.software_manager.install(DNSClient) - dns_client: DNSClient = server.software_manager.software.get("DNSClient") + dns_client: DNSClient = server.software_manager.software.get("dns-client") dns_client.start() # configure DatabaseService @@ -77,7 +77,7 @@ Via Configuration ... services: - ref: dns_client - type: DNSClient + type: dns-client options: dns_server: 192.168.0.10 diff --git a/docs/source/simulation_components/system/services/dns_server.rst b/docs/source/simulation_components/system/services/dns_server.rst index 1e30b9bd..ca0e3691 100644 --- a/docs/source/simulation_components/system/services/dns_server.rst +++ b/docs/source/simulation_components/system/services/dns_server.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _DNSServer: @@ -53,7 +53,7 @@ Python # Install DNSServer on server server.software_manager.install(DNSServer) - dns_server: DNSServer = server.software_manager.software.get("DNSServer") + dns_server: DNSServer = server.software_manager.software.get("dns-server") dns_server.start() # configure DatabaseService @@ -74,7 +74,7 @@ Via Configuration ... services: - ref: dns_server - type: DNSServer + type: dns-server options: domain_mapping: arcd.com: 192.168.0.10 diff --git a/docs/source/simulation_components/system/services/ftp_client.rst b/docs/source/simulation_components/system/services/ftp_client.rst index c8a21743..530b5aff 100644 --- a/docs/source/simulation_components/system/services/ftp_client.rst +++ b/docs/source/simulation_components/system/services/ftp_client.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _FTPClient: @@ -15,7 +15,7 @@ Key features - Connects to the :ref:`FTPServer` via the ``SoftwareManager``. - Simulates FTP requests and FTPPacket transfer across a network - Allows the emulation of FTP commands between an FTP client and server: - - PORT: specifies the port that server should connect to on the client (currently only uses ``Port.FTP``) + - PORT: specifies the port that server should connect to on the client (currently only uses ``Port["FTP"]``) - STOR: stores a file from client to server - RETR: retrieves a file from the FTP server - QUIT: disconnect from server @@ -60,7 +60,7 @@ Python # Install FTPClient on server server.software_manager.install(FTPClient) - ftp_client: FTPClient = server.software_manager.software.get("FTPClient") + ftp_client: FTPClient = server.software_manager.software.get("ftp-client") ftp_client.start() @@ -78,7 +78,7 @@ Via Configuration ... services: - ref: ftp_client - type: FTPClient + type: ftp-client Configuration ============= diff --git a/docs/source/simulation_components/system/services/ftp_server.rst b/docs/source/simulation_components/system/services/ftp_server.rst index f52fa043..20dd6707 100644 --- a/docs/source/simulation_components/system/services/ftp_server.rst +++ b/docs/source/simulation_components/system/services/ftp_server.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _FTPServer: @@ -55,7 +55,7 @@ Python # Install FTPServer on server server.software_manager.install(FTPServer) - ftp_server: FTPServer = server.software_manager.software.get("FTPServer") + ftp_server: FTPServer = server.software_manager.software.get("ftp-server") ftp_server.start() ftp_server.server_password = "test" @@ -74,7 +74,7 @@ Via Configuration ... services: - ref: ftp_server - type: FTPServer + type: ftp-server options: server_password: test diff --git a/docs/source/simulation_components/system/services/ntp_client.rst b/docs/source/simulation_components/system/services/ntp_client.rst index 7af831bf..5406d9fc 100644 --- a/docs/source/simulation_components/system/services/ntp_client.rst +++ b/docs/source/simulation_components/system/services/ntp_client.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _NTPClient: @@ -53,7 +53,7 @@ Python # Install NTPClient on server server.software_manager.install(NTPClient) - ntp_client: NTPClient = server.software_manager.software.get("NTPClient") + ntp_client: NTPClient = server.software_manager.software.get("ntp-client") ntp_client.start() ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.10")) @@ -73,7 +73,7 @@ Via Configuration ... services: - ref: ntp_client - type: NTPClient + type: ntp-client options: ntp_server_ip: 192.168.0.10 diff --git a/docs/source/simulation_components/system/services/ntp_server.rst b/docs/source/simulation_components/system/services/ntp_server.rst index a09c8bdd..2c01dcaf 100644 --- a/docs/source/simulation_components/system/services/ntp_server.rst +++ b/docs/source/simulation_components/system/services/ntp_server.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _NTPServer: @@ -55,7 +55,7 @@ Python # Install NTPServer on server server.software_manager.install(NTPServer) - ntp_server: NTPServer = server.software_manager.software.get("NTPServer") + ntp_server: NTPServer = server.software_manager.software.get("ntp-server") ntp_server.start() @@ -73,7 +73,7 @@ Via Configuration ... services: - ref: ntp_server - type: NTPServer + type: ntp-server ``Common Attributes`` diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index b11d74bb..5c9bad79 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _Terminal: @@ -78,16 +78,16 @@ The below code examples demonstrate how to use terminal related actions in yaml yaml """" -``NODE_SEND_LOCAL_COMMAND`` +``node-send-local-command`` """"""""""""""""""""""""""" -Agents can execute local commands without needing to perform a separate remote login action (``SSH_TO_REMOTE``). +Agents can execute local commands without needing to perform a separate remote login action (``node-session-remote-login``). .. code-block:: yaml ... ... - action: NODE_SEND_LOCAL_COMMAND + action: node-send-local-command options: node_id: 0 username: admin @@ -101,7 +101,7 @@ Agents can execute local commands without needing to perform a separate remote l - "False" -``SSH_TO_REMOTE`` +``node-session-remote-login`` """"""""""""""""" Agents are able to use the terminal to login into remote nodes via ``SSH`` which allows for agents to execute commands on remote hosts. @@ -110,7 +110,7 @@ Agents are able to use the terminal to login into remote nodes via ``SSH`` which ... ... - action: SSH_TO_REMOTE + action: node-session-remote-login options: node_id: 0 username: admin @@ -118,16 +118,16 @@ Agents are able to use the terminal to login into remote nodes via ``SSH`` which remote_ip: 192.168.0.10 # Example Ip Address. (The remote host's IP that will be used by ssh) -``NODE_SEND_REMOTE_COMMAND`` +``node-send-remote-command`` """""""""""""""""""""""""""" -After remotely logging into another host, an agent can use the ``NODE_SEND_REMOTE_COMMAND`` to execute commands across the network remotely. +After remotely logging into another host, an agent can use the ``node-send-remote-command`` to execute commands across the network remotely. .. code-block:: yaml ... ... - action: NODE_SEND_REMOTE_COMMAND + action: node-send-remote-command options: node_id: 0 remote_ip: 192.168.0.10 @@ -166,7 +166,7 @@ Python operating_state=NodeOperatingState.ON, ) - terminal: Terminal = client.software_manager.software.get("Terminal") + terminal: Terminal = client.software_manager.software.get("terminal") Creating Remote Terminal Connection """"""""""""""""""""""""""""""""""" @@ -187,7 +187,7 @@ Creating Remote Terminal Connection node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") @@ -213,12 +213,12 @@ Executing a basic application install command node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") - term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) + term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "ransomware-script"]) @@ -241,7 +241,7 @@ Creating a folder on a remote node node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") @@ -268,7 +268,7 @@ Disconnect from Remote Node node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") diff --git a/docs/source/simulation_components/system/services/web_server.rst b/docs/source/simulation_components/system/services/web_server.rst index cec20a60..9d7f4d2f 100644 --- a/docs/source/simulation_components/system/services/web_server.rst +++ b/docs/source/simulation_components/system/services/web_server.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _WebServer: @@ -56,7 +56,7 @@ Python # Install WebServer on server server.software_manager.install(WebServer) - web_server: WebServer = server.software_manager.software.get("WebServer") + web_server: WebServer = server.software_manager.software.get("web-server") web_server.start() Via Configuration @@ -73,7 +73,7 @@ Via Configuration ... services: - ref: web_server - type: WebServer + type: web-server ``Common Attributes`` diff --git a/docs/source/simulation_components/system/session_and_software_manager.rst b/docs/source/simulation_components/system/session_and_software_manager.rst index 230f6687..f20af556 100644 --- a/docs/source/simulation_components/system/session_and_software_manager.rst +++ b/docs/source/simulation_components/system/session_and_software_manager.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK Session and Software Manager ============================ diff --git a/docs/source/simulation_components/system/software.rst b/docs/source/simulation_components/system/software.rst index c8f0e2d3..c2f3066b 100644 --- a/docs/source/simulation_components/system/software.rst +++ b/docs/source/simulation_components/system/software.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _software: @@ -30,7 +30,7 @@ See :ref:`Node Start up and Shut down` node.software_manager.install(WebServer) - web_server: WebServer = node.software_manager.software.get("WebServer") + web_server: WebServer = node.software_manager.software.get("web-server") assert web_server.operating_state is ServiceOperatingState.RUNNING # service is immediately ran after install node.power_off() diff --git a/docs/source/simulation_components/system/sys_log.rst b/docs/source/simulation_components/system/sys_log.rst index cdf19faa..05629993 100644 --- a/docs/source/simulation_components/system/sys_log.rst +++ b/docs/source/simulation_components/system/sys_log.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK SysLog ====== diff --git a/docs/source/simulation_structure.rst b/docs/source/simulation_structure.rst index cd9ac409..7debe112 100644 --- a/docs/source/simulation_structure.rst +++ b/docs/source/simulation_structure.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK Simulation Structure diff --git a/docs/source/state_system.rst b/docs/source/state_system.rst index e31474ea..a5fd1df1 100644 --- a/docs/source/state_system.rst +++ b/docs/source/state_system.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK Simulation State ================ diff --git a/docs/source/varying_config_files.rst b/docs/source/varying_config_files.rst index fa66f0d9..942e522b 100644 --- a/docs/source/varying_config_files.rst +++ b/docs/source/varying_config_files.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK Defining variations in the config files ======================================= diff --git a/pyproject.toml b/pyproject.toml index 354df8b2..e840797c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "primaite" description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme." authors = [{name="Defence Science and Technology Laboratory UK", email="oss@dstl.gov.uk"}] license = {file = "LICENSE"} -requires-python = ">=3.9, <3.12" +requires-python = ">=3.9, <3.13" dynamic = ["version", "readme"] classifiers = [ "Development Status :: 5 - Production/Stable", @@ -26,15 +26,15 @@ dependencies = [ "gymnasium==0.28.1", "jupyterlab==3.6.1", "kaleido==0.2.1", - "matplotlib==3.7.1", + "matplotlib>=3.7.1", "networkx==3.1", - "numpy==1.23.5", + "numpy~=1.23", "platformdirs==3.5.1", "plotly==5.15.0", "polars==0.20.30", "prettytable==3.8.0", - "PyYAML==6.0", - "typer[all]==0.9.0", + "PyYAML>=6.0", + "typer[all]>=0.9", "pydantic==2.7.0", "ipywidgets", "deepdiff" @@ -53,8 +53,8 @@ license-files = ["LICENSE"] [project.optional-dependencies] rl = [ "ray[rllib] >= 2.20.0, <2.33", - "tensorflow==2.12.0", - "stable-baselines3[extra]==2.1.0", + "tensorflow~=2.12", + "stable-baselines3==2.1.0", "sb3-contrib==2.1.0", ] dev = [ @@ -69,7 +69,7 @@ dev = [ "pytest-xdist==3.3.1", "pytest-cov==4.0.0", "pytest-flake8==1.1.1", - "setuptools==66", + "setuptools==75.6.0", "Sphinx==7.1.2", "sphinx-copybutton==0.5.2", "wheel==0.38.4", diff --git a/run_test_and_coverage.py b/run_test_and_coverage.py index 3bd9072d..dfa71f74 100644 --- a/run_test_and_coverage.py +++ b/run_test_and_coverage.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import subprocess import sys from typing import Any diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 688932aa..d9b058f1 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -3.4.0-dev +4.0.0-dev diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 8dd84428..54eac69d 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import datetime as datetime import logging import logging.config diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 4fbbdec9..2bd18baf 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Provides a CLI using Typer as an entry point.""" import logging import os diff --git a/src/primaite/config/__init__.py b/src/primaite/config/__init__.py index c2ae1b5b..7b5e2889 100644 --- a/src/primaite/config/__init__.py +++ b/src/primaite/config/__init__.py @@ -1,2 +1,2 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Configuration parameters for running experiments.""" diff --git a/src/primaite/config/_package_data/basic_lan_network_example.yaml b/src/primaite/config/_package_data/basic_lan_network_example.yaml index 9490ff00..9996be84 100644 --- a/src/primaite/config/_package_data/basic_lan_network_example.yaml +++ b/src/primaite/config/_package_data/basic_lan_network_example.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + game: ports: - ARP diff --git a/src/primaite/config/_package_data/client_server_p2p_network_example.yaml b/src/primaite/config/_package_data/client_server_p2p_network_example.yaml index 798dd318..1a9fca98 100644 --- a/src/primaite/config/_package_data/client_server_p2p_network_example.yaml +++ b/src/primaite/config/_package_data/client_server_p2p_network_example.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + game: ports: - ARP diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 2a069971..7d9fdf36 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_agent_actions: true save_step_metadata: false @@ -24,98 +27,72 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_2 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_2 + application_name: database-client reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_2 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_2 - ref: client_1_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_1 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_1 + application_name: database-client reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_1 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_1 @@ -126,52 +103,31 @@ agents: - ref: data_manipulation_attacker team: RED - type: RedDatabaseCorruptingAgent + type: red-database-corrupting-agent - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: DataManipulationBot - - node_name: client_2 - applications: - - application_name: DataManipulationBot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY - - agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + agent_settings: + possible_start_nodes: [client_1, client_2] + target_application: data-manipulation-bot + start_step: 25 + frequency: 20 + variance: 5 - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -208,15 +164,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -230,511 +186,447 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - - type: ROUTER_ACL_REMOVERULE - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE - action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-file-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-folder-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 0 + node_name: domain_controller 19: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 0 + node_name: domain_controller 20: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: domain_controller 21: - action: NODE_RESET + action: node-reset options: - node_id: 0 + node_name: domain_controller 22: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 1 + node_name: web_server 23: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 1 + node_name: web_server 24: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: web_server 25: - action: NODE_RESET + action: node-reset options: - node_id: 1 + node_name: web_server 26: # old action num: 18 - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 27: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 2 + node_name: database_server 28: - action: NODE_STARTUP + action: node-startup options: - node_id: 2 + node_name: database_server 29: - action: NODE_RESET + action: node-reset options: - node_id: 2 + node_name: database_server 30: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 3 + node_name: backup_server 31: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 3 + node_name: backup_server 32: - action: NODE_STARTUP + action: node-startup options: - node_id: 3 + node_name: backup_server 33: - action: NODE_RESET + action: node-reset options: - node_id: 3 + node_name: backup_server 34: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 4 + node_name: security_suite 35: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 4 + node_name: security_suite 36: - action: NODE_STARTUP + action: node-startup options: - node_id: 4 + node_name: security_suite 37: - action: NODE_RESET + action: node-reset options: - node_id: 4 + node_name: security_suite 38: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 5 + node_name: client_1 39: # old action num: 19 # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 40: # old action num: 20 - action: NODE_STARTUP + action: node-startup options: - node_id: 5 + node_name: client_1 41: # old action num: 21 - action: NODE_RESET + action: node-reset options: - node_id: 5 + node_name: client_1 42: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 6 + node_name: client_2 43: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 6 + node_name: client_2 44: - action: NODE_STARTUP + action: node-startup options: - node_id: 6 + node_name: client_2 45: - action: NODE_RESET + action: node-reset options: - node_id: 6 + node_name: client_2 - 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" - action: "ROUTER_ACL_ADDRULE" + 46: # old action num: 22 # "acl: ADDRULE - Block outgoing traffic from client 1" + action: "router-acl-add-rule" options: target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 47: # old action num: 23 # "acl: ADDRULE - Block outgoing traffic from client 2" + action: "router-acl-add-rule" options: target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 48: # old action num: 24 # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 49: # old action num: 25 # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 50: # old action num: 26 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 51: # old action num: 27 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 52: # old action num: 28 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 0 53: # old action num: 29 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 1 54: # old action num: 30 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 2 55: # old action num: 31 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 3 56: # old action num: 32 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 4 57: # old action num: 33 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 5 58: # old action num: 34 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 6 59: # old action num: 35 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 7 60: # old action num: 36 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 8 61: # old action num: 37 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 9 62: # old action num: 38 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 63: # old action num: 39 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 64: # old action num: 40 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 65: # old action num: 41 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 66: # old action num: 42 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 67: # old action num: 43 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 68: # old action num: 44 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 69: # old action num: 45 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 70: # old action num: 46 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 71: # old action num: 47 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 72: # old action num: 48 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 73: # old action num: 49 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 74: # old action num: 50 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 75: # old action num: 51 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 76: # old action num: 52 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 77: # old action num: 53 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 - options: - nodes: - - node_name: domain_controller - - node_name: web_server - applications: - - application_name: DatabaseClient - services: - - service_name: WebServer - - node_name: database_server - folders: - - folder_name: database - files: - - file_name: database.db - services: - - service_name: DatabaseService - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 - reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_1_green_user - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_2_green_user @@ -804,7 +696,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DNSServer + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.12 # web server @@ -816,9 +708,9 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: WebServer + - type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 @@ -830,10 +722,10 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: DatabaseService + - type: database-service options: backup_server_ip: 192.168.1.16 - - type: FTPClient + - type: ftp-client - hostname: backup_server type: server @@ -842,7 +734,8 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: FTPServer + - type: ftp-server + - hostname: security_suite type: server ip_address: 192.168.1.110 @@ -861,20 +754,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client - hostname: client_2 type: computer @@ -883,20 +776,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client links: - endpoint_a_hostname: router_1 diff --git a/src/primaite/config/_package_data/data_manipulation_marl.yaml b/src/primaite/config/_package_data/data_manipulation_marl.yaml index ba666781..71507acb 100644 --- a/src/primaite/config/_package_data/data_manipulation_marl.yaml +++ b/src/primaite/config/_package_data/data_manipulation_marl.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_agent_actions: true save_step_metadata: false @@ -20,98 +23,72 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_2 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_2 + application_name: database-client reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_2 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_2 - ref: client_1_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_1 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_1 + application_name: web-browser reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_1 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_1 @@ -122,55 +99,32 @@ agents: - ref: data_manipulation_attacker team: RED - type: RedDatabaseCorruptingAgent + type: red-database-corrupting-agent - observation_space: null - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - - type: NODE_FILE_DELETE - - type: NODE_FILE_CORRUPT - - type: NODE_OS_SCAN - options: - nodes: - - node_name: client_1 - applications: - - application_name: DataManipulationBot - - node_name: client_2 - applications: - - application_name: DataManipulationBot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY - - agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + agent_settings: + possible_start_nodes: [client_1, client_2] + target_application: data-manipulation-bot + start_step: 25 + frequency: 20 + variance: 5 - ref: defender_1 team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -202,15 +156,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -224,508 +178,443 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - - type: ROUTER_ACL_REMOVERULE - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE - action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-file-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-folder-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 0 + node_name: domain_controller 19: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 0 + node_name: domain_controller 20: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: domain_controller 21: - action: NODE_RESET + action: node-reset options: - node_id: 0 + node_name: domain_controller 22: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 1 + node_name: web_server 23: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 1 + node_name: web_server 24: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: web_server 25: - action: NODE_RESET + action: node-reset options: - node_id: 1 + node_name: web_server 26: # old action num: 18 - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 27: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 2 + node_name: database_server 28: - action: NODE_STARTUP + action: node-startup options: - node_id: 2 + node_name: database_server 29: - action: NODE_RESET + action: node-reset options: - node_id: 2 + node_name: database_server 30: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 3 + node_name: backup_server 31: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 3 + node_name: backup_server 32: - action: NODE_STARTUP + action: node-startup options: - node_id: 3 + node_name: backup_server 33: - action: NODE_RESET + action: node-reset options: - node_id: 3 + node_name: backup_server 34: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 4 + node_name: security_suite 35: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 4 + node_name: security_suite 36: - action: NODE_STARTUP + action: node-startup options: - node_id: 4 + node_name: security_suite 37: - action: NODE_RESET + action: node-reset options: - node_id: 4 + node_name: security_suite 38: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 5 + node_name: client_1 39: # old action num: 19 # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 40: # old action num: 20 - action: NODE_STARTUP + action: node-startup options: - node_id: 5 + node_name: client_1 41: # old action num: 21 - action: NODE_RESET + action: node-reset options: - node_id: 5 + node_name: client_1 42: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 6 + node_name: client_2 43: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 6 + node_name: client_2 44: - action: NODE_STARTUP + action: node-startup options: - node_id: 6 + node_name: client_2 45: - action: NODE_RESET + action: node-reset options: - node_id: 6 + node_name: client_2 - 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" - action: "ROUTER_ACL_ADDRULE" + 46: # old action num: 22 # "acl: ADDRULE - Block outgoing traffic from client 1" + action: "router-acl-add-rule" options: target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 47: # old action num: 23 # "acl: ADDRULE - Block outgoing traffic from client 2" + action: "router-acl-add-rule" options: target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 48: # old action num: 24 # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 49: # old action num: 25 # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 50: # old action num: 26 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 51: # old action num: 27 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 52: # old action num: 28 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 0 53: # old action num: 29 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 1 54: # old action num: 30 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 2 55: # old action num: 31 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 3 56: # old action num: 32 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 4 57: # old action num: 33 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 5 58: # old action num: 34 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 6 59: # old action num: 35 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 7 60: # old action num: 36 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 8 61: # old action num: 37 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 9 62: # old action num: 38 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 63: # old action num: 39 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 64: # old action num: 40 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 65: # old action num: 41 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 66: # old action num: 42 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 67: # old action num: 43 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 68: # old action num: 44 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 69: # old action num: 45 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 70: # old action num: 46 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 71: # old action num: 47 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 72: # old action num: 48 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 73: # old action num: 49 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 74: # old action num: 50 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 75: # old action num: 51 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 76: # old action num: 52 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 77: # old action num: 53 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 - - - options: - nodes: - - node_name: domain_controller - - node_name: web_server - applications: - - application_name: DatabaseClient - services: - - service_name: WebServer - - node_name: database_server - folders: - - folder_name: database - files: - - file_name: database.db - services: - - service_name: DatabaseService - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 + node_name: client_2 + nic_num: 1 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_1_green_user - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_2_green_user @@ -737,20 +626,20 @@ agents: - ref: defender_2 team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -782,15 +671,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -804,512 +693,444 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - options: - target_router: router_1 - - type: ROUTER_ACL_REMOVERULE - options: - target_router: router_1 - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE - action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-file-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-folder-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 0 + node_name: domain_controller 19: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 0 + node_name: domain_controller 20: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: domain_controller 21: - action: NODE_RESET + action: node-reset options: - node_id: 0 + node_name: domain_controller 22: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 1 + node_name: web_server 23: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 1 + node_name: web_server 24: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: web_server 25: - action: NODE_RESET + action: node-reset options: - node_id: 1 + node_name: web_server 26: # old action num: 18 - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 27: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 2 + node_name: database_server 28: - action: NODE_STARTUP + action: node-startup options: - node_id: 2 + node_name: database_server 29: - action: NODE_RESET + action: node-reset options: - node_id: 2 + node_name: database_server 30: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 3 + node_name: backup_server 31: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 3 + node_name: backup_server 32: - action: NODE_STARTUP + action: node-startup options: - node_id: 3 + node_name: backup_server 33: - action: NODE_RESET + action: node-reset options: - node_id: 3 + node_name: backup_server 34: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 4 + node_name: security_suite 35: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 4 + node_name: security_suite 36: - action: NODE_STARTUP + action: node-startup options: - node_id: 4 + node_name: security_suite 37: - action: NODE_RESET + action: node-reset options: - node_id: 4 + node_name: security_suite 38: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 5 + node_name: client_1 39: # old action num: 19 # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 40: # old action num: 20 - action: NODE_STARTUP + action: node-startup options: - node_id: 5 + node_name: client_1 41: # old action num: 21 - action: NODE_RESET + action: node-reset options: - node_id: 5 + node_name: client_1 42: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 6 + node_name: client_2 43: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 6 + node_name: client_2 44: - action: NODE_STARTUP + action: node-startup options: - node_id: 6 + node_name: client_2 45: - action: NODE_RESET + action: node-reset options: - node_id: 6 + node_name: client_2 - 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" - action: "ROUTER_ACL_ADDRULE" + 46: # old action num: 22 # "acl: ADDRULE - Block outgoing traffic from client 1" + action: "router-acl-add-rule" options: target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 47: # old action num: 23 # "acl: ADDRULE - Block outgoing traffic from client 2" + action: "router-acl-add-rule" options: target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 48: # old action num: 24 # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 49: # old action num: 25 # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 50: # old action num: 26 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 51: # old action num: 27 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 52: # old action num: 28 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 0 53: # old action num: 29 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 1 54: # old action num: 30 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 2 55: # old action num: 31 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 3 56: # old action num: 32 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 4 57: # old action num: 33 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 5 58: # old action num: 34 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 6 59: # old action num: 35 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 7 60: # old action num: 36 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 8 61: # old action num: 37 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 9 62: # old action num: 38 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 63: # old action num: 39 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 64: # old action num: 40 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 65: # old action num: 41 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 66: # old action num: 42 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 67: # old action num: 43 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 68: # old action num: 44 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 69: # old action num: 45 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 70: # old action num: 46 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 71: # old action num: 47 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 72: # old action num: 48 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 73: # old action num: 49 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 74: # old action num: 50 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 75: # old action num: 51 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 76: # old action num: 52 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 77: # old action num: 53 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 - options: - nodes: - - node_name: domain_controller - - node_name: web_server - applications: - - application_name: DatabaseClient - services: - - service_name: WebServer - - node_name: database_server - folders: - - folder_name: database - files: - - file_name: database.db - services: - - service_name: DatabaseService - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 - reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_1_green_user - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_2_green_user @@ -1379,7 +1200,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DNSServer + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.12 # web server @@ -1391,9 +1212,9 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: WebServer + - type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 @@ -1405,10 +1226,10 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: DatabaseService + - type: database-service options: backup_server_ip: 192.168.1.16 - - type: FTPClient + - type: ftp-client - hostname: backup_server type: server @@ -1417,7 +1238,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: FTPServer + - type: ftp-server - hostname: security_suite type: server @@ -1437,20 +1258,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client - hostname: client_2 type: computer @@ -1459,20 +1280,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client diff --git a/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/base_scenario.yaml b/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/base_scenario.yaml index b4457a28..2ea18867 100644 --- a/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/base_scenario.yaml +++ b/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/base_scenario.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + game: max_episode_length: 128 ports: [] @@ -5,69 +8,49 @@ game: agents: - ref: RL_Agent - type: ProxyAgent - observation_space: null + type: proxy-agent + action_space: - action_list: - - type: DONOTHING - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_SHUTDOWN + action: node-shutdown options: - node_id: 0 + node_name: client_1 2: - action: NODE_SHUTDOWN + action: node-shutdown options: - node_id: 1 + node_name: server 3: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: client_1 4: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: server 5: - action: HOST_NIC_DISABLE + action: host-nic-disable options: - node_id: 0 - nic_id: 0 + node_name: client_1 + nic_num: 1 6: - action: HOST_NIC_DISABLE + action: host-nic-disable options: - node_id: 1 - nic_id: 0 + node_name: server + nic_num: 1 7: - action: HOST_NIC_ENABLE + action: host-nic-enable options: - node_id: 0 - nic_id: 0 + node_name: client_1 + nic_num: 1 8: - action: HOST_NIC_ENABLE + action: host-nic-enable options: - node_id: 1 - nic_id: 0 - options: - nodes: - - node_name: client_1 - - node_name: server - max_folders_per_node: 0 - max_files_per_folder: 0 - max_services_per_node: 0 - max_nics_per_node: 1 - max_acl_rules: 0 - ip_list: - - 192.168.1.2 - - 192.168.1.3 - reward_function: - reward_components: [] + node_name: server + nic_num: 1 simulation: network: diff --git a/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/simulation_variant_1.yaml b/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/simulation_variant_1.yaml index 3e27cc27..5a976294 100644 --- a/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/simulation_variant_1.yaml +++ b/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/simulation_variant_1.yaml @@ -1,5 +1,5 @@ server_services: &server_services - - type: DatabaseService + - type: database-service client_applications: &client_applications - - type: DatabaseClient + - type: database-client diff --git a/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/simulation_variant_2.yaml b/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/simulation_variant_2.yaml index 207e0c73..8b89e9f6 100644 --- a/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/simulation_variant_2.yaml +++ b/src/primaite/config/_package_data/mini_scenario_with_simulation_variation/simulation_variant_2.yaml @@ -1,5 +1,5 @@ server_services: &server_services - - type: FTPServer + - type: ftp-server client_applications: &client_applications - - type: RansomwareScript + - type: ransomware-script diff --git a/src/primaite/config/_package_data/multi_lan_internet_network_example.yaml b/src/primaite/config/_package_data/multi_lan_internet_network_example.yaml index 61562418..deaef3bd 100644 --- a/src/primaite/config/_package_data/multi_lan_internet_network_example.yaml +++ b/src/primaite/config/_package_data/multi_lan_internet_network_example.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + game: ports: - ARP @@ -20,10 +23,10 @@ simulation: default_gateway: 192.168.1.1 dns_server: 8.8.8.2 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 - - type: WebBrowser + - type: web-browser options: target_url: http://sometech.ai/users/ @@ -34,10 +37,10 @@ simulation: default_gateway: 192.168.1.1 dns_server: 8.8.8.2 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 - - type: WebBrowser + - type: web-browser options: target_url: http://sometech.ai/users/ @@ -102,8 +105,7 @@ simulation: subnet_mask: 255.255.255.252 default_gateway: 8.8.8.1 services: - - ref: dns_server - type: DNSServer + - type: dns-server options: domain_mapping: sometech.ai: 94.10.180.6 @@ -150,7 +152,7 @@ simulation: dst_ip: 94.10.180.6 dst_port: POSTGRES_SERVER dst_wildcard_mask: 0.0.0.0 - 8: # Permit SomeTech DMZ to use ARP + 8: # Permit SomeTech DMZ to use arp action: PERMIT src_port: ARP dst_port: ARP @@ -170,7 +172,7 @@ simulation: dst_ip: 10.10.1.11 dst_port: POSTGRES_SERVER dst_wildcard_mask: 0.0.0.0 - 8: # Permit SomeTech DMZ to use ARP + 8: # Permit SomeTech DMZ to use arp action: PERMIT src_port: ARP dst_port: ARP @@ -196,10 +198,9 @@ simulation: default_gateway: 94.10.180.5 dns_server: 8.8.8.2 services: - - ref: web_server - type: WebServer + - type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 @@ -269,12 +270,12 @@ simulation: action: PERMIT src_port: HTTP dst_port: HTTP - 18: # Allow the SomeTech internal network to use ARP + 18: # Allow the SomeTech internal network to use arp action: PERMIT src_ip: 10.10.0.0 src_wildcard_mask: 0.0.255.255 src_port: ARP - 19: # Allow the SomeTech internal network to use ICMP + 19: # Allow the SomeTech internal network to use icmp action: PERMIT src_ip: 10.10.0.0 src_wildcard_mask: 0.0.255.255 @@ -318,10 +319,10 @@ simulation: default_gateway: 10.10.1.1 dns_server: 8.8.8.2 services: - - type: DatabaseService + - type: database-service options: backup_server_ip: 10.10.1.12 # The some_tech_storage_srv server - - type: FTPClient + - type: ftp-client - hostname: some_tech_storage_srv type: server @@ -330,7 +331,7 @@ simulation: default_gateway: 10.10.1.1 dns_server: 8.8.8.2 services: - - type: FTPServer + - type: ftp-server - hostname: some_tech_hr_1 type: computer @@ -339,10 +340,10 @@ simulation: default_gateway: 10.10.3.1 dns_server: 8.8.8.2 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 - - type: WebBrowser + - type: web-browser options: target_url: http://sometech.ai/users/ @@ -353,10 +354,10 @@ simulation: default_gateway: 10.10.2.1 dns_server: 8.8.8.2 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 - - type: WebBrowser + - type: web-browser options: target_url: http://sometech.ai/users/ @@ -367,10 +368,10 @@ simulation: default_gateway: 10.10.2.1 dns_server: 8.8.8.2 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 10.10.1.11 - - type: WebBrowser + - type: web-browser options: target_url: http://sometech.ai/users/ diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/greens_1.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/greens_1.yaml index 98d2392a..3f9b65f4 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/greens_1.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/greens_1.yaml @@ -1,34 +1,26 @@ agents: &greens - ref: green_A team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.2 1: 0.8 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client - applications: - - application_name: DatabaseClient action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client + application_name: database-client reward_function: reward_components: - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 1.0 options: node_hostname: client diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/greens_2.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/greens_2.yaml index 17a5977b..77a689e7 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/greens_2.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/greens_2.yaml @@ -1,34 +1,26 @@ agents: &greens - ref: green_B team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.95 1: 0.05 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client - applications: - - application_name: DatabaseClient action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client + application_name: database-client reward_function: reward_components: - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 1.0 options: node_hostname: client diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/reds_1.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/reds_1.yaml index 31675a0b..b95955b4 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/reds_1.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/reds_1.yaml @@ -1,26 +1,11 @@ reds: &reds - ref: red_A team: RED - type: RedDatabaseCorruptingAgent - - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client - applications: - - application_name: DataManipulationBot - - reward_function: - reward_components: - - type: DUMMY + type: red-database-corrupting-agent agent_settings: - start_settings: - start_step: 10 - frequency: 10 - variance: 0 + possible_start_nodes: [client,] + target_application: data-manipulation-bot + start_step: 10 + frequency: 10 + variance: 0 diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/reds_2.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/reds_2.yaml index c5572b89..653051c6 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/reds_2.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/reds_2.yaml @@ -1,26 +1,11 @@ reds: &reds - ref: red_B team: RED - type: RedDatabaseCorruptingAgent - - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client - applications: - - application_name: DataManipulationBot - - reward_function: - reward_components: - - type: DUMMY + type: red-database-corrupting-agent agent_settings: - start_settings: - start_step: 3 - frequency: 2 - variance: 1 + possible_start_nodes: [client_1] + target_application: data-manipulation-bot + start_step: 3 + frequency: 2 + variance: 1 diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml index dfd200f3..4ec3d257 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_agent_actions: true save_step_metadata: false @@ -26,12 +29,12 @@ agents: - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: routers: [] @@ -46,7 +49,7 @@ agents: include_num_access: false include_nmne: true - - type: LINKS + - type: links label: LINKS options: link_references: @@ -54,69 +57,50 @@ agents: - server:eth-1<->switch_1:eth-2 action_space: - action_list: - - type: DONOTHING - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_SHUTDOWN + action: node-shutdown options: - node_id: 0 + node_name: client_1 2: - action: NODE_SHUTDOWN + action: node-shutdown options: - node_id: 1 + node_name: server 3: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: client_1 4: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: server 5: - action: HOST_NIC_DISABLE + action: host-nic-disable options: - node_id: 0 - nic_id: 0 + node_name: client_1 + nic_num: 1 6: - action: HOST_NIC_DISABLE + action: host-nic-disable options: - node_id: 1 - nic_id: 0 + node_name: server + nic_num: 1 7: - action: HOST_NIC_ENABLE + action: host-nic-enable options: - node_id: 0 - nic_id: 0 + node_name: client_1 + nic_num: 1 8: - action: HOST_NIC_ENABLE + action: host-nic-enable options: - node_id: 1 - nic_id: 0 - options: - nodes: - - node_name: client - - node_name: server - - max_folders_per_node: 0 - max_files_per_folder: 0 - max_services_per_node: 0 - max_nics_per_node: 1 - max_acl_rules: 0 - ip_list: - - 192.168.1.2 - - 192.168.1.3 + node_name: server + nic_num: 1 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.40 options: node_hostname: database_server @@ -140,10 +124,10 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.3 - - type: DataManipulationBot + - type: data-manipulation-bot options: server_ip: 192.168.1.3 payload: "DELETE" @@ -158,7 +142,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DatabaseService + - type: database-service links: - endpoint_a_hostname: client diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py index 144e0733..3553f527 100644 --- a/src/primaite/config/load.py +++ b/src/primaite/config/load.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from pathlib import Path from typing import Dict, Final, Union @@ -59,3 +59,18 @@ def data_manipulation_marl_config_path() -> Path: _LOGGER.error(msg) raise FileNotFoundError(msg) return path + + +def get_extended_config_path() -> Path: + """ + Get the path to an 'extended' example config that contains nodes using the extension framework. + + :return: Path to the extended example config + :rtype: Path + """ + path = _EXAMPLE_CFG / "extended_config.yaml" + if not path.exists(): + msg = f"Example config does not exist: {path}. Have you run `primaite setup`?" + _LOGGER.error(msg) + raise FileNotFoundError(msg) + return path diff --git a/src/primaite/exceptions.py b/src/primaite/exceptions.py index afc55271..4487111d 100644 --- a/src/primaite/exceptions.py +++ b/src/primaite/exceptions.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK class PrimaiteError(Exception): """The root PrimAITE Error.""" diff --git a/src/primaite/game/__init__.py b/src/primaite/game/__init__.py index 39034e92..57f96a56 100644 --- a/src/primaite/game/__init__.py +++ b/src/primaite/game/__init__.py @@ -1,2 +1,2 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """PrimAITE Game Layer.""" diff --git a/src/primaite/game/agent/__init__.py b/src/primaite/game/agent/__init__.py index be6c00e7..c005c173 100644 --- a/src/primaite/game/agent/__init__.py +++ b/src/primaite/game/agent/__init__.py @@ -1 +1,7 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from primaite.game.agent.interface import ProxyAgent +from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent +from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent +from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent, RandomAgent + +__all__ = ("ProbabilisticAgent", "ProxyAgent", "RandomAgent", "PeriodicAgent", "DataManipulationAgent") diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py deleted file mode 100644 index 2439ccc4..00000000 --- a/src/primaite/game/agent/actions.py +++ /dev/null @@ -1,1861 +0,0 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -""" -This module contains the ActionManager class which belongs to the Agent class. - -An agent's action space is made up of a collection of actions. Each action is an instance of a subclass of -AbstractAction. The ActionManager is responsible for: - 1. Creating the action space from a list of action types. - 2. Converting an integer action choice into a specific action and parameter choice. - 3. Converting an action and parameter choice into a request which can be ingested by the PrimAITE simulation. This - ensures that requests conform to the simulator's request format. -""" -import itertools -from abc import ABC, abstractmethod -from typing import Dict, List, Literal, Optional, Tuple, TYPE_CHECKING, Union - -from gymnasium import spaces -from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo - -from primaite import getLogger -from primaite.interface.request import RequestFormat - -_LOGGER = getLogger(__name__) - -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame - - -class AbstractAction(ABC): - """Base class for actions.""" - - @abstractmethod - def __init__(self, manager: "ActionManager", **kwargs) -> None: - """ - Init method for action. - - All action init functions should accept **kwargs as a way of ignoring extra arguments. - - Since many parameters are defined for the action space as a whole (such as max files per folder, max services - per node), we need to pass those options to every action that gets created. To prevent verbosity, these - parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply. - """ - self.name: str = "" - """Human-readable action identifier used for printing, logging, and reporting.""" - self.shape: Dict[str, int] = {} - """Dictionary describing the number of options for each parameter of this action. The keys of this dict must - align with the keyword args of the form_request method.""" - self.manager: ActionManager = manager - """Reference to the ActionManager which created this action. This is used to access the game and simulation - objects.""" - - @abstractmethod - def form_request(self) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return [] - - -class DoNothingAction(AbstractAction): - """Action which does nothing. This is here to allow agents to be idle if they choose to.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - self.name = "DONOTHING" - self.shape: Dict[str, int] = { - "dummy": 1, - } - # This action does not accept any parameters, therefore it technically has a gymnasium shape of Discrete(1), - # i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter - # with one option. This just aids the Action Manager to enumerate all possibilities. - - def form_request(self, **kwargs) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["do_nothing"] - - -class NodeServiceAbstractAction(AbstractAction): - """ - Base class for service actions. - - Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from - this base class. - """ - - @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int, service_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - service_name = self.manager.get_service_name_by_idx(node_id, service_id) - if node_name is None or service_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "service", service_name, self.verb] - - -class NodeServiceScanAction(NodeServiceAbstractAction): - """Action which scans a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "scan" - - -class NodeServiceStopAction(NodeServiceAbstractAction): - """Action which stops a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "stop" - - -class NodeServiceStartAction(NodeServiceAbstractAction): - """Action which starts a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "start" - - -class NodeServicePauseAction(NodeServiceAbstractAction): - """Action which pauses a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "pause" - - -class NodeServiceResumeAction(NodeServiceAbstractAction): - """Action which resumes a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "resume" - - -class NodeServiceRestartAction(NodeServiceAbstractAction): - """Action which restarts a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "restart" - - -class NodeServiceDisableAction(NodeServiceAbstractAction): - """Action which disables a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "disable" - - -class NodeServiceEnableAction(NodeServiceAbstractAction): - """Action which enables a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "enable" - - -class NodeServiceFixAction(NodeServiceAbstractAction): - """Action which fixes a service.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) - self.verb: str = "fix" - - -class NodeApplicationAbstractAction(AbstractAction): - """ - Base class for application actions. - - Any action which applies to an application and uses node_id and application_id as its only two parameters can - inherit from this base class. - """ - - @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int, application_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - application_name = self.manager.get_application_name_by_idx(node_id, application_id) - if node_name is None or application_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "application", application_name, self.verb] - - -class NodeApplicationExecuteAction(NodeApplicationAbstractAction): - """Action which executes an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) - self.verb: str = "execute" - - -class NodeApplicationScanAction(NodeApplicationAbstractAction): - """Action which scans an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) - self.verb: str = "scan" - - -class NodeApplicationCloseAction(NodeApplicationAbstractAction): - """Action which closes an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) - self.verb: str = "close" - - -class NodeApplicationFixAction(NodeApplicationAbstractAction): - """Action which fixes an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) - self.verb: str = "fix" - - -class NodeApplicationInstallAction(AbstractAction): - """Action which installs an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes} - - def form_request(self, node_id: int, application_name: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - return [ - "network", - "node", - node_name, - "software_manager", - "application", - "install", - application_name, - ] - - -class ConfigureDatabaseClientAction(AbstractAction): - """Action which sets config parameters for a database client on a node.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this action.""" - - model_config = ConfigDict(extra="forbid") - server_ip_address: Optional[str] = None - server_password: Optional[str] = None - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, config: Dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - ConfigureDatabaseClientAction._Opts.model_validate(config) # check that options adhere to schema - return ["network", "node", node_name, "application", "DatabaseClient", "configure", config] - - -class ConfigureRansomwareScriptAction(AbstractAction): - """Action which sets config parameters for a ransomware script on a node.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this option.""" - - model_config = ConfigDict(extra="forbid") - server_ip_address: Optional[str] = None - server_password: Optional[str] = None - payload: Optional[str] = None - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, config: Dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema - return ["network", "node", node_name, "application", "RansomwareScript", "configure", config] - - -class ConfigureDoSBotAction(AbstractAction): - """Action which sets config parameters for a DoS bot on a node.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this action.""" - - model_config = ConfigDict(extra="forbid") - target_ip_address: Optional[str] = None - target_port: Optional[str] = None - payload: Optional[str] = None - repeat: Optional[bool] = None - port_scan_p_of_success: Optional[float] = None - dos_intensity: Optional[float] = None - max_sessions: Optional[int] = None - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, config: Dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - self._Opts.model_validate(config) # check that options adhere to schema - return ["network", "node", node_name, "application", "DoSBot", "configure", config] - - -class NodeApplicationRemoveAction(AbstractAction): - """Action which removes/uninstalls an application.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes} - - def form_request(self, node_id: int, application_name: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "software_manager", "application", "uninstall", application_name] - - -class NodeFolderAbstractAction(AbstractAction): - """ - Base class for folder actions. - - Any action which applies to a folder and uses node_id and folder_id as its only two parameters can inherit from - this base class. - """ - - @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int, folder_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) - if node_name is None or folder_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "folder", folder_name, self.verb] - - -class NodeFolderScanAction(NodeFolderAbstractAction): - """Action which scans a folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "scan" - - -class NodeFolderCheckhashAction(NodeFolderAbstractAction): - """Action which checks the hash of a folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "checkhash" - - -class NodeFolderRepairAction(NodeFolderAbstractAction): - """Action which repairs a folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "repair" - - -class NodeFolderRestoreAction(NodeFolderAbstractAction): - """Action which restores a folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "restore" - - -class NodeFileCreateAction(AbstractAction): - """Action which creates a new file in a given folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "create" - - def form_request( - self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False - ) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None or folder_name is None or file_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "create", "file", folder_name, file_name, force] - - -class NodeFolderCreateAction(AbstractAction): - """Action which creates a new folder.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "create" - - def form_request(self, node_id: int, folder_name: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None or folder_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "create", "folder", folder_name] - - -class NodeFileAbstractAction(AbstractAction): - """Abstract base class for file actions. - - Any action which applies to a file and uses node_id, folder_id, and file_id as its only three parameters can inherit - from this base class. - """ - - @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) - file_name = self.manager.get_file_name_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id) - if node_name is None or folder_name is None or file_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "folder", folder_name, "file", file_name, self.verb] - - -class NodeFileScanAction(NodeFileAbstractAction): - """Action which scans a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "scan" - - -class NodeFileCheckhashAction(NodeFileAbstractAction): - """Action which checks the hash of a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "checkhash" - - -class NodeFileDeleteAction(NodeFileAbstractAction): - """Action which deletes a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "delete" - - def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) - file_name = self.manager.get_file_name_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id) - if node_name is None or folder_name is None or file_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "delete", "file", folder_name, file_name] - - -class NodeFileRepairAction(NodeFileAbstractAction): - """Action which repairs a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "repair" - - -class NodeFileRestoreAction(NodeFileAbstractAction): - """Action which restores a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "restore" - - -class NodeFileCorruptAction(NodeFileAbstractAction): - """Action which corrupts a file.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) - self.verb: str = "corrupt" - - -class NodeFileAccessAction(AbstractAction): - """Action which increases a file's access count.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None: - super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) - self.verb: str = "access" - - def form_request(self, node_id: int, folder_name: str, file_name: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None or folder_name is None or file_name is None: - return ["do_nothing"] - return ["network", "node", node_name, "file_system", "access", folder_name, file_name] - - -class NodeAbstractAction(AbstractAction): - """ - Abstract base class for node actions. - - Any action which applies to a node and uses node_id as its only parameter can inherit from this base class. - """ - - @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return ["network", "node", node_name, self.verb] - - -class NodeOSScanAction(NodeAbstractAction): - """Action which scans a node's OS.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes) - self.verb: str = "scan" - - -class NodeShutdownAction(NodeAbstractAction): - """Action which shuts down a node.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes) - self.verb: str = "shutdown" - - -class NodeStartupAction(NodeAbstractAction): - """Action which starts up a node.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes) - self.verb: str = "startup" - - -class NodeResetAction(NodeAbstractAction): - """Action which resets a node.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes) - self.verb: str = "reset" - - -class RouterACLAddRuleAction(AbstractAction): - """Action which adds a rule to a router's ACL.""" - - class ACLRuleOptions(BaseModel): - """Validator for ACL_ADD_RULE options.""" - - target_router: str - """On which router to add the rule, must be specified.""" - position: int - """At what position to add the rule, must be specified.""" - permission: Literal[1, 2] - """Whether to allow or deny traffic, must be specified. 1 = PERMIT, 2 = DENY.""" - source_ip_id: int = Field(default=1, ge=1) - """Rule source IP address. By default, all ip addresses.""" - source_wildcard_id: int = Field(default=0, ge=0) - """Rule source IP wildcard. By default, use the wildcard at index 0 from action manager.""" - source_port_id: int = Field(default=1, ge=1) - """Rule source port. By default, all source ports.""" - dest_ip_id: int = Field(default=1, ge=1) - """Rule destination IP address. By default, all ip addresses.""" - dest_wildcard_id: int = Field(default=0, ge=0) - """Rule destination IP wildcard. By default, use the wildcard at index 0 from action manager.""" - dest_port_id: int = Field(default=1, ge=1) - """Rule destination port. By default, all destination ports.""" - protocol_id: int = Field(default=1, ge=1) - """Rule protocol. By default, all protocols.""" - - @field_validator( - "source_ip_id", - "source_port_id", - "source_wildcard_id", - "dest_ip_id", - "dest_port_id", - "dest_wildcard_id", - "protocol_id", - mode="before", - ) - @classmethod - def not_none(cls, v: str, info: ValidationInfo) -> int: - """If None is passed, use the default value instead.""" - if v is None: - return cls.model_fields[info.field_name].default - return v - - def __init__( - self, - manager: "ActionManager", - max_acl_rules: int, - num_ips: int, - num_ports: int, - num_protocols: int, - **kwargs, - ) -> None: - """Init method for RouterACLAddRuleAction. - - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param max_acl_rules: Maximum number of ACL rules that can be added to the router. - :type max_acl_rules: int - :param num_ips: Number of IP addresses in the simulation. - :type num_ips: int - :param num_ports: Number of ports in the simulation. - :type num_ports: int - :param num_protocols: Number of protocols in the simulation. - :type num_protocols: int - """ - super().__init__(manager=manager) - num_permissions = 3 - self.shape: Dict[str, int] = { - "position": max_acl_rules, - "permission": num_permissions, - "source_ip_id": num_ips, - "dest_ip_id": num_ips, - "source_port_id": num_ports, - "dest_port_id": num_ports, - "protocol_id": num_protocols, - } - - def form_request( - self, - target_router: str, - position: int, - permission: int, - source_ip_id: int, - source_wildcard_id: int, - dest_ip_id: int, - dest_wildcard_id: int, - source_port_id: int, - dest_port_id: int, - protocol_id: int, - ) -> List[str]: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - # Validate incoming data. - parsed_options = RouterACLAddRuleAction.ACLRuleOptions( - target_router=target_router, - position=position, - permission=permission, - source_ip_id=source_ip_id, - source_wildcard_id=source_wildcard_id, - dest_ip_id=dest_ip_id, - dest_wildcard_id=dest_wildcard_id, - source_port_id=source_port_id, - dest_port_id=dest_port_id, - protocol_id=protocol_id, - ) - if parsed_options.permission == 1: - permission_str = "PERMIT" - elif parsed_options.permission == 2: - permission_str = "DENY" - else: - _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.") - - if parsed_options.protocol_id == 1: - protocol = "ALL" - else: - protocol = self.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2) - # subtract 2 to account for UNUSED=0 and ALL=1. - - if parsed_options.source_ip_id == 1: - src_ip = "ALL" - else: - src_ip = self.manager.get_ip_address_by_idx(parsed_options.source_ip_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - src_wildcard = self.manager.get_wildcard_by_idx(parsed_options.source_wildcard_id) - - if parsed_options.source_port_id == 1: - src_port = "ALL" - else: - src_port = self.manager.get_port_by_idx(parsed_options.source_port_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - if parsed_options.dest_ip_id == 1: - dst_ip = "ALL" - else: - dst_ip = self.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - dst_wildcard = self.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id) - - if parsed_options.dest_port_id == 1: - dst_port = "ALL" - else: - dst_port = self.manager.get_port_by_idx(parsed_options.dest_port_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - return [ - "network", - "node", - target_router, - "acl", - "add_rule", - permission_str, - protocol, - str(src_ip), - src_wildcard, - src_port, - str(dst_ip), - dst_wildcard, - dst_port, - position, - ] - - -class RouterACLRemoveRuleAction(AbstractAction): - """Action which removes a rule from a router's ACL.""" - - def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None: - """Init method for RouterACLRemoveRuleAction. - - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param max_acl_rules: Maximum number of ACL rules that can be added to the router. - :type max_acl_rules: int - """ - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"position": max_acl_rules} - - def form_request(self, target_router: str, position: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", target_router, "acl", "remove_rule", position] - - -class FirewallACLAddRuleAction(AbstractAction): - """Action which adds a rule to a firewall port's ACL.""" - - def __init__( - self, - manager: "ActionManager", - max_acl_rules: int, - num_ips: int, - num_ports: int, - num_protocols: int, - **kwargs, - ) -> None: - """Init method for FirewallACLAddRuleAction. - - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param max_acl_rules: Maximum number of ACL rules that can be added to the router. - :type max_acl_rules: int - :param num_ips: Number of IP addresses in the simulation. - :type num_ips: int - :param num_ports: Number of ports in the simulation. - :type num_ports: int - :param num_protocols: Number of protocols in the simulation. - :type num_protocols: int - """ - super().__init__(manager=manager) - num_permissions = 3 - self.shape: Dict[str, int] = { - "position": max_acl_rules, - "permission": num_permissions, - "source_ip_id": num_ips, - "dest_ip_id": num_ips, - "source_port_id": num_ports, - "dest_port_id": num_ports, - "protocol_id": num_protocols, - } - - def form_request( - self, - target_firewall_nodename: str, - firewall_port_name: str, - firewall_port_direction: str, - position: int, - permission: int, - source_ip_id: int, - source_wildcard_id: int, - dest_ip_id: int, - dest_wildcard_id: int, - source_port_id: int, - dest_port_id: int, - protocol_id: int, - ) -> List[str]: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if permission == 0: - permission_str = "UNUSED" - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS - elif permission == 1: - permission_str = "PERMIT" - elif permission == 2: - permission_str = "DENY" - else: - _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.") - - if protocol_id == 0: - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS - - if protocol_id == 1: - protocol = "ALL" - else: - protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2) - # subtract 2 to account for UNUSED=0 and ALL=1. - - if source_ip_id == 0: - return ["do_nothing"] # invalid formulation - elif source_ip_id == 1: - src_ip = "ALL" - else: - src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - if source_port_id == 0: - return ["do_nothing"] # invalid formulation - elif source_port_id == 1: - src_port = "ALL" - else: - src_port = self.manager.get_port_by_idx(source_port_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - if dest_ip_id == 0: - return ["do_nothing"] # invalid formulation - elif dest_ip_id == 1: - dst_ip = "ALL" - else: - dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - - if dest_port_id == 0: - return ["do_nothing"] # invalid formulation - elif dest_port_id == 1: - dst_port = "ALL" - else: - dst_port = self.manager.get_port_by_idx(dest_port_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id) - dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id) - - return [ - "network", - "node", - target_firewall_nodename, - firewall_port_name, - firewall_port_direction, - "acl", - "add_rule", - permission_str, - protocol, - str(src_ip), - src_wildcard, - src_port, - str(dst_ip), - dst_wildcard, - dst_port, - position, - ] - - -class FirewallACLRemoveRuleAction(AbstractAction): - """Action which removes a rule from a firewall port's ACL.""" - - def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None: - """Init method for FirewallACLRemoveRuleAction. - - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param max_acl_rules: Maximum number of ACL rules that can be added to the router. - :type max_acl_rules: int - """ - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"position": max_acl_rules} - - def form_request( - self, target_firewall_nodename: str, firewall_port_name: str, firewall_port_direction: str, position: int - ) -> List[str]: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return [ - "network", - "node", - target_firewall_nodename, - firewall_port_name, - firewall_port_direction, - "acl", - "remove_rule", - position, - ] - - -class HostNICAbstractAction(AbstractAction): - """ - Abstract base class for NIC actions. - - Any action which applies to a NIC and uses node_id and nic_id as its only two parameters can inherit from this base - class. - """ - - def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: - """Init method for HostNICAbstractAction. - - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param num_nodes: Number of nodes in the simulation. - :type num_nodes: int - :param max_nics_per_node: Maximum number of NICs per node. - :type max_nics_per_node: int - """ - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} - self.verb: str # define but don't initialise: defends against children classes not defining this - - def form_request(self, node_id: int, nic_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_idx=node_id) - nic_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=nic_id) - if node_name is None or nic_num is None: - return ["do_nothing"] - return ["network", "node", node_name, "network_interface", nic_num, self.verb] - - -class HostNICEnableAction(HostNICAbstractAction): - """Action which enables a NIC.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) - self.verb: str = "enable" - - -class HostNICDisableAction(HostNICAbstractAction): - """Action which disables a NIC.""" - - def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: - super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) - self.verb: str = "disable" - - -class NetworkPortEnableAction(AbstractAction): - """Action which enables are port on a router or a firewall.""" - - def __init__(self, manager: "ActionManager", max_nics_per_node: int, **kwargs) -> None: - """Init method for NetworkPortEnableAction. - - :param max_nics_per_node: Maximum number of NICs per node. - :type max_nics_per_node: int - """ - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"port_id": max_nics_per_node} - - def form_request(self, target_nodename: str, port_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if target_nodename is None or port_id is None: - return ["do_nothing"] - return ["network", "node", target_nodename, "network_interface", port_id, "enable"] - - -class NetworkPortDisableAction(AbstractAction): - """Action which disables are port on a router or a firewall.""" - - def __init__(self, manager: "ActionManager", max_nics_per_node: int, **kwargs) -> None: - """Init method for NetworkPortDisableAction. - - :param max_nics_per_node: Maximum number of NICs per node. - :type max_nics_per_node: int - """ - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"port_id": max_nics_per_node} - - def form_request(self, target_nodename: str, port_id: int) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if target_nodename is None or port_id is None: - return ["do_nothing"] - return ["network", "node", target_nodename, "network_interface", port_id, "disable"] - - -class NodeNMAPPingScanAction(AbstractAction): - """Action which performs an NMAP ping scan.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request( - self, source_node: str, target_ip_address: Union[str, List[str]], show: Optional[bool] = False - ) -> List[str]: # noqa - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return [ - "network", - "node", - source_node, - "application", - "NMAP", - "ping_scan", - {"target_ip_address": target_ip_address, "show": show}, - ] - - -class NodeNMAPPortScanAction(AbstractAction): - """Action which performs an NMAP port scan.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request( - self, - source_node: str, - target_ip_address: Union[str, List[str]], - target_protocol: Optional[Union[str, List[str]]] = None, - target_port: Optional[Union[str, List[str]]] = None, - show: Optional[bool] = False, - ) -> List[str]: # noqa - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return [ - "network", - "node", - source_node, - "application", - "NMAP", - "port_scan", - { - "target_ip_address": target_ip_address, - "target_port": target_port, - "target_protocol": target_protocol, - "show": show, - }, - ] - - -class NodeNetworkServiceReconAction(AbstractAction): - """Action which performs an NMAP network service recon (ping scan followed by port scan).""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request( - self, - source_node: str, - target_ip_address: Union[str, List[str]], - target_protocol: Optional[Union[str, List[str]]] = None, - target_port: Optional[Union[str, List[str]]] = None, - show: Optional[bool] = False, - ) -> List[str]: # noqa - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return [ - "network", - "node", - source_node, - "application", - "NMAP", - "network_service_recon", - { - "target_ip_address": target_ip_address, - "target_port": target_port, - "target_protocol": target_protocol, - "show": show, - }, - ] - - -class ConfigureC2BeaconAction(AbstractAction): - """Action which configures a C2 Beacon based on the parameters given.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this action.""" - - c2_server_ip_address: str - keep_alive_frequency: int = Field(default=5, ge=1) - masquerade_protocol: str = Field(default="TCP") - masquerade_port: str = Field(default="HTTP") - - @field_validator( - "c2_server_ip_address", - "keep_alive_frequency", - "masquerade_protocol", - "masquerade_port", - mode="before", - ) - @classmethod - def not_none(cls, v: str, info: ValidationInfo) -> int: - """If None is passed, use the default value instead.""" - if v is None: - return cls.model_fields[info.field_name].default - return v - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, config: Dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - config = ConfigureC2BeaconAction._Opts( - c2_server_ip_address=config["c2_server_ip_address"], - keep_alive_frequency=config["keep_alive_frequency"], - masquerade_port=config["masquerade_port"], - masquerade_protocol=config["masquerade_protocol"], - ) - - ConfigureC2BeaconAction._Opts.model_validate(config) # check that options adhere to schema - - return ["network", "node", node_name, "application", "C2Beacon", "configure", config.__dict__] - - -class NodeAccountsAddUserAction(AbstractAction): - """Action which changes adds a User.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: str, username: str, password: str, is_admin: bool) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return ["network", "node", node_name, "service", "UserManager", "add_user", username, password, is_admin] - - -class NodeAccountsDisableUserAction(AbstractAction): - """Action which disables a user.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: str, username: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return [ - "network", - "node", - node_name, - "service", - "UserManager", - "disable_user", - username, - ] - - -class NodeAccountsChangePasswordAction(AbstractAction): - """Action which changes the password for a user.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return [ - "network", - "node", - node_name, - "service", - "UserManager", - "change_password", - username, - current_password, - new_password, - ] - - -class NodeSessionsRemoteLoginAction(AbstractAction): - """Action which performs a remote session login.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return [ - "network", - "node", - node_name, - "service", - "Terminal", - "ssh_to_remote", - username, - password, - remote_ip, - ] - - -class NodeSessionsRemoteLogoutAction(AbstractAction): - """Action which performs a remote session logout.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: str, remote_ip: str) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return ["network", "node", node_name, "service", "Terminal", "remote_logoff", remote_ip] - - -class RansomwareConfigureC2ServerAction(AbstractAction): - """Action which sends a command from the C2 Server to the C2 Beacon which configures a local RansomwareScript.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, config: Dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - # Using the ransomware scripts model to validate. - ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema - return ["network", "node", node_name, "application", "C2Server", "ransomware_configure", config] - - -class RansomwareLaunchC2ServerAction(AbstractAction): - """Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - # This action currently doesn't require any further configuration options. - return ["network", "node", node_name, "application", "C2Server", "ransomware_launch"] - - -class ExfiltrationC2ServerAction(AbstractAction): - """Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this action.""" - - username: Optional[str] - password: Optional[str] - target_ip_address: str - target_file_name: str - target_folder_name: str - exfiltration_folder_name: Optional[str] - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request( - self, - node_id: int, - account: dict, - target_ip_address: str, - target_file_name: str, - target_folder_name: str, - exfiltration_folder_name: Optional[str], - ) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - - command_model = { - "target_file_name": target_file_name, - "target_folder_name": target_folder_name, - "exfiltration_folder_name": exfiltration_folder_name, - "target_ip_address": target_ip_address, - "username": account["username"], - "password": account["password"], - } - ExfiltrationC2ServerAction._Opts.model_validate(command_model) - return ["network", "node", node_name, "application", "C2Server", "exfiltrate", command_model] - - -class NodeSendRemoteCommandAction(AbstractAction): - """Action which sends a terminal command to a remote node via SSH.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return [ - "network", - "node", - node_name, - "service", - "Terminal", - "send_remote_command", - remote_ip, - {"command": command}, - ] - - -class NodeSendLocalCommandAction(AbstractAction): - """Action which sends a terminal command using a local terminal session.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, username: str, password: str, command: RequestFormat) -> RequestFormat: - """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - return [ - "network", - "node", - node_name, - "service", - "Terminal", - "send_local_command", - username, - password, - {"command": command}, - ] - - -class TerminalC2ServerAction(AbstractAction): - """Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed.""" - - class _Opts(BaseModel): - """Schema for options that can be passed to this action.""" - - commands: Union[List[RequestFormat], RequestFormat] - ip_address: Optional[str] - username: Optional[str] - password: Optional[str] - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, commands: List, ip_address: Optional[str], account: dict) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - - command_model = { - "commands": commands, - "ip_address": ip_address, - "username": account["username"], - "password": account["password"], - } - - TerminalC2ServerAction._Opts.model_validate(command_model) - return ["network", "node", node_name, "application", "C2Server", "terminal_command", command_model] - - -class RansomwareLaunchC2ServerAction(AbstractAction): - """Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript.""" - - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int) -> RequestFormat: - """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: - return ["do_nothing"] - # This action currently doesn't require any further configuration options. - return ["network", "node", node_name, "application", "C2Server", "ransomware_launch"] - - -class ActionManager: - """Class which manages the action space for an agent.""" - - act_class_identifiers: Dict[str, type] = { - "DONOTHING": DoNothingAction, - "NODE_SERVICE_SCAN": NodeServiceScanAction, - "NODE_SERVICE_STOP": NodeServiceStopAction, - "NODE_SERVICE_START": NodeServiceStartAction, - "NODE_SERVICE_PAUSE": NodeServicePauseAction, - "NODE_SERVICE_RESUME": NodeServiceResumeAction, - "NODE_SERVICE_RESTART": NodeServiceRestartAction, - "NODE_SERVICE_DISABLE": NodeServiceDisableAction, - "NODE_SERVICE_ENABLE": NodeServiceEnableAction, - "NODE_SERVICE_FIX": NodeServiceFixAction, - "NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction, - "NODE_APPLICATION_SCAN": NodeApplicationScanAction, - "NODE_APPLICATION_CLOSE": NodeApplicationCloseAction, - "NODE_APPLICATION_FIX": NodeApplicationFixAction, - "NODE_APPLICATION_INSTALL": NodeApplicationInstallAction, - "NODE_APPLICATION_REMOVE": NodeApplicationRemoveAction, - "NODE_FILE_SCAN": NodeFileScanAction, - "NODE_FILE_CREATE": NodeFileCreateAction, - "NODE_FILE_CHECKHASH": NodeFileCheckhashAction, - "NODE_FILE_DELETE": NodeFileDeleteAction, - "NODE_FILE_REPAIR": NodeFileRepairAction, - "NODE_FILE_RESTORE": NodeFileRestoreAction, - "NODE_FILE_CORRUPT": NodeFileCorruptAction, - "NODE_FILE_ACCESS": NodeFileAccessAction, - "NODE_FOLDER_CREATE": NodeFolderCreateAction, - "NODE_FOLDER_SCAN": NodeFolderScanAction, - "NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction, - "NODE_FOLDER_REPAIR": NodeFolderRepairAction, - "NODE_FOLDER_RESTORE": NodeFolderRestoreAction, - "NODE_OS_SCAN": NodeOSScanAction, - "NODE_SHUTDOWN": NodeShutdownAction, - "NODE_STARTUP": NodeStartupAction, - "NODE_RESET": NodeResetAction, - "ROUTER_ACL_ADDRULE": RouterACLAddRuleAction, - "ROUTER_ACL_REMOVERULE": RouterACLRemoveRuleAction, - "FIREWALL_ACL_ADDRULE": FirewallACLAddRuleAction, - "FIREWALL_ACL_REMOVERULE": FirewallACLRemoveRuleAction, - "HOST_NIC_ENABLE": HostNICEnableAction, - "HOST_NIC_DISABLE": HostNICDisableAction, - "NETWORK_PORT_ENABLE": NetworkPortEnableAction, - "NETWORK_PORT_DISABLE": NetworkPortDisableAction, - "NODE_NMAP_PING_SCAN": NodeNMAPPingScanAction, - "NODE_NMAP_PORT_SCAN": NodeNMAPPortScanAction, - "NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction, - "CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction, - "CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction, - "CONFIGURE_DOSBOT": ConfigureDoSBotAction, - "CONFIGURE_C2_BEACON": ConfigureC2BeaconAction, - "C2_SERVER_RANSOMWARE_LAUNCH": RansomwareLaunchC2ServerAction, - "C2_SERVER_RANSOMWARE_CONFIGURE": RansomwareConfigureC2ServerAction, - "C2_SERVER_TERMINAL_COMMAND": TerminalC2ServerAction, - "C2_SERVER_DATA_EXFILTRATE": ExfiltrationC2ServerAction, - "NODE_ACCOUNTS_ADD_USER": NodeAccountsAddUserAction, - "NODE_ACCOUNTS_DISABLE_USER": NodeAccountsDisableUserAction, - "NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction, - "SSH_TO_REMOTE": NodeSessionsRemoteLoginAction, - "SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction, - "NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction, - "NODE_SEND_LOCAL_COMMAND": NodeSendLocalCommandAction, - } - """Dictionary which maps action type strings to the corresponding action class.""" - - def __init__( - self, - actions: List[Dict], # stores list of actions available to agent - nodes: List[Dict], # extra configuration for each node - max_folders_per_node: int = 2, # allows calculating shape - max_files_per_folder: int = 2, # allows calculating shape - max_services_per_node: int = 2, # allows calculating shape - max_applications_per_node: int = 2, # allows calculating shape - max_nics_per_node: int = 8, # allows calculating shape - max_acl_rules: int = 10, # allows calculating shape - protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol - ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port - ip_list: List[str] = [], # to allow us to map an index to an ip address. - wildcard_list: List[str] = [], # to allow mapping from wildcard index to - act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions - ) -> None: - """Init method for ActionManager. - - :param game: Reference to the game to which the agent belongs. - :type game: PrimaiteGame - :param actions: List of action specs which should be made available to the agent. The keys of each spec are: - 'type' and 'options' for passing any options to the action class's init method - :type actions: List[dict] - :param nodes: Extra configuration for each node. - :type nodes: List[Dict] - :param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape. - :type max_folders_per_node: int - :param max_files_per_folder: Maximum number of files per folder. Used for calculating action shape. - :type max_files_per_folder: int - :param max_services_per_node: Maximum number of services per node. Used for calculating action shape. - :type max_services_per_node: int - :param max_nics_per_node: Maximum number of NICs per node. Used for calculating action shape. - :type max_nics_per_node: int - :param max_acl_rules: Maximum number of ACL rules per router. Used for calculating action shape. - :type max_acl_rules: int - :param protocols: List of protocols that are available in the simulation. Used for calculating action shape. - :type protocols: List[str] - :param ports: List of ports that are available in the simulation. Used for calculating action shape. - :type ports: List[str] - :param ip_list: List of IP addresses that known to this agent. Used for calculating action shape. - :type ip_list: Optional[List[str]] - :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. - :type act_map: Optional[Dict[int, Dict]] - """ - self.node_names: List[str] = [n["node_name"] for n in nodes] - """List of node names in this action space. The list order is the mapping between node index and node name.""" - self.application_names: List[List[str]] = [] - """ - List of applications per node. The list order gives the two-index mapping between (node_id, app_id) to app name. - The first index corresponds to node id, the second index is the app id on that particular node. - For instance, self.application_names[0][2] is the name of the third application on the first node. - """ - self.service_names: List[List[str]] = [] - """ - List of services per node. The list order gives the two-index mapping between (node_id, svc_id) to svc name. - The first index corresponds to node id, the second index is the service id on that particular node. - For instance, self.service_names[0][2] is the name of the third service on the first node. - """ - self.folder_names: List[List[str]] = [] - """ - List of folders per node. The list order gives the two-index mapping between (node_id, folder_id) to folder - name. The first index corresponds to node id, the second index is the folder id on that particular node. - For instance, self.folder_names[0][2] is the name of the third folder on the first node. - """ - self.file_names: List[List[List[str]]] = [] - """ - List of files per folder per node. The list order gives the three-index mapping between - (node_id, folder_id, file_id) to file name. The first index corresponds to node id, the second index is the - folder id on that particular node, and the third index is the file id in that particular folder. - For instance, self.file_names[0][2][1] is the name of the second file in the third folder on the first node. - """ - - # Populate lists of apps, services, files, folders, etc on nodes. - for node in nodes: - app_list = [a["application_name"] for a in node.get("applications", [])] - while len(app_list) < max_applications_per_node: - app_list.append(None) - self.application_names.append(app_list) - - svc_list = [s["service_name"] for s in node.get("services", [])] - while len(svc_list) < max_services_per_node: - svc_list.append(None) - self.service_names.append(svc_list) - - folder_list = [f["folder_name"] for f in node.get("folders", [])] - while len(folder_list) < max_folders_per_node: - folder_list.append(None) - self.folder_names.append(folder_list) - - file_sublist = [] - for folder in node.get("folders", [{"files": []}]): - file_list = [f["file_name"] for f in folder.get("files", [])] - while len(file_list) < max_files_per_folder: - file_list.append(None) - file_sublist.append(file_list) - while len(file_sublist) < max_folders_per_node: - file_sublist.append([None] * max_files_per_folder) - self.file_names.append(file_sublist) - self.protocols: List[str] = protocols - self.ports: List[str] = ports - - self.ip_address_list: List[str] = ip_list - self.wildcard_list: List[str] = wildcard_list - if self.wildcard_list == []: - self.wildcard_list = ["NONE"] - # action_args are settings which are applied to the action space as a whole. - global_action_args = { - "num_nodes": len(self.node_names), - "num_folders": max_folders_per_node, - "num_files": max_files_per_folder, - "num_services": max_services_per_node, - "num_applications": max_applications_per_node, - "num_nics": max_nics_per_node, - "num_acl_rules": max_acl_rules, - "num_protocols": len(self.protocols), - "num_ports": len(self.ports), - "num_ips": len(self.ip_address_list), - "max_acl_rules": max_acl_rules, - "max_nics_per_node": max_nics_per_node, - } - self.actions: Dict[str, AbstractAction] = {} - for act_spec in actions: - # each action is provided into the action space config like this: - # - type: ACTION_TYPE - # options: - # option_1: value1 - # option_2: value2 - # where `type` decides which AbstractAction subclass should be used - # and `options` is an optional dict of options to pass to the init method of the action class - act_type = act_spec.get("type") - act_options = act_spec.get("options", {}) - self.actions[act_type] = self.act_class_identifiers[act_type](self, **global_action_args, **act_options) - - self.action_map: Dict[int, Tuple[str, Dict]] = {} - """ - Action mapping that converts an integer to a specific action and parameter choice. - - For example : - {0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})} - """ - if act_map is None: - # raise RuntimeError("Action map must be specified in the config file.") - pass - else: - self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()} - # make sure all numbers between 0 and N are represented as dict keys in action map - assert all([i in self.action_map.keys() for i in range(len(self.action_map))]) - - def _enumerate_actions( - self, - ) -> Dict[int, Tuple[str, Dict]]: - """Generate a list of all the possible actions that could be taken. - - This enumerates all actions all combinations of parameters you could choose for those actions. The output - of this function is intended to populate the self.action_map parameter in the situation where the user provides - a list of action types, but doesn't specify any subset of actions that should be made available to the agent. - - The enumeration relies on the Actions' `shape` attribute. - - :return: An action map maps consecutive integers to a combination of Action type and parameter choices. - An example output could be: - {0: ("DONOTHING", {'dummy': 0}), - 1: ("NODE_OS_SCAN", {'node_id': 0}), - 2: ("NODE_OS_SCAN", {'node_id': 1}), - 3: ("NODE_FOLDER_SCAN", {'node_id:0, folder_id:0}), - ... #etc... - } - :rtype: Dict[int, Tuple[AbstractAction, Dict]] - """ - all_action_possibilities = [] - for act_name, action in self.actions.items(): - param_names = list(action.shape.keys()) - num_possibilities = list(action.shape.values()) - possibilities = [range(n) for n in num_possibilities] - - param_combinations = list(itertools.product(*possibilities)) - all_action_possibilities.extend( - [ - (act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))}) - for j in range(len(param_combinations)) - ] - ) - - return {i: p for i, p in enumerate(all_action_possibilities)} - - def get_action(self, action: int) -> Tuple[str, Dict]: - """Produce action in CAOS format.""" - """the agent chooses an action (as an integer), this is converted into an action in CAOS format""" - """The CAOS format is basically a action identifier, followed by parameters stored in a dictionary""" - act_identifier, act_options = self.action_map[action] - return act_identifier, act_options - - def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat: - """Take action in CAOS format and use the execution definition to change it into PrimAITE request format.""" - act_obj = self.actions[action_identifier] - return act_obj.form_request(**action_options) - - @property - def space(self) -> spaces.Space: - """Return the gymnasium action space for this agent.""" - return spaces.Discrete(len(self.action_map)) - - def get_node_name_by_idx(self, node_idx: int) -> str: - """ - Get the node name corresponding to the given index. - - :param node_idx: The index of the node to retrieve. - :type node_idx: int - :return: The node hostname. - :rtype: str - """ - if not node_idx < len(self.node_names): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx}, but its action space only" - f"has {len(self.node_names)} nodes." - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.node_names[node_idx] - - def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]: - """ - Get the folder name corresponding to the given node and folder indices. - - :param node_idx: The index of the node. - :type node_idx: int - :param folder_idx: The index of the folder on the node. - :type folder_idx: int - :return: The name of the folder. Or None if the node has fewer folders than the given index. - :rtype: Optional[str] - """ - if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this" - f" is out of range for its action space. Folder on each node: {self.folder_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.folder_names[node_idx][folder_idx] - - def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]: - """Get the file name corresponding to the given node, folder, and file indices. - - :param node_idx: The index of the node. - :type node_idx: int - :param folder_idx: The index of the folder on the node. - :type folder_idx: int - :param file_idx: The index of the file in the folder. - :type file_idx: int - :return: The name of the file. Or None if the node has fewer folders than the given index, or the folder has - fewer files than the given index. - :rtype: Optional[str] - """ - if ( - node_idx >= len(self.file_names) - or folder_idx >= len(self.file_names[node_idx]) - or file_idx >= len(self.file_names[node_idx][folder_idx]) - ): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}" - f" but this is out of range for its action space. Files on each node: {self.file_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.file_names[node_idx][folder_idx][file_idx] - - def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]: - """Get the service name corresponding to the given node and service indices. - - :param node_idx: The index of the node. - :type node_idx: int - :param service_idx: The index of the service on the node. - :type service_idx: int - :return: The name of the service. Or None if the node has fewer services than the given index. - :rtype: Optional[str] - """ - if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this" - f" is out of range for its action space. Services on each node: {self.service_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.service_names[node_idx][service_idx] - - def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]: - """Get the application name corresponding to the given node and service indices. - - :param node_idx: The index of the node. - :type node_idx: int - :param application_idx: The index of the service on the node. - :type application_idx: int - :return: The name of the service. Or None if the node has fewer services than the given index. - :rtype: Optional[str] - """ - if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]): - msg = ( - f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but " - f"this is out of range for its action space. Applications on each node: {self.application_names}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.application_names[node_idx][application_idx] - - def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: - """Get the internet protocol corresponding to the given index. - - :param protocol_idx: The index of the protocol to retrieve. - :type protocol_idx: int - :return: The protocol. - :rtype: str - """ - if protocol_idx >= len(self.protocols): - msg = ( - f"Error: agent attempted to perform an action on protocol {protocol_idx} but this" - f" is out of range for its action space. Protocols: {self.protocols}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.protocols[protocol_idx] - - def get_ip_address_by_idx(self, ip_idx: int) -> str: - """ - Get the IP address corresponding to the given index. - - :param ip_idx: The index of the IP address to retrieve. - :type ip_idx: int - :return: The IP address. - :rtype: str - """ - if ip_idx >= len(self.ip_address_list): - msg = ( - f"Error: agent attempted to perform an action on ip address {ip_idx} but this" - f" is out of range for its action space. IP address list: {self.ip_address_list}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.ip_address_list[ip_idx] - - def get_wildcard_by_idx(self, wildcard_idx: int) -> str: - """ - Get the IP wildcard corresponding to the given index. - - :param ip_idx: The index of the IP wildcard to retrieve. - :type ip_idx: int - :return: The wildcard address. - :rtype: str - """ - if wildcard_idx >= len(self.wildcard_list): - msg = ( - f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this" - f" is out of range for its action space. Wildcard list: {self.wildcard_list}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.wildcard_list[wildcard_idx] - - def get_port_by_idx(self, port_idx: int) -> str: - """ - Get the port corresponding to the given index. - - :param port_idx: The index of the port to retrieve. - :type port_idx: int - :return: The port. - :rtype: str - """ - if port_idx >= len(self.ports): - msg = ( - f"Error: agent attempted to perform an action on port {port_idx} but this" - f" is out of range for its action space. Port list: {self.ip_address_list}" - ) - _LOGGER.error(msg) - raise RuntimeError(msg) - return self.ports[port_idx] - - def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int: - """ - Get the NIC number corresponding to the given node and NIC indices. - - :param node_idx: The index of the node. - :type node_idx: int - :param nic_idx: The index of the NIC on the node. - :type nic_idx: int - :return: The NIC number. - :rtype: int - """ - return nic_idx + 1 - - @classmethod - def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": - """ - Construct an ActionManager from a config definition. - - The action space config supports the following three sections: - 1. ``action_list`` - ``action_list`` contains a list action components which need to be included in the action space. - Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options - which will be passed to the action class's __init__ method during initialisation. - 2. ``action_map`` - Since the agent uses a discrete action space which acts as a flattened version of the component-based - action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful - action and values of parameters. For example action 0 can correspond to do nothing, action 1 can - correspond to "NODE_SERVICE_SCAN" with ``node_id=1`` and ``service_id=1``, action 2 can be " - 3. ``options`` - ``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method. - These options are used to calculate the shape of the action space, and to provide additional information - to the ActionManager which is required to convert the agent's action choice into a CAOS request. - - :param game: The Primaite Game to which the agent belongs. - :type game: PrimaiteGame - :param cfg: The action space config. - :type cfg: Dict - :return: The constructed ActionManager. - :rtype: ActionManager - """ - if "ip_list" not in cfg["options"]: - cfg["options"]["ip_list"] = [] - - obj = cls( - actions=cfg["action_list"], - **cfg["options"], - protocols=game.options.protocols, - ports=game.options.ports, - act_map=cfg.get("action_map"), - ) - - return obj diff --git a/src/primaite/game/agent/actions/__init__.py b/src/primaite/game/agent/actions/__init__.py new file mode 100644 index 00000000..8517ded8 --- /dev/null +++ b/src/primaite/game/agent/actions/__init__.py @@ -0,0 +1,33 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +from primaite.game.agent.actions import ( + abstract, + acl, + application, + file, + folder, + host_nic, + manager, + network, + node, + service, + session, + software, +) +from primaite.game.agent.actions.manager import ActionManager + +__all__ = ( + "abstract", + "acl", + "application", + "software", + "file", + "folder", + "host_nic", + "manager", + "network", + "node", + "service", + "session", + "ActionManager", +) diff --git a/src/primaite/game/agent/actions/abstract.py b/src/primaite/game/agent/actions/abstract.py new file mode 100644 index 00000000..1c039ed3 --- /dev/null +++ b/src/primaite/game/agent/actions/abstract.py @@ -0,0 +1,41 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from __future__ import annotations + +from abc import ABC +from typing import Any, ClassVar, Dict, Optional, Type + +from pydantic import BaseModel, ConfigDict + +from primaite.interface.request import RequestFormat + + +class AbstractAction(BaseModel, ABC): + """Base class for actions.""" + + class ConfigSchema(BaseModel, ABC): + """Base configuration schema for Actions.""" + + model_config = ConfigDict(extra="forbid") + type: str = "" + + _registry: ClassVar[Dict[str, Type[AbstractAction]]] = {} + + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: + """ + Register an action type. + + :param discriminator: discriminator used to uniquely specify action types. + :type discriminator: str + :raises ValueError: When attempting to create an action with a name that is already in use. + """ + super().__init_subclass__(**kwargs) + if discriminator is None: + return + if discriminator in cls._registry: + raise ValueError(f"Cannot create new action under reserved name {discriminator}") + cls._registry[discriminator] = cls + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + pass diff --git a/src/primaite/game/agent/actions/acl.py b/src/primaite/game/agent/actions/acl.py new file mode 100644 index 00000000..fb59574d --- /dev/null +++ b/src/primaite/game/agent/actions/acl.py @@ -0,0 +1,153 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from __future__ import annotations + +from abc import ABC +from typing import Literal, Union + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import Port + +__all__ = ( + "RouterACLAddRuleAction", + "RouterACLRemoveRuleAction", + "FirewallACLAddRuleAction", + "FirewallACLRemoveRuleAction", +) + + +class ACLAddRuleAbstractAction(AbstractAction, ABC): + """Base abstract class for ACL add rule actions.""" + + class ConfigSchema(AbstractAction.ConfigSchema, ABC): + """Configuration Schema base for ACL add rule abstract actions.""" + + src_ip: Union[IPV4Address, Literal["ALL"]] + protocol_name: Union[IPProtocol, Literal["ALL"]] + permission: Literal["PERMIT", "DENY"] + position: int + dst_ip: Union[IPV4Address, Literal["ALL"]] + src_port: Union[Port, Literal["ALL"]] + dst_port: Union[Port, Literal["ALL"]] + src_wildcard: Union[IPV4Address, Literal["NONE"]] + dst_wildcard: Union[IPV4Address, Literal["NONE"]] + + +class ACLRemoveRuleAbstractAction(AbstractAction, ABC): + """Base abstract class for acl remove rule actions.""" + + class ConfigSchema(AbstractAction.ConfigSchema, ABC): + """Configuration Schema base for ACL remove rule abstract actions.""" + + position: int + + +class RouterACLAddRuleAction(ACLAddRuleAbstractAction, discriminator="router-acl-add-rule"): + """Action which adds a rule to a router's acl.""" + + config: "RouterACLAddRuleAction.ConfigSchema" + + class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema): + """Configuration Schema for RouterACLAddRuleAction.""" + + target_router: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.target_router, + "acl", + "add_rule", + config.permission, + config.protocol_name, + str(config.src_ip), + str(config.src_wildcard), + config.src_port, + str(config.dst_ip), + str(config.dst_wildcard), + config.dst_port, + config.position, + ] + + +class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, discriminator="router-acl-remove-rule"): + """Action which removes a rule from a router's acl.""" + + config: "RouterACLRemoveRuleAction.ConfigSchema" + + class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema): + """Configuration schema for RouterACLRemoveRuleAction.""" + + target_router: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["network", "node", config.target_router, "acl", "remove_rule", config.position] + + +class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, discriminator="firewall-acl-add-rule"): + """Action which adds a rule to a firewall port's acl.""" + + config: "FirewallACLAddRuleAction.ConfigSchema" + + class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema): + """Configuration schema for FirewallACLAddRuleAction.""" + + target_firewall_nodename: str + firewall_port_name: str + firewall_port_direction: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.target_firewall_nodename, + config.firewall_port_name, + config.firewall_port_direction, + "acl", + "add_rule", + config.permission, + config.protocol_name, + str(config.src_ip), + str(config.src_wildcard), + config.src_port, + str(config.dst_ip), + str(config.dst_wildcard), + config.dst_port, + config.position, + ] + + +class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, discriminator="firewall-acl-remove-rule"): + """Action which removes a rule from a firewall port's acl.""" + + config: "FirewallACLRemoveRuleAction.ConfigSchema" + + class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema): + """Configuration schema for FirewallACLRemoveRuleAction.""" + + target_firewall_nodename: str + firewall_port_name: str + firewall_port_direction: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.target_firewall_nodename, + config.firewall_port_name, + config.firewall_port_direction, + "acl", + "remove_rule", + config.position, + ] diff --git a/src/primaite/game/agent/actions/application.py b/src/primaite/game/agent/actions/application.py new file mode 100644 index 00000000..9651b600 --- /dev/null +++ b/src/primaite/game/agent/actions/application.py @@ -0,0 +1,135 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC +from typing import ClassVar + +from primaite.game.agent.actions.abstract import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ( + "NodeApplicationExecuteAction", + "NodeApplicationScanAction", + "NodeApplicationCloseAction", + "NodeApplicationFixAction", + "NodeApplicationInstallAction", + "NodeApplicationRemoveAction", +) + + +class NodeApplicationAbstractAction(AbstractAction, ABC): + """ + Base class for application actions. + + Any action which applies to an application and uses node_name and application_name as its only two parameters can + inherit from this base class. + """ + + class ConfigSchema(AbstractAction.ConfigSchema, ABC): + """Base Configuration schema for Node Application actions.""" + + node_name: str + application_name: str + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "application", + config.application_name, + config.verb, + ] + + +class NodeApplicationExecuteAction(NodeApplicationAbstractAction, discriminator="node-application-execute"): + """Action which executes an application.""" + + config: "NodeApplicationExecuteAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationExecuteAction.""" + + verb: str = "execute" + + +class NodeApplicationScanAction(NodeApplicationAbstractAction, discriminator="node-application-scan"): + """Action which scans an application.""" + + config: "NodeApplicationScanAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationScanAction.""" + + verb: str = "scan" + + +class NodeApplicationCloseAction(NodeApplicationAbstractAction, discriminator="node-application-close"): + """Action which closes an application.""" + + config: "NodeApplicationCloseAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationCloseAction.""" + + verb: str = "close" + + +class NodeApplicationFixAction(NodeApplicationAbstractAction, discriminator="node-application-fix"): + """Action which fixes an application.""" + + config: "NodeApplicationFixAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationFixAction.""" + + verb: str = "fix" + + +class NodeApplicationInstallAction(NodeApplicationAbstractAction, discriminator="node-application-install"): + """Action which installs an application.""" + + config: "NodeApplicationInstallAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationInstallAction.""" + + verb: str = "install" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "software_manager", + "application", + config.verb, + config.application_name, + ] + + +class NodeApplicationRemoveAction(NodeApplicationAbstractAction, discriminator="node-application-remove"): + """Action which removes/uninstalls an application.""" + + config: "NodeApplicationRemoveAction.ConfigSchema" + + class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): + """Configuration schema for NodeApplicationRemoveAction.""" + + verb: str = "uninstall" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "software_manager", + "application", + config.verb, + config.application_name, + ] diff --git a/src/primaite/game/agent/actions/file.py b/src/primaite/game/agent/actions/file.py new file mode 100644 index 00000000..a2fcd3e2 --- /dev/null +++ b/src/primaite/game/agent/actions/file.py @@ -0,0 +1,187 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC +from typing import ClassVar + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ( + "NodeFileCreateAction", + "NodeFileScanAction", + "NodeFileDeleteAction", + "NodeFileRestoreAction", + "NodeFileCorruptAction", + "NodeFileAccessAction", + "NodeFileCheckhashAction", + "NodeFileRepairAction", +) + + +class NodeFileAbstractAction(AbstractAction, ABC): + """Abstract base class for file actions. + + Any action which applies to a file and uses node_name, folder_name, and file_name as its + only three parameters can inherit from this base class. + """ + + class ConfigSchema(AbstractAction.ConfigSchema, ABC): + """Configuration Schema for NodeFileAbstractAction.""" + + node_name: str + folder_name: str + file_name: str + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None or config.file_name is None: + return ["do-nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + "folder", + config.folder_name, + "file", + config.file_name, + config.verb, + ] + + +class NodeFileCreateAction(NodeFileAbstractAction, discriminator="node-file-create"): + """Action which creates a new file in a given folder.""" + + config: "NodeFileCreateAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileCreateAction.""" + + verb: ClassVar[str] = "create" + force: bool = False + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None or config.file_name is None: + return ["do-nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + config.verb, + "file", + config.folder_name, + config.file_name, + config.verb, + ] + + +class NodeFileScanAction(NodeFileAbstractAction, discriminator="node-file-scan"): + """Action which scans a file.""" + + config: "NodeFileScanAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileScanAction.""" + + verb: ClassVar[str] = "scan" + + +class NodeFileDeleteAction(NodeFileAbstractAction, discriminator="node-file-delete"): + """Action which deletes a file.""" + + config: "NodeFileDeleteAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileDeleteAction.""" + + verb: ClassVar[str] = "delete" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None or config.file_name is None: + return ["do-nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + config.verb, + "file", + config.folder_name, + config.file_name, + ] + + +class NodeFileRestoreAction(NodeFileAbstractAction, discriminator="node-file-restore"): + """Action which restores a file.""" + + config: "NodeFileRestoreAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileRestoreAction.""" + + verb: ClassVar[str] = "restore" + + +class NodeFileCorruptAction(NodeFileAbstractAction, discriminator="node-file-corrupt"): + """Action which corrupts a file.""" + + config: "NodeFileCorruptAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileCorruptAction.""" + + verb: ClassVar[str] = "corrupt" + + +class NodeFileAccessAction(NodeFileAbstractAction, discriminator="node-file-access"): + """Action which increases a file's access count.""" + + config: "NodeFileAccessAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileAccessAction.""" + + verb: ClassVar[str] = "access" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None or config.file_name is None: + return ["do-nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + config.verb, + config.folder_name, + config.file_name, + ] + + +class NodeFileCheckhashAction(NodeFileAbstractAction, discriminator="node-file-checkhash"): + """Action which checks the hash of a file.""" + + config: "NodeFileCheckhashAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration schema for NodeFileCheckhashAction.""" + + verb: ClassVar[str] = "checkhash" + + +class NodeFileRepairAction(NodeFileAbstractAction, discriminator="node-file-repair"): + """Action which repairs a file.""" + + config: "NodeFileRepairAction.ConfigSchema" + + class ConfigSchema(NodeFileAbstractAction.ConfigSchema): + """Configuration Schema for NodeFileRepairAction.""" + + verb: ClassVar[str] = "repair" diff --git a/src/primaite/game/agent/actions/folder.py b/src/primaite/game/agent/actions/folder.py new file mode 100644 index 00000000..80be0cd5 --- /dev/null +++ b/src/primaite/game/agent/actions/folder.py @@ -0,0 +1,115 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC +from typing import ClassVar + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ( + "NodeFolderScanAction", + "NodeFolderCheckhashAction", + "NodeFolderRepairAction", + "NodeFolderRestoreAction", + "NodeFolderCreateAction", +) + + +class NodeFolderAbstractAction(AbstractAction, ABC): + """ + Base class for folder actions. + + Any action which applies to a folder and uses node_name and folder_name as its only two parameters can inherit from + this base class. + """ + + class ConfigSchema(AbstractAction.ConfigSchema, ABC): + """Base configuration schema for NodeFolder actions.""" + + node_name: str + folder_name: str + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None: + return ["do-nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + "folder", + config.folder_name, + config.verb, + ] + + +class NodeFolderScanAction(NodeFolderAbstractAction, discriminator="node-folder-scan"): + """Action which scans a folder.""" + + config: "NodeFolderScanAction.ConfigSchema" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + """Configuration schema for NodeFolderScanAction.""" + + verb: ClassVar[str] = "scan" + + +class NodeFolderCheckhashAction(NodeFolderAbstractAction, discriminator="node-folder-checkhash"): + """Action which checks the hash of a folder.""" + + config: "NodeFolderCheckhashAction.ConfigSchema" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + """Configuration schema for NodeFolderCheckhashAction.""" + + verb: ClassVar[str] = "checkhash" + + +class NodeFolderRepairAction(NodeFolderAbstractAction, discriminator="node-folder-repair"): + """Action which repairs a folder.""" + + config: "NodeFolderRepairAction.ConfigSchema" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + """Configuration schema for NodeFolderRepairAction.""" + + verb: ClassVar[str] = "repair" + + +class NodeFolderRestoreAction(NodeFolderAbstractAction, discriminator="node-folder-restore"): + """Action which restores a folder.""" + + config: "NodeFolderRestoreAction.ConfigSchema" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + """Configuration schema for NodeFolderRestoreAction.""" + + verb: ClassVar[str] = "restore" + + +class NodeFolderCreateAction(NodeFolderAbstractAction, discriminator="node-folder-create"): + """Action which creates a new folder.""" + + config: "NodeFolderCreateAction.ConfigSchema" + + class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): + """Configuration schema for NodeFolderCreateAction.""" + + verb: ClassVar[str] = "create" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.folder_name is None: + return ["do-nothing"] + return [ + "network", + "node", + config.node_name, + "file_system", + config.verb, + "folder", + config.folder_name, + ] diff --git a/src/primaite/game/agent/actions/host_nic.py b/src/primaite/game/agent/actions/host_nic.py new file mode 100644 index 00000000..d192a757 --- /dev/null +++ b/src/primaite/game/agent/actions/host_nic.py @@ -0,0 +1,60 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC +from typing import ClassVar + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ("HostNICEnableAction", "HostNICDisableAction") + + +class HostNICAbstractAction(AbstractAction, ABC): + """ + Abstract base class for NIC actions. + + Any action which applies to a NIC and uses node_name and nic_num as its only two parameters can inherit from this + base class. + """ + + class ConfigSchema(AbstractAction.ConfigSchema, ABC): + """Base Configuration schema for HostNIC actions.""" + + node_name: str + nic_num: int + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.nic_num is None: + return ["do-nothing"] + return [ + "network", + "node", + config.node_name, + "network_interface", + config.nic_num, + config.verb, + ] + + +class HostNICEnableAction(HostNICAbstractAction, discriminator="host-nic-enable"): + """Action which enables a NIC.""" + + config: "HostNICEnableAction.ConfigSchema" + + class ConfigSchema(HostNICAbstractAction.ConfigSchema): + """Configuration schema for HostNICEnableAction.""" + + verb: ClassVar[str] = "enable" + + +class HostNICDisableAction(HostNICAbstractAction, discriminator="host-nic-disable"): + """Action which disables a NIC.""" + + config: "HostNICDisableAction.ConfigSchema" + + class ConfigSchema(HostNICAbstractAction.ConfigSchema): + """Configuration schema for HostNICDisableAction.""" + + verb: ClassVar[str] = "disable" diff --git a/src/primaite/game/agent/actions/manager.py b/src/primaite/game/agent/actions/manager.py new file mode 100644 index 00000000..0a9d3ffd --- /dev/null +++ b/src/primaite/game/agent/actions/manager.py @@ -0,0 +1,108 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +"""yaml example. + +agents: + - name: agent_1 + action_space: + actions: + - do-nothing + - node-service-start + - node-service-stop + action_map: +""" + +from __future__ import annotations + +from typing import Dict, Tuple + +from gymnasium import spaces +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from primaite.game.agent.actions.abstract import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ("DoNothingAction", "ActionManager") + + +class DoNothingAction(AbstractAction, discriminator="do-nothing"): + """Do Nothing Action.""" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration Schema for do_nothingAction.""" + + type: str = "do-nothing" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["do-nothing"] + + +class _ActionMapItem(BaseModel): + model_config = ConfigDict(extra="forbid") + + action: str + options: Dict + + +class ActionManager(BaseModel): + """Class which manages the action space for an agent.""" + + class ConfigSchema(BaseModel): + """Config Schema for ActionManager.""" + + model_config = ConfigDict(extra="forbid") + action_map: Dict[int, _ActionMapItem] = {} + """Mapping between integer action choices and CAOS actions.""" + + @field_validator("action_map", mode="after") + def consecutive_action_nums(cls, v: Dict) -> Dict: + """Make sure all numbers between 0 and N are represented as dict keys in action map.""" + assert all([i in v.keys() for i in range(len(v))]) + return v + + config: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema()) + + action_map: Dict[int, Tuple[str, Dict]] = {} + """Init as empty, populate after model validation.""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.action_map = {n: (v.action, v.options) for n, v in self.config.action_map.items()} + + def get_action(self, action: int) -> Tuple[str, Dict]: + """ + Produce action in CAOS format. + + The agent chooses an action (as an integer), this is converted into an action in CAOS format + The CAOS format is basically an action identifier, followed by parameters stored in a dictionary. + """ + act_identifier, act_options = self.action_map[action] + return act_identifier, act_options + + def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat: + """Take action in CAOS format and use the execution definition to change it into PrimAITE request format.""" + act_class = AbstractAction._registry[action_identifier] + config = act_class.ConfigSchema(**action_options) + return act_class.form_request(config=config) + + @property + def space(self) -> spaces.Space: + """Return the gymnasium action space for this agent.""" + return spaces.Discrete(len(self.action_map)) + + @classmethod + def from_config(cls, cfg: Dict) -> "ActionManager": + """ + Construct an ActionManager from a config dictionary. + + The action space config supports must contain the following key: + ``action_map`` - List of actions available to the agent, formatted as a dictionary where the key is the + action number between 0 - N, and the value is the CAOS-formatted action. + + :param cfg: The action space config. + :type cfg: Dict + :return: The constructed ActionManager. + :rtype: ActionManager + """ + return cls(**cfg.get("options", {}), act_map=cfg.get("action_map")) diff --git a/src/primaite/game/agent/actions/network.py b/src/primaite/game/agent/actions/network.py new file mode 100644 index 00000000..22fc2c2d --- /dev/null +++ b/src/primaite/game/agent/actions/network.py @@ -0,0 +1,56 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +from abc import ABC +from typing import ClassVar + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction") + + +class NetworkPortAbstractAction(AbstractAction, ABC): + """Base class for Network port actions.""" + + class ConfigSchema(AbstractAction.ConfigSchema, ABC): + """Base configuration schema for NetworkPort actions.""" + + target_nodename: str + port_num: int + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.target_nodename is None or config.port_num is None: + return ["do-nothing"] + return [ + "network", + "node", + config.target_nodename, + "network_interface", + config.port_num, + config.verb, + ] + + +class NetworkPortEnableAction(NetworkPortAbstractAction, discriminator="network-port-enable"): + """Action which enables are port on a router or a firewall.""" + + config: "NetworkPortEnableAction.ConfigSchema" + + class ConfigSchema(NetworkPortAbstractAction.ConfigSchema): + """Configuration schema for NetworkPortEnableAction.""" + + verb: ClassVar[str] = "enable" + + +class NetworkPortDisableAction(NetworkPortAbstractAction, discriminator="network-port-disable"): + """Action which disables are port on a router or a firewall.""" + + config: "NetworkPortDisableAction.ConfigSchema" + + class ConfigSchema(NetworkPortAbstractAction.ConfigSchema): + """Configuration schema for NetworkPortDisableAction.""" + + verb: ClassVar[str] = "disable" diff --git a/src/primaite/game/agent/actions/node.py b/src/primaite/game/agent/actions/node.py new file mode 100644 index 00000000..b1b6ec12 --- /dev/null +++ b/src/primaite/game/agent/actions/node.py @@ -0,0 +1,246 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC, abstractmethod +from typing import ClassVar, List, Literal, Optional, Union + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port + +__all__ = ( + "NodeOSScanAction", + "NodeShutdownAction", + "NodeStartupAction", + "NodeResetAction", + "NodeNMAPPingScanAction", + "NodeNMAPPortScanAction", + "NodeNetworkServiceReconAction", +) + + +class NodeAbstractAction(AbstractAction, ABC): + """ + Abstract base class for node actions. + + Any action which applies to a node and uses node_name as its only parameter can inherit from this base class. + """ + + class ConfigSchema(AbstractAction.ConfigSchema, ABC): + """Base Configuration schema for Node actions.""" + + node_name: str + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["network", "node", config.node_name, config.verb] + + +class NodeOSScanAction(NodeAbstractAction, discriminator="node-os-scan"): + """Action which scans a node's OS.""" + + config: "NodeOSScanAction.ConfigSchema" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + """Configuration schema for NodeOSScanAction.""" + + verb: ClassVar[str] = "scan" + + +class NodeShutdownAction(NodeAbstractAction, discriminator="node-shutdown"): + """Action which shuts down a node.""" + + config: "NodeShutdownAction.ConfigSchema" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + """Configuration schema for NodeShutdownAction.""" + + verb: ClassVar[str] = "shutdown" + + +class NodeStartupAction(NodeAbstractAction, discriminator="node-startup"): + """Action which starts up a node.""" + + config: "NodeStartupAction.ConfigSchema" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + """Configuration schema for NodeStartupAction.""" + + verb: ClassVar[str] = "startup" + + +class NodeResetAction(NodeAbstractAction, discriminator="node-reset"): + """Action which resets a node.""" + + config: "NodeResetAction.ConfigSchema" + + class ConfigSchema(NodeAbstractAction.ConfigSchema): + """Configuration schema for NodeResetAction.""" + + verb: ClassVar[str] = "reset" + + +class NodeNMAPAbstractAction(AbstractAction, ABC): + """Base class for NodeNMAP actions.""" + + class ConfigSchema(AbstractAction.ConfigSchema, ABC): + """Base Configuration Schema for NodeNMAP actions.""" + + target_ip_address: Union[str, List[str]] + show: bool = False + source_node: str + + @classmethod + @abstractmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + # NMAP action requests don't share a common format for their requests + # This is just a placeholder to ensure the method is defined. + pass + + +class NodeNMAPPingScanAction(NodeNMAPAbstractAction, discriminator="node-nmap-ping-scan"): + """Action which performs an nmap ping scan.""" + + config: "NodeNMAPPingScanAction.ConfigSchema" + + @classmethod + def form_request(cls, config: "NodeNMAPPingScanAction.ConfigSchema") -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.source_node, + "application", + "nmap", + "ping_scan", + {"target_ip_address": config.target_ip_address, "show": config.show}, + ] + + +class NodeNMAPPortScanAction(NodeNMAPAbstractAction, discriminator="node-nmap-port-scan"): + """Action which performs an nmap port scan.""" + + config: "NodeNMAPPortScanAction.ConfigSchema" + + class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): + """Configuration Schema for NodeNMAPPortScanAction.""" + + source_node: str + target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None + target_port: Optional[Union[Port, List[Port]]] = None + show: Optional[bool] = (False,) + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.source_node, + "application", + "nmap", + "port_scan", + { + "target_ip_address": config.target_ip_address, + "target_port": config.target_port, + "target_protocol": config.target_protocol, + "show": config.show, + }, + ] + + +class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-network-service-recon"): + """Action which performs an nmap network service recon (ping scan followed by port scan).""" + + class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): + """Configuration schema for NodeNetworkServiceReconAction.""" + + target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None + target_port: Optional[Union[Port, List[Port]]] = None + show: Optional[bool] = (False,) + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.source_node, + "application", + "nmap", + "network_service_recon", + { + "target_ip_address": config.target_ip_address, + "target_port": config.target_port, + "target_protocol": config.target_protocol, + "show": config.show, + }, + ] + + +class NodeAccountsAddUserAction(AbstractAction, discriminator="node-account-add-user"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-account-add-user"] = "node-account-add-user" + node_name: str + username: str + password: str + is_admin: bool + + @classmethod + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "user-manager", + "add_user", + config.username, + config.password, + config.is_admin, + ] + + +class NodeAccountsDisableUserAction(AbstractAction, discriminator="node-account-disable-user"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-account-disable-user"] = "node-account-disable-user" + node_name: str + username: str + + @classmethod + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "user-manager", + "disable_user", + config.username, + ] + + +class NodeSendLocalCommandAction(AbstractAction, discriminator="node-send-local-command"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-send-local-command"] = "node-send-local-command" + node_name: str + username: str + password: str + command: RequestFormat + + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "terminal", + "send_local_command", + config.username, + config.password, + {"command": config.command}, + ] diff --git a/src/primaite/game/agent/actions/service.py b/src/primaite/game/agent/actions/service.py new file mode 100644 index 00000000..dfe1f91d --- /dev/null +++ b/src/primaite/game/agent/actions/service.py @@ -0,0 +1,134 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC +from typing import ClassVar + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ( + "NodeServiceScanAction", + "NodeServiceStopAction", + "NodeServiceStartAction", + "NodeServicePauseAction", + "NodeServiceResumeAction", + "NodeServiceRestartAction", + "NodeServiceDisableAction", + "NodeServiceEnableAction", + "NodeServiceFixAction", +) + + +class NodeServiceAbstractAction(AbstractAction, ABC): + """Abstract Action for Node Service related actions. + + Any actions which use node_name and service_name can inherit from this class. + """ + + class ConfigSchema(AbstractAction.ConfigSchema, ABC): + node_name: str + service_name: str + verb: ClassVar[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["network", "node", config.node_name, "service", config.service_name, config.verb] + + +class NodeServiceScanAction(NodeServiceAbstractAction, discriminator="node-service-scan"): + """Action which scans a service.""" + + config: "NodeServiceScanAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceScanAction.""" + + verb: ClassVar[str] = "scan" + + +class NodeServiceStopAction(NodeServiceAbstractAction, discriminator="node-service-stop"): + """Action which stops a service.""" + + config: "NodeServiceStopAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceStopAction.""" + + verb: ClassVar[str] = "stop" + + +class NodeServiceStartAction(NodeServiceAbstractAction, discriminator="node-service-start"): + """Action which starts a service.""" + + config: "NodeServiceStartAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceStartAction.""" + + verb: ClassVar[str] = "start" + + +class NodeServicePauseAction(NodeServiceAbstractAction, discriminator="node-service-pause"): + """Action which pauses a service.""" + + config: "NodeServicePauseAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServicePauseAction.""" + + verb: ClassVar[str] = "pause" + + +class NodeServiceResumeAction(NodeServiceAbstractAction, discriminator="node-service-resume"): + """Action which resumes a service.""" + + config: "NodeServiceResumeAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceResumeAction.""" + + verb: ClassVar[str] = "resume" + + +class NodeServiceRestartAction(NodeServiceAbstractAction, discriminator="node-service-restart"): + """Action which restarts a service.""" + + config: "NodeServiceRestartAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceRestartAction.""" + + verb: ClassVar[str] = "restart" + + +class NodeServiceDisableAction(NodeServiceAbstractAction, discriminator="node-service-disable"): + """Action which disables a service.""" + + config: "NodeServiceDisableAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceDisableAction.""" + + verb: ClassVar[str] = "disable" + + +class NodeServiceEnableAction(NodeServiceAbstractAction, discriminator="node-service-enable"): + """Action which enables a service.""" + + config: "NodeServiceEnableAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceEnableAction.""" + + verb: ClassVar[str] = "enable" + + +class NodeServiceFixAction(NodeServiceAbstractAction, discriminator="node-service-fix"): + """Action which fixes a service.""" + + config: "NodeServiceFixAction.ConfigSchema" + + class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): + """Configuration Schema for NodeServiceFixAction.""" + + verb: ClassVar[str] = "fix" diff --git a/src/primaite/game/agent/actions/session.py b/src/primaite/game/agent/actions/session.py new file mode 100644 index 00000000..63a45c5e --- /dev/null +++ b/src/primaite/game/agent/actions/session.py @@ -0,0 +1,101 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC, abstractmethod + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat + +__all__ = ( + "NodeSessionsRemoteLoginAction", + "NodeSessionsRemoteLogoutAction", + "NodeAccountChangePasswordAction", +) + + +class NodeSessionAbstractAction(AbstractAction, ABC): + """Base class for NodeSession actions.""" + + class ConfigSchema(AbstractAction.ConfigSchema, ABC): + """Base configuration schema for NodeSessionAbstractActions.""" + + node_name: str + remote_ip: str + + @classmethod + @abstractmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """ + Abstract method for request forming. + + Should return the action formatted as a request which can be ingested by the PrimAITE simulation. + """ + pass + + +class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="node-session-remote-login"): + """Action which performs a remote session login.""" + + class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): + """Configuration schema for NodeSessionsRemoteLoginAction.""" + + username: str + password: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.remote_ip is None: + return ["do-nothing"] + return [ + "network", + "node", + config.node_name, + "service", + "terminal", + "node_session_remote_login", + config.username, + config.password, + config.remote_ip, + ] + + +class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, discriminator="node-session-remote-logoff"): + """Action which performs a remote session logout.""" + + class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): + """Configuration schema for NodeSessionsRemoteLogoutAction.""" + + verb: str = "remote_logoff" + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + if config.node_name is None or config.remote_ip is None: + return ["do-nothing"] + return ["network", "node", config.node_name, "service", "terminal", config.verb, config.remote_ip] + + +class NodeAccountChangePasswordAction(AbstractAction, discriminator="node-account-change-password"): + """Action which changes the password for a user.""" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for NodeAccountsChangePasswordAction.""" + + node_name: str + username: str + current_password: str + new_password: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "service", + "user-manager", + "change_password", + config.username, + config.current_password, + config.new_password, + ] diff --git a/src/primaite/game/agent/actions/software.py b/src/primaite/game/agent/actions/software.py new file mode 100644 index 00000000..f170146b --- /dev/null +++ b/src/primaite/game/agent/actions/software.py @@ -0,0 +1,246 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +from typing import List, Optional, Union + +from pydantic import ConfigDict, Field + +from primaite.game.agent.actions.manager import AbstractAction +from primaite.interface.request import RequestFormat +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import StrIP +from primaite.utils.validation.port import Port + +__all__ = ( + "ConfigureRansomwareScriptAction", + "ConfigureDoSBotAction", + "ConfigureC2BeaconAction", + "NodeSendRemoteCommandAction", + "TerminalC2ServerAction", + "RansomwareLaunchC2ServerAction", + "ExfiltrationC2ServerAction", + "ConfigureDatabaseClientAction", +) + + +class ConfigureRansomwareScriptAction(AbstractAction, discriminator="configure-ransomware-script"): + """Action which sets config parameters for a ransomware script on a node.""" + + config: "ConfigureRansomwareScriptAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for ConfigureRansomwareScriptAction.""" + + node_name: str + server_ip_address: Optional[str] = None + server_password: Optional[str] = None + payload: Optional[str] = None + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + if config.node_name is None: + return ["do-nothing"] + data = dict( + server_ip_address=config.server_ip_address, + server_password=config.server_password, + payload=config.payload, + ) + return ["network", "node", config.node_name, "application", "ransomware-script", "configure", data] + + +class RansomwareConfigureC2ServerAction( + ConfigureRansomwareScriptAction, discriminator="c2-server-ransomware-configure" +): + """Action which causes a C2 server to send a command to set options on a ransomware script remotely.""" + + @classmethod + def form_request(cls, config: ConfigureRansomwareScriptAction.ConfigSchema) -> RequestFormat: + data = dict( + server_ip_address=config.server_ip_address, server_password=config.server_password, payload=config.payload + ) + return ["network", "node", config.node_name, "application", "c2-server", "ransomware_configure", data] + + +class ConfigureDoSBotAction(AbstractAction, discriminator="configure-dos-bot"): + """Action which sets config parameters for a DoS bot on a node.""" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Schema for options that can be passed to this action.""" + + model_config = ConfigDict(extra="forbid") + node_name: str + target_ip_address: Optional[StrIP] = None + target_port: Optional[Port] = None + payload: Optional[str] = None + repeat: Optional[bool] = None + port_scan_p_of_success: Optional[float] = None + dos_intensity: Optional[float] = None + max_sessions: Optional[int] = None + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + data = dict( + target_ip_address=config.target_ip_address, + target_port=config.target_port, + payload=config.payload, + repeat=config.repeat, + port_scan_p_of_success=config.port_scan_p_of_success, + dos_intensity=config.dos_intensity, + max_sessions=config.max_sessions, + ) + data = {k: v for k, v in data.items() if v is not None} + return ["network", "node", config.node_name, "application", "dos-bot", "configure", data] + + +class ConfigureC2BeaconAction(AbstractAction, discriminator="configure-c2-beacon"): + """Action which configures a C2 Beacon based on the parameters given.""" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for ConfigureC2BeaconAction.""" + + node_name: str + c2_server_ip_address: StrIP + keep_alive_frequency: int = Field(default=5, ge=1) + masquerade_protocol: IPProtocol = Field(default="tcp") + masquerade_port: Port = Field(default=80) + + @classmethod + def form_request(self, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + data = dict( + c2_server_ip_address=config.c2_server_ip_address, + keep_alive_frequency=config.keep_alive_frequency, + masquerade_protocol=config.masquerade_protocol, + masquerade_port=config.masquerade_port, + ) + return ["network", "node", config.node_name, "application", "c2-beacon", "configure", data] + + +class NodeSendRemoteCommandAction(AbstractAction, discriminator="node-send-remote-command"): + """Action which sends a terminal command to a remote node via SSH.""" + + config: "NodeSendRemoteCommandAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for NodeSendRemoteCommandAction.""" + + node_name: str + remote_ip: StrIP + command: RequestFormat + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "service", + "terminal", + "send_remote_command", + config.remote_ip, + {"command": config.command}, + ] + + +class TerminalC2ServerAction(AbstractAction, discriminator="c2-server-terminal-command"): + """Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed.""" + + config: "TerminalC2ServerAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Schema for options that can be passed to this action.""" + + node_name: str + commands: Union[List[RequestFormat], RequestFormat] + ip_address: Optional[StrIP] + username: Optional[str] + password: Optional[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + if config.node_name is None: + return ["do-nothing"] + + command_model = { + "commands": config.commands, + "ip_address": config.ip_address, + "username": config.username, + "password": config.password, + } + return ["network", "node", config.node_name, "application", "c2-server", "terminal_command", command_model] + + +class RansomwareLaunchC2ServerAction(AbstractAction, discriminator="c2-server-ransomware-launch"): + """Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript.""" + + config: "RansomwareLaunchC2ServerAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for RansomwareLaunchC2ServerAction.""" + + node_name: str + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + if config.node_name is None: + return ["do-nothing"] + # This action currently doesn't require any further configuration options. + return ["network", "node", config.node_name, "application", "c2-server", "ransomware_launch"] + + +class ExfiltrationC2ServerAction(AbstractAction, discriminator="c2-server-data-exfiltrate"): + """Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server.""" + + config: "ExfiltrationC2ServerAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Schema for options that can be passed to this action.""" + + node_name: str + username: Optional[str] + password: Optional[str] + target_ip_address: StrIP + target_file_name: str + target_folder_name: str + exfiltration_folder_name: Optional[str] + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + if config.node_name is None: + return ["do-nothing"] + + command_model = { + "target_file_name": config.target_file_name, + "target_folder_name": config.target_folder_name, + "exfiltration_folder_name": config.exfiltration_folder_name, + "target_ip_address": config.target_ip_address, + "username": config.username, + "password": config.password, + } + return ["network", "node", config.node_name, "application", "c2-server", "exfiltrate", command_model] + + +class ConfigureDatabaseClientAction(AbstractAction, discriminator="configure-database-client"): + """Action which sets config parameters for a database client on a node.""" + + config: "ConfigureDatabaseClientAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Schema for options that can be passed to this action.""" + + node_name: str + server_ip_address: Optional[StrIP] = None + server_password: Optional[str] = None + + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + if config.node_name is None: + return ["do-nothing"] + data = {"server_ip_address": config.server_ip_address, "server_password": config.server_password} + return ["network", "node", config.node_name, "application", "database-client", "configure", data] diff --git a/src/primaite/game/agent/agent_log.py b/src/primaite/game/agent/agent_log.py index 62ef4884..ddf14489 100644 --- a/src/primaite/game/agent/agent_log.py +++ b/src/primaite/game/agent/agent_log.py @@ -1,6 +1,7 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import logging from pathlib import Path +from typing import Optional from prettytable import MARKDOWN, PrettyTable @@ -20,20 +21,22 @@ class _NotJSONFilter(logging.Filter): class AgentLog: """ - A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent. + An Agent Log class is a simple logger dedicated to managing and writing updates and information for an agent. - Each log message is written to a file located at: /agent_name/agent_name.log + Each log message is written to a file located at: + /agent_name/agent_name.log """ - def __init__(self, agent_name: str): + def __init__(self, agent_name: Optional[str]): """ Constructs a Agent Log instance for a given hostname. - :param hostname: The hostname associated with the system logs being recorded. + :param agent_name: The agent_name associated with the system logs being recorded. """ - self.agent_name = agent_name - self.current_episode: int = 1 + super().__init__() + self.agent_name = agent_name if agent_name else "unnamed_agent" self.current_timestep: int = 0 + self.current_episode: int = 1 self.setup_logger() @property @@ -62,8 +65,9 @@ class AgentLog: The logger is set to the DEBUG level, and is equipped with a handler that writes to a file and filters out JSON-like messages. """ - if not SIM_OUTPUT.save_agent_logs: - return + # TODO: uncomment this once we figure out why it's broken + # if not SIM_OUTPUT.save_agent_logs: + # return log_path = self._get_log_path() file_handler = logging.FileHandler(filename=log_path) @@ -90,7 +94,7 @@ class AgentLog: def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False): if to_terminal or SIM_OUTPUT.write_agent_log_to_terminal: - print(f"{self.agent_name}: ({ self.timestep}) ({level}) {msg}") + print(f"{self.agent_name}: ({self.timestep}) ({level}) {msg}") def debug(self, msg: str, to_terminal: bool = False): """ diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 6609dd03..a6e9739f 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,11 +1,13 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Interface for agents.""" +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING from gymnasium.core import ActType, ObsType from prettytable import PrettyTable -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, ConfigDict, Field from primaite.game.agent.actions import ActionManager from primaite.game.agent.agent_log import AgentLog @@ -16,6 +18,8 @@ from primaite.interface.request import RequestFormat, RequestResponse if TYPE_CHECKING: pass +__all__ = ("AgentHistoryItem", "AbstractAgent", "AbstractScriptedAgent", "ProxyAgent") + class AgentHistoryItem(BaseModel): """One entry of an agent's action log - what the agent did and how the simulator responded in 1 step.""" @@ -43,119 +47,87 @@ class AgentHistoryItem(BaseModel): """The observation space data for this step.""" -class AgentStartSettings(BaseModel): - """Configuration values for when an agent starts performing actions.""" - - start_step: int = 5 - "The timestep at which an agent begins performing it's actions" - frequency: int = 5 - "The number of timesteps to wait between performing actions" - variance: int = 0 - "The amount the frequency can randomly change to" - - @model_validator(mode="after") - def check_variance_lt_frequency(self) -> "AgentStartSettings": - """ - Make sure variance is equal to or lower than frequency. - - This is because the calculation for the next execution time is now + (frequency +- variance). If variance were - greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again. - """ - if self.variance > self.frequency: - raise ValueError( - f"Agent start settings error: variance must be lower than frequency " - f"{self.variance=}, {self.frequency=}" - ) - return self - - -class AgentSettings(BaseModel): - """Settings for configuring the operation of an agent.""" - - start_settings: Optional[AgentStartSettings] = None - "Configuration for when an agent begins performing it's actions" - flatten_obs: bool = True - "Whether to flatten the observation space before passing it to the agent. True by default." - action_masking: bool = False - "Whether to return action masks at each step." - - @classmethod - def from_config(cls, config: Optional[Dict]) -> "AgentSettings": - """Construct agent settings from a config dictionary. - - :param config: A dict of options for the agent settings. - :type config: Dict - :return: The agent settings. - :rtype: AgentSettings - """ - if config is None: - return cls() - - return cls(**config) - - -class AbstractAgent(ABC): +class AbstractAgent(BaseModel, ABC): """Base class for scripted and RL agents.""" - def __init__( - self, - agent_name: Optional[str], - action_space: Optional[ActionManager], - observation_space: Optional[ObservationManager], - reward_function: Optional[RewardFunction], - agent_settings: Optional[AgentSettings] = None, - ) -> None: - """ - Initialize an agent. + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - :param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes. - :type agent_name: Optional[str] - :param action_space: Action space for the agent. - :type action_space: Optional[ActionManager] - :param observation_space: Observation space for the agent. - :type observation_space: Optional[ObservationSpace] - :param reward_function: Reward function for the agent. - :type reward_function: Optional[RewardFunction] - :param agent_settings: Configurable Options for Abstracted Agents - :type agent_settings: Optional[AgentSettings] - """ - self.agent_name: str = agent_name or "unnamed_agent" - self.action_manager: Optional[ActionManager] = action_space - self.observation_manager: Optional[ObservationManager] = observation_space - self.reward_function: Optional[RewardFunction] = reward_function - self.agent_settings = agent_settings or AgentSettings() - self.history: List[AgentHistoryItem] = [] - self.logger = AgentLog(agent_name) + class AgentSettingsSchema(BaseModel, ABC): + """Schema for the 'agent_settings' key.""" - def add_agent_action(self, item: AgentHistoryItem, table: PrettyTable) -> PrettyTable: - """Update the given table with information from given AgentHistoryItem.""" - node, application = "unknown", "unknown" - if (node_id := item.parameters.get("node_id")) is not None: - node = self.action_manager.node_names[node_id] - if (application_id := item.parameters.get("application_id")) is not None: - application = self.action_manager.application_names[node_id][application_id] - if (application_name := item.parameters.get("application_name")) is not None: - application = application_name - table.add_row([item.timestep, item.action, node, application, item.response.status]) - return table + model_config = ConfigDict(extra="forbid") + + class ConfigSchema(BaseModel, ABC): + """Configuration Schema for AbstractAgents.""" + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + type: str + ref: str = "" + """name of the agent.""" + team: Optional[Literal["BLUE", "GREEN", "RED"]] = None + agent_settings: AbstractAgent.AgentSettingsSchema = Field(default=lambda: AbstractAgent.AgentSettingsSchema()) + action_space: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema()) + observation_space: ObservationManager.ConfigSchema = Field( + default_factory=lambda: ObservationManager.ConfigSchema() + ) + reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema()) + thresholds: Optional[Dict] = {} + # TODO: this is only relevant to some observations, need to refactor the way thresholds are dealt with (#3085) + """A dict containing the observation thresholds.""" + + config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema()) + + logger: AgentLog = None + history: List[AgentHistoryItem] = [] + + action_manager: ActionManager = Field(default_factory=lambda: ActionManager()) + observation_manager: ObservationManager = Field(default_factory=lambda: ObservationManager()) + reward_function: RewardFunction = Field(default_factory=lambda: RewardFunction()) + + _registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {} + + def __init__(self, **kwargs): + """Initialise and setup agent logger.""" + super().__init__(**kwargs) + self.logger: AgentLog = AgentLog(agent_name=kwargs["config"]["ref"]) + + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if discriminator is None: + return + if discriminator in cls._registry: + raise ValueError(f"Cannot create a new agent under reserved name {discriminator}") + cls._registry[discriminator] = cls + + def model_post_init(self, __context: Any) -> None: + """Overwrite the default empty action, observation, and rewards with ones defined through the config.""" + self.action_manager = ActionManager(config=self.config.action_space) + self.config.observation_space.options.thresholds = self.config.thresholds + self.observation_manager = ObservationManager(config=self.config.observation_space) + self.reward_function = RewardFunction(config=self.config.reward_function) + return super().model_post_init(__context) def show_history(self, ignored_actions: Optional[list] = None): """ - Print an agent action provided it's not the DONOTHING action. + Print an agent action provided it's not the do-nothing action. :param ignored_actions: OPTIONAL: List of actions to be ignored when displaying the history. - If not provided, defaults to ignore DONOTHING actions. + If not provided, defaults to ignore do-nothing actions. """ if not ignored_actions: - ignored_actions = ["DONOTHING"] + ignored_actions = ["do-nothing"] table = PrettyTable() - table.field_names = ["Step", "Action", "Node", "Application", "Response"] - print(f"Actions for '{self.agent_name}':") + table.field_names = ["Step", "Action", "Params", "Response", "Response Data"] + print(f"Actions for '{self.config.ref}':") for item in self.history: if item.action in ignored_actions: pass else: - table = self.add_agent_action(item=item, table=table) + # format dict by putting each key-value entry on a separate line and putting a blank line on the end. + param_string = "\n".join([*[f"{k}: {v:.30}" for k, v in item.parameters.items()], ""]) + data_string = "\n".join([*[f"{k}: {v:.30}" for k, v in item.response.data], ""]) + + table.add_row([item.timestep, item.action, param_string, item.response.status, data_string]) print(table) def update_observation(self, state: Dict) -> ObsType: @@ -194,9 +166,9 @@ class AbstractAgent(ABC): """ # in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39, # then use a bespoke conversion to take 1-40 int back into CAOS action - return ("DO_NOTHING", {}) + return ("do-nothing", {}) - def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]: + def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> RequestFormat: # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator. # therefore the execution definition needs to be a mapping from CAOS into SIMULATOR """Format action into format expected by the simulator, and apply execution definition if applicable.""" @@ -228,36 +200,47 @@ class AbstractAgent(ABC): """Update the most recent history item with the reward value.""" self.history[-1].reward = self.reward_function.current_reward + @classmethod + def from_config(cls, config: Dict) -> AbstractAgent: + """Grab the relevant agent class and construct an instance from a config dict.""" + agent_type = config["type"] + agent_class = cls._registry[agent_type] + return agent_class(config=config) -class AbstractScriptedAgent(AbstractAgent): + +class AbstractScriptedAgent(AbstractAgent, ABC): """Base class for actors which generate their own behaviour.""" + class ConfigSchema(AbstractAgent.ConfigSchema, ABC): + """Configuration Schema for AbstractScriptedAgents.""" + + type: str = "AbstractScriptedAgent" + + config: ConfigSchema = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema()) + @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """Return an action to be taken in the environment.""" return super().get_action(obs=obs, timestep=timestep) -class ProxyAgent(AbstractAgent): +class ProxyAgent(AbstractAgent, discriminator="proxy-agent"): """Agent that sends observations to an RL model and receives actions from that model.""" - def __init__( - self, - agent_name: Optional[str], - action_space: Optional[ActionManager], - observation_space: Optional[ObservationManager], - reward_function: Optional[RewardFunction], - agent_settings: Optional[AgentSettings] = None, - ) -> None: - super().__init__( - agent_name=agent_name, - action_space=action_space, - observation_space=observation_space, - reward_function=reward_function, - ) - self.most_recent_action: ActType - self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False - self.action_masking: bool = agent_settings.action_masking if agent_settings else False + config: "ProxyAgent.ConfigSchema" = Field(default_factory=lambda: ProxyAgent.ConfigSchema()) + most_recent_action: ActType = None + + class AgentSettingsSchema(AbstractAgent.AgentSettingsSchema): + """Schema for the `agent_settings` part of the agent config.""" + + flatten_obs: bool = False + action_masking: bool = False + + class ConfigSchema(AbstractAgent.ConfigSchema): + """Configuration Schema for Proxy Agent.""" + + type: str = "Proxy_Agent" + agent_settings: ProxyAgent.AgentSettingsSchema = Field(default_factory=lambda: ProxyAgent.AgentSettingsSchema()) def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ @@ -279,3 +262,8 @@ class ProxyAgent(AbstractAgent): The environment is responsible for calling this method when it receives an action from the agent policy. """ self.most_recent_action = action + + @property + def flatten_obs(self) -> bool: + """Return agent flatten_obs param.""" + return self.config.agent_settings.flatten_obs diff --git a/src/primaite/game/agent/observations/__init__.py b/src/primaite/game/agent/observations/__init__.py index 6c88f844..a38095b3 100644 --- a/src/primaite/game/agent/observations/__init__.py +++ b/src/primaite/game/agent/observations/__init__.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # flake8: noqa # Pre-import all the observations when we load up the observations module so that they can be resolved by the parser. from primaite.game.agent.observations.acl_observation import ACLObservation @@ -17,5 +17,5 @@ from primaite.game.agent.observations.software_observation import ApplicationObs __all__ = [ "ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation", "LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation", - "ObservationManager", "ApplicationObservation", "ServiceObservation",] + "ObservationManager", "ApplicationObservation", "ServiceObservation", "RouterObservation", "LinkObservation",] # fmt: on diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index 41af5a8f..b2f5e786 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -1,7 +1,6 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from ipaddress import IPv4Address from typing import Dict, List, Optional from gymnasium import spaces @@ -10,23 +9,26 @@ from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import StrIP +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) -class ACLObservation(AbstractObservation, identifier="ACL"): +class ACLObservation(AbstractObservation, discriminator="acl"): """ACL observation, provides information about access control lists within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for ACLObservation.""" - ip_list: Optional[List[IPv4Address]] = None + ip_list: Optional[List[StrIP]] = None """List of IP addresses.""" wildcard_list: Optional[List[str]] = None """List of wildcard strings.""" - port_list: Optional[List[int]] = None - """List of port numbers.""" - protocol_list: Optional[List[str]] = None + port_list: Optional[List[Port]] = None + """List of port names.""" + protocol_list: Optional[List[IPProtocol]] = None """List of protocol names.""" num_rules: Optional[int] = None """Number of ACL rules.""" @@ -35,10 +37,10 @@ class ACLObservation(AbstractObservation, identifier="ACL"): self, where: WhereType, num_rules: int, - ip_list: List[IPv4Address], + ip_list: List[StrIP], wildcard_list: List[str], - port_list: List[int], - protocol_list: List[str], + port_list: List[Port], + protocol_list: List[IPProtocol], ) -> None: """ Initialise an ACL observation instance. @@ -48,19 +50,19 @@ class ACLObservation(AbstractObservation, identifier="ACL"): :param num_rules: Number of ACL rules. :type num_rules: int :param ip_list: List of IP addresses. - :type ip_list: List[IPv4Address] + :type ip_list: List[StrIP] :param wildcard_list: List of wildcard strings. :type wildcard_list: List[str] - :param port_list: List of port numbers. - :type port_list: List[int] + :param port_list: List of port names. + :type port_list: List[Port] :param protocol_list: List of protocol names. - :type protocol_list: List[str] + :type protocol_list: List[IPProtocol] """ self.where = where self.num_rules: int = num_rules self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)} self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)} - self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)} + self.port_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(port_list)} self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { i diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index b24b26a6..a9e3a9aa 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from typing import Dict, Iterable, List, Optional @@ -13,7 +13,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST _LOGGER = getLogger(__name__) -class FileObservation(AbstractObservation, identifier="FILE"): +class FileObservation(AbstractObservation, discriminator="file"): """File observation, provides status information about a file within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -158,7 +158,7 @@ class FileObservation(AbstractObservation, identifier="FILE"): ) -class FolderObservation(AbstractObservation, identifier="FOLDER"): +class FolderObservation(AbstractObservation, discriminator="folder"): """Folder observation, provides status information about a folder within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -225,6 +225,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): if self.files: self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)} + self.cached_obs: Optional[ObsType] = self.default_observation + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. @@ -239,7 +241,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): return self.default_observation if self.file_system_requires_scan: - health_status = folder_state["visible_status"] + if not folder_state["scanned_this_step"]: + health_status = self.cached_obs["health_status"] + else: + health_status = folder_state["visible_status"] else: health_status = folder_state["health_status"] diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 42ceaff0..ac3b30d8 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from typing import Dict, List, Optional @@ -11,11 +11,14 @@ from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import StrIP +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) -class FirewallObservation(AbstractObservation, identifier="FIREWALL"): +class FirewallObservation(AbstractObservation, discriminator="firewall"): """Firewall observation, provides status information about a firewall within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -23,26 +26,26 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): hostname: str """Hostname of the firewall node, used for querying simulation state dictionary.""" - ip_list: Optional[List[str]] = None + ip_list: Optional[List[StrIP]] = None """List of IP addresses for encoding ACLs.""" wildcard_list: Optional[List[str]] = None """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[int]] = None + port_list: Optional[List[Port]] = None """List of ports for encoding ACLs.""" - protocol_list: Optional[List[str]] = None + protocol_list: Optional[List[IPProtocol]] = None """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" - include_users: Optional[bool] = True + include_users: Optional[bool] = None """If True, report user session information.""" def __init__( self, where: WhereType, - ip_list: List[str], + ip_list: List[StrIP], wildcard_list: List[str], - port_list: List[int], - protocol_list: List[str], + port_list: List[Port], + protocol_list: List[IPProtocol], num_rules: int, include_users: bool, ) -> None: @@ -53,13 +56,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): A typical location for a firewall might be ['network', 'nodes', ]. :type where: WhereType :param ip_list: List of IP addresses. - :type ip_list: List[str] + :type ip_list: List[StrIP] :param wildcard_list: List of wildcard rules. :type wildcard_list: List[str] - :param port_list: List of port numbers. - :type port_list: List[int] + :param port_list: List of port names. + :type port_list: List[Port] :param protocol_list: List of protocol types. - :type protocol_list: List[str] + :type protocol_list: List[IPProtocol] :param num_rules: Number of rules configured in the firewall. :type num_rules: int :param include_users: If True, report user session information. @@ -72,7 +75,6 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): self.ports: List[PortObservation] = [ PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3) ] - # TODO: check what the port nums are for firewall. self.internal_inbound_acl = ACLObservation( where=self.where + ["internal_inbound_acl", "acl"], @@ -140,6 +142,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): }, }, } + if self.include_users: + self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0} def observe(self, state: Dict) -> ObsType: """ @@ -153,29 +157,35 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): firewall_state = access_from_nested_dict(state, self.where) if firewall_state is NOT_PRESENT_IN_STATE: return self.default_observation - obs = { - "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, - "ACL": { - "INTERNAL": { - "INBOUND": self.internal_inbound_acl.observe(state), - "OUTBOUND": self.internal_outbound_acl.observe(state), + + is_on = firewall_state["operating_state"] == 1 + if not is_on: + obs = {**self.default_observation} + + else: + obs = { + "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, + "ACL": { + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.observe(state), + "OUTBOUND": self.internal_outbound_acl.observe(state), + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.observe(state), + "OUTBOUND": self.dmz_outbound_acl.observe(state), + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.observe(state), + "OUTBOUND": self.external_outbound_acl.observe(state), + }, }, - "DMZ": { - "INBOUND": self.dmz_inbound_acl.observe(state), - "OUTBOUND": self.dmz_outbound_acl.observe(state), - }, - "EXTERNAL": { - "INBOUND": self.external_inbound_acl.observe(state), - "OUTBOUND": self.external_outbound_acl.observe(state), - }, - }, - } - if self.include_users: - sess = firewall_state["services"]["UserSessionManager"] - obs["users"] = { - "local_login": 1 if sess["current_local_user"] else 0, - "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), } + if self.include_users: + sess = firewall_state["services"]["user-session-manager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -186,34 +196,36 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :return: Gymnasium space representing the observation space for firewall status. :rtype: spaces.Space """ - space = spaces.Dict( - { - "PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), - "ACL": spaces.Dict( - { - "INTERNAL": spaces.Dict( - { - "INBOUND": self.internal_inbound_acl.space, - "OUTBOUND": self.internal_outbound_acl.space, - } - ), - "DMZ": spaces.Dict( - { - "INBOUND": self.dmz_inbound_acl.space, - "OUTBOUND": self.dmz_outbound_acl.space, - } - ), - "EXTERNAL": spaces.Dict( - { - "INBOUND": self.external_inbound_acl.space, - "OUTBOUND": self.external_outbound_acl.space, - } - ), - } - ), - } - ) - return space + shape = { + "PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), + "ACL": spaces.Dict( + { + "INTERNAL": spaces.Dict( + { + "INBOUND": self.internal_inbound_acl.space, + "OUTBOUND": self.internal_outbound_acl.space, + } + ), + "DMZ": spaces.Dict( + { + "INBOUND": self.dmz_inbound_acl.space, + "OUTBOUND": self.dmz_outbound_acl.space, + } + ), + "EXTERNAL": spaces.Dict( + { + "INBOUND": self.external_inbound_acl.space, + "OUTBOUND": self.external_outbound_acl.space, + } + ), + } + ), + } + if self.include_users: + shape["users"] = spaces.Dict( + {"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)} + ) + return spaces.Dict(shape) @classmethod def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation: diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 2e7c381b..9b979063 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from typing import Dict, List, Optional @@ -12,11 +12,13 @@ from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) -class HostObservation(AbstractObservation, identifier="HOST"): +class HostObservation(AbstractObservation, discriminator="host"): """Host observation, provides status information about a host within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -44,7 +46,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): """Number of spaces for network interface observations on this host.""" include_nmne: Optional[bool] = None """Whether network interface observations should include number of malicious network events.""" - monitored_traffic: Optional[Dict] = None + monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Whether to include the number of accesses to files observations on this host.""" @@ -213,25 +215,31 @@ class HostObservation(AbstractObservation, identifier="HOST"): if node_state is NOT_PRESENT_IN_STATE: return self.default_observation - obs = {} + is_on = node_state["operating_state"] == 1 + if not is_on: + obs = {**self.default_observation} + + else: + obs = {} + if self.services: + obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + if self.applications: + obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)} + if self.folders: + obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} + if self.nics: + obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)} + if self.include_num_access: + obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] + obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + if self.include_users: + sess = node_state["services"]["user-session-manager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } + obs["operating_status"] = node_state["operating_state"] - if self.services: - obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} - if self.applications: - obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)} - if self.folders: - obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} - if self.nics: - obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)} - if self.include_num_access: - obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] - obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] - if self.include_users: - sess = node_state["services"]["UserSessionManager"] - obs["users"] = { - "local_login": 1 if sess["current_local_user"] else 0, - "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), - } return obs @property diff --git a/src/primaite/game/agent/observations/link_observation.py b/src/primaite/game/agent/observations/link_observation.py index 9af39a22..014a96c2 100644 --- a/src/primaite/game/agent/observations/link_observation.py +++ b/src/primaite/game/agent/observations/link_observation.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from typing import Any, Dict, List @@ -13,7 +13,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST _LOGGER = getLogger(__name__) -class LinkObservation(AbstractObservation, identifier="LINK"): +class LinkObservation(AbstractObservation, discriminator="link"): """Link observation, providing information about a specific link within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -90,7 +90,7 @@ class LinkObservation(AbstractObservation, identifier="LINK"): return cls(where=where) -class LinksObservation(AbstractObservation, identifier="LINKS"): +class LinksObservation(AbstractObservation, discriminator="links"): """Collection of link observations representing multiple links within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 0dabd9f4..8faeb906 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from typing import ClassVar, Dict, List, Optional @@ -9,10 +9,11 @@ from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE from primaite.simulator.network.nmne import NMNEConfig -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port -class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): +class NICObservation(AbstractObservation, discriminator="network-interface"): """Status information about a network interface within the simulation environment.""" capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne @@ -25,7 +26,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """Number of the network interface.""" include_nmne: Optional[bool] = None """Whether to include number of malicious network events (NMNE) in the observation.""" - monitored_traffic: Optional[Dict] = None + monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" def __init__( @@ -33,7 +34,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None, - thresholds: Optional[Dict] = {}, + thresholds: Dict = {}, ) -> None: """ Initialise a network interface observation instance. @@ -76,7 +77,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): def _default_monitored_traffic_observation(self, monitored_traffic_config: Dict) -> Dict: default_traffic_obs = {"TRAFFIC": {}} - for protocol in monitored_traffic_config: + for protocol in self.monitored_traffic: protocol = str(protocol).lower() default_traffic_obs["TRAFFIC"][protocol] = {} @@ -84,8 +85,8 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): default_traffic_obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0} else: default_traffic_obs["TRAFFIC"][protocol] = {} - for port in monitored_traffic_config[protocol]: - default_traffic_obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + for port in self.monitored_traffic[protocol]: + default_traffic_obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0} return default_traffic_obs @@ -147,7 +148,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """ nic_state = access_from_nested_dict(state, self.where) - if nic_state is NOT_PRESENT_IN_STATE: + if nic_state is NOT_PRESENT_IN_STATE or self.where is None: return self.default_observation obs = {"nic_status": 1 if nic_state["enabled"] else 2} @@ -174,17 +175,16 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): } else: for port in self.monitored_traffic[protocol]: - port_enum = Port[port] - obs["TRAFFIC"][protocol][port_enum.value] = {} + obs["TRAFFIC"][protocol][port] = {} traffic = {"inbound": 0, "outbound": 0} - if nic_state["traffic"][protocol].get(port_enum.value) is not None: - traffic = nic_state["traffic"][protocol][port_enum.value] + if nic_state["traffic"][protocol].get(port) is not None: + traffic = nic_state["traffic"][protocol][port] - obs["TRAFFIC"][protocol][port_enum.value]["inbound"] = self._categorise_traffic( + obs["TRAFFIC"][protocol][port]["inbound"] = self._categorise_traffic( traffic_value=traffic["inbound"], nic_state=nic_state ) - obs["TRAFFIC"][protocol][port_enum.value]["outbound"] = self._categorise_traffic( + obs["TRAFFIC"][protocol][port]["outbound"] = self._categorise_traffic( traffic_value=traffic["outbound"], nic_state=nic_state ) @@ -194,7 +194,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0} else: for port in self.monitored_traffic[protocol]: - obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0} if self.capture_nmne and self.include_nmne: obs.update({"NMNE": {}}) @@ -233,7 +233,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): else: space["TRAFFIC"][protocol] = spaces.Dict({}) for port in self.monitored_traffic[protocol]: - space["TRAFFIC"][protocol][Port[port].value] = spaces.Dict( + space["TRAFFIC"][protocol][port] = spaces.Dict( {"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)} ) @@ -260,7 +260,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): ) -class PortObservation(AbstractObservation, identifier="PORT"): +class PortObservation(AbstractObservation, discriminator="port"): """Port observation, provides status information about a network port within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 26861028..260fac68 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from typing import Dict, List, Optional @@ -12,11 +12,14 @@ from primaite.game.agent.observations.firewall_observation import FirewallObserv from primaite.game.agent.observations.host_observations import HostObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.router_observation import RouterObservation +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import StrIP +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) -class NodesObservation(AbstractObservation, identifier="NODES"): +class NodesObservation(AbstractObservation, discriminator="nodes"): """Nodes observation, provides status information about nodes within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -40,7 +43,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """Number of network interface cards (NICs).""" include_nmne: Optional[bool] = None """Flag to include nmne.""" - monitored_traffic: Optional[Dict] = None + monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" @@ -56,13 +59,13 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """If True, report user session information.""" num_ports: Optional[int] = None """Number of ports.""" - ip_list: Optional[List[str]] = None + ip_list: Optional[List[StrIP]] = None """List of IP addresses for encoding ACLs.""" wildcard_list: Optional[List[str]] = None """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[int]] = None + port_list: Optional[List[Port]] = None """List of ports for encoding ACLs.""" - protocol_list: Optional[List[str]] = None + protocol_list: Optional[List[IPProtocol]] = None """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" @@ -205,7 +208,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): host_config.applications_requires_scan = config.applications_requires_scan if host_config.include_users is None: host_config.include_users = config.include_users - if host_config.thresholds is None: + if not host_config.thresholds: host_config.thresholds = config.thresholds for router_config in config.routers: @@ -223,7 +226,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): router_config.num_rules = config.num_rules if router_config.include_users is None: router_config.include_users = config.include_users - if router_config.thresholds is None: + if not router_config.thresholds: router_config.thresholds = config.thresholds for firewall_config in config.firewalls: @@ -239,7 +242,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): firewall_config.num_rules = config.num_rules if firewall_config.include_users is None: firewall_config.include_users = config.include_users - if firewall_config.thresholds is None: + if not firewall_config.thresholds: firewall_config.thresholds = config.thresholds hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index cc32918c..c979c132 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -1,16 +1,17 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations +from functools import cached_property from typing import Any, Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType -from pydantic import BaseModel, ConfigDict, model_validator, ValidationError +from pydantic import BaseModel, computed_field, ConfigDict, Field, model_validator, ValidationError from primaite.game.agent.observations.observations import AbstractObservation, WhereType -class NestedObservation(AbstractObservation, identifier="CUSTOM"): +class NestedObservation(AbstractObservation, discriminator="custom"): """Observation type that allows combining other observations into a gymnasium.spaces.Dict space.""" class NestedObservationItem(BaseModel): @@ -18,7 +19,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): model_config = ConfigDict(extra="forbid") type: str - """Select observation class. It maps to the identifier of the obs class by checking the registry.""" + """Select observation class. It maps to the discriminator of the obs class by checking the registry.""" label: str """Dict key in the final observation space.""" options: Dict @@ -47,7 +48,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): def __init__(self, components: Dict[str, AbstractObservation]) -> None: """Initialise nested observation.""" self.components: Dict[str, AbstractObservation] = components - """Maps label: observation object""" + """Maps label observation object""" self.default_observation = {label: obs.default_observation for label, obs in self.components.items()} """Default observation is just the default observations of constituents.""" @@ -83,7 +84,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): ```yaml observation_space: - - type: CUSTOM + - type: custom options: components: @@ -120,7 +121,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): return cls(components=instances) -class NullObservation(AbstractObservation, identifier="NONE"): +class NullObservation(AbstractObservation, discriminator="none"): """Empty observation that acts as a placeholder.""" def __init__(self) -> None: @@ -142,7 +143,7 @@ class NullObservation(AbstractObservation, identifier="NONE"): return cls() -class ObservationManager: +class ObservationManager(BaseModel): """ Manage the observations of an Agent. @@ -152,15 +153,66 @@ class ObservationManager: 3. Formatting this information so an agent can use it to make decisions. """ - def __init__(self, obs: AbstractObservation) -> None: - """Initialise observation space. + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - :param observation: Observation object - :type observation: AbstractObservation - """ - self.obs: AbstractObservation = obs - self.current_observation: ObsType - """Cached copy of the observation at the time it was most recently calculated.""" + class ConfigSchema(BaseModel): + """Config Schema for Observation Manager.""" + + model_config = ConfigDict(extra="forbid") + type: str = "none" + """discriminator name for the top-level observation.""" + options: AbstractObservation.ConfigSchema = Field( + default_factory=lambda: NullObservation.ConfigSchema(), validate_default=True + ) + """Options to pass into the top-level observation during creation.""" + + @model_validator(mode="before") + @classmethod + def resolve_obs_options_type(cls, data: Any) -> Any: + """ + When constructing the model from a dict, resolve the correct observation class based on `type` field. + + Workaround: The `options` field is statically typed as AbstractObservation. Therefore, it falls over when + passing in data that adheres to a subclass schema rather than the plain AbstractObservation schema. There is + a way to do this properly using discriminated union, but most advice on the internet assumes that the full + list of types between which to discriminate is known ahead-of-time. That is not the case for us, because of + our plugin architecture. + + We may be able to revisit and implement a better solution when needed using the following resources as + research starting points: + https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions + https://github.com/pydantic/pydantic/issues/7366 + https://github.com/pydantic/pydantic/issues/7462 + https://github.com/pydantic/pydantic/pull/7983 + """ + if not isinstance(data, dict): + return data + + # (TODO: duplicate default definition between here and the actual model) + obs_type = data["type"] if "type" in data else "none" + obs_class = AbstractObservation._registry[obs_type] + + # if no options are passed in, try to create a default schema. Only works if there are no mandatory fields + if "options" not in data: + data["options"] = obs_class.ConfigSchema() + + # if options passed as a dict, validate against schema + elif isinstance(data["options"], dict): + data["options"] = obs_class.ConfigSchema(**data["options"]) + + return data + + config: ConfigSchema = Field(default_factory=lambda: ObservationManager.ConfigSchema()) + + current_observation: ObsType = 0 + + @computed_field + @cached_property + def obs(self) -> AbstractObservation: + """Create the main observation component for the observation manager from the config.""" + obs_class = AbstractObservation._registry[self.config.type] + obs_instance = obs_class.from_config(config=self.config.options) + return obs_instance def update(self, state: Dict) -> Dict: """ @@ -185,7 +237,7 @@ class ObservationManager: :param config: Dictionary containing the configuration for this observation space. If None, a blank observation space is created. Otherwise, this must be a Dict with a type field and options field. - type: string that corresponds to one of the observation identifiers that are provided when subclassing + type: string that corresponds to one of the observation discriminators that are provided when subclassing AbstractObservation options: this must adhere to the chosen observation type's ConfigSchema nested class. :type config: Dict @@ -194,10 +246,5 @@ class ObservationManager: """ if config is None: return cls(NullObservation()) - obs_type = config["type"] - obs_class = AbstractObservation._registry[obs_type] - observation = obs_class.from_config( - config=obs_class.ConfigSchema(**config["options"], thresholds=thresholds), - ) - obs_manager = cls(observation) + obs_manager = cls(config=config) return obs_manager diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 7a31a26b..8558b75c 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Manages the observation space for the agent.""" from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, List, Optional, Type, Union @@ -19,7 +19,7 @@ class AbstractObservation(ABC): class ConfigSchema(ABC, BaseModel): """Config schema for observations.""" - thresholds: Optional[Dict] = None + thresholds: Optional[Dict] = {} """A dict containing the observation thresholds.""" model_config = ConfigDict(extra="forbid") @@ -34,18 +34,20 @@ class AbstractObservation(ABC): """Initialise an observation. This method must be overwritten.""" self.default_observation: ObsType - def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: """ Register an observation type. - :param identifier: Identifier used to uniquely specify observation component types. - :type identifier: str + :param discriminator: discriminator used to uniquely specify observation component types. + :type discriminator: str :raises ValueError: When attempting to create a component with a name that is already in use. """ super().__init_subclass__(**kwargs) - if identifier in cls._registry: - raise ValueError(f"Duplicate observation component type {identifier}") - cls._registry[identifier] = cls + if discriminator is None: + return + if discriminator in cls._registry: + raise ValueError(f"Duplicate observation component type {discriminator}") + cls._registry[discriminator] = cls @abstractmethod def observe(self, state: Dict) -> Any: diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index d064936a..9a7f51cd 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from typing import Dict, List, Optional @@ -11,11 +11,14 @@ from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import StrIP +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) -class RouterObservation(AbstractObservation, identifier="ROUTER"): +class RouterObservation(AbstractObservation, discriminator="router"): """Router observation, provides status information about a router within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -29,17 +32,17 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): """Number of port observations configured for this router.""" acl: Optional[ACLObservation.ConfigSchema] = None """Configuration of ACL observation on this router.""" - ip_list: Optional[List[str]] = None + ip_list: Optional[List[StrIP]] = None """List of IP addresses for encoding ACLs.""" wildcard_list: Optional[List[str]] = None """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[int]] = None + port_list: Optional[List[Port]] = None """List of ports for encoding ACLs.""" - protocol_list: Optional[List[str]] = None + protocol_list: Optional[List[IPProtocol]] = None """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" - include_users: Optional[bool] = True + include_users: Optional[bool] = None """If True, report user session information.""" def __init__( @@ -84,6 +87,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): } if self.ports: self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)} + if self.include_users: + self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0} def observe(self, state: Dict) -> ObsType: """ @@ -98,16 +103,21 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): if router_state is NOT_PRESENT_IN_STATE: return self.default_observation - obs = {} - obs["ACL"] = self.acl.observe(state) - if self.ports: - obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} - if self.include_users: - sess = router_state["services"]["UserSessionManager"] - obs["users"] = { - "local_login": 1 if sess["current_local_user"] else 0, - "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), - } + is_on = router_state["operating_state"] == 1 + if not is_on: + obs = {**self.default_observation} + + else: + obs = {} + obs["ACL"] = self.acl.observe(state) + if self.ports: + obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} + if self.include_users: + sess = router_state["services"]["user-session-manager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -121,6 +131,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): shape = {"ACL": self.acl.space} if self.ports: shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}) + if self.include_users: + shape["users"] = spaces.Dict( + {"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)} + ) return spaces.Dict(shape) @classmethod diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 0318c864..dac6b362 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from typing import Dict, List, Optional @@ -10,7 +10,7 @@ from primaite.game.agent.observations.observations import AbstractObservation, W from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -class ServiceObservation(AbstractObservation, identifier="SERVICE"): +class ServiceObservation(AbstractObservation, discriminator="service"): """Service observation, shows status of a service in the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -81,7 +81,7 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"): ) -class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): +class ApplicationObservation(AbstractObservation, discriminator="application"): """Application observation, shows the status of an application within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 1de34b40..39f6fcfb 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """ Manages the reward function for the agent. @@ -12,7 +12,7 @@ the structure: ```yaml reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.5 options: node_name: database_server @@ -20,16 +20,17 @@ the structure: file_name: database.db - - type: WEB_SERVER_404_PENALTY + - type: web-server-404-penalty weight: 0.5 options: node_name: web_server service_ref: web_server_database_client ``` """ -from abc import abstractmethod -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from abc import ABC, abstractmethod +from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Never from primaite import getLogger @@ -42,25 +43,28 @@ _LOGGER = getLogger(__name__) WhereType = Optional[Iterable[Union[str, int]]] -class AbstractReward: +class AbstractReward(BaseModel, ABC): """Base class for reward function components.""" - @abstractmethod - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state. + class ConfigSchema(BaseModel, ABC): + """Config schema for AbstractReward.""" - :param state: Current simulation state - :type state: Dict - :param last_action_response: Current agent history state - :type last_action_response: AgentHistoryItem state - :return: Reward value - :rtype: float - """ - return 0.0 + type: str = "" + + config: ConfigSchema + + _registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {} + + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if discriminator is None: + return + if discriminator in cls._registry: + raise ValueError(f"Duplicate reward {discriminator}") + cls._registry[discriminator] = cls @classmethod - @abstractmethod - def from_config(cls, config: dict) -> "AbstractReward": + def from_config(cls, config: Dict) -> "AbstractReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor @@ -68,11 +72,28 @@ class AbstractReward: :return: The reward component. :rtype: AbstractReward """ - return cls() + if config["type"] not in cls._registry: + raise ValueError(f"Invalid reward type {config['type']}") + reward_class = cls._registry[config["type"]] + reward_obj = reward_class(config=reward_class.ConfigSchema(**config)) + return reward_obj + + @abstractmethod + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ + return 0.0 -class DummyReward(AbstractReward): - """Dummy reward function component which always returns 0.""" +class DummyReward(AbstractReward, discriminator="dummy"): + """Dummy reward function component which always returns 0.0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -86,41 +107,21 @@ class DummyReward(AbstractReward): """ return 0.0 - @classmethod - def from_config(cls, config: dict) -> "DummyReward": - """Create a reward function component from a config dictionary. - :param config: dict of options for the reward component's constructor. Should be empty. - :type config: dict - :return: The reward component. - :rtype: DummyReward - """ - return cls() - - -class DatabaseFileIntegrity(AbstractReward): +class DatabaseFileIntegrity(AbstractReward, discriminator="database-file-integrity"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" - def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None: - """Initialise the reward component. + config: "DatabaseFileIntegrity.ConfigSchema" + location_in_state: List[str] = [""] + reward: float = 0.0 - :param node_hostname: Hostname of the node which contains the database file. - :type node_hostname: str - :param folder_name: folder which contains the database file. - :type folder_name: str - :param file_name: name of the database file. - :type file_name: str - """ - self.location_in_state = [ - "network", - "nodes", - node_hostname, - "file_system", - "folders", - folder_name, - "files", - file_name, - ] + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for DatabaseFileIntegrity.""" + + type: str = "database-file-integrity" + node_hostname: str + folder_name: str + file_name: str def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -132,6 +133,17 @@ class DatabaseFileIntegrity(AbstractReward): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.config.node_hostname, + "file_system", + "folders", + self.config.folder_name, + "files", + self.config.file_name, + ] + database_file_state = access_from_nested_dict(state, self.location_in_state) if database_file_state is NOT_PRESENT_IN_STATE: _LOGGER.debug( @@ -148,44 +160,21 @@ class DatabaseFileIntegrity(AbstractReward): else: return 0 - @classmethod - def from_config(cls, config: Dict) -> "DatabaseFileIntegrity": - """Create a reward function component from a config dictionary. - :param config: dict of options for the reward component's constructor - :type config: Dict - :return: The reward component. - :rtype: DatabaseFileIntegrity - """ - node_hostname = config.get("node_hostname") - folder_name = config.get("folder_name") - file_name = config.get("file_name") - if not (node_hostname and folder_name and file_name): - msg = f"{cls.__name__} could not be initialised with parameters {config}" - _LOGGER.error(msg) - raise ValueError(msg) - - return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name) - - -class WebServer404Penalty(AbstractReward): +class WebServer404Penalty(AbstractReward, discriminator="web-server-404-penalty"): """Reward function component which penalises the agent when the web server returns a 404 error.""" - def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None: - """Initialise the reward component. + config: "WebServer404Penalty.ConfigSchema" + location_in_state: List[str] = [""] + reward: float = 0.0 - :param node_hostname: Hostname of the node which contains the web server service. - :type node_hostname: str - :param service_name: Name of the web server service. - :type service_name: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" - self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for WebServer404Penalty.""" + + type: str = "web-server-404-penalty" + node_hostname: str + service_name: str + sticky: bool = True def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -197,6 +186,13 @@ class WebServer404Penalty(AbstractReward): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.config.node_hostname, + "services", + self.config.service_name, + ] web_service_state = access_from_nested_dict(state, self.location_in_state) # if webserver is no longer installed on the node, return 0 @@ -211,54 +207,27 @@ class WebServer404Penalty(AbstractReward): return 1.0 if status == 200 else -1.0 if status == 404 else 0.0 self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average - elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0 + elif not self.config.sticky: # there are no codes, but reward is not sticky, set reward to 0 self.reward = 0.0 else: # skip calculating if sticky and no new codes. instead, reuse last step's value pass return self.reward - @classmethod - def from_config(cls, config: Dict) -> "WebServer404Penalty": - """Create a reward function component from a config dictionary. - :param config: dict of options for the reward component's constructor - :type config: Dict - :return: The reward component. - :rtype: WebServer404Penalty - """ - node_hostname = config.get("node_hostname") - service_name = config.get("service_name") - if not (node_hostname and service_name): - msg = ( - f"{cls.__name__} could not be initialised from config because node_name and service_ref were not " - "found in reward config." - ) - _LOGGER.warning(msg) - raise ValueError(msg) - sticky = config.get("sticky", True) - - return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky) - - -class WebpageUnavailablePenalty(AbstractReward): +class WebpageUnavailablePenalty(AbstractReward, discriminator="webpage-unavailable-penalty"): """Penalises the agent when the web browser fails to fetch a webpage.""" - def __init__(self, node_hostname: str, sticky: bool = True) -> None: - """ - Initialise the reward component. + config: "WebpageUnavailablePenalty.ConfigSchema" + reward: float = 0.0 + location_in_state: List[str] = [""] # Calculate in __init__()? - :param node_hostname: Hostname of the node which has the web browser. - :type node_hostname: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self._node: str = node_hostname - self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for WebpageUnavailablePenalty.""" + + type: str = "webpage-unavailable-penalty" + node_hostname: str = "" + sticky: bool = True def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -274,6 +243,13 @@ class WebpageUnavailablePenalty(AbstractReward): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.config.node_hostname, + "applications", + "web-browser", + ] web_browser_state = access_from_nested_dict(state, self.location_in_state) if web_browser_state is NOT_PRESENT_IN_STATE: @@ -283,14 +259,14 @@ class WebpageUnavailablePenalty(AbstractReward): request_attempted = last_action_response.request == [ "network", "node", - self._node, + self.config.node_hostname, "application", - "WebBrowser", + "web-browser", "execute", ] # skip calculating if sticky and no new codes, reusing last step value - if not request_attempted and self.sticky: + if not request_attempted and self.config.sticky: return self.reward if last_action_response.response.status != "success": @@ -298,7 +274,7 @@ class WebpageUnavailablePenalty(AbstractReward): elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]: _LOGGER.debug( "Web browser reward could not be calculated because the web browser history on node", - f"{self._node} was not reported in the simulation state. Returning 0.0", + f"{self.config.node_hostname} was not reported in the simulation state. Returning 0.0", ) self.reward = 0.0 else: @@ -312,37 +288,19 @@ class WebpageUnavailablePenalty(AbstractReward): return self.reward - @classmethod - def from_config(cls, config: dict) -> AbstractReward: - """ - Build the reward component object from config. - :param config: Configuration dictionary. - :type config: Dict - """ - node_hostname = config.get("node_hostname") - sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, sticky=sticky) - - -class GreenAdminDatabaseUnreachablePenalty(AbstractReward): +class GreenAdminDatabaseUnreachablePenalty(AbstractReward, discriminator="green-admin-database-unreachable-penalty"): """Penalises the agent when the green db clients fail to connect to the database.""" - def __init__(self, node_hostname: str, sticky: bool = True) -> None: - """ - Initialise the reward component. + config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema" + reward: float = 0.0 - :param node_hostname: Hostname of the node where the database client sits. - :type node_hostname: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self._node: str = node_hostname - self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for GreenAdminDatabaseUnreachablePenalty.""" + + type: str = "green-admin-database-unreachable-penalty" + node_hostname: str + sticky: bool = True def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -362,16 +320,16 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): request_attempted = last_action_response.request == [ "network", "node", - self._node, + self.config.node_hostname, "application", - "DatabaseClient", + "database-client", "execute", ] if request_attempted: # if agent makes request, always recalculate fresh value last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status} self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 - elif not self.sticky: # if no new request and not sticky, set reward to 0 + elif not self.config.sticky: # if no new request and not sticky, set reward to 0 last_action_response.reward_info = {"connection_attempt_status": "n/a"} self.reward = 0.0 else: # if no new request and sticky, reuse reward value from last step @@ -380,47 +338,30 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): return self.reward - @classmethod - def from_config(cls, config: Dict) -> AbstractReward: - """ - Build the reward component object from config. - :param config: Configuration dictionary. - :type config: Dict - """ - node_hostname = config.get("node_hostname") - sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, sticky=sticky) - - -class SharedReward(AbstractReward): +class SharedReward(AbstractReward, discriminator="shared-reward"): """Adds another agent's reward to the overall reward.""" - def __init__(self, agent_name: Optional[str] = None) -> None: + config: "SharedReward.ConfigSchema" + + class ConfigSchema(AbstractReward.ConfigSchema): + """Config schema for SharedReward.""" + + type: str = "shared-reward" + agent_name: str + + def default_callback(agent_name: str) -> Never: """ - Initialise the shared reward. + Default callback to prevent calling this reward until it's properly initialised. - The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work - correctly. - - :param agent_name: The name whose reward is an input - :type agent_name: Optional[str] + SharedReward should not be used until the game layer replaces self.callback with a reference to the + function that retrieves the desired agent's reward. Therefore, we define this default callback that raises + an error. """ - self.agent_name = agent_name - """Agent whose reward to track.""" + raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - def default_callback(agent_name: str) -> Never: - """ - Default callback to prevent calling this reward until it's properly initialised. - - SharedReward should not be used until the game layer replaces self.callback with a reference to the - function that retrieves the desired agent's reward. Therefore, we define this default callback that raises - an error. - """ - raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - - self.callback: Callable[[str], float] = default_callback - """Method that retrieves an agent's current reward given the agent's name.""" + callback: Callable[[str], float] = default_callback + """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Simply access the other agent's reward and return it. @@ -432,36 +373,25 @@ class SharedReward(AbstractReward): :return: Reward value :rtype: float """ - return self.callback(self.agent_name) - - @classmethod - def from_config(cls, config: Dict) -> "SharedReward": - """ - Build the SharedReward object from config. - - :param config: Configuration dictionary - :type config: Dict - """ - agent_name = config.get("agent_name") - return cls(agent_name=agent_name) + return self.callback(self.config.agent_name) -class ActionPenalty(AbstractReward): - """Apply a negative reward when taking any action except DONOTHING.""" +class ActionPenalty(AbstractReward, discriminator="action-penalty"): + """Apply a negative reward when taking any action except do-nothing.""" - def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: - """ - Initialise the reward. + config: "ActionPenalty.ConfigSchema" - Reward or penalise agents for doing nothing or taking actions. + class ConfigSchema(AbstractReward.ConfigSchema): + """Config schema for ActionPenalty. - :param action_penalty: Reward to give agents for taking any action except DONOTHING + :param action_penalty: Reward to give agents for taking any action except do-nothing :type action_penalty: float - :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action + :param do_nothing_penalty: Reward to give agent for taking the do-nothing action :type do_nothing_penalty: float """ - self.action_penalty = action_penalty - self.do_nothing_penalty = do_nothing_penalty + + action_penalty: float = -1.0 + do_nothing_penalty: float = 0.0 def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the penalty to be applied. @@ -473,39 +403,81 @@ class ActionPenalty(AbstractReward): :return: Reward value :rtype: float """ - if last_action_response.action == "DONOTHING": - return self.do_nothing_penalty + if last_action_response.action == "do-nothing": + return self.config.do_nothing_penalty + else: - return self.action_penalty + return self.config.action_penalty + +class _SingleComponentConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + type: str + options: AbstractReward.ConfigSchema + weight: float = 1.0 + + @model_validator(mode="before") @classmethod - def from_config(cls, config: Dict) -> "ActionPenalty": - """Build the ActionPenalty object from config.""" - action_penalty = config.get("action_penalty", -1.0) - do_nothing_penalty = config.get("do_nothing_penalty", 0.0) - return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty) + def resolve_obs_options_type(cls, data: Any) -> Any: + """ + When constructing the model from a dict, resolve the correct reward class based on `type` field. + + Workaround: The `options` field is statically typed as AbstractReward. Therefore, it falls over when + passing in data that adheres to a subclass schema rather than the plain AbstractReward schema. There is + a way to do this properly using discriminated union, but most advice on the internet assumes that the full + list of types between which to discriminate is known ahead-of-time. That is not the case for us, because of + our plugin architecture. + + We may be able to revisit and implement a better solution when needed using the following resources as + research starting points: + https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions + https://github.com/pydantic/pydantic/issues/7366 + https://github.com/pydantic/pydantic/issues/7462 + https://github.com/pydantic/pydantic/pull/7983 + """ + if not isinstance(data, dict): + return data + + assert "type" in data, ValueError('Reward component definition is missing the "type" key.') + rew_type = data["type"] + rew_class = AbstractReward._registry[rew_type] + + # if no options are passed in, try to create a default schema. Only works if there are no mandatory fields. + if "options" not in data: + data["options"] = rew_class.ConfigSchema() + + # if options are passed as a dict, validate against schema + elif isinstance(data["options"], dict): + data["options"] = rew_class.ConfigSchema(**data["options"]) + + return data -class RewardFunction: +class RewardFunction(BaseModel): """Manages the reward function for the agent.""" - rew_class_identifiers: Dict[str, Type[AbstractReward]] = { - "DUMMY": DummyReward, - "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, - "WEB_SERVER_404_PENALTY": WebServer404Penalty, - "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty, - "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty, - "SHARED_REWARD": SharedReward, - "ACTION_PENALTY": ActionPenalty, - } - """List of reward class identifiers.""" + model_config = ConfigDict(extra="forbid") - def __init__(self): - """Initialise the reward function object.""" - self.reward_components: List[Tuple[AbstractReward, float]] = [] - "attribute reward_components keeps track of reward components and the weights assigned to each." - self.current_reward: float = 0.0 - self.total_reward: float = 0.0 + class ConfigSchema(BaseModel): + """Config Schema for RewardFunction.""" + + model_config = ConfigDict(extra="forbid") + + reward_components: Iterable[_SingleComponentConfig] = [] + + config: ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema()) + + reward_components: List[Tuple[AbstractReward, float]] = [] + + current_reward: float = 0.0 + total_reward: float = 0.0 + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + for rew_config in self.config.reward_components: + rew_class = AbstractReward._registry[rew_config.type] + rew_instance = rew_class(config=rew_config.options) + self.register_component(component=rew_instance, weight=rew_config.weight) def register_component(self, component: AbstractReward, weight: float = 1.0) -> None: """Add a reward component to the reward function. @@ -534,7 +506,7 @@ class RewardFunction: @classmethod def from_config(cls, config: Dict) -> "RewardFunction": - """Create a reward function from a config dictionary. + """Create a reward function from a config dictionary and its related reward class. :param config: dict of options for the reward manager's constructor :type config: Dict @@ -545,8 +517,11 @@ class RewardFunction: for rew_component_cfg in config["reward_components"]: rew_type = rew_component_cfg["type"] + # XXX: If options key is missing add key then add type key. + if "options" not in rew_component_cfg: + rew_component_cfg["options"] = {} + rew_component_cfg["options"]["type"] = rew_type weight = rew_component_cfg.get("weight", 1.0) - rew_class = cls.rew_class_identifiers[rew_type] - rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {})) + rew_instance = AbstractReward.from_config(rew_component_cfg["options"]) new.register_component(component=rew_instance, weight=weight) return new diff --git a/src/primaite/game/agent/scripted_agents/__init__.py b/src/primaite/game/agent/scripted_agents/__init__.py index be6c00e7..5a97d15b 100644 --- a/src/primaite/game/agent/scripted_agents/__init__.py +++ b/src/primaite/game/agent/scripted_agents/__init__.py @@ -1 +1,6 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +from primaite.game.agent import interface +from primaite.game.agent.scripted_agents import abstract_tap, data_manipulation_bot, probabilistic_agent, random_agent + +__all__ = ("abstract_tap", "data_manipulation_bot", "interface", "probabilistic_agent", "random_agent") diff --git a/src/primaite/game/agent/scripted_agents/abstract_tap.py b/src/primaite/game/agent/scripted_agents/abstract_tap.py new file mode 100644 index 00000000..679f69fa --- /dev/null +++ b/src/primaite/game/agent/scripted_agents/abstract_tap.py @@ -0,0 +1,61 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from __future__ import annotations + +import random +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple + +from gymnasium.core import ObsType +from pydantic import Field + +from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent + +__all__ = "AbstractTAPAgent" + + +class AbstractTAPAgent(PeriodicAgent, ABC): + """Base class for TAP agents to inherit from.""" + + config: "AbstractTAPAgent.ConfigSchema" = Field(default_factory=lambda: AbstractTAPAgent.ConfigSchema()) + next_execution_timestep: int = 0 + + class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema, ABC): + """Schema for the `agent_settings` part of the agent config.""" + + possible_starting_nodes: List[str] = Field(default_factory=list) + + class ConfigSchema(PeriodicAgent.ConfigSchema, ABC): + """Configuration schema for Abstract TAP agents.""" + + type: str = "abstract-tap" + agent_settings: AbstractTAPAgent.AgentSettingsSchema = Field( + default_factory=lambda: AbstractTAPAgent.AgentSettingsSchema() + ) + + starting_node: Optional[str] = None + + @abstractmethod + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: + """Return an action to be taken in the environment.""" + return super().get_action(obs=obs, timestep=timestep) + + @abstractmethod + def setup_agent(self) -> None: + """Set up agent.""" + pass + + def _set_next_execution_timestep(self, timestep: int) -> None: + """Set the next execution timestep with a configured random variance. + + :param timestep: The timestep to add variance to. + """ + random_timestep_increment = random.randint( + -self.config.agent_settings.variance, self.config.agent_settings.variance + ) + self.next_execution_timestep = timestep + random_timestep_increment + + def _select_start_node(self) -> None: + """Set the starting starting node of the agent to be a random node from this agent's action manager.""" + # we are assuming that every node in the node manager has a data manipulation application at idx 0 + self.starting_node = random.choice(self.config.agent_settings.possible_starting_nodes) + self.logger.debug(f"Selected starting node: {self.starting_node}") diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index 129fac1a..7a88bd12 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -1,31 +1,33 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -import random +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Dict, Tuple from gymnasium.core import ObsType +from pydantic import Field -from primaite.game.agent.interface import AbstractScriptedAgent +from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent -class DataManipulationAgent(AbstractScriptedAgent): +class DataManipulationAgent(PeriodicAgent, discriminator="red-database-corrupting-agent"): """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" - next_execution_timestep: int = 0 - starting_node_idx: int = 0 + class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema): + """Schema for the `agent_settings` part of the agent config.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setup_agent() + target_application: str = "data-manipulation-bot" - def _set_next_execution_timestep(self, timestep: int) -> None: - """Set the next execution timestep with a configured random variance. + class ConfigSchema(PeriodicAgent.ConfigSchema): + """Configuration Schema for DataManipulationAgent.""" - :param timestep: The timestep to add variance to. - """ - random_timestep_increment = random.randint( - -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance + type: str = "red-database-corrupting-agent" + agent_settings: "DataManipulationAgent.AgentSettingsSchema" = Field( + default_factory=lambda: DataManipulationAgent.AgentSettingsSchema() ) - self.next_execution_timestep = timestep + random_timestep_increment + + config: "DataManipulationAgent.ConfigSchema" = Field(default_factory=lambda: DataManipulationAgent.ConfigSchema()) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._set_next_execution_timestep(timestep=self.config.agent_settings.start_step, variance=0) def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]: """Waits until a specific timestep, then attempts to execute its data manipulation application. @@ -38,21 +40,14 @@ class DataManipulationAgent(AbstractScriptedAgent): :rtype: Tuple[str, Dict] """ if timestep < self.next_execution_timestep: - self.logger.debug(msg="Performing do NOTHING") - return "DONOTHING", {} + self.logger.debug(msg="Performing do nothing action") + return "do-nothing", {} - self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) + self._set_next_execution_timestep( + timestep=timestep + self.config.agent_settings.frequency, variance=self.config.agent_settings.variance + ) self.logger.info(msg="Performing a data manipulation attack!") - return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} - - def setup_agent(self) -> None: - """Set the next execution timestep when the episode resets.""" - self._select_start_node() - self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) - - def _select_start_node(self) -> None: - """Set the starting starting node of the agent to be a random node from this agent's action manager.""" - # we are assuming that every node in the node manager has a data manipulation application at idx 0 - num_nodes = len(self.action_manager.node_names) - self.starting_node_idx = random.randint(0, num_nodes - 1) - self.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}") + return "node-application-execute", { + "node_name": self.start_node, + "application_name": self.config.agent_settings.target_application, + } diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index cd44644f..babf9179 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -1,29 +1,28 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Agents with predefined behaviours.""" -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple import numpy as np import pydantic from gymnasium.core import ObsType +from numpy.random import Generator +from pydantic import Field -from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractScriptedAgent -from primaite.game.agent.observations.observation_manager import ObservationManager -from primaite.game.agent.rewards import RewardFunction + +__all__ = "ProbabilisticAgent" -class ProbabilisticAgent(AbstractScriptedAgent): +class ProbabilisticAgent(AbstractScriptedAgent, discriminator="probabilistic-agent"): """Scripted agent which randomly samples its action space with prescribed probabilities for each action.""" - class Settings(pydantic.BaseModel): - """Config schema for Probabilistic agent settings.""" + rng: Generator = Field(default_factory=lambda: np.random.default_rng(np.random.randint(0, 65535))) - model_config = pydantic.ConfigDict(extra="forbid") - """Strict validation.""" - action_probabilities: Dict[int, float] + class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema): + """Schema for the `agent_settings` part of the agent config.""" + + action_probabilities: Dict[int, float] = None """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" - # TODO: give the option to still set a random seed, but have it vary each episode in a predictable way - # for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed. @pydantic.field_validator("action_probabilities", mode="after") @classmethod @@ -44,31 +43,20 @@ class ProbabilisticAgent(AbstractScriptedAgent): ) return v - def __init__( - self, - agent_name: str, - action_space: Optional[ActionManager], - observation_space: Optional[ObservationManager], - reward_function: Optional[RewardFunction], - settings: Dict = {}, - ) -> None: - # If the action probabilities are not specified, create equal probabilities for all actions - if "action_probabilities" not in settings: - num_actions = len(action_space.action_map) - settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}} + class ConfigSchema(AbstractScriptedAgent.ConfigSchema): + """Configuration schema for Probabilistic Agent.""" - # The random number seed for np.random is dependent on whether a random number seed is set - # in the config file. If there is one it is processed by set_random_seed() in environment.py - # and as a consequence the the sequence of rng_seed's used here will be repeatable. - self.settings = ProbabilisticAgent.Settings(**settings) - rng_seed = np.random.randint(0, 65535) - self.rng = np.random.default_rng(rng_seed) + type: str = "probabilistic-agent" + agent_settings: "ProbabilisticAgent.AgentSettingsSchema" = Field( + default_factory=lambda: ProbabilisticAgent.AgentSettingsSchema() + ) - # convert probabilities from - self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) + config: "ProbabilisticAgent.ConfigSchema" = Field(default_factory=lambda: ProbabilisticAgent.ConfigSchema()) - super().__init__(agent_name, action_space, observation_space, reward_function) - self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}") + @property + def probabilities(self) -> Dict[str, int]: + """Convenience method to view the probabilities of the Agent.""" + return np.asarray(list(self.config.agent_settings.action_probabilities.values())) def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/game/agent/scripted_agents/random_agent.py b/src/primaite/game/agent/scripted_agents/random_agent.py index df9273f7..eebf2c93 100644 --- a/src/primaite/game/agent/scripted_agents/random_agent.py +++ b/src/primaite/game/agent/scripted_agents/random_agent.py @@ -1,20 +1,27 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import random -from typing import Dict, Optional, Tuple +from functools import cached_property +from typing import Dict, List, Tuple from gymnasium.core import ObsType -from pydantic import BaseModel +from pydantic import computed_field, Field, model_validator -from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractScriptedAgent -from primaite.game.agent.observations.observation_manager import ObservationManager -from primaite.game.agent.rewards import RewardFunction + +__all__ = ("RandomAgent", "PeriodicAgent") -class RandomAgent(AbstractScriptedAgent): +class RandomAgent(AbstractScriptedAgent, discriminator="random-agent"): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: + config: "RandomAgent.ConfigSchema" = Field(default_factory=lambda: RandomAgent.ConfigSchema()) + + class ConfigSchema(AbstractScriptedAgent.ConfigSchema): + """Configuration Schema for Random Agents.""" + + type: str = "random-agent" + + def get_action(self) -> Tuple[str, Dict]: """Sample the action space randomly. :param obs: Current observation for this agent, not used in RandomAgent @@ -27,41 +34,66 @@ class RandomAgent(AbstractScriptedAgent): return self.action_manager.get_action(self.action_manager.space.sample()) -class PeriodicAgent(AbstractScriptedAgent): +class PeriodicAgent(AbstractScriptedAgent, discriminator="periodic-agent"): """Agent that does nothing most of the time, but executes application at regular intervals (with variance).""" - class Settings(BaseModel): - """Configuration values for when an agent starts performing actions.""" + config: "PeriodicAgent.ConfigSchema" = Field(default_factory=lambda: PeriodicAgent.ConfigSchema()) - start_step: int = 20 - "The timestep at which an agent begins performing it's actions." - start_variance: int = 5 - "Deviation around the start step." + class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema): + """Schema for the `agent_settings` part of the agent config.""" + + start_step: int = 5 + "The timestep at which an agent begins performing it's actions" + start_variance: int = 0 frequency: int = 5 - "The number of timesteps to wait between performing actions." + "The number of timesteps to wait between performing actions" variance: int = 0 - "The amount the frequency can randomly change to." + "The amount the frequency can randomly change to" max_executions: int = 999999 - "Maximum number of times the agent can execute its action." + possible_start_nodes: List[str] + target_application: str - def __init__( - self, - agent_name: str, - action_space: ActionManager, - observation_space: ObservationManager, - reward_function: RewardFunction, - settings: Optional[Settings] = None, - ) -> None: - """Initialise PeriodicAgent.""" - super().__init__( - agent_name=agent_name, - action_space=action_space, - observation_space=observation_space, - reward_function=reward_function, + @model_validator(mode="after") + def check_variance_lt_frequency(self) -> "PeriodicAgent.ConfigSchema": + """ + Make sure variance is equal to or lower than frequency. + + This is because the calculation for the next execution time is now + (frequency +- variance). + If variance were greater than frequency, sometimes the bracketed term would be negative + and the attack would never happen again. + """ + if self.variance >= self.frequency: + raise ValueError( + f"Agent start settings error: variance must be lower than frequency " + f"{self.variance=}, {self.frequency=}" + ) + return self + + class ConfigSchema(AbstractScriptedAgent.ConfigSchema): + """Configuration Schema for Periodic Agent.""" + + type: str = "periodic-agent" + """Name of the agent.""" + agent_settings: "PeriodicAgent.AgentSettingsSchema" = Field( + default_factory=lambda: PeriodicAgent.AgentSettingsSchema() ) - self.settings = settings or PeriodicAgent.Settings() - self._set_next_execution_timestep(timestep=self.settings.start_step, variance=self.settings.start_variance) - self.num_executions = 0 + + num_executions: int = 0 + """Number of times the agent has executed an action.""" + next_execution_timestep: int = 0 + """Timestep of the next action execution by the agent.""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self._set_next_execution_timestep( + timestep=self.config.agent_settings.start_step, variance=self.config.agent_settings.start_variance + ) + + @computed_field + @cached_property + def start_node(self) -> str: + """On instantiation, randomly select a start node.""" + return random.choice(self.config.agent_settings.possible_start_nodes) def _set_next_execution_timestep(self, timestep: int, variance: int) -> None: """Set the next execution timestep with a configured random variance. @@ -76,9 +108,14 @@ class PeriodicAgent(AbstractScriptedAgent): def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]: """Do nothing, unless the current timestep is the next execution timestep, in which case do the action.""" - if timestep == self.next_execution_timestep and self.num_executions < self.settings.max_executions: + if timestep == self.next_execution_timestep and self.num_executions < self.config.agent_settings.max_executions: self.num_executions += 1 - self._set_next_execution_timestep(timestep + self.settings.frequency, self.settings.variance) - return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0} + self._set_next_execution_timestep( + timestep + self.config.agent_settings.frequency, self.config.agent_settings.variance + ) + return "node-application-execute", { + "node_name": self.start_node, + "application_name": self.config.agent_settings.target_application, + } - return "DONOTHING", {} + return "do-nothing", {} diff --git a/src/primaite/game/agent/scripted_agents/tap001.py b/src/primaite/game/agent/scripted_agents/tap001.py deleted file mode 100644 index c4f6062a..00000000 --- a/src/primaite/game/agent/scripted_agents/tap001.py +++ /dev/null @@ -1,78 +0,0 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -import random -from typing import Dict, Tuple - -from gymnasium.core import ObsType - -from primaite.game.agent.interface import AbstractScriptedAgent - - -class TAP001(AbstractScriptedAgent): - """ - TAP001 | Mobile Malware -- Ransomware Variant. - - Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application) - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setup_agent() - - next_execution_timestep: int = 0 - starting_node_idx: int = 0 - installed: bool = False - - def _set_next_execution_timestep(self, timestep: int) -> None: - """Set the next execution timestep with a configured random variance. - - :param timestep: The timestep to add variance to. - """ - random_timestep_increment = random.randint( - -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance - ) - self.next_execution_timestep = timestep + random_timestep_increment - - def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]: - """Waits until a specific timestep, then attempts to execute the ransomware application. - - This application acts a wrapper around the kill-chain, similar to green-analyst and - the previous UC2 data manipulation bot. - - :param obs: Current observation for this agent. - :type obs: ObsType - :param timestep: The current simulation timestep, used for scheduling actions - :type timestep: int - :return: Action formatted in CAOS format - :rtype: Tuple[str, Dict] - """ - if timestep < self.next_execution_timestep: - return "DONOTHING", {} - - self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) - - if not self.installed: - self.installed = True - return "NODE_APPLICATION_INSTALL", { - "node_id": self.starting_node_idx, - "application_name": "RansomwareScript", - } - - return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} - - def setup_agent(self) -> None: - """Set the next execution timestep when the episode resets.""" - self._select_start_node() - self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) - for n, act in self.action_manager.action_map.items(): - if not act[0] == "NODE_APPLICATION_INSTALL": - continue - if act[1]["node_id"] == self.starting_node_idx: - self.ip_address = act[1]["ip_address"] - return - raise RuntimeError("TAP001 agent could not find database server ip address in action map") - - def _select_start_node(self) -> None: - """Set the starting starting node of the agent to be a random node from this agent's action manager.""" - # we are assuming that every node in the node manager has a data manipulation application at idx 0 - num_nodes = len(self.action_manager.node_names) - self.starting_node_idx = random.randint(0, num_nodes - 1) diff --git a/src/primaite/game/agent/utils.py b/src/primaite/game/agent/utils.py index 15efd0b6..87b02858 100644 --- a/src/primaite/game/agent/utils.py +++ b/src/primaite/game/agent/utils.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Any, Dict, Hashable, Optional, Sequence NOT_PRESENT_IN_STATE = object() diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8d25af07..a3b77ec3 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,36 +1,22 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """PrimAITE game - Encapsulates the simulation and agents.""" -from ipaddress import IPv4Address from typing import Dict, List, Optional, Union import numpy as np from pydantic import BaseModel, ConfigDict from primaite import DEFAULT_BANDWIDTH, getLogger -from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent +from primaite.game.agent.interface import AbstractAgent, ProxyAgent from primaite.game.agent.observations import NICObservation -from primaite.game.agent.observations.observation_manager import ObservationManager -from primaite.game.agent.rewards import RewardFunction, SharedReward -from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent -from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent -from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent -from primaite.game.agent.scripted_agents.tap001 import TAP001 +from primaite.game.agent.rewards import SharedReward from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT -from primaite.simulator.file_system.file_type import FileType -from primaite.simulator.network.airspace import AirSpaceFrequency -from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager -from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.creation import NetworkNodeAdder +from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager from primaite.simulator.network.hardware.nodes.host.host_node import NIC -from primaite.simulator.network.hardware.nodes.host.server import Printer, Server -from primaite.simulator.network.hardware.nodes.network.firewall import Firewall -from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.network.nmne import NMNEConfig -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient # noqa: F401 @@ -53,19 +39,21 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import Software +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) SERVICE_TYPES_MAPPING = { - "DNSClient": DNSClient, - "DNSServer": DNSServer, - "DatabaseService": DatabaseService, - "WebServer": WebServer, - "FTPClient": FTPClient, - "FTPServer": FTPServer, - "NTPClient": NTPClient, - "NTPServer": NTPServer, - "Terminal": Terminal, + "dns-client": DNSClient, + "dns-server": DNSServer, + "database-service": DatabaseService, + "web-server": WebServer, + "ftp-client": FTPClient, + "ftp-server": FTPServer, + "ntp-client": NTPClient, + "ntp-server": NTPServer, + "terminal": Terminal, } """List of available services that can be installed on nodes in the PrimAITE Simulation.""" @@ -85,9 +73,9 @@ class PrimaiteGameOptions(BaseModel): """Internally generated seed value.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" - ports: List[str] + ports: List[Port] """A whitelist of available ports in the simulation.""" - protocols: List[str] + protocols: List[IPProtocol] """A whitelist of available protocols in the simulation.""" thresholds: Optional[Dict] = {} """A dict containing the thresholds used for determining what is acceptable during observations.""" @@ -268,95 +256,49 @@ class PrimaiteGame: net = sim.network simulation_config = cfg.get("simulation", {}) + defaults_config = cfg.get("defaults", {}) network_config = simulation_config.get("network", {}) airspace_cfg = network_config.get("airspace", {}) frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {}) - - frequency_max_capacity_mbps_cfg = {AirSpaceFrequency[k]: v for k, v in frequency_max_capacity_mbps_cfg.items()} - - net.airspace.frequency_max_capacity_mbps_ = frequency_max_capacity_mbps_cfg + net.airspace.set_frequency_max_capacity_mbps(frequency_max_capacity_mbps_cfg) nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) + node_sets_cfg = network_config.get("node_sets", []) # Set the NMNE capture config NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {})) NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne for node_cfg in nodes_cfg: n_type = node_cfg["type"] + new_node = None - if n_type == "computer": - new_node = Computer( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "server": - new_node = Server( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "switch": - new_node = Switch( - hostname=node_cfg["hostname"], - num_ports=int(node_cfg.get("num_ports", "8")), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "router": - new_node = Router.from_config(node_cfg) - elif n_type == "firewall": - new_node = Firewall.from_config(node_cfg) - elif n_type == "wireless_router": - new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace) - elif n_type == "printer": - new_node = Printer( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=node_cfg["subnet_mask"], - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) + if n_type in Node._registry: + n_class = Node._registry[n_type] + if issubclass(n_class, WirelessRouter): + new_node = n_class.from_config(config=node_cfg, airspace=net.airspace) + else: + new_node = Node._registry[n_type].from_config(config=node_cfg) else: msg = f"invalid node type {n_type} in config" _LOGGER.error(msg) raise ValueError(msg) - # handle node file system - if node_cfg.get("file_system"): - for folder_idx, folder_obj in enumerate(node_cfg.get("file_system")): - # if the folder is not a Dict, create an empty folder - if not isinstance(folder_obj, Dict): - new_node.file_system.create_folder(folder_name=folder_obj) - else: - folder_name = next(iter(folder_obj)) - for file_idx, file_obj in enumerate(node_cfg["file_system"][folder_idx][folder_name]): - if not isinstance(file_obj, Dict): - new_node.file_system.create_file(folder_name=folder_name, file_name=file_obj) - else: - file_name = next(iter(file_obj)) - new_node.file_system.create_file( - folder_name=folder_name, - file_name=file_name, - size=file_obj[file_name].get("size", 0), - file_type=FileType[file_obj[file_name].get("type", "UNKNOWN").upper()], - ) + # TODO: handle simulation defaults more cleanly + if "node_start_up_duration" in defaults_config: + new_node.config.start_up_duration = defaults_config["node_startup_duration"] + if "node_shut_down_duration" in defaults_config: + new_node.config.shut_down_duration = defaults_config["node_shut_down_duration"] + if "node_scan_duration" in defaults_config: + new_node.config.node_scan_duration = defaults_config["node_scan_duration"] + if "folder_scan_duration" in defaults_config: + new_node.file_system._default_folder_scan_duration = defaults_config["folder_scan_duration"] + if "folder_restore_duration" in defaults_config: + new_node.file_system._default_folder_restore_duration = defaults_config["folder_restore_duration"] + + if "users" in node_cfg and new_node.software_manager.software.get("user-manager"): + user_manager: UserManager = new_node.software_manager.software["user-manager"] # noqa - if "users" in node_cfg and new_node.software_manager.software.get("UserManager"): - user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa for user_cfg in node_cfg["users"]: user_manager.add_user(**user_cfg, bypass_can_perform_action=True) @@ -366,9 +308,9 @@ class PrimaiteGame: for port_id in set(software_cfg.get("options", {}).get("listen_on_ports", [])): port = None if isinstance(port_id, int): - port = Port(port_id) + port = port_id elif isinstance(port_id, str): - port = Port[port_id] + port = PORT_LOOKUP[port_id] if port: listen_on_ports.append(port) software.listen_on_ports = set(listen_on_ports) @@ -377,14 +319,22 @@ class PrimaiteGame: for service_cfg in node_cfg["services"]: new_service = None service_type = service_cfg["type"] - if service_type in SERVICE_TYPES_MAPPING: - _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") - new_node.software_manager.install(SERVICE_TYPES_MAPPING[service_type]) + + service_class = None + # Handle extended services + if service_type.lower() in Service._registry: + service_class = Service._registry[service_type.lower()] + elif service_type in SERVICE_TYPES_MAPPING: + service_class = SERVICE_TYPES_MAPPING[service_type] + + if service_class is not None: + _LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}") + new_node.software_manager.install(service_class, software_config=service_cfg.get("options", {})) new_service = new_node.software_manager.software[service_type] # fixing duration for the service - if "fix_duration" in service_cfg.get("options", {}): - new_service.fixing_duration = service_cfg["options"]["fix_duration"] + if "fixing_duration" in service_cfg.get("options", {}): + new_service.config.fixing_duration = service_cfg["options"]["fixing_duration"] _set_software_listen_on_ports(new_service, service_cfg) # start the service @@ -393,111 +343,42 @@ class PrimaiteGame: msg = f"Configuration contains an invalid service type: {service_type}" _LOGGER.error(msg) raise ValueError(msg) - # service-dependent options - if service_type == "DNSClient": - if "options" in service_cfg: - opt = service_cfg["options"] - if "dns_server" in opt: - new_service.dns_server = IPv4Address(opt["dns_server"]) - if service_type == "DNSServer": - if "options" in service_cfg: - opt = service_cfg["options"] - if "domain_mapping" in opt: - for domain, ip in opt["domain_mapping"].items(): - new_service.dns_register(domain, IPv4Address(ip)) - if service_type == "DatabaseService": - if "options" in service_cfg: - opt = service_cfg["options"] - new_service.password = opt.get("db_password", None) - if "backup_server_ip" in opt: - new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip"))) - if service_type == "FTPServer": - if "options" in service_cfg: - opt = service_cfg["options"] - new_service.server_password = opt.get("server_password") - if service_type == "NTPClient": - if "options" in service_cfg: - opt = service_cfg["options"] - new_service.ntp_server = IPv4Address(opt.get("ntp_server_ip")) + + # TODO: handle simulation defaults more cleanly + if "service_fix_duration" in defaults_config: + new_service.config.fixing_duration = defaults_config["service_fix_duration"] + if "service_restart_duration" in defaults_config: + new_service.restart_duration = defaults_config["service_restart_duration"] + if "service_install_duration" in defaults_config: + new_service.install_duration = defaults_config["service_install_duration"] + if "applications" in node_cfg: for application_cfg in node_cfg["applications"]: new_application = None application_type = application_cfg["type"] - if application_type in Application._application_registry: - new_node.software_manager.install(Application._application_registry[application_type]) + if application_type in Application._registry: + application_class = Application._registry[application_type] + application_options = application_cfg.get("options", {}) + application_options["type"] = application_type + new_node.software_manager.install(application_class, software_config=application_options) new_application = new_node.software_manager.software[application_type] # grab the instance - # fixing duration for the application - if "fix_duration" in application_cfg.get("options", {}): - new_application.fixing_duration = application_cfg["options"]["fix_duration"] else: msg = f"Configuration contains an invalid application type: {application_type}" _LOGGER.error(msg) raise ValueError(msg) - _set_software_listen_on_ports(new_application, application_cfg) - # run the application new_application.run() - if application_type == "DataManipulationBot": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - server_ip_address=IPv4Address(opt.get("server_ip")), - server_password=opt.get("server_password"), - payload=opt.get("payload", "DELETE"), - port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), - data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), - ) - elif application_type == "RansomwareScript": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None, - server_password=opt.get("server_password"), - payload=opt.get("payload", "ENCRYPT"), - ) - elif application_type == "DatabaseClient": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - server_ip_address=IPv4Address(opt.get("db_server_ip")), - server_password=opt.get("server_password"), - ) - elif application_type == "WebBrowser": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.target_url = opt.get("target_url") - elif application_type == "DoSBot": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - target_ip_address=IPv4Address(opt.get("target_ip_address")), - target_port=Port(opt.get("target_port", Port.POSTGRES_SERVER.value)), - payload=opt.get("payload"), - repeat=bool(opt.get("repeat")), - port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), - dos_intensity=float(opt.get("dos_intensity", "1.0")), - max_sessions=int(opt.get("max_sessions", "1000")), - ) - elif application_type == "C2Beacon": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")), - keep_alive_frequency=(opt.get("keep_alive_frequency", 5)), - masquerade_protocol=IPProtocol[(opt.get("masquerade_protocol", IPProtocol.TCP))], - masquerade_port=Port[(opt.get("masquerade_port", Port.HTTP))], - ) if "network_interfaces" in node_cfg: for nic_num, nic_cfg in node_cfg["network_interfaces"].items(): new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) # temporarily set to 0 so all nodes are initially on - new_node.start_up_duration = 0 - new_node.shut_down_duration = 0 + new_node.config.start_up_duration = 0 + new_node.config.shut_down_duration = 0 net.add_node(new_node) # run through the power on step if the node is to be turned on at the start @@ -505,8 +386,12 @@ class PrimaiteGame: new_node.power_on() # set start up and shut down duration - new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3)) - new_node.shut_down_duration = int(node_cfg.get("shut_down_duration", 3)) + new_node.config.start_up_duration = int(node_cfg.get("start_up_duration", 3)) + new_node.config.shut_down_duration = int(node_cfg.get("shut_down_duration", 3)) + + # 1.1 Create Node Sets + for node_set_cfg in node_sets_cfg: + NetworkNodeAdder.from_config(node_set_cfg, network=net) # 2. create links between nodes for link_cfg in links_cfg: @@ -528,76 +413,11 @@ class PrimaiteGame: agents_cfg = cfg.get("agents", []) for agent_cfg in agents_cfg: - agent_ref = agent_cfg["ref"] # noqa: F841 - agent_type = agent_cfg["type"] - action_space_cfg = agent_cfg["action_space"] - observation_space_cfg = agent_cfg["observation_space"] - reward_function_cfg = agent_cfg["reward_function"] - - # CREATE OBSERVATION SPACE - obs_space = ObservationManager.from_config(config=observation_space_cfg, thresholds=game.options.thresholds) - - # CREATE ACTION SPACE - action_space = ActionManager.from_config(game, action_space_cfg) - - # CREATE REWARD FUNCTION - reward_function = RewardFunction.from_config(reward_function_cfg) - - # CREATE AGENT - if agent_type == "ProbabilisticAgent": - # TODO: implement non-random agents and fix this parsing - settings = agent_cfg.get("agent_settings", {}) - new_agent = ProbabilisticAgent( - agent_name=agent_cfg["ref"], - action_space=action_space, - observation_space=obs_space, - reward_function=reward_function, - settings=settings, - ) - elif agent_type == "PeriodicAgent": - settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {})) - new_agent = PeriodicAgent( - agent_name=agent_cfg["ref"], - action_space=action_space, - observation_space=obs_space, - reward_function=reward_function, - settings=settings, - ) - - elif agent_type == "ProxyAgent": - agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) - new_agent = ProxyAgent( - agent_name=agent_cfg["ref"], - action_space=action_space, - observation_space=obs_space, - reward_function=reward_function, - agent_settings=agent_settings, - ) - game.rl_agents[agent_cfg["ref"]] = new_agent - elif agent_type == "RedDatabaseCorruptingAgent": - agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) - - new_agent = DataManipulationAgent( - agent_name=agent_cfg["ref"], - action_space=action_space, - observation_space=obs_space, - reward_function=reward_function, - agent_settings=agent_settings, - ) - elif agent_type == "TAP001": - agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) - new_agent = TAP001( - agent_name=agent_cfg["ref"], - action_space=action_space, - observation_space=obs_space, - reward_function=reward_function, - agent_settings=agent_settings, - ) - else: - msg = f"Configuration error: {agent_type} is not a valid agent type." - _LOGGER.error(msg) - raise ValueError(msg) + agent_cfg = {**agent_cfg, "thresholds": game.options.thresholds} + new_agent = AbstractAgent.from_config(agent_cfg) game.agents[agent_cfg["ref"]] = new_agent + if isinstance(new_agent, ProxyAgent): + game.rl_agents[agent_cfg["ref"]] = new_agent # Validate that if any agents are sharing rewards, they aren't forming an infinite loop. game.setup_reward_sharing() @@ -626,7 +446,7 @@ class PrimaiteGame: for comp, weight in agent.reward_function.reward_components: if isinstance(comp, SharedReward): comp: SharedReward - graph[name].add(comp.agent_name) + graph[name].add(comp.config.agent_name) # while constructing the graph, we might as well set up the reward sharing itself. comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward diff --git a/src/primaite/game/science.py b/src/primaite/game/science.py index 8d8949df..2cb5de7d 100644 --- a/src/primaite/game/science.py +++ b/src/primaite/game/science.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from random import random from typing import Any, Iterable, Mapping diff --git a/src/primaite/interface/__init__.py b/src/primaite/interface/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/interface/__init__.py +++ b/src/primaite/interface/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/interface/request.py b/src/primaite/interface/request.py index 1a9f0e5f..03d6491e 100644 --- a/src/primaite/interface/request.py +++ b/src/primaite/interface/request.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Dict, ForwardRef, List, Literal, Union from pydantic import BaseModel, ConfigDict, StrictBool # , validate_call diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb index ba70f2b4..74504878 100644 --- a/src/primaite/notebooks/Action-masking.ipynb +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -6,11 +6,20 @@ "source": [ "# Action Masking\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -19,7 +28,7 @@ "source": [ "from primaite.session.environment import PrimaiteGymEnv\n", "from primaite.config.load import data_manipulation_config_path\n", - "from prettytable import PrettyTable\n" + "from prettytable import PrettyTable" ] }, { @@ -195,7 +204,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb similarity index 81% rename from src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb rename to src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb index 2469835b..f187c8d5 100644 --- a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb @@ -6,11 +6,20 @@ "source": [ "# Command and Control Application Suite E2E Demonstration\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "This notebook demonstrates the current implementation of the command and control (C2) server and beacon applications in primAITE." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -50,118 +59,79 @@ "custom_c2_agent = \"\"\"\n", " - ref: CustomC2Agent\n", " team: RED\n", - " type: ProxyAgent\n", - " observation_space: null\n", + " type: proxy-agent\n", + "\n", " action_space:\n", - " action_list:\n", - " - type: DONOTHING\n", - " - type: NODE_APPLICATION_INSTALL\n", - " - type: NODE_APPLICATION_EXECUTE\n", - " - type: CONFIGURE_C2_BEACON\n", - " - type: C2_SERVER_RANSOMWARE_LAUNCH\n", - " - type: C2_SERVER_RANSOMWARE_CONFIGURE\n", - " - type: C2_SERVER_TERMINAL_COMMAND\n", - " - type: C2_SERVER_DATA_EXFILTRATE\n", - " options:\n", - " nodes:\n", - " - node_name: web_server\n", - " applications: \n", - " - application_name: C2Beacon\n", - " - node_name: client_1\n", - " applications: \n", - " - application_name: C2Server\n", - " max_folders_per_node: 1\n", - " max_files_per_folder: 1\n", - " max_services_per_node: 2\n", - " max_nics_per_node: 8\n", - " max_acl_rules: 10\n", - " ip_list:\n", - " - 192.168.1.21\n", - " - 192.168.1.14\n", - " wildcard_list:\n", - " - 0.0.0.1\n", " action_map:\n", " 0:\n", - " action: DONOTHING\n", + " action: do-nothing\n", " options: {}\n", " 1:\n", - " action: NODE_APPLICATION_INSTALL\n", + " action: node-application-install\n", " options:\n", - " node_id: 0\n", - " application_name: C2Beacon\n", + " node_name: web_server\n", + " application_name: c2-beacon\n", " 2:\n", - " action: CONFIGURE_C2_BEACON\n", + " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", - " config:\n", - " c2_server_ip_address: 192.168.10.21\n", - " keep_alive_frequency:\n", - " masquerade_protocol:\n", - " masquerade_port:\n", + " node_name: web_server\n", + " c2_server_ip_address: 192.168.10.21\n", " 3:\n", - " action: NODE_APPLICATION_EXECUTE\n", + " action: node-application-execute\n", " options:\n", - " node_id: 0\n", - " application_id: 0 \n", + " node_name: web_server\n", + " application_name: c2-beacon\n", " 4:\n", - " action: C2_SERVER_TERMINAL_COMMAND\n", + " action: c2-server-terminal-command\n", " options:\n", - " node_id: 1\n", + " node_name: client_1\n", " ip_address:\n", - " account:\n", - " username: admin\n", - " password: admin\n", + " username: admin\n", + " password: admin\n", " commands:\n", - " - \n", + " -\n", " - software_manager\n", " - application\n", " - install\n", - " - RansomwareScript\n", + " - ransomware-script\n", " 5:\n", - " action: C2_SERVER_RANSOMWARE_CONFIGURE\n", + " action: c2-server-ransomware-configure\n", " options:\n", - " node_id: 1\n", - " config:\n", - " server_ip_address: 192.168.1.14\n", - " payload: ENCRYPT\n", + " node_name: client_1\n", + " server_ip_address: 192.168.1.14\n", + " payload: ENCRYPT\n", " 6:\n", - " action: C2_SERVER_DATA_EXFILTRATE\n", + " action: c2-server-data-exfiltrate\n", " options:\n", - " node_id: 1\n", + " node_name: client_1\n", " target_file_name: \"database.db\"\n", " target_folder_name: \"database\"\n", " exfiltration_folder_name: \"spoils\"\n", " target_ip_address: 192.168.1.14\n", - " account:\n", - " username: admin\n", - " password: admin \n", + " username: admin\n", + " password: admin\n", "\n", " 7:\n", - " action: C2_SERVER_RANSOMWARE_LAUNCH\n", + " action: c2-server-ransomware-launch\n", " options:\n", - " node_id: 1\n", + " node_name: client_1\n", " 8:\n", - " action: CONFIGURE_C2_BEACON\n", + " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", - " config:\n", - " c2_server_ip_address: 192.168.10.21\n", - " keep_alive_frequency: 10\n", - " masquerade_protocol: TCP\n", - " masquerade_port: DNS\n", + " node_name: web_server\n", + " c2_server_ip_address: 192.168.10.21\n", + " keep_alive_frequency: 10\n", + " masquerade_protocol: tcp\n", + " masquerade_port: dns\n", " 9:\n", - " action: CONFIGURE_C2_BEACON\n", + " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", - " config:\n", - " c2_server_ip_address: 192.168.10.22\n", - " keep_alive_frequency:\n", - " masquerade_protocol:\n", - " masquerade_port:\n", + " node_name: web_server\n", + " c2_server_ip_address: 192.168.10.22\n", "\n", " reward_function:\n", " reward_components:\n", - " - type: DUMMY\n", + " - type: dummy\n", "\"\"\"\n", "c2_agent_yaml = yaml.safe_load(custom_c2_agent)" ] @@ -177,7 +147,7 @@ " # removing all agents & adding the custom agent.\n", " cfg['agents'] = {}\n", " cfg['agents'] = c2_agent_yaml\n", - " \n", + "\n", "\n", "env = PrimaiteGymEnv(env_config=cfg)" ] @@ -202,7 +172,7 @@ "source": [ "client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n", "client_1.software_manager.install(C2Server)\n", - "c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + "c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n", "c2_server.run()\n", "client_1.software_manager.show()" ] @@ -222,7 +192,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Command and Control** | C2 Beacon Actions | NODE_APPLICATION_INSTALL\n", + "### **Command and Control** | C2 Beacon Actions | node_application_install\n", "\n", "The custom proxy red agent defined at the start of this notebook has been configured to install the C2 Beacon as action ``1`` in it's action map. \n", "\n", @@ -230,23 +200,19 @@ "\n", "```yaml\n", " action_space:\n", - " action_list:\n", - " ...\n", - " - type: NODE_APPLICATION_INSTALL\n", - " ...\n", " options:\n", " nodes: # Node List\n", " - node_name: web_server\n", " applications: \n", - " - application_name: C2Beacon\n", + " - application_name: c2-beacon\n", " ...\n", " ...\n", " action_map:\n", " 1:\n", - " action: NODE_APPLICATION_INSTALL \n", + " action: node-application-install \n", " options:\n", " node_id: 0 # Index 0 at the node list.\n", - " application_name: C2Beacon\n", + " application_name: c2-beacon\n", "```" ] }, @@ -265,34 +231,25 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Command and Control** | C2 Beacon Actions | CONFIGURE_C2_BEACON \n", + "### **Command and Control** | C2 Beacon Actions | configure_c2_beacon \n", "\n", "The custom proxy red agent defined at the start of this notebook can configure the C2 Beacon via action ``2`` in it's action map. \n", "\n", "The yaml snippet below shows all the relevant agent options for this action:\n", "\n", "```yaml\n", + "\n", " action_space:\n", - " action_list:\n", - " ...\n", - " - type: CONFIGURE_C2_BEACON\n", - " ...\n", - " options:\n", - " nodes: # Node List\n", - " - node_name: web_server\n", - " ...\n", - " ...\n", " action_map:\n", - " ...\n", " 2:\n", - " action: CONFIGURE_C2_BEACON\n", + " action: configure-c2-beacon\n", " options:\n", - " node_id: 0 # Node Index\n", - " config: # Further information about these config options can be found at the bottom of this notebook.\n", - " c2_server_ip_address: 192.168.10.21\n", - " keep_alive_frequency:\n", - " masquerade_protocol:\n", - " masquerade_port:\n", + " node_name: web_server\n", + " c2_server_ip_address: 192.168.10.21 # Further information about these config options can be found at the bottom of this notebook.\n", + " keep_alive_frequency:\n", + " masquerade_protocol:\n", + " masquerade_port:\n", + "\n", "```" ] }, @@ -303,7 +260,7 @@ "outputs": [], "source": [ "env.step(2)\n", - "c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n", + "c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n", "web_server.software_manager.show()\n", "c2_beacon.show()" ] @@ -312,32 +269,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Command and Control** | C2 Beacon Actions | NODE_APPLICATION_EXECUTE\n", + "### **Command and Control** | C2 Beacon Actions | node_application_execute\n", "\n", - "The final action is ``NODE_APPLICATION_EXECUTE`` which is used to establish a connection for the C2 application. This action can be called by the Red Agent via action ``3`` in it's action map. \n", + "The final action is ``node-application-execute`` which is used to establish a connection for the C2 application. This action can be called by the Red Agent via action ``3`` in it's action map. \n", "\n", "The yaml snippet below shows all the relevant agent options for this action:\n", "\n", "```yaml\n", " action_space:\n", - " action_list:\n", - " ...\n", - " - type: NODE_APPLICATION_EXECUTE\n", - " ...\n", - " options:\n", - " nodes: # Node List\n", - " - node_name: web_server\n", - " applications: \n", - " - application_name: C2Beacon\n", - " ...\n", - " ...\n", " action_map:\n", - " ...\n", " 3:\n", - " action: NODE_APPLICATION_EXECUTE\n", + " action: node-application-execute\n", " options:\n", - " node_id: 0\n", - " application_id: 0\n", + " node_name: web_server\n", + " application_name: c2-beacon\n", "```" ] }, @@ -347,7 +292,7 @@ "metadata": {}, "outputs": [], "source": [ - "env.step(3) " + "env.step(3)" ] }, { @@ -376,38 +321,26 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Command and Control** | C2 Server Actions | C2_SERVER_TERMINAL_COMMAND\n", + "### **Command and Control** | C2 Server Actions | c2-server-terminal-command\n", "\n", - "The C2 Server's terminal action: ``C2_SERVER_TERMINAL_COMMAND`` is indexed at ``4`` in it's action map. \n", + "The C2 Server's terminal action: ``c2-server-terminal-command`` is indexed at ``4`` in it's action map. \n", "\n", "This action leverages the terminal service that is installed by default on all nodes to grant red agents a lot more configurability. If you're unfamiliar with terminals then it's recommended that you refer to the ``Terminal Processing`` notebook.\n", "\n", "It's worth noting that an additional benefit a red agent has when using the terminal service via the C2 Server is that you can execute multiple commands in one action. \n", "\n", - "In this notebook, the ``C2_SERVER_TERMINAL_COMMAND`` is used to install a RansomwareScript application on the ``web_server`` node.\n", + "In this notebook, the ``c2-server-terminal-command`` is used to install a RansomwareScript application on the ``web_server`` node.\n", "\n", "The yaml snippet below shows all the relevant agent options for this action:\n", "\n", "``` yaml\n", " action_space:\n", - " action_list:\n", - " ...\n", - " - type: C2_SERVER_TERMINAL_COMMAND\n", - " ...\n", - " options:\n", - " nodes: # Node List\n", - " ...\n", - " - node_name: client_1\n", - " applications: \n", - " - application_name: C2Server\n", - " ...\n", " action_map:\n", " 4:\n", - " action: C2_SERVER_TERMINAL_COMMAND\n", + " action: c2-server-terminal-command\n", " options:\n", - " node_id: 1\n", + " node_name: client_1\n", " ip_address:\n", - " account:\n", " username: admin\n", " password: admin\n", " commands:\n", @@ -415,7 +348,7 @@ " - software_manager\n", " - application\n", " - install\n", - " - RansomwareScript\n", + " - ransomware-script\n", "```" ] }, @@ -441,7 +374,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Command and Control** | C2 Server Actions | C2_SERVER_RANSOMWARE_CONFIGURE\n", + "### **Command and Control** | C2 Server Actions | c2-server-ransomware-configure\n", "\n", "Another action the C2 Server grants is the ability for a Red Agent to configure the RansomwareScript via the C2 Server rather than the note directly.\n", "\n", @@ -451,25 +384,13 @@ "\n", "``` yaml\n", " action_space:\n", - " action_list:\n", - " ...\n", - " - type: C2_SERVER_RANSOMWARE_CONFIGURE\n", - " ...\n", - " options:\n", - " nodes: # Node List\n", - " ...\n", - " - node_name: client_1\n", - " applications: \n", - " - application_name: C2Server\n", - " ...\n", " action_map:\n", " 5:\n", - " action: C2_SERVER_RANSOMWARE_CONFIG\n", + " action: c2-server-ransomware-configure\n", " options:\n", - " node_id: 1\n", - " config:\n", - " server_ip_address: 192.168.1.14\n", - " payload: ENCRYPT\n", + " node_name: client_1\n", + " server_ip_address: 192.168.1.14\n", + " payload: ENCRYPT\n", "```\n" ] }, @@ -488,7 +409,7 @@ "metadata": {}, "outputs": [], "source": [ - "ransomware_script: RansomwareScript = web_server.software_manager.software[\"RansomwareScript\"]\n", + "ransomware_script: RansomwareScript = web_server.software_manager.software[\"ransomware-script\"]\n", "web_server.software_manager.show()\n", "ransomware_script.show()" ] @@ -497,9 +418,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Command and Control** | C2 Server Actions | C2_SERVER_DATA_EXFILTRATE\n", + "### **Command and Control** | C2 Server Actions | c2-server-data-exfiltrate\n", "\n", - "The second to last action available is the ``C2_SERVER_DATA_EXFILTRATE`` which is indexed as action ``6`` in the action map.\n", + "The second to last action available is the ``c2-server-data-exfiltrate`` which is indexed as action ``6`` in the action map.\n", "\n", "This action can be used to exfiltrate a target file on a remote node to the C2 Beacon and the C2 Server's host file system via the ``FTP`` services.\n", "\n", @@ -507,29 +428,17 @@ "\n", "``` yaml\n", " action_space:\n", - " action_list:\n", - " ...\n", - " - type: C2_SERVER_DATA_EXFILTRATE\n", - " ...\n", - " options:\n", - " nodes: # Node List\n", - " ...\n", - " - node_name: client_1\n", - " applications: \n", - " - application_name: C2Server\n", - " ...\n", " action_map:\n", " 6:\n", - " action: C2_SERVER_DATA_EXFILTRATE\n", + " action: c2-server-data-exfiltrate\n", " options:\n", " node_id: 1\n", " target_file_name: \"database.db\"\n", " target_folder_name: \"database\"\n", " exfiltration_folder_name: \"spoils\"\n", " target_ip_address: \"192.168.1.14\"\n", - " account:\n", - " username: \"admin\",\n", - " password: \"admin\"\n", + " username: \"admin\"\n", + " password: \"admin\"\n", "\n", "```" ] @@ -567,9 +476,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### **Command and Control** | C2 Server Actions | C2_SERVER_RANSOMWARE_LAUNCH\n", + "### **Command and Control** | C2 Server Actions | c2-server-ransomware-launch\n", "\n", - "Finally, the last available action is for the C2_SERVER_RANSOMWARE_LAUNCH to start the ransomware script installed on the same node as the C2 beacon.\n", + "Finally, the last available action is for the c2-server-ransomware-launch to start the ransomware script installed on the same node as the C2 beacon.\n", "\n", "This action is indexed as action ``7``.\n", "\n", @@ -577,20 +486,9 @@ "\n", "``` yaml\n", " action_space:\n", - " action_list:\n", - " ...\n", - " - type: C2_SERVER_RANSOMWARE_LAUNCH\n", - " ...\n", - " options:\n", - " nodes: # Node List\n", - " ...\n", - " - node_name: client_1\n", - " applications: \n", - " - application_name: C2Server\n", - " ...\n", " action_map:\n", " 7:\n", - " action: C2_SERVER_RANSOMWARE_LAUNCH\n", + " action: c2-server-ransomware-launch\n", " options:\n", " node_id: 1\n", "```\n" @@ -632,23 +530,23 @@ "metadata": {}, "outputs": [], "source": [ - "custom_blue_agent_yaml = \"\"\" \n", + "custom_blue_agent_yaml = \"\"\"\n", " - ref: defender\n", " team: BLUE\n", - " type: ProxyAgent\n", + " type: proxy-agent\n", "\n", " observation_space:\n", - " type: CUSTOM\n", + " type: custom\n", " options:\n", " components:\n", - " - type: NODES\n", - " label: NODES\n", + " - type: nodes\n", + " label: nodes\n", " options:\n", " hosts:\n", " - hostname: web_server\n", " applications:\n", - " - application_name: C2Beacon\n", - " - application_name: RansomwareScript\n", + " - application_name: c2-beacon\n", + " - application_name: ransomware-script\n", " folders:\n", " - folder_name: exfiltration_folder\n", " files:\n", @@ -698,7 +596,7 @@ " - UDP\n", " num_rules: 10\n", "\n", - " - type: LINKS\n", + " - type: links\n", " label: LINKS\n", " options:\n", " link_references:\n", @@ -712,72 +610,41 @@ " - switch_2:eth-1<->client_1:eth-1\n", " - switch_2:eth-2<->client_2:eth-1\n", " - switch_2:eth-7<->security_suite:eth-2\n", - " - type: \"NONE\"\n", + " - type: \"none\"\n", " label: ICS\n", " options: {}\n", - " \n", + "\n", " action_space:\n", - " action_list:\n", - " - type: NODE_APPLICATION_REMOVE\n", - " - type: NODE_SHUTDOWN\n", - " - type: ROUTER_ACL_ADDRULE\n", - " - type: DONOTHING\n", " action_map:\n", " 0:\n", - " action: DONOTHING\n", + " action: do-nothing\n", " options: {}\n", " 1:\n", - " action: NODE_APPLICATION_REMOVE\n", + " action: node-application-remove\n", " options:\n", - " node_id: 0\n", - " application_name: C2Beacon\n", + " node_name: web_server\n", + " application_name: c2-beacon\n", " 2:\n", - " action: NODE_SHUTDOWN\n", + " action: node-shutdown\n", " options:\n", - " node_id: 0\n", + " node_name: web_server\n", " 3:\n", - " action: ROUTER_ACL_ADDRULE\n", + " action: router-acl-add-rule\n", " options:\n", " target_router: router_1\n", " position: 1\n", - " permission: 2\n", - " source_ip_id: 2\n", - " dest_ip_id: 3\n", - " source_port_id: 2\n", - " dest_port_id: 2\n", - " protocol_id: 1\n", - " source_wildcard_id: 0\n", - " dest_wildcard_id: 0 \n", + " permission: DENY\n", + " src_ip: 192.168.10.21\n", + " dst_ip: 192.168.1.12\n", + " src_port: HTTP\n", + " dst_port: HTTP\n", + " protocol_name: ALL\n", + " src_wildcard: 0.0.0.1\n", + " dst_wildcard: 0.0.0.1\n", "\n", - "\n", - " options:\n", - " nodes:\n", - " - node_name: web_server\n", - " applications:\n", - " - application_name: C2Beacon\n", - "\n", - " - node_name: database_server\n", - " folders:\n", - " - folder_name: database\n", - " files:\n", - " - file_name: database.db\n", - " services:\n", - " - service_name: DatabaseService\n", - " - node_name: router_1\n", - "\n", - " max_folders_per_node: 2\n", - " max_files_per_folder: 2\n", - " max_services_per_node: 2\n", - " max_nics_per_node: 8\n", - " max_acl_rules: 10\n", - " ip_list:\n", - " - 192.168.10.21\n", - " - 192.168.1.12\n", - " wildcard_list:\n", - " - 0.0.0.1\n", " reward_function:\n", " reward_components:\n", - " - type: DUMMY\n", + " - type: dummy\n", "\n", " agent_settings:\n", " flatten_obs: False\n", @@ -796,7 +663,7 @@ " # removing all agents & adding the custom agent.\n", " cfg['agents'] = {}\n", " cfg['agents'] = custom_blue\n", - " \n", + "\n", "\n", "blue_env = PrimaiteGymEnv(env_config=cfg)" ] @@ -868,12 +735,12 @@ "\n", "# Installing the C2 Server.\n", "client_1.software_manager.install(C2Server)\n", - "c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + "c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n", "c2_server.run()\n", "\n", "# Installing the C2 Beacon.\n", "web_server.software_manager.install(C2Beacon)\n", - "c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n", + "c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n", "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\")\n", "c2_beacon.establish()" ] @@ -917,7 +784,7 @@ "outputs": [], "source": [ "# Installing RansomwareScript via C2 Terminal Commands\n", - "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"]],\n", + "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"ransomware-script\"]],\n", " \"username\": \"admin\",\n", " \"password\": \"admin\"}\n", "c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)\n" @@ -1076,11 +943,11 @@ " web_server: Server = given_env.game.simulation.network.get_node_by_hostname(\"web_server\")\n", "\n", " client_1.software_manager.install(C2Server)\n", - " c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + " c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n", " c2_server.run()\n", "\n", " web_server.software_manager.install(C2Beacon)\n", - " c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n", + " c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n", " c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\")\n", " c2_beacon.establish()\n", "\n", @@ -1121,7 +988,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The code cell below uses the custom blue agent defined at the start of this section perform a NODE_APPLICATION_REMOVE on the C2 beacon:" + "The code cell below uses the custom blue agent defined at the start of this section perform a node_application_remove on the C2 beacon:" ] }, { @@ -1130,7 +997,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Using CAOS ACTION: NODE_APPLICATION_REMOVE & capturing the OBS\n", + "# Using CAOS ACTION: node_application_remove & capturing the OBS\n", "post_blue_action_obs, _, _, _, _ = blue_env.step(1)" ] }, @@ -1178,7 +1045,7 @@ " \"username\": \"admin\",\n", " \"password\": \"admin\"}\n", "\n", - "c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + "c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n", "c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)" ] }, @@ -1216,7 +1083,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The code cell below uses the custom blue agent defined at the start of this section to perform a ``NODE_SHUT_DOWN`` action on the web server." + "The code cell below uses the custom blue agent defined at the start of this section to perform a ``node_shut_down`` action on the web server." ] }, { @@ -1225,7 +1092,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Using CAOS ACTION: NODE_SHUT_DOWN & capturing the OBS\n", + "# Using CAOS ACTION: node_shut_down & capturing the OBS\n", "post_blue_action_obs, _, _, _, _ = blue_env.step(2)" ] }, @@ -1266,7 +1133,7 @@ " \"username\": \"admin\",\n", " \"password\": \"admin\"}\n", "\n", - "c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + "c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n", "c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)" ] }, @@ -1306,7 +1173,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The code cell below uses the custom blue agent defined at the start of this section to perform a ROUTER_ACL_ADDRULE on router 1." + "The code cell below uses the custom blue agent defined at the start of this section to perform a router_acl_add_rule on router 1." ] }, { @@ -1315,7 +1182,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Using CAOS ACTION: ROUTER_ACL_ADDRULE & capturing the OBS\n", + "# Using CAOS ACTION: router_acl_add_rule & capturing the OBS\n", "post_blue_action_obs, _, _, _, _ = blue_env.step(3)" ] }, @@ -1429,18 +1296,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "As demonstrated earlier, red agents can use the ``CONFIGURE_C2_BEACON`` action to configure these settings mid episode through the configuration options:\n", + "As demonstrated earlier, red agents can use the ``configure-c2-beacon`` action to configure these settings mid episode through the configuration options:\n", "\n", - "``` YAML\n", - "...\n", - " action: CONFIGURE_C2_BEACON\n", - " options:\n", - " node_id: 0\n", - " config:\n", + "```YAML\n", + "\n", + " action_space:\n", + " action_map:\n", + " 8:\n", + " action: configure-c2-beacon\n", + " options:\n", + " node_name: web_server\n", " c2_server_ip_address: 192.168.10.21\n", " keep_alive_frequency: 10\n", - " masquerade_protocol: TCP\n", - " masquerade_port: DNS\n", + " masquerade_protocol: tcp\n", + " masquerade_port: dns\n", "```" ] }, @@ -1468,7 +1337,7 @@ " # removing all agents & adding the custom agent.\n", " cfg['agents'] = {}\n", " cfg['agents'] = c2_agent_yaml\n", - " \n", + "\n", "\n", "c2_config_env = PrimaiteGymEnv(env_config=cfg)" ] @@ -1488,16 +1357,16 @@ "source": [ "web_server: Server = c2_config_env.game.simulation.network.get_node_by_hostname(\"web_server\")\n", "web_server.software_manager.install(C2Beacon)\n", - "c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n", + "c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n", "\n", "client_1: Computer = c2_config_env.game.simulation.network.get_node_by_hostname(\"client_1\")\n", "client_1.software_manager.install(C2Server)\n", - "c2_server_1: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + "c2_server_1: C2Server = client_1.software_manager.software[\"c2-server\"]\n", "c2_server_1.run()\n", "\n", "client_2: Computer = c2_config_env.game.simulation.network.get_node_by_hostname(\"client_2\")\n", "client_2.software_manager.install(C2Server)\n", - "c2_server_2: C2Server = client_2.software_manager.software[\"C2Server\"]\n", + "c2_server_2: C2Server = client_2.software_manager.software[\"c2-server\"]\n", "c2_server_2.run()" ] }, @@ -1555,7 +1424,7 @@ "source": [ "for i in range(6):\n", " env.step(0)\n", - " \n", + "\n", "c2_server_1.show()" ] }, @@ -1676,7 +1545,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Comparing the OBS of the default frequency to a timestep frequency of 1 \n", + "# Comparing the OBS of the default frequency to a timestep frequency of 1\n", "for i in range(2):\n", " keep_alive_obs, _, _, _, _ = blue_config_env.step(0)\n", " display_obs_diffs(default_obs, keep_alive_obs, blue_config_env.game.step_counter)" @@ -1760,7 +1629,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Capturing default C2 Traffic \n", + "# Capturing default C2 Traffic\n", "for i in range(3):\n", " tcp_c2_obs, _, _, _, _ = blue_config_env.step(0)\n", "\n", @@ -1780,10 +1649,11 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "from primaite.utils.validation.port import PORT_LOOKUP\n", + "\n", "# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n", - "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol.UDP, masquerade_port=Port.DNS)\n", + "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=PROTOCOL_LOOKUP[\"UDP\"], masquerade_port=PORT_LOOKUP[\"DNS\"])\n", "c2_beacon.establish()\n", "c2_beacon.show()" ] @@ -1800,21 +1670,11 @@ "\n", "display_obs_diffs(tcp_c2_obs, udp_c2_obs, blue_config_env.game.step_counter)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "env.game.agents[\"CustomC2Agent\"].show_history()" - ] } ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -1828,7 +1688,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb index dd5def9e..9ac1da9b 100644 --- a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb @@ -6,7 +6,7 @@ "source": [ "# Customising UC2 Red Agents\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "This notebook will go over some examples of how red agent behaviour can be varied by changing its configuration parameters.\n", "\n", @@ -15,6 +15,15 @@ "*(For a full explanation of the Data Manipulation scenario, check out the data manipulation scenario notebook)*" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -38,7 +47,7 @@ "source": [ "def make_cfg_have_flat_obs(cfg):\n", " for agent in cfg['agents']:\n", - " if agent['type'] == \"ProxyAgent\":\n", + " if agent['type'] == \"proxy-agent\":\n", " agent['agent_settings']['flatten_obs'] = False" ] }, @@ -67,11 +76,10 @@ " # parse the info dict form step output and write out what the red agent is doing\n", " red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", " red_action = red_info.action\n", - " if red_action == 'DONOTHING':\n", + " if red_action == 'do-nothing':\n", " red_str = 'DO NOTHING'\n", - " elif red_action == 'NODE_APPLICATION_EXECUTE':\n", - " client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n", - " red_str = f\"ATTACK from {client}\"\n", + " elif red_action == 'node-application-execute':\n", + " red_str = f\"ATTACK from {red_info.parameters['node_name']}\"\n", " return red_str" ] }, @@ -138,47 +146,15 @@ "```yaml\n", " - ref: data_manipulation_attacker # name of agent\n", " team: RED # not used, just for human reference\n", - " type: RedDatabaseCorruptingAgent # type of agent - this lets primaite know which agent class to use\n", - "\n", - " # Since the agent does not need to react to what is happening in the environment, the observation space is empty.\n", - " observation_space:\n", - " type: UC2RedObservation\n", - " options:\n", - " nodes: {}\n", - "\n", - " action_space:\n", - "\n", - " # The agent has two action choices, either do nothing, or execute a pre-scripted attack by using \n", - " action_list:\n", - " - type: DONOTHING\n", - " - type: NODE_APPLICATION_EXECUTE\n", - "\n", - " # The agent has access to the DataManipulationBoth on clients 1 and 2.\n", - " options:\n", - " nodes:\n", - " - node_name: client_1 # The network should have a node called client_1\n", - " applications:\n", - " - application_name: DataManipulationBot # The node client_1 should have DataManipulationBot configured on it\n", - " - node_name: client_2 # The network should have a node called client_2\n", - " applications:\n", - " - application_name: DataManipulationBot # The node client_2 should have DataManipulationBot configured on it\n", - "\n", - " # not important\n", - " max_folders_per_node: 1\n", - " max_files_per_folder: 1\n", - " max_services_per_node: 1\n", - "\n", - " # red agent does not need a reward function\n", - " reward_function:\n", - " reward_components:\n", - " - type: DUMMY\n", + " type: red-database-corrupting-agent # type of agent - this lets primaite know which agent class to use\n", "\n", " # These actions are passed to the RedDatabaseCorruptingAgent init method, they dictate the schedule of attacks\n", " agent_settings:\n", - " start_settings:\n", - " start_step: 25 # first attack at step 25\n", - " frequency: 20 # attacks will happen every 20 steps (on average)\n", - " variance: 5 # the timing of attacks will vary by up to 5 steps earlier or later\n", + " possible_start_nodes: [client_1, client_2] # List of clients the attack can start from\n", + " target_application: data-manipulation-bot\n", + " start_step: 25 # first attack at step 25\n", + " frequency: 20 # attacks will happen every 20 steps (on average)\n", + " variance: 5 # the timing of attacks will vary by up to 5 steps earlier or later\n", "```" ] }, @@ -198,8 +174,7 @@ "simulation:\n", " network:\n", " nodes:\n", - " - ref: client_1\n", - " hostname: client_1\n", + " - hostname: client_1\n", " type: computer\n", " ip_address: 192.168.10.21\n", " subnet_mask: 255.255.255.0\n", @@ -207,15 +182,13 @@ " \n", " # \n", " applications:\n", - " - ref: data_manipulation_bot\n", - " type: DataManipulationBot\n", + " - type: data-manipulation-bot\n", " options:\n", " port_scan_p_of_success: 0.8 # Probability that port scan is successful\n", " data_manipulation_p_of_success: 0.8 # Probability that SQL attack is successful\n", " payload: \"DELETE\" # The SQL query which causes the attack (this has to be DELETE)\n", " server_ip: 192.168.1.14 # IP address of server hosting the database\n", - " - ref: client_1_database_client\n", - " type: DatabaseClient # Database client must be installed in order for DataManipulationBot to function\n", + " - type: database-client # Database client must be installed in order for DataManipulationBot to function\n", " options:\n", " db_server_ip: 192.168.1.14 # IP address of server hosting the database\n", "```" @@ -239,7 +212,8 @@ "outputs": [], "source": [ "change = yaml.safe_load(\"\"\"\n", - "start_settings:\n", + " possible_start_nodes: [client_1]\n", + " target_application: DataManipulationBot\n", " start_step: 25\n", " frequency: 20\n", " variance: 0\n", @@ -249,7 +223,9 @@ " cfg = yaml.safe_load(f)\n", " for agent in cfg['agents']:\n", " if agent['ref'] == \"data_manipulation_attacker\":\n", + " print(f\"{agent['agent_settings']=}\")\n", " agent['agent_settings'] = change\n", + " print(f\"{agent['agent_settings']=}\")\n", "\n", "env = PrimaiteGymEnv(env_config = cfg)\n", "env.reset()\n", @@ -306,18 +282,20 @@ "outputs": [], "source": [ "change = yaml.safe_load(\"\"\"\n", - "action_space:\n", - " action_list:\n", - " - type: DONOTHING\n", - " - type: NODE_APPLICATION_EXECUTE\n", - " options:\n", - " nodes:\n", - " - node_name: client_1\n", - " applications:\n", - " - application_name: DataManipulationBot\n", - " max_folders_per_node: 1\n", - " max_files_per_folder: 1\n", - " max_services_per_node: 1\n", + " agent_settings:\n", + " possible_start_nodes: [client_1]\n", + " target_application: DataManipulationBot\n", + "\n", + " action_space:\n", + " action_map:\n", + " 0:\n", + " action: do-nothing\n", + " options: {}\n", + " 1:\n", + " action: node-application-execute\n", + " options:\n", + " node_name: client_1\n", + " application_name: DataManipulationBot\n", "\"\"\")\n", "\n", "with open(data_manipulation_config_path(), 'r') as f:\n", @@ -361,18 +339,18 @@ "change = yaml.safe_load(\"\"\"\n", " applications:\n", " - ref: data_manipulation_bot\n", - " type: DataManipulationBot\n", + " type: data-manipulation-bot\n", " options:\n", " port_scan_p_of_success: 1.0\n", " data_manipulation_p_of_success: 1.0\n", " payload: \"DELETE\"\n", " server_ip: 192.168.1.14\n", " - ref: client_1_web_browser\n", - " type: WebBrowser\n", + " type: web-browser\n", " options:\n", " target_url: http://arcd.com/users/\n", " - ref: client_1_database_client\n", - " type: DatabaseClient\n", + " type: database-client\n", " options:\n", " db_server_ip: 192.168.1.14\n", "\"\"\")\n", @@ -406,18 +384,18 @@ "change = yaml.safe_load(\"\"\"\n", " applications:\n", " - ref: data_manipulation_bot\n", - " type: DataManipulationBot\n", + " type: data-manipulation-bot\n", " options:\n", " port_scan_p_of_success: 0.0\n", " data_manipulation_p_of_success: 0.0\n", " payload: \"DELETE\"\n", " server_ip: 192.168.1.14\n", " - ref: client_1_web_browser\n", - " type: WebBrowser\n", + " type: web-browser\n", " options:\n", " target_url: http://arcd.com/users/\n", " - ref: client_1_database_client\n", - " type: DatabaseClient\n", + " type: database-client\n", " options:\n", " db_server_ip: 192.168.1.14\n", "\"\"\")\n", @@ -444,7 +422,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -458,7 +436,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 13533097..562f0f91 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -6,7 +6,7 @@ "source": [ "# Data Manipulation Scenario\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK" + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK" ] }, { @@ -165,13 +165,13 @@ "\n", "| node_id | node name |\n", "|---------|------------------|\n", - "| 1 | domain_controller|\n", - "| 2 | web_server |\n", - "| 3 | database_server |\n", - "| 4 | backup_server |\n", - "| 5 | security_suite |\n", - "| 6 | client_1 |\n", - "| 7 | client_2 |\n", + "| 0 | domain_controller|\n", + "| 1 | web_server |\n", + "| 2 | database_server |\n", + "| 3 | backup_server |\n", + "| 4 | security_suite |\n", + "| 5 | client_1 |\n", + "| 6 | client_2 |\n", "\n", "Service 1 on node 2 (web_server) corresponds to the Web Server service. Other services are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n", "\n", @@ -371,6 +371,15 @@ "First, load the required modules" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -420,9 +429,9 @@ " cfg = yaml.safe_load(f)\n", " # set success probability to 1.0 to avoid rerunning cells.\n", " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n", - " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n", + " cfg['simulation']['network']['nodes'][9]['applications'][1]['options']['data_manipulation_p_of_success'] = 1.0\n", " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", - " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", + " cfg['simulation']['network']['nodes'][9]['applications'][1]['options']['port_scan_p_of_success'] = 1.0\n", " # don't flatten observations so that we can see what is going on\n", " cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n", "\n", @@ -449,10 +458,10 @@ " # parse the info dict form step output and write out what the red agent is doing\n", " red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", " red_action = red_info.action\n", - " if red_action == 'DONOTHING':\n", + " if red_action == 'do-nothing':\n", " red_str = 'DO NOTHING'\n", - " elif red_action == 'NODE_APPLICATION_EXECUTE':\n", - " client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n", + " elif red_action == 'node-application-execute':\n", + " client = \"client 1\" if red_info.parameters['node_name'] == 0 else \"client 2\"\n", " red_str = f\"ATTACK from {client}\"\n", " return red_str" ] @@ -547,7 +556,7 @@ "\n", "The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 0.9 when both green agents make successful requests.\n", "\n", - "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should increase. If you run it enough times, another red attack will happen and the reward will drop again." + "Run the following cell until the green action is `node_application_execute` for application 0, then the reward should increase. If you run it enough times, another red attack will happen and the reward will drop again." ] }, { @@ -591,7 +600,7 @@ "while abs(reward - 0.8) > 1e-5:\n", " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", - " if env.game.step_counter > 10000:\n", + " if env.game.step_counter > 2000:\n", " break # make sure there's no infinite loop if something went wrong" ] }, diff --git a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb index 6a60c1bc..40250d0c 100644 --- a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb +++ b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb @@ -6,7 +6,16 @@ "source": [ "# Getting information out of PrimAITE\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n" + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" ] }, { @@ -191,7 +200,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -205,7 +214,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Privilege-Escalation-and Data-Loss-Example.ipynb b/src/primaite/notebooks/Privilege-Escalation-and-Data-Loss-Example.ipynb similarity index 93% rename from src/primaite/notebooks/Privilege-Escalation-and Data-Loss-Example.ipynb rename to src/primaite/notebooks/Privilege-Escalation-and-Data-Loss-Example.ipynb index c751edfd..deb38eea 100644 --- a/src/primaite/notebooks/Privilege-Escalation-and Data-Loss-Example.ipynb +++ b/src/primaite/notebooks/Privilege-Escalation-and-Data-Loss-Example.ipynb @@ -6,7 +6,7 @@ "source": [ "# Simulating Privilege Escalation and Data Loss Using SSH and ACLs Manipulation\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "## Overview\n", "\n", @@ -51,6 +51,15 @@ "## The Scenario" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -68,7 +77,8 @@ "from primaite.simulator.network.hardware.nodes.host.server import Server\n", "from primaite.simulator.system.applications.database_client import DatabaseClient\n", "from primaite.simulator.system.applications.web_browser import WebBrowser\n", - "from primaite.simulator.system.services.database.database_service import DatabaseService" + "from primaite.simulator.system.services.database.database_service import DatabaseService\n", + "from primaite.simulator.network.hardware.nodes.network import firewall\n" ] }, { @@ -110,11 +120,11 @@ "outputs": [], "source": [ "some_tech_jnr_dev_pc: Computer = game.simulation.network.get_node_by_hostname(\"some_tech_jnr_dev_pc\")\n", - "some_tech_jnr_dev_db_client: DatabaseClient = some_tech_jnr_dev_pc.software_manager.software[\"DatabaseClient\"]\n", - "some_tech_jnr_dev_web_browser: WebBrowser = some_tech_jnr_dev_pc.software_manager.software[\"WebBrowser\"]\n", + "some_tech_jnr_dev_db_client: DatabaseClient = some_tech_jnr_dev_pc.software_manager.software[\"database-client\"]\n", + "some_tech_jnr_dev_web_browser: WebBrowser = some_tech_jnr_dev_pc.software_manager.software[\"web-browser\"]\n", "some_tech_rt: Router = game.simulation.network.get_node_by_hostname(\"some_tech_rt\")\n", "some_tech_db_srv: Server = game.simulation.network.get_node_by_hostname(\"some_tech_db_srv\")\n", - "some_tech_db_service: DatabaseService = some_tech_db_srv.software_manager.software[\"DatabaseService\"]\n", + "some_tech_db_service: DatabaseService = some_tech_db_srv.software_manager.software[\"database-service\"]\n", "some_tech_storage_srv: Server = game.simulation.network.get_node_by_hostname(\"some_tech_storage_srv\")\n", "some_tech_web_srv: Server = game.simulation.network.get_node_by_hostname(\"some_tech_web_srv\")" ] @@ -201,7 +211,7 @@ "source": [ "caos_action = [\n", " \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n", - " \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n", + " \"service\", \"terminal\", \"node-session-remote-login\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n", "]\n", "game.simulation.apply_request(caos_action)" ] @@ -223,7 +233,7 @@ }, "outputs": [], "source": [ - "caos_action = [\"network\", \"node\", \"some_tech_jnr_dev_pc\", \"application\", \"WebBrowser\", \"execute\"]\n", + "caos_action = [\"network\", \"node\", \"some_tech_jnr_dev_pc\", \"application\", \"web-browser\", \"execute\"]\n", "game.simulation.apply_request(caos_action)" ] }, @@ -246,7 +256,7 @@ "metadata": {}, "outputs": [], "source": [ - "game.get_sim_state()[\"network\"][\"nodes\"][\"some_tech_rt\"][\"services\"][\"UserSessionManager\"][\"active_remote_sessions\"]" + "game.get_sim_state()[\"network\"][\"nodes\"][\"some_tech_rt\"][\"services\"][\"user-session-manager\"][\"active_remote_sessions\"]" ] }, { @@ -259,7 +269,7 @@ "source": [ "caos_action = [\n", " \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n", - " \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_rt.network_interface[4].ip_address)\n", + " \"service\", \"terminal\", \"node-session-remote-login\", \"admin\", \"admin\", str(some_tech_rt.network_interface[4].ip_address)\n", "]\n", "game.simulation.apply_request(caos_action)" ] @@ -272,7 +282,7 @@ }, "outputs": [], "source": [ - "game.get_sim_state()[\"network\"][\"nodes\"][\"some_tech_rt\"][\"services\"][\"UserSessionManager\"][\"active_remote_sessions\"]" + "game.get_sim_state()[\"network\"][\"nodes\"][\"some_tech_rt\"][\"services\"][\"user-session-manager\"][\"active_remote_sessions\"]" ] }, { @@ -305,7 +315,7 @@ "source": [ "caos_action = [\n", " \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n", - " \"service\", \"Terminal\", \"send_remote_command\", str(some_tech_rt.network_interface[4].ip_address),\n", + " \"service\", \"terminal\", \"send_remote_command\", str(some_tech_rt.network_interface[4].ip_address),\n", " {\n", " \"command\": [\n", " \"acl\", \"add_rule\", \"PERMIT\", \"TCP\",\n", @@ -358,7 +368,7 @@ "source": [ "caos_action = [\n", " \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n", - " \"service\", \"Terminal\", \"remote_logoff\", str(some_tech_rt.network_interface[4].ip_address)\n", + " \"service\", \"terminal\", \"remote_logoff\", str(some_tech_rt.network_interface[4].ip_address)\n", "]\n", "game.simulation.apply_request(caos_action)" ] @@ -376,7 +386,7 @@ "metadata": {}, "outputs": [], "source": [ - "game.get_sim_state()[\"network\"][\"nodes\"][\"some_tech_rt\"][\"services\"][\"UserSessionManager\"][\"active_remote_sessions\"]" + "game.get_sim_state()[\"network\"][\"nodes\"][\"some_tech_rt\"][\"services\"][\"user-session-manager\"][\"active_remote_sessions\"]" ] }, { @@ -396,7 +406,7 @@ "source": [ "caos_action = [\n", " \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n", - " \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n", + " \"service\", \"terminal\", \"node-session-remote-login\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n", "]\n", "game.simulation.apply_request(caos_action)" ] @@ -411,7 +421,7 @@ "source": [ "caos_action = [\n", " \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n", - " \"service\", \"Terminal\", \"send_remote_command\", str(some_tech_storage_srv.network_interface[1].ip_address),\n", + " \"service\", \"terminal\", \"send_remote_command\", str(some_tech_storage_srv.network_interface[1].ip_address),\n", " {\n", " \"command\": [\n", " \"file_system\", \"delete\", \"file\", db_backup_folder, \"database.db\"\n", @@ -466,7 +476,7 @@ }, "outputs": [], "source": [ - "caos_action = [\"network\", \"node\", \"some_tech_jnr_dev_pc\", \"application\", \"WebBrowser\", \"execute\"]\n", + "caos_action = [\"network\", \"node\", \"some_tech_jnr_dev_pc\", \"application\", \"web-browser\", \"execute\"]\n", "game.simulation.apply_request(caos_action)" ] }, @@ -525,7 +535,7 @@ }, "outputs": [], "source": [ - "caos_action = [\"network\", \"node\", \"some_tech_jnr_dev_pc\", \"application\", \"WebBrowser\", \"execute\"]\n", + "caos_action = [\"network\", \"node\", \"some_tech_jnr_dev_pc\", \"application\", \"web-browser\", \"execute\"]\n", "game.simulation.apply_request(caos_action)" ] }, diff --git a/src/primaite/notebooks/Requests-and-Responses.ipynb b/src/primaite/notebooks/Requests-and-Responses.ipynb index da614c93..c29d41dc 100644 --- a/src/primaite/notebooks/Requests-and-Responses.ipynb +++ b/src/primaite/notebooks/Requests-and-Responses.ipynb @@ -6,7 +6,7 @@ "source": [ "# Requests and Responses\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "Agents interact with the PrimAITE simulation via the Request system.\n" ] @@ -25,6 +25,15 @@ "Let's set up a minimal network simulation and send some requests to see how it works." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -43,12 +52,17 @@ "outputs": [], "source": [ "sim = Simulation()\n", + "\n", "sim.network.add_node(\n", - " HostNode(\n", - " hostname=\"client\",\n", - " ip_address='10.0.0.1',\n", - " subnet_mask='255.255.255.0',\n", - " operating_state=NodeOperatingState.ON)\n", + " HostNode.from_config(\n", + " config = {\n", + " 'type': \"host-node\",\n", + " 'hostname': \"client\",\n", + " 'ip_address': '10.0.0.1',\n", + " 'subnet_mask': '255.255.255.0',\n", + " 'operating_state': \"ON\",\n", + " }\n", + " )\n", ")\n", "client = sim.network.get_node_by_hostname('client')\n" ] @@ -85,7 +99,7 @@ "outputs": [], "source": [ "response = sim.apply_request(\n", - " request=[\"network\", \"node\", \"client\", \"service\", \"DNSClient\", \"stop\"],\n", + " request=[\"network\", \"node\", \"client\", \"service\", \"dns-client\", \"stop\"],\n", " context={}\n", " )\n", "print(response)" @@ -105,7 +119,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(f\"DNS Client state: {client.software_manager.software.get('DNSClient').operating_state.name}\")" + "print(f\"DNS Client state: {client.software_manager.software.get('dns-client').operating_state.name}\")" ] }, { @@ -129,7 +143,7 @@ "outputs": [], "source": [ "response = sim.apply_request(\n", - " request=[\"network\", \"node\", \"client\", \"service\", \"NonExistentApplication\", \"stop\"],\n", + " request=[\"network\", \"node\", \"client\", \"service\", \"non-existent-application\", \"stop\"],\n", " context={}\n", " )\n", "print(response)" @@ -192,7 +206,7 @@ "outputs": [], "source": [ "response = sim.apply_request(\n", - " request=[\"network\", \"node\", \"client\", \"service\", \"DNSClient\", \"start\"],\n", + " request=[\"network\", \"node\", \"client\", \"service\", \"dns-client\", \"start\"],\n", " context={}\n", " )\n", "print(response)" @@ -201,7 +215,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 2ab06a5c..755b0184 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -6,7 +6,7 @@ "source": [ "# Terminal Processing\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK" + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK" ] }, { @@ -25,6 +25,15 @@ "The Terminal service comes pre-installed on most Nodes (The exception being Switches, as these are currently dumb). " ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -40,9 +49,26 @@ "def basic_network() -> Network:\n", " \"\"\"Utility function for creating a default network to demonstrate Terminal functionality\"\"\"\n", " network = Network()\n", - " node_a = Computer(hostname=\"node_a\", ip_address=\"192.168.0.10\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n", + " node_a = Computer.from_config(\n", + " config = {\n", + " \"type\": \"computer\",\n", + " \"hostname\": \"node_a\",\n", + " \"ip_address\": \"192.168.0.10\",\n", + " \"subnet_mask\": \"255.255.255.0\",\n", + " # \"startup_duration\": 0,\n", + " }\n", + " )\n", + " print(f\"{node_a=}\")\n", " node_a.power_on()\n", - " node_b = Computer(hostname=\"node_b\", ip_address=\"192.168.0.11\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n", + " node_b = Computer.from_config(\n", + " config = {\n", + " \"type\": \"computer\",\n", + " \"hostname\": \"node_b\",\n", + " \"ip_address\": \"192.168.0.11\",\n", + " \"subnet_mask\": \"255.255.255.0\",\n", + " # \"startup_duration\": 0,\n", + " }\n", + " )\n", " node_b.power_on()\n", " network.connect(node_a.network_interface[1], node_b.network_interface[1])\n", " return network" @@ -65,9 +91,9 @@ "source": [ "network: Network = basic_network()\n", "computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n", - "terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n", + "terminal_a: Terminal = computer_a.software_manager.software.get(\"terminal\")\n", "computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n", - "terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")" + "terminal_b: Terminal = computer_b.software_manager.software.get(\"terminal\")" ] }, { @@ -119,7 +145,7 @@ "metadata": {}, "outputs": [], "source": [ - "term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])" + "term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"ransomware-script\"])" ] }, { @@ -235,9 +261,9 @@ "\n", "| Game Layer Action | Simulation Layer |\n", "|-----------------------------------|--------------------------|\n", - "| ``NODE_SEND_LOCAL_COMMAND`` | Uses the given user credentials, creates a ``LocalTerminalSession`` and executes the given command and returns the ``RequestResponse``.\n", - "| ``SSH_TO_REMOTE`` | Uses the given user credentials and remote IP to create a ``RemoteTerminalSession``.\n", - "| ``NODE_SEND_REMOTE_COMMAND`` | Uses the given remote IP to locate the correct ``RemoteTerminalSession``, executes the given command and returns the ``RequestsResponse``." + "| ``node-send-local-command`` | Uses the given user credentials, creates a ``LocalTerminalSession`` and executes the given command and returns the ``RequestResponse``.\n", + "| ``node-session-remote-login`` | Uses the given user credentials and remote IP to create a ``RemoteTerminalSession``.\n", + "| ``node-send-remote-command`` | Uses the given remote IP to locate the correct ``RemoteTerminalSession``, executes the given command and returns the ``RequestsResponse``." ] }, { @@ -271,35 +297,16 @@ "custom_terminal_agent = \"\"\"\n", " - ref: CustomC2Agent\n", " team: RED\n", - " type: ProxyAgent\n", - " observation_space: null\n", + " type: proxy-agent\n", " action_space:\n", - " action_list:\n", - " - type: DONOTHING\n", - " - type: NODE_SEND_LOCAL_COMMAND\n", - " - type: SSH_TO_REMOTE\n", - " - type: NODE_SEND_REMOTE_COMMAND\n", - " options:\n", - " nodes:\n", - " - node_name: client_1\n", - " max_folders_per_node: 1\n", - " max_files_per_folder: 1\n", - " max_services_per_node: 2\n", - " max_nics_per_node: 8\n", - " max_acl_rules: 10\n", - " ip_list:\n", - " - 192.168.1.21\n", - " - 192.168.1.14\n", - " wildcard_list:\n", - " - 0.0.0.1\n", " action_map:\n", " 0:\n", - " action: DONOTHING\n", + " action: do-nothing\n", " options: {}\n", " 1:\n", - " action: NODE_SEND_LOCAL_COMMAND\n", + " action: node-send-local-command\n", " options:\n", - " node_id: 0\n", + " node_name: client_1\n", " username: admin\n", " password: admin\n", " command:\n", @@ -310,16 +317,16 @@ " - dog.png\n", " - False\n", " 2:\n", - " action: SSH_TO_REMOTE\n", + " action: node-session-remote-login\n", " options:\n", - " node_id: 0\n", + " node_name: client_1\n", " username: admin\n", " password: admin\n", " remote_ip: 192.168.10.22\n", " 3:\n", - " action: NODE_SEND_REMOTE_COMMAND\n", + " action: node-send-remote-command\n", " options:\n", - " node_id: 0\n", + " node_name: client_1\n", " remote_ip: 192.168.10.22\n", " command:\n", " - file_system\n", @@ -328,9 +335,6 @@ " - downloads\n", " - cat.png\n", " - False\n", - " reward_function:\n", - " reward_components:\n", - " - type: DUMMY\n", "\"\"\"\n", "custom_terminal_agent_yaml = yaml.safe_load(custom_terminal_agent)" ] @@ -346,7 +350,7 @@ " # removing all agents & adding the custom agent.\n", " cfg['agents'] = {}\n", " cfg['agents'] = custom_terminal_agent_yaml\n", - " \n", + "\n", "env = PrimaiteGymEnv(env_config=cfg)\n", "\n", "client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n", @@ -357,7 +361,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Terminal Action | ``NODE_SEND_LOCAL_COMMAND`` \n", + "### Terminal Action | ``node-send-local-command`` \n", "\n", "The yaml snippet below shows all the relevant agent options for this action:\n", "\n", @@ -366,7 +370,7 @@ " action_space:\n", " action_list:\n", " ...\n", - " - type: NODE_SEND_LOCAL_COMMAND\n", + " - type: node-send-local-command\n", " ...\n", " options:\n", " nodes: # Node List\n", @@ -375,7 +379,7 @@ " ...\n", " action_map:\n", " 1:\n", - " action: NODE_SEND_LOCAL_COMMAND\n", + " action: node-send-local-command\n", " options:\n", " node_id: 0 # Index 0 at the node list.\n", " username: admin\n", @@ -404,7 +408,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Terminal Action | ``SSH_TO_REMOTE`` \n", + "### Terminal Action | ``node-session-remote-login`` \n", "\n", "The yaml snippet below shows all the relevant agent options for this action:\n", "\n", @@ -413,7 +417,7 @@ " action_space:\n", " action_list:\n", " ...\n", - " - type: SSH_TO_REMOTE\n", + " - type: node-session-remote-login\n", " ...\n", " options:\n", " nodes: # Node List\n", @@ -422,7 +426,7 @@ " ...\n", " action_map:\n", " 2:\n", - " action: SSH_TO_REMOTE\n", + " action: node-session-remote-login\n", " options:\n", " node_id: 0 # Index 0 at the node list.\n", " username: admin\n", @@ -445,7 +449,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Terminal Action | ``NODE_SEND_REMOTE_COMMAND``\n", + "### Terminal Action | ``node-send-remote-command``\n", "\n", "The yaml snippet below shows all the relevant agent options for this action:\n", "\n", @@ -454,7 +458,7 @@ " action_space:\n", " action_list:\n", " ...\n", - " - type: NODE_SEND_REMOTE_COMMAND\n", + " - type: node-send-remote-command\n", " ...\n", " options:\n", " nodes: # Node List\n", @@ -463,7 +467,7 @@ " ...\n", " action_map:\n", " 1:\n", - " action: NODE_SEND_REMOTE_COMMAND\n", + " action: node-send-remote-command\n", " options:\n", " node_id: 0 # Index 0 at the node list.\n", " remote_ip: 192.168.10.22\n", @@ -490,7 +494,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -504,7 +508,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb index 19e95a95..87d9c377 100644 --- a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb +++ b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb @@ -6,7 +6,7 @@ "source": [ "# Train a Multi agent system using RLLIB\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "This notebook will demonstrate how to use the `PrimaiteRayMARLEnv` to train a very basic system with two PPO agents." ] @@ -18,6 +18,15 @@ "#### First, Import packages and read our config file." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -31,9 +40,8 @@ "import ray\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "from primaite.session.ray_envs import PrimaiteRayMARLEnv\n", + "from primaite.game.agent.scripted_agents import probabilistic_agent\n", "\n", - "# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n", - "# to copy the files to your user data path.\n", "with open(PRIMAITE_PATHS.user_config_path / 'example_config/data_manipulation_marl.yaml', 'r') as f:\n", " cfg = yaml.safe_load(f)\n", "\n", @@ -103,21 +111,9 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb index dbe8871c..79740bca 100644 --- a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb @@ -6,11 +6,20 @@ "source": [ "# Train a Single agent system using RLLib\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -23,6 +32,8 @@ "from primaite.session.ray_envs import PrimaiteRayEnv\n", "import ray\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", + "from primaite.game.agent.scripted_agents import probabilistic_agent\n", + "\n", "\n", "# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n", "# to copy the files to your user data path.\n", @@ -95,21 +106,9 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index 892736fe..1bec0183 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -6,7 +6,7 @@ "source": [ "# Training an SB3 Agent\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "This notebook will demonstrate how to use primaite to create and train a PPO agent, using a pre-defined configuration file." ] @@ -18,6 +18,15 @@ "#### First, we import the inital packages and read in our configuration file." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -26,6 +35,7 @@ "source": [ "from primaite.game.game import PrimaiteGame\n", "from primaite.session.environment import PrimaiteGymEnv\n", + "from primaite.game.agent.scripted_agents import probabilistic_agent\n", "import yaml" ] }, @@ -182,7 +192,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Using-Episode-Schedules.ipynb b/src/primaite/notebooks/Using-Episode-Schedules.ipynb index cb06e0f9..4bc49a70 100644 --- a/src/primaite/notebooks/Using-Episode-Schedules.ipynb +++ b/src/primaite/notebooks/Using-Episode-Schedules.ipynb @@ -6,7 +6,7 @@ "source": [ "# Using Episode Schedules\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "PrimAITE supports the ability to use different variations on a scenario at different episodes. This can be used to increase \n", "domain randomisation to prevent overfitting, or to set up curriculum learning to train agents to perform more complicated tasks.\n", @@ -48,6 +48,7 @@ "from primaite.session.environment import PrimaiteGymEnv\n", "from primaite import PRIMAITE_PATHS\n", "from prettytable import PrettyTable\n", + "from primaite.game.agent.scripted_agents import probabilistic_agent, data_manipulation_bot\n", "scenario_path = PRIMAITE_PATHS.user_config_path / \"example_config/scenario_with_placeholders\"" ] }, @@ -238,7 +239,7 @@ "### Episode 2\n", "When we reset the environment again, it moves onto episode 2, where it will bring in greens_1 and reds_1 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n", "\n", - "Most green actions will be `NODE_APPLICATION_EXECUTE` while red will `DONOTHING` except at steps 10 and 20." + "Most green actions will be `node_application_execute` while red will `DONOTHING` except at steps 10 and 20." ] }, { @@ -269,7 +270,7 @@ "### Episode 3\n", "When we reset the environment again, it moves onto episode 3, where it will bring in greens_2 and reds_2 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n", "\n", - "Now, green will perform `NODE_APPLICATION_EXECUTE` only 5% of the time, while red will perform `NODE_APPLICATION_EXECUTE` more frequently than before." + "Now, green will perform `node_application_execute` only 5% of the time, while red will perform `node_application_execute` more frequently than before." ] }, { @@ -409,7 +410,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -423,7 +424,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/_package_data/uc2_attack.png b/src/primaite/notebooks/_package_data/uc2_attack.png index 8b8df5ce..03797d00 100644 Binary files a/src/primaite/notebooks/_package_data/uc2_attack.png and b/src/primaite/notebooks/_package_data/uc2_attack.png differ diff --git a/src/primaite/notebooks/_package_data/uc2_network.png b/src/primaite/notebooks/_package_data/uc2_network.png index 20fa43c9..10989201 100644 Binary files a/src/primaite/notebooks/_package_data/uc2_network.png and b/src/primaite/notebooks/_package_data/uc2_network.png differ diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/notebooks/create-simulation_demo.ipynb similarity index 77% rename from src/primaite/simulator/_package_data/create-simulation_demo.ipynb rename to src/primaite/notebooks/create-simulation_demo.ipynb index 77ac4842..46f03be6 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/notebooks/create-simulation_demo.ipynb @@ -6,7 +6,7 @@ "source": [ "# Build a simulation using the Python API\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects." ] @@ -70,9 +70,23 @@ "metadata": {}, "outputs": [], "source": [ - "my_pc = Computer(hostname=\"Computer\", ip_address=\"192.168.1.10\", subnet_mask=\"255.255.255.0\")\n", + "my_pc = Computer.from_config(\n", + " config={\n", + " \"type\": \"computer\",\n", + " \"hostname\":\"pc_1\",\n", + " \"ip_address\":\"192.168.1.10\",\n", + " \"subnet_mask\":\"255.255.255.0\",\n", + " }\n", + " )\n", "net.add_node(my_pc)\n", - "my_server = Server(hostname=\"Server\", ip_address=\"192.168.1.11\", subnet_mask=\"255.255.255.0\")\n", + "my_server = Server.from_config(\n", + " config={\n", + " \"type\": \"server\",\n", + " \"hostname\":\"Server\",\n", + " \"ip_address\":\"192.168.1.11\",\n", + " \"subnet_mask\":\"255.255.255.0\"\n", + " }\n", + ")\n", "net.add_node(my_server)\n" ] }, @@ -99,7 +113,13 @@ "metadata": {}, "outputs": [], "source": [ - "my_switch = Switch(hostname=\"switch1\", num_ports=12)\n", + "my_switch = Switch.from_config(\n", + " config = {\n", + " \"type\":\"switch\",\n", + " \"hostname\":\"switch1\",\n", + " \"num_ports\":12\n", + " }\n", + ")\n", "net.add_node(my_switch)\n", "\n", "pc_nic = NIC(ip_address=\"130.1.1.1\", gateway=\"130.1.1.255\", subnet_mask=\"255.255.255.0\")\n", @@ -163,15 +183,30 @@ "metadata": {}, "outputs": [], "source": [ + "from pydantic import Field\n", + "\n", "from pathlib import Path\n", "from primaite.simulator.system.applications.application import Application, ApplicationOperatingState\n", "from primaite.simulator.system.software import SoftwareHealthState, SoftwareCriticality\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", "from primaite.simulator.file_system.file_system import FileSystem\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "from primaite.utils.validation.port import PORT_LOOKUP\n", + "from primaite.simulator.system.core.sys_log import SysLog\n", + "\n", "\n", "# no applications exist yet so we will create our own.\n", - "class MSPaint(Application, identifier=\"MSPaint\"):\n", + "class MSPaint(Application, discriminator=\"MSPaint\"):\n", + " class ConfigSchema(Application.ConfigSchema):\n", + " type: str = \"MSPaint\"\n", + "\n", + " config: ConfigSchema = Field(default_factory=lambda: MSPaint.ConfigSchema())\n", + "\n", + " def __init__(self, **kwargs):\n", + " kwargs[\"name\"] = \"MSPaint\"\n", + " kwargs[\"port\"] = PORT_LOOKUP[\"HTTP\"]\n", + " kwargs[\"protocol\"] = PROTOCOL_LOOKUP[\"NONE\"]\n", + " super().__init__(**kwargs)\n", + "\n", " def describe_state(self):\n", " return super().describe_state()" ] @@ -182,7 +217,8 @@ "metadata": {}, "outputs": [], "source": [ - "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port.HTTP, protocol = IPProtocol.NONE,operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" + "my_pc.software_manager.install(MSPaint)\n", + "mspaint = my_pc.software_manager.software.get(\"MSPaint\")" ] }, { @@ -249,7 +285,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/src/primaite/notebooks/multi-processing.ipynb b/src/primaite/notebooks/multi-processing.ipynb index 305cfd70..e56bf362 100644 --- a/src/primaite/notebooks/multi-processing.ipynb +++ b/src/primaite/notebooks/multi-processing.ipynb @@ -6,7 +6,7 @@ "source": [ "# Simple multi-processing demonstration\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "This notebook uses SubprocVecEnv from SB3." ] @@ -18,6 +18,15 @@ "Import packages and read config file." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite setup" + ] + }, { "cell_type": "code", "execution_count": null, @@ -28,7 +37,6 @@ "from stable_baselines3 import PPO\n", "from stable_baselines3.common.utils import set_random_seed\n", "from stable_baselines3.common.vec_env import SubprocVecEnv\n", - "\n", "from primaite.session.environment import PrimaiteGymEnv\n" ] }, @@ -129,7 +137,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -143,7 +151,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb b/src/primaite/notebooks/network_simulator_demo.ipynb similarity index 96% rename from src/primaite/simulator/_package_data/network_simulator_demo.ipynb rename to src/primaite/notebooks/network_simulator_demo.ipynb index 17a0f796..be930ac0 100644 --- a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb +++ b/src/primaite/notebooks/network_simulator_demo.ipynb @@ -7,7 +7,7 @@ "source": [ "# PrimAITE Router Simulation Demo\n", "\n", - "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n", + "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n", "\n", "This demo uses a modified version of the ARCD Use Case 2 Network (seen below) to demonstrate the capabilities of the Network simulator in PrimAITE." ] @@ -532,12 +532,12 @@ }, "outputs": [], "source": [ - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", "from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "\n", "network.get_node_by_hostname(\"router_1\").acl.add_rule(\n", " action=ACLAction.DENY,\n", - " protocol=IPProtocol.ICMP,\n", + " protocol=PROTOCOL_LOOKUP[\"ICMP\"],\n", " src_ip_address=\"192.168.10.22\",\n", " position=1\n", ")" @@ -653,18 +653,6 @@ "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/session/__init__.py b/src/primaite/session/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/session/__init__.py +++ b/src/primaite/session/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index db5425e3..fa545dbc 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import json import random import sys @@ -111,7 +111,7 @@ class PrimaiteGymEnv(gymnasium.Env): :return: Action mask :rtype: List[bool] """ - if not self.agent.action_masking: + if not self.agent.config.agent_settings.action_masking: return np.asarray([True] * len(self.agent.action_manager.action_map)) else: return self.game.action_mask(self._agent_name) diff --git a/src/primaite/session/episode_schedule.py b/src/primaite/session/episode_schedule.py index ad4d38e9..126dcf9f 100644 --- a/src/primaite/session/episode_schedule.py +++ b/src/primaite/session/episode_schedule.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import copy from abc import ABC, abstractmethod from itertools import chain diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 78d7cb3c..6c2f4f29 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import json from datetime import datetime from pathlib import Path diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 33c74b0e..16c85cb3 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import json from typing import Dict, SupportsFloat, Tuple @@ -44,7 +44,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): ) for agent_name in self._agent_ids: agent = self.game.rl_agents[agent_name] - if agent.action_masking: + if agent.config.agent_settings.action_masking: self.observation_space[agent_name] = spaces.Dict( { "action_mask": spaces.MultiBinary(agent.action_manager.space.n), @@ -143,7 +143,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation obs = gymnasium.spaces.flatten(unflat_space, unflat_obs) - if agent.action_masking: + if agent.config.agent_settings.action_masking: all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs} else: all_obs[agent_name] = obs @@ -168,7 +168,7 @@ class PrimaiteRayEnv(gymnasium.Env): self.env = PrimaiteGymEnv(env_config=env_config) # self.env.episode_counter -= 1 self.action_space = self.env.action_space - if self.env.agent.action_masking: + if self.env.agent.config.agent_settings.action_masking: self.observation_space = spaces.Dict( {"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space} ) @@ -178,7 +178,7 @@ class PrimaiteRayEnv(gymnasium.Env): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" super().reset() # Ensure PRNG seed is set everywhere - if self.env.agent.action_masking: + if self.env.agent.config.agent_settings.action_masking: obs, *_ = self.env.reset(seed=seed) new_obs = {"action_mask": self.env.action_masks(), "observations": obs} return new_obs, *_ @@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env): def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: """Perform a step in the environment.""" # if action masking is enabled, intercept the step method and add action mask to observation - if self.env.agent.action_masking: + if self.env.agent.config.agent_settings.action_masking: obs, *_ = self.env.step(action) new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs} return new_obs, *_ diff --git a/src/primaite/setup/__init__.py b/src/primaite/setup/__init__.py index 12e7c4e7..1447a47b 100644 --- a/src/primaite/setup/__init__.py +++ b/src/primaite/setup/__init__.py @@ -1,2 +1,2 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Utilities to prepare the user's data folders.""" diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index f17fb211..ad4091e3 100644 --- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import filecmp import shutil from logging import Logger diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py index c7eeecd5..a94d6d4a 100644 --- a/src/primaite/setup/reset_example_configs.py +++ b/src/primaite/setup/reset_example_configs.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import filecmp import os import shutil diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index ade1a73b..e85a2d1e 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Warning: SIM_OUTPUT is a mutable global variable for the simulation output directory.""" from datetime import datetime from enum import IntEnum diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 848570fe..750372b3 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # flake8: noqa """Core of the PrimAITE Simulator.""" import warnings @@ -244,7 +244,7 @@ class SimComponent(BaseModel): ..code::python - class WebBrowser(Application, identifier="WebBrowser"): + class WebBrowser(Application, discriminator="web-browser"): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() # all requests generic to any Application get initialised rm.add_request(...) # initialise any requests specific to the web browser diff --git a/src/primaite/simulator/domain/__init__.py b/src/primaite/simulator/domain/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/domain/__init__.py +++ b/src/primaite/simulator/domain/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/domain/account.py b/src/primaite/simulator/domain/account.py index d955cf55..85ec6d46 100644 --- a/src/primaite/simulator/domain/account.py +++ b/src/primaite/simulator/domain/account.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """User account simulation.""" from enum import Enum from typing import Dict diff --git a/src/primaite/simulator/domain/controller.py b/src/primaite/simulator/domain/controller.py index a264ba24..d8b7782c 100644 --- a/src/primaite/simulator/domain/controller.py +++ b/src/primaite/simulator/domain/controller.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum from typing import Dict, Final, List, Literal, Tuple diff --git a/src/primaite/simulator/file_system/__init__.py b/src/primaite/simulator/file_system/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/file_system/__init__.py +++ b/src/primaite/simulator/file_system/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index ba39c791..58607bf6 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations import hashlib @@ -130,8 +130,8 @@ class File(FileSystemItemABC): Return False if corruption is detected, otherwise True """ - warnings.warn("NODE_FILE_CHECKHASH is currently not implemented.") - self.sys_log.warning("NODE_FILE_CHECKHASH is currently not implemented.") + warnings.warn("node-file-checkhash is currently not implemented.") + self.sys_log.warning("node-file-checkhash is currently not implemented.") return False if self.deleted: diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index 2162915f..54e649f2 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from pathlib import Path @@ -30,6 +30,11 @@ class FileSystem(SimComponent): num_file_deletions: int = 0 "Number of file deletions in the current step." + _default_folder_scan_duration: Optional[int] = None + "Override default scan duration for folders" + _default_folder_restore_duration: Optional[int] = None + "Override default restore duration for folders" + def __init__(self, **kwargs): super().__init__(**kwargs) # Ensure a default root folder @@ -258,6 +263,11 @@ class FileSystem(SimComponent): name=folder.name, request_type=RequestType(func=folder._request_manager) ) self.folders[folder.uuid] = folder + # set the folder scan and restore durations. + if self._default_folder_scan_duration is not None: + folder.scan_duration = self._default_folder_scan_duration + if self._default_folder_restore_duration is not None: + folder.restore_duration = self._default_folder_restore_duration return folder def delete_folder(self, folder_name: str) -> bool: diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index a9db8825..db51924c 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations import math @@ -43,6 +43,9 @@ def convert_size(size_bytes: int) -> str: class FileSystemItemHealthStatus(Enum): """Status of the FileSystemItem.""" + NONE = 0 + """File system item health status is not known.""" + GOOD = 1 """File/Folder is OK.""" @@ -72,7 +75,7 @@ class FileSystemItemABC(SimComponent): health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD "Actual status of the current FileSystemItem" - visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD + visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.NONE "Visible status of the current FileSystemItem" previous_hash: Optional[str] = None diff --git a/src/primaite/simulator/file_system/file_type.py b/src/primaite/simulator/file_system/file_type.py index e6e81070..343d3565 100644 --- a/src/primaite/simulator/file_system/file_type.py +++ b/src/primaite/simulator/file_system/file_type.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from enum import Enum diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index c98e4492..5b9a6931 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations import warnings @@ -46,7 +46,7 @@ class Folder(FileSystemItemABC): :param sys_log: The SysLog instance to us to create system logs. """ super().__init__(**kwargs) - + self._scanned_this_step: bool = False self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") def _init_request_manager(self) -> RequestManager: @@ -83,6 +83,7 @@ class Folder(FileSystemItemABC): state = super().describe_state() state["files"] = {file.name: file.describe_state() for uuid, file in self.files.items()} state["deleted_files"] = {file.name: file.describe_state() for uuid, file in self.deleted_files.items()} + state["scanned_this_step"] = self._scanned_this_step return state def show(self, markdown: bool = False): @@ -135,7 +136,7 @@ class Folder(FileSystemItemABC): def pre_timestep(self, timestep: int) -> None: """Apply pre-timestep logic.""" super().pre_timestep(timestep) - + self._scanned_this_step = False for file in self.files.values(): file.pre_timestep(timestep) @@ -148,9 +149,17 @@ class Folder(FileSystemItemABC): for file_id in self.files: file = self.get_file_by_id(file_uuid=file_id) file.scan() - if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT: - self.health_status = FileSystemItemHealthStatus.CORRUPT + # set folder health to worst file's health by generating a list of file healths. If no files, use 0 + self.health_status = FileSystemItemHealthStatus( + max( + [f.health_status.value for f in self.files.values()] + or [ + 0, + ] + ) + ) self.visible_health_status = self.health_status + self._scanned_this_step = True def _reveal_to_red_timestep(self) -> None: """Apply reveal to red timestep.""" @@ -387,8 +396,8 @@ class Folder(FileSystemItemABC): Return False if corruption is detected, otherwise True """ - warnings.warn("NODE_FOLDER_CHECKHASH is currently not implemented.") - self.sys_log.error("NODE_FOLDER_CHECKHASH is currently not implemented.") + warnings.warn("node-folder-checkhash is currently not implemented.") + self.sys_log.error("node-folder-checkhash is currently not implemented.") return False if self.deleted: diff --git a/src/primaite/simulator/network/__init__.py b/src/primaite/simulator/network/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/network/__init__.py +++ b/src/primaite/simulator/network/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index cdb01514..434940f8 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -1,12 +1,11 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from abc import ABC, abstractmethod -from enum import Enum -from typing import Any, Dict, List +from typing import Any, ClassVar, Dict, List from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, validate_call from primaite import getLogger from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface @@ -41,50 +40,28 @@ def format_hertz(hertz: float, format_terahertz: bool = False, decimals: int = 3 return format_str.format(hertz) + " Hz" -class AirSpaceFrequency(Enum): - """Enumeration representing the operating frequencies for wireless communications.""" +class AirSpaceFrequency(BaseModel): + """Data transfer object for defining properties of an airspace frequency.""" - WIFI_2_4 = 2.4e9 - """WiFi 2.4 GHz. Known for its extensive range and ability to penetrate solid objects effectively.""" - WIFI_5 = 5e9 - """WiFi 5 GHz. Known for its higher data transmission speeds and reduced interference from other devices.""" + model_config = ConfigDict(extra="forbid") + name: str + """Alias for frequency.""" + frequency_hz: int + """This acts as the primary key. If two names are mapped to the same frequency, they will share a bandwidth.""" + data_rate_bps: float + """How much data can be transmitted on this frequency per second.""" - def __str__(self) -> str: - hertz_str = format_hertz(hertz=self.value) - if self == AirSpaceFrequency.WIFI_2_4: - return f"WiFi {hertz_str}" - if self == AirSpaceFrequency.WIFI_5: - return f"WiFi {hertz_str}" - return "Unknown Frequency" + _registry: ClassVar[Dict[str, AirSpaceFrequency]] = {} - @property - def maximum_data_rate_bps(self) -> float: - """ - Retrieves the maximum data transmission rate in bits per second (bps). + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.name in self._registry: + raise RuntimeError(f"Frequency {self.name} is already registered. Cannot register it again.") + self._registry[self.name] = self - The maximum rates are predefined for frequencies.: - - WIFI 2.4 supports 100,000,000 bps - - WIFI 5 supports 500,000,000 bps - :return: The maximum data rate in bits per second. - """ - if self == AirSpaceFrequency.WIFI_2_4: - return 100_000_000.0 # 100 Megabits per second - if self == AirSpaceFrequency.WIFI_5: - return 500_000_000.0 # 500 Megabits per second - return 0.0 - - @property - def maximum_data_rate_mbps(self) -> float: - """ - Retrieves the maximum data transmission rate in megabits per second (Mbps). - - This is derived by converting the maximum data rate from bits per second, as defined - in `maximum_data_rate_bps`, to megabits per second. - - :return: The maximum data rate in megabits per second. - """ - return self.maximum_data_rate_bps / 1_000_000.0 +FREQ_WIFI_2_4 = AirSpaceFrequency(name="WIFI_2_4", frequency_hz=2.4e9, data_rate_bps=100_000_000.0) +FREQ_WIFI_5 = AirSpaceFrequency(name="WIFI_5", frequency_hz=5e9, data_rate_bps=500_000_000.0) class AirSpace(BaseModel): @@ -97,38 +74,53 @@ class AirSpace(BaseModel): """ wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {}) - wireless_interfaces_by_frequency: Dict[AirSpaceFrequency, List[WirelessNetworkInterface]] = Field( - default_factory=lambda: {} - ) - bandwidth_load: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) - frequency_max_capacity_mbps_: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) + wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field(default_factory=lambda: {}) + bandwidth_load: Dict[int, float] = Field(default_factory=lambda: {}) + frequencies: Dict[str, AirSpaceFrequency] = AirSpaceFrequency._registry - def get_frequency_max_capacity_mbps(self, frequency: AirSpaceFrequency) -> float: + @validate_call + def get_frequency_max_capacity_mbps(self, freq_name: str) -> float: """ Retrieves the maximum data transmission capacity for a specified frequency. - This method checks a dictionary holding custom maximum capacities. If the frequency is found, it returns the - custom set maximum capacity. If the frequency is not found in the dictionary, it defaults to the standard - maximum data rate associated with that frequency. - - :param frequency: The frequency for which the maximum capacity is queried. - + :param freq_name: The frequency for which the maximum capacity is queried. :return: The maximum capacity in Mbps for the specified frequency. """ - if frequency in self.frequency_max_capacity_mbps_: - return self.frequency_max_capacity_mbps_[frequency] - return frequency.maximum_data_rate_mbps + if freq_name in self.frequencies: + return self.frequencies[freq_name].data_rate_bps / (1024.0 * 1024.0) + return 0.0 - def set_frequency_max_capacity_mbps(self, cfg: Dict[AirSpaceFrequency, float]): + def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]) -> None: """ Sets custom maximum data transmission capacities for multiple frequencies. :param cfg: A dictionary mapping frequencies to their new maximum capacities in Mbps. """ - self.frequency_max_capacity_mbps_ = cfg for freq, mbps in cfg.items(): + self.frequencies[freq].data_rate_bps = mbps * 1024 * 1024 print(f"Overriding {freq} max capacity as {mbps:.3f} mbps") + def register_frequency(self, freq_name: str, freq_hz: float, data_rate_bps: float) -> None: + """ + Define a new frequency for this airspace. + + :param freq_name: The frequency name. If this clashes with an existing frequency name, it will be overwritten. + :type freq_name: str + :param freq_hz: The frequency itself, measured in Hertz. + :type freq_hz: float + :param data_rate_bps: The transmission capacity over this frequency, in bits per second. + :type data_rate_bps: float + """ + if freq_name in self.frequencies: + _LOGGER.info( + f"Overwriting Air space frequency {freq_name}. " + f"Previous data rate: {self.frequencies[freq_name].data_rate_bps}. " + f"Current data rate: {data_rate_bps}." + ) + self.frequencies.update( + {freq_name: AirSpaceFrequency(name=freq_name, frequency_hz=freq_hz, data_rate_bps=data_rate_bps)} + ) + def show_bandwidth_load(self, markdown: bool = False): """ Prints a table of the current bandwidth load for each frequency on the airspace. @@ -145,12 +137,20 @@ class AirSpace(BaseModel): table.set_style(MARKDOWN) table.align = "l" table.title = "Airspace Frequency Channel Loads" - for frequency, load in self.bandwidth_load.items(): - maximum_capacity = self.get_frequency_max_capacity_mbps(frequency) - load_percent = load / maximum_capacity if maximum_capacity > 0 else 0.0 + for freq_name, freq_obj in self.frequencies.items(): + maximum_capacity = self.get_frequency_max_capacity_mbps(freq_name) + load_percent = ( + self.bandwidth_load.get(freq_obj.frequency_hz, 0.0) / maximum_capacity if maximum_capacity > 0 else 0.0 + ) if load_percent > 1.0: load_percent = 1.0 - table.add_row([format_hertz(frequency.value), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"]) + table.add_row( + [ + format_hertz(self.frequencies[freq_name].frequency_hz), + f"{load_percent:.0%}", + f"{maximum_capacity:.3f}", + ] + ) print(table) def show_wireless_interfaces(self, markdown: bool = False): @@ -178,11 +178,11 @@ class AirSpace(BaseModel): status = "Enabled" if interface.enabled else "Disabled" table.add_row( [ - interface._connected_node.hostname, # noqa + interface._connected_node.config.hostname, # noqa interface.mac_address, interface.ip_address if hasattr(interface, "ip_address") else None, interface.subnet_mask if hasattr(interface, "subnet_mask") else None, - format_hertz(interface.frequency.value), + format_hertz(self.frequencies[interface.frequency.name].frequency_hz), f"{interface.speed:.3f}", status, ] @@ -210,9 +210,9 @@ class AirSpace(BaseModel): """ if wireless_interface.mac_address not in self.wireless_interfaces: self.wireless_interfaces[wireless_interface.mac_address] = wireless_interface - if wireless_interface.frequency not in self.wireless_interfaces_by_frequency: - self.wireless_interfaces_by_frequency[wireless_interface.frequency] = [] - self.wireless_interfaces_by_frequency[wireless_interface.frequency].append(wireless_interface) + if wireless_interface.frequency.frequency_hz not in self.wireless_interfaces_by_frequency: + self.wireless_interfaces_by_frequency[wireless_interface.frequency.frequency_hz] = [] + self.wireless_interfaces_by_frequency[wireless_interface.frequency.frequency_hz].append(wireless_interface) def remove_wireless_interface(self, wireless_interface: WirelessNetworkInterface): """ @@ -222,7 +222,7 @@ class AirSpace(BaseModel): """ if wireless_interface.mac_address in self.wireless_interfaces: self.wireless_interfaces.pop(wireless_interface.mac_address) - self.wireless_interfaces_by_frequency[wireless_interface.frequency].remove(wireless_interface) + self.wireless_interfaces_by_frequency[wireless_interface.frequency.frequency_hz].remove(wireless_interface) def clear(self): """ @@ -255,11 +255,11 @@ class AirSpace(BaseModel): relevant frequency and its current bandwidth load. :return: True if the frame can be transmitted within the bandwidth limit, False if it would exceed the limit. """ - if sender_network_interface.frequency not in self.bandwidth_load: - self.bandwidth_load[sender_network_interface.frequency] = 0.0 + if sender_network_interface.frequency.frequency_hz not in self.bandwidth_load: + self.bandwidth_load[sender_network_interface.frequency.frequency_hz] = 0.0 return self.bandwidth_load[ - sender_network_interface.frequency - ] + frame.size_Mbits <= self.get_frequency_max_capacity_mbps(sender_network_interface.frequency) + sender_network_interface.frequency.frequency_hz + ] + frame.size_Mbits <= self.get_frequency_max_capacity_mbps(sender_network_interface.frequency.name) def transmit(self, frame: Frame, sender_network_interface: WirelessNetworkInterface): """ @@ -271,8 +271,10 @@ class AirSpace(BaseModel): :param sender_network_interface: The wireless network interface sending the frame. This interface will be excluded from the list of receivers to prevent it from receiving its own transmission. """ - self.bandwidth_load[sender_network_interface.frequency] += frame.size_Mbits - for wireless_interface in self.wireless_interfaces_by_frequency.get(sender_network_interface.frequency, []): + self.bandwidth_load[sender_network_interface.frequency.frequency_hz] += frame.size_Mbits + for wireless_interface in self.wireless_interfaces_by_frequency.get( + sender_network_interface.frequency.frequency_hz, [] + ): if wireless_interface != sender_network_interface and wireless_interface.enabled: wireless_interface.receive_frame(frame) @@ -298,7 +300,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC): """ airspace: AirSpace - frequency: AirSpaceFrequency = AirSpaceFrequency.WIFI_2_4 + frequency: AirSpaceFrequency = FREQ_WIFI_2_4 def enable(self): """Attempt to enable the network interface.""" @@ -318,7 +320,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC): self.enabled = True self._connected_node.sys_log.info(f"Network Interface {self} enabled") self.pcap = PacketCapture( - hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name + hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name ) self.airspace.add_wireless_interface(self) @@ -430,7 +432,7 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC) # Update the state with information from Layer3Interface state.update(Layer3Interface.describe_state(self)) - state["frequency"] = self.frequency.value + state["frequency"] = self.frequency.name return state @@ -447,10 +449,8 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC) `default_gateway_hello` method is not defined, ignoring such errors to proceed without interruption. """ super().enable() - try: + if hasattr(self._connected_node, "default_gateway_hello"): self._connected_node.default_gateway_hello() - except AttributeError: - pass @abstractmethod def receive_frame(self, frame: Frame) -> bool: diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 0408acde..b0426537 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Any, Dict, List, Optional @@ -12,7 +12,9 @@ from primaite import getLogger from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.network.airspace import AirSpace from primaite.simulator.network.hardware.base import Link, Node, WiredNetworkInterface +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.host.server import Printer +from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service @@ -129,6 +131,16 @@ class Network(SimComponent): """The Firewalls in the Network.""" return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"] + @property + def extended_hostnodes(self) -> List[Node]: + """Extended nodes that inherited HostNode in the network.""" + return [node for node in self.nodes.values() if node.__class__.__name__.lower() in HostNode._registry] + + @property + def extended_networknodes(self) -> List[Node]: + """Extended nodes that inherited NetworkNode in the network.""" + return [node for node in self.nodes.values() if node.__class__.__name__.lower() in NetworkNode._registry] + @property def printer_nodes(self) -> List[Node]: """The printers on the network.""" @@ -151,24 +163,14 @@ class Network(SimComponent): :param links: Include link details in the output. Defaults to True. :param markdown: Use Markdown style in table output. Defaults to False. """ - nodes_type_map = { - "Router": self.router_nodes, - "Firewall": self.firewall_nodes, - "Switch": self.switch_nodes, - "Server": self.server_nodes, - "Computer": self.computer_nodes, - "Printer": self.printer_nodes, - "Wireless Router": self.wireless_router_nodes, - } if nodes: table = PrettyTable(["Node", "Type", "Operating State"]) if markdown: table.set_style(MARKDOWN) table.align = "l" table.title = "Nodes" - for node_type, nodes in nodes_type_map.items(): - for node in nodes: - table.add_row([node.hostname, node_type, node.operating_state.name]) + for node in self.nodes.values(): + table.add_row((node.config.hostname, type(node)._discriminator, node.operating_state.name)) print(table) if ip_addresses: @@ -177,15 +179,20 @@ class Network(SimComponent): table.set_style(MARKDOWN) table.align = "l" table.title = "IP Addresses" - for nodes in nodes_type_map.values(): - for node in nodes: - for i, port in node.network_interface.items(): - if hasattr(port, "ip_address"): - if port.ip_address != IPv4Address("127.0.0.1"): - port_str = port.port_name if port.port_name else port.port_num - table.add_row( - [node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway] - ) + for node in self.nodes.values(): + for i, port in node.network_interface.items(): + if hasattr(port, "ip_address"): + if port.ip_address != IPv4Address("127.0.0.1"): + port_str = port.port_name if port.port_name else port.port_num + table.add_row( + [ + node.config.hostname, + port_str, + port.ip_address, + port.subnet_mask, + node.config.default_gateway, + ] + ) print(table) if links: @@ -197,22 +204,21 @@ class Network(SimComponent): table.align = "l" table.title = "Links" links = list(self.links.values()) - for nodes in nodes_type_map.values(): - for node in nodes: - for link in links[::-1]: - if node in [link.endpoint_a.parent, link.endpoint_b.parent]: - table.add_row( - [ - link.endpoint_a.parent.hostname, - str(link.endpoint_a), - link.endpoint_b.parent.hostname, - str(link.endpoint_b), - link.is_up, - link.bandwidth, - link.current_load_percent, - ] - ) - links.remove(link) + for node in self.nodes.values(): + for link in links[::-1]: + if node in [link.endpoint_a.parent, link.endpoint_b.parent]: + table.add_row( + [ + link.endpoint_a.parent.config.hostname, + str(link.endpoint_a), + link.endpoint_b.parent.config.hostname, + str(link.endpoint_b), + link.is_up, + link.bandwidth, + link.current_load_percent, + ] + ) + links.remove(link) print(table) def clear_links(self): @@ -239,7 +245,7 @@ class Network(SimComponent): state = super().describe_state() state.update( { - "nodes": {node.hostname: node.describe_state() for node in self.nodes.values()}, + "nodes": {node.config.hostname: node.describe_state() for node in self.nodes.values()}, "links": {}, } ) @@ -247,8 +253,8 @@ class Network(SimComponent): for _, link in self.links.items(): node_a = link.endpoint_a._connected_node node_b = link.endpoint_b._connected_node - hostname_a = node_a.hostname if node_a else None - hostname_b = node_b.hostname if node_b else None + hostname_a = node_a.config.hostname if node_a else None + hostname_b = node_b.config.hostname if node_b else None port_a = link.endpoint_a.port_num port_b = link.endpoint_b.port_num link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}" @@ -274,9 +280,11 @@ class Network(SimComponent): self.nodes[node.uuid] = node self._node_id_map[len(self.nodes)] = node node.parent = self - self._nx_graph.add_node(node.hostname) + self._nx_graph.add_node(node.config.hostname) _LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}") - self._node_request_manager.add_request(name=node.hostname, request_type=RequestType(func=node._request_manager)) + self._node_request_manager.add_request( + name=node.config.hostname, request_type=RequestType(func=node._request_manager) + ) def get_node_by_hostname(self, hostname: str) -> Optional[Node]: """ @@ -288,7 +296,7 @@ class Network(SimComponent): :return: The Node if it exists in the network. """ for node in self.nodes.values(): - if node.hostname == hostname: + if node.config.hostname == hostname: return node def remove_node(self, node: Node) -> None: @@ -301,7 +309,7 @@ class Network(SimComponent): :type node: Node """ if node not in self: - _LOGGER.warning(f"Can't remove node {node.hostname}. It's not in the network.") + _LOGGER.warning(f"Can't remove node {node.config.hostname}. It's not in the network.") return self.nodes.pop(node.uuid) for i, _node in self._node_id_map.items(): @@ -309,8 +317,8 @@ class Network(SimComponent): self._node_id_map.pop(i) break node.parent = None - self._node_request_manager.remove_request(name=node.hostname) - _LOGGER.info(f"Removed node {node.hostname} from network {self.uuid}") + self._node_request_manager.remove_request(name=node.config.hostname) + _LOGGER.info(f"Removed node {node.config.hostname} from network {self.uuid}") def connect( self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, bandwidth: int = 100, **kwargs @@ -340,7 +348,7 @@ class Network(SimComponent): link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth, **kwargs) self.links[link.uuid] = link self._link_id_map[len(self.links)] = link - self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname) + self._nx_graph.add_edge(endpoint_a.parent.config.hostname, endpoint_b.parent.config.hostname) link.parent = self _LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}") return link diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index b801a38e..089ed00d 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -1,12 +1,256 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from abc import ABC, abstractmethod from ipaddress import IPv4Address -from typing import Optional +from typing import Any, ClassVar, Dict, Literal, Optional, Type + +from pydantic import BaseModel, ConfigDict, model_validator from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP + + +class NetworkNodeAdder(BaseModel): + """ + Base class for adding a set of related nodes to a network in a standardised way. + + Child classes should define a ConfigSchema nested class that subclasses NetworkNodeAdder.ConfigSchema and a __call__ + method which performs the node addition to the network. + + Here is a template that users can use to define custom node adders: + ``` + class YourNodeAdder(NetworkNodeAdder, discriminator="your-name"): + class ConfigSchema(NetworkNodeAdder.ConfigSchema): + property_1 : str + property_2 : int + + @classmethod + def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None: + node_1 = Node(property_1, ...) + node_2 = Node(...) + network.connect(node_1.network_interface[1], node_2.network_interface[1]) + ... + ``` + """ + + class ConfigSchema(BaseModel, ABC): + """ + Base schema for node adders. + + Child classes of NetworkNodeAdder must define a schema which inherits from this schema. The discriminator is + used by the from_config method to select the correct node adder at runtime. + """ + + model_config = ConfigDict(extra="forbid") + type: str + """Uniquely identifies the node adder class to use for adding nodes to network.""" + + _registry: ClassVar[Dict[str, Type["NetworkNodeAdder"]]] = {} + + def __init_subclass__(cls, discriminator: Optional[str], **kwargs: Any) -> None: + """ + Register a network node adder class. + + :param discriminator: Unique name for the node adder to use for matching against primaite config entries. + :type discriminator: str + :raises ValueError: When attempting to register a name that is already reserved. + """ + super().__init_subclass__(**kwargs) + if discriminator is None: + return + if discriminator in cls._registry: + raise ValueError(f"Duplicate node adder {discriminator}") + cls._registry[discriminator] = cls + + @classmethod + @abstractmethod + def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None: + """ + Add nodes to the network. + + Abstract method that must be overwritten by child classes. Use the config definition to create nodes and add + them to the network that is passed in. + + :param config: Config object that defines how to create and add nodes to the network + :type config: ConfigSchema + :param network: PrimAITE network object to which to add nodes. + :type network: Network + """ + pass + + @classmethod + def from_config(cls, config: Dict, network: Network) -> None: + """ + Accept a config, find the relevant node adder class, and call it to add nodes to the network. + + Child classes do not need to define this method. + + :param config: Configuration object for the child adder class + :type config: Dict + :param network: The Network object to which to add nodes + :type network: Network + """ + if config["type"] not in cls._registry: + raise ValueError(f"Invalid node adder type {config['type']}") + adder_class = cls._registry[config["type"]] + adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config), network=network) + + +class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"): + """Creates an office LAN.""" + + class ConfigSchema(NetworkNodeAdder.ConfigSchema): + """Configuration schema for OfficeLANAdder.""" + + type: Literal["office-lan"] = "office-lan" + lan_name: str + """Name of lan used for generating hostnames for new nodes.""" + subnet_base: int + """Used as the third octet of IP addresses for nodes in the network.""" + pcs_ip_block_start: int + """Starting point for the fourth octet of IP addresses of nodes in the network.""" + num_pcs: int + """The number of hosts to generate.""" + include_router: bool = True + """Whether to include a router in the new office LAN.""" + bandwidth: int = 100 + """Data bandwidth to the LAN measured in Mbps.""" + + @model_validator(mode="after") + def check_ip_range(self) -> "OfficeLANAdder.ConfigSchema": + """Make sure the ip addresses of hosts don't exceed the maximum possible ip address.""" + if self.pcs_ip_block_start + self.num_pcs >= 254: + raise ValueError( + f"Cannot create {self.num_pcs} pcs starting at ip block {self.pcs_ip_block_start} " + f"because ip address octets cannot exceed 254." + ) + return self + + @classmethod + def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None: + """ + Add an office lan to the network according to the config definition. + + This method creates a number of hosts and enough switches such that all hosts can be connected to a switch. + Optionally, a router is added to connect the switches together. All the nodes and networking devices are added + to the provided network. + + :param config: Configuration object specifying office LAN parameters + :type config: OfficeLANAdder.ConfigSchema + :param network: The PrimAITE network to which to add the office LAN. + :type network: Network + :raises ValueError: upon invalid configuration + """ + # Calculate the required number of switches + num_of_switches = num_of_switches_required(num_nodes=config.num_pcs) + effective_network_interface = 23 # One port less for router connection + if config.pcs_ip_block_start <= num_of_switches: + raise ValueError( + f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}" + ) + + # Create a core switch if more than one edge switch is needed + if num_of_switches > 1: + core_switch = Switch.from_config( + config={ + "type": "switch", + "hostname": f"switch_core_{config.lan_name}", + "start_up_duration": 0, + "num_ports": 24, + } + ) + core_switch.power_on() + network.add_node(core_switch) + core_switch_port = 1 + + # Initialise the default gateway to None + default_gateway = None + + # Optionally include a router in the LAN + if config.include_router: + default_gateway = IPv4Address(f"192.168.{config.subnet_base}.1") + router = Router.from_config( + config={"hostname": f"router_{config.lan_name}", "type": "router", "start_up_duration": 0} + ) + router.power_on() + router.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + network.add_node(router) + router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0") + router.enable_port(1) + + # Initialise the first edge switch and connect to the router or core switch + switch_port = 0 + switch_n = 1 + switch = Switch.from_config( + config={ + "type": "switch", + "hostname": f"switch_edge_{switch_n}_{config.lan_name}", + "start_up_duration": 0, + "num_ports": 24, + } + ) + switch.power_on() + network.add_node(switch) + if num_of_switches > 1: + network.connect( + core_switch.network_interface[core_switch_port], + switch.network_interface[24], + bandwidth=config.bandwidth, + ) + else: + network.connect(router.network_interface[1], switch.network_interface[24], bandwidth=config.bandwidth) + + # Add PCs to the LAN and connect them to switches + for i in range(1, config.num_pcs + 1): + # Add a new edge switch if the current one is full + if switch_port == effective_network_interface: + switch_n += 1 + switch_port = 0 + switch = Switch.from_config( + config={ + "type": "switch", + "hostname": f"switch_edge_{switch_n}_{config.lan_name}", + "start_up_duration": 0, + "num_ports": 24, + } + ) + switch.power_on() + network.add_node(switch) + # Connect the new switch to the router or core switch + if num_of_switches > 1: + core_switch_port += 1 + network.connect( + core_switch.network_interface[core_switch_port], + switch.network_interface[24], + bandwidth=config.bandwidth, + ) + else: + network.connect( + router.network_interface[1], switch.network_interface[24], bandwidth=config.bandwidth + ) + + # Create and add a PC to the network + pc_cfg = { + "type": "computer", + "hostname": f"pc_{i}_{config.lan_name}", + "ip_address": f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}", + "default_gateway": default_gateway, + "start_up_duration": 0, + } + pc = Computer.from_config(config=pc_cfg) + pc.power_on() + network.add_node(pc) + + # Connect the PC to the switch + switch_port += 1 + network.connect(switch.network_interface[switch_port], pc.network_interface[1], bandwidth=config.bandwidth) + switch.network_interface[switch_port].enable() def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int: @@ -41,112 +285,3 @@ def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> # Return the total number of switches required return full_switches + (1 if extra_pcs > 0 else 0) - - -def create_office_lan( - lan_name: str, - subnet_base: int, - pcs_ip_block_start: int, - num_pcs: int, - network: Optional[Network] = None, - include_router: bool = True, - bandwidth: int = 100, -) -> Network: - """ - Creates a 2-Tier or 3-Tier office local area network (LAN). - - The LAN is configured with a specified number of personal computers (PCs), optionally including a router, - and multiple edge switches to connect them. A core switch is added only if more than one edge switch is required. - The network topology involves edge switches connected either directly to the router in a 2-Tier setup or - to a core switch in a 3-Tier setup. If a router is included, it is connected to the core switch (if present) - and configured with basic access control list (ACL) rules. PCs are distributed across the edge switches. - - - :param str lan_name: The name to be assigned to the LAN. - :param int subnet_base: The subnet base number to be used in the IP addresses. - :param int pcs_ip_block_start: The starting block for assigning IP addresses to PCs. - :param int num_pcs: The number of PCs to be added to the LAN. - :param Optional[Network] network: The network to which the LAN components will be added. If None, a new network is - created. - :param bool include_router: Flag to determine if a router should be included in the LAN. Defaults to True. - :return: The network object with the LAN components added. - :raises ValueError: If pcs_ip_block_start is less than or equal to the number of required switches. - """ - # Initialise the network if not provided - if not network: - network = Network() - - # Calculate the required number of switches - num_of_switches = num_of_switches_required(num_nodes=num_pcs) - effective_network_interface = 23 # One port less for router connection - if pcs_ip_block_start <= num_of_switches: - raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}") - - # Create a core switch if more than one edge switch is needed - if num_of_switches > 1: - core_switch = Switch(hostname=f"switch_core_{lan_name}", start_up_duration=0) - core_switch.power_on() - network.add_node(core_switch) - core_switch_port = 1 - - # Initialise the default gateway to None - default_gateway = None - - # Optionally include a router in the LAN - if include_router: - default_gateway = IPv4Address(f"192.168.{subnet_base}.1") - router = Router(hostname=f"router_{lan_name}", start_up_duration=0) - router.power_on() - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - network.add_node(router) - router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0") - router.enable_port(1) - - # Initialise the first edge switch and connect to the router or core switch - switch_port = 0 - switch_n = 1 - switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0) - switch.power_on() - network.add_node(switch) - if num_of_switches > 1: - network.connect( - core_switch.network_interface[core_switch_port], switch.network_interface[24], bandwidth=bandwidth - ) - else: - network.connect(router.network_interface[1], switch.network_interface[24], bandwidth=bandwidth) - - # Add PCs to the LAN and connect them to switches - for i in range(1, num_pcs + 1): - # Add a new edge switch if the current one is full - if switch_port == effective_network_interface: - switch_n += 1 - switch_port = 0 - switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0) - switch.power_on() - network.add_node(switch) - # Connect the new switch to the router or core switch - if num_of_switches > 1: - core_switch_port += 1 - network.connect( - core_switch.network_interface[core_switch_port], switch.network_interface[24], bandwidth=bandwidth - ) - else: - network.connect(router.network_interface[1], switch.network_interface[24], bandwidth=bandwidth) - - # Create and add a PC to the network - pc = Computer( - hostname=f"pc_{i}_{lan_name}", - ip_address=f"192.168.{subnet_base}.{i+pcs_ip_block_start-1}", - subnet_mask="255.255.255.0", - default_gateway=default_gateway, - start_up_duration=0, - ) - pc.power_on() - network.add_node(pc) - - # Connect the PC to the switch - switch_port += 1 - network.connect(switch.network_interface[switch_port], pc.network_interface[1], bandwidth=bandwidth) - switch.network_interface[switch_port].enable() - - return network diff --git a/src/primaite/simulator/network/hardware/__init__.py b/src/primaite/simulator/network/hardware/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/network/hardware/__init__.py +++ b/src/primaite/simulator/network/hardware/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 570a69b3..9cc39848 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations import re @@ -9,7 +9,7 @@ from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, Field, validate_call +from pydantic import BaseModel, ConfigDict, Field, validate_call from primaite import getLogger from primaite.exceptions import NetworkError @@ -21,8 +21,6 @@ from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.packet_capture import PacketCapture from primaite.simulator.system.core.session_manager import SessionManager @@ -33,7 +31,9 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.software import IOSoftware, Software from primaite.utils.converters import convert_dict_enum_keys_to_enum_values -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware) @@ -203,16 +203,16 @@ class NetworkInterface(SimComponent, ABC): # Initialise basic frame data variables direction = "inbound" if inbound else "outbound" # Direction of the traffic ip_address = str(frame.ip.src_ip_address if inbound else frame.ip.dst_ip_address) # Source or destination IP - protocol = frame.ip.protocol.name # Network protocol used in the frame + protocol = frame.ip.protocol # Network protocol used in the frame # Initialise port variable; will be determined based on protocol type port = None # Determine the source or destination port based on the protocol (TCP/UDP) if frame.tcp: - port = frame.tcp.src_port.value if inbound else frame.tcp.dst_port.value + port = frame.tcp.src_port if inbound else frame.tcp.dst_port elif frame.udp: - port = frame.udp.src_port.value if inbound else frame.udp.dst_port.value + port = frame.udp.src_port if inbound else frame.udp.dst_port # Convert frame payload to string for keyword checking frame_str = str(frame.payload) @@ -274,20 +274,20 @@ class NetworkInterface(SimComponent, ABC): # Identify the protocol and port from the frame if frame.tcp: - protocol = IPProtocol.TCP + protocol = PROTOCOL_LOOKUP["TCP"] port = frame.tcp.dst_port elif frame.udp: - protocol = IPProtocol.UDP + protocol = PROTOCOL_LOOKUP["UDP"] port = frame.udp.dst_port elif frame.icmp: - protocol = IPProtocol.ICMP + protocol = PROTOCOL_LOOKUP["ICMP"] # Ensure the protocol is in the capture dict if protocol not in self.traffic: self.traffic[protocol] = {} # Handle non-ICMP protocols that use ports - if protocol != IPProtocol.ICMP: + if protocol != PROTOCOL_LOOKUP["ICMP"]: if port not in self.traffic[protocol]: self.traffic[protocol][port] = {"inbound": 0, "outbound": 0} self.traffic[protocol][port][direction] += frame.size_Mbits @@ -431,7 +431,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): self.enabled = True self._connected_node.sys_log.info(f"Network Interface {self} enabled") self.pcap = PacketCapture( - hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name + hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name ) if self._connected_link: self._connected_link.endpoint_up() @@ -639,10 +639,8 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC): `default_gateway_hello` method is not defined, ignoring such errors to proceed without interruption. """ super().enable() - try: + if hasattr(self._connected_node, "default_gateway_hello"): self._connected_node.default_gateway_hello() - except AttributeError: - pass return True @abstractmethod @@ -824,7 +822,7 @@ class User(SimComponent): return self.model_dump() -class UserManager(Service): +class UserManager(Service, discriminator="user-manager"): """ Manages users within the PrimAITE system, handling creation, authentication, and administration. @@ -833,18 +831,25 @@ class UserManager(Service): :param disabled_admins: A dictionary of currently disabled admin users by their usernames """ + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for UserManager.""" + + type: str = "user-manager" + + config: "UserManager.ConfigSchema" = Field(default_factory=lambda: UserManager.ConfigSchema()) + users: Dict[str, User] = {} def __init__(self, **kwargs): """ - Initializes a UserManager instanc. + Initializes a UserManager instance. :param username: The username for the default admin user :param password: The password for the default admin user """ - kwargs["name"] = "UserManager" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["name"] = "user-manager" + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self.start() @@ -1044,7 +1049,7 @@ class UserManager(Service): @property def _user_session_manager(self) -> "UserSessionManager": - return self.software_manager.software["UserSessionManager"] # noqa + return self.software_manager.software["user-session-manager"] # noqa class UserSession(SimComponent): @@ -1144,13 +1149,20 @@ class RemoteUserSession(UserSession): return state -class UserSessionManager(Service): +class UserSessionManager(Service, discriminator="user-session-manager"): """ Manages user sessions on a Node, including local and remote sessions. This class handles authentication, session management, and session timeouts for users interacting with the Node. """ + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for UserSessionManager.""" + + type: str = "user-session-manager" + + config: "UserSessionManager.ConfigSchema" = Field(default_factory=lambda: UserSessionManager.ConfigSchema()) + local_session: Optional[UserSession] = None """The current local user session, if any.""" @@ -1179,9 +1191,9 @@ class UserSessionManager(Service): :param username: The username for the default admin user :param password: The password for the default admin user """ - kwargs["name"] = "UserSessionManager" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["name"] = "user-session-manager" + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self.start() @@ -1197,7 +1209,7 @@ class UserSessionManager(Service): """Request should take the form [username, password, remote_ip_address].""" username, password, remote_ip_address = request response = RequestResponse.from_bool(self.remote_login(username, password, remote_ip_address)) - response.data = {"remote_hostname": self.parent.hostname, "username": username} + response.data = {"remote_hostname": self.parent.config.hostname, "username": username} return response rm.add_request("remote_login", RequestType(func=_remote_login)) @@ -1230,7 +1242,7 @@ class UserSessionManager(Service): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.parent.hostname} User Sessions" + table.title = f"{self.parent.config.hostname} User Sessions" def _add_session_to_table(user_session: UserSession): """ @@ -1289,7 +1301,7 @@ class UserSessionManager(Service): :return: The UserManager instance. """ - return self.software_manager.software["UserManager"] # noqa + return self.software_manager.software["user-manager"] # noqa def pre_timestep(self, timestep: int) -> None: """Apply any pre-timestep logic that helps make sure we have the correct observations.""" @@ -1326,7 +1338,7 @@ class UserSessionManager(Service): software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( payload={"type": "user_timeout", "connection_id": session.uuid}, - dest_port=Port.SSH, + dest_port=PORT_LOOKUP["SSH"], dest_ip_address=session.remote_ip_address, ) @@ -1483,7 +1495,7 @@ class UserSessionManager(Service): return self.local_session is not None -class Node(SimComponent): +class Node(SimComponent, ABC): """ A basic Node class that represents a node on the network. @@ -1494,19 +1506,12 @@ class Node(SimComponent): :param operating_state: The node operating state, either ON or OFF. """ - hostname: str - "The node hostname on the network." - default_gateway: Optional[IPV4Address] = None - "The default gateway IP address for forwarding network traffic to other networks." operating_state: NodeOperatingState = NodeOperatingState.OFF "The hardware state of the node." network_interfaces: Dict[str, NetworkInterface] = {} "The Network Interfaces on the node." network_interface: Dict[int, NetworkInterface] = {} "The Network Interfaces on the node by port id." - dns_server: Optional[IPv4Address] = None - "List of IP addresses of DNS servers used for name resolution." - accounts: Dict[str, Account] = {} "All accounts on the node." applications: Dict[str, Application] = {} @@ -1523,36 +1528,98 @@ class Node(SimComponent): session_manager: SessionManager software_manager: SoftwareManager - revealed_to_red: bool = False - "Informs whether the node has been revealed to a red agent." - - start_up_duration: int = 3 - "Time steps needed for the node to start up." - - start_up_countdown: int = 0 - "Time steps needed until node is booted up." - - shut_down_duration: int = 3 - "Time steps needed for the node to shut down." - - shut_down_countdown: int = 0 - "Time steps needed until node is shut down." - - is_resetting: bool = False - "If true, the node will try turning itself off then back on again." - - node_scan_duration: int = 10 - "How many timesteps until the whole node is scanned. Default 10 time steps." - - node_scan_countdown: int = 0 - "Time steps until scan is complete" - - red_scan_countdown: int = 0 - "Time steps until reveal to red scan is complete." - SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {} "Base system software that must be preinstalled." + _registry: ClassVar[Dict[str, Type["Node"]]] = {} + """Registry of application types. Automatically populated when subclasses are defined.""" + + # TODO: this should not be set for abstract classes. + _discriminator: ClassVar[str] + """discriminator for this particular class, used for printing and logging. Each subclass redefines this.""" + + class ConfigSchema(BaseModel, ABC): + """Configuration Schema for Node based classes.""" + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + """Configure pydantic to allow arbitrary types, let the instance have attributes not present in the model.""" + + type: str + + hostname: str + "The node hostname on the network." + + revealed_to_red: bool = False + "Informs whether the node has been revealed to a red agent." + + start_up_duration: int = 3 + "Time steps needed for the node to start up." + + start_up_countdown: int = 0 + "Time steps needed until node is booted up." + + shut_down_duration: int = 3 + "Time steps needed for the node to shut down." + + shut_down_countdown: int = 0 + "Time steps needed until node is shut down." + + is_resetting: bool = False + "If true, the node will try turning itself off then back on again." + + node_scan_duration: int = 10 + "How many timesteps until the whole node is scanned. Default 10 time steps." + + node_scan_countdown: int = 0 + "Time steps until scan is complete" + + red_scan_countdown: int = 0 + "Time steps until reveal to red scan is complete." + + dns_server: Optional[IPv4Address] = None + "List of IP addresses of DNS servers used for name resolution." + + default_gateway: Optional[IPV4Address] = None + "The default gateway IP address for forwarding network traffic to other networks." + + operating_state: Any = None + + users: List[Dict] = [] # Temporary to appease "extra=forbid" + + config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) + """Configuration items within Node""" + + @property + def dns_server(self) -> Optional[IPv4Address]: + """Convenience method to access the dns_server IP.""" + return self.config.dns_server + + @classmethod + def from_config(cls, config: Dict) -> "Node": + """Create Node object from a given configuration dictionary.""" + if config["type"] not in cls._registry: + msg = f"Configuration contains an invalid Node type: {config['type']}" + return ValueError(msg) + obj = cls(config=cls.ConfigSchema(**config)) + return obj + + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: + """ + Register a node type. + + :param discriminator: Uniquely specifies an node class by name. Used for finding items by config. + :type discriminator: str + :raises ValueError: When attempting to register an node with a name that is already allocated. + """ + super().__init_subclass__(**kwargs) + if discriminator is None: + return + discriminator = discriminator.lower() + if discriminator in cls._registry: + raise ValueError(f"Tried to define new node {discriminator}, but this name is already reserved.") + cls._registry[discriminator] = cls + cls._discriminator = discriminator + def __init__(self, **kwargs): """ Initialize the Node with various components and managers. @@ -1561,11 +1628,11 @@ class Node(SimComponent): provided. """ if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(kwargs["hostname"]) + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) if not kwargs.get("session_manager"): kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log")) if not kwargs.get("root"): - kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"] + kwargs["root"] = SIM_OUTPUT.path / kwargs["config"].hostname if not kwargs.get("file_system"): kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs") if not kwargs.get("software_manager"): @@ -1574,27 +1641,32 @@ class Node(SimComponent): sys_log=kwargs.get("sys_log"), session_manager=kwargs.get("session_manager"), file_system=kwargs.get("file_system"), - dns_server=kwargs.get("dns_server"), + dns_server=kwargs["config"].dns_server, ) super().__init__(**kwargs) + self.operating_state = ( + NodeOperatingState.ON if not (p := kwargs["config"].operating_state) else NodeOperatingState[p.upper()] + ) self._install_system_software() self.session_manager.node = self self.session_manager.software_manager = self.software_manager + for user in self.config.users: + self.user_manager.add_user(**user, bypass_can_perform_action=True) @property def user_manager(self) -> Optional[UserManager]: """The Nodes User Manager.""" - return self.software_manager.software.get("UserManager") # noqa + return self.software_manager.software.get("user-manager") # noqa @property def user_session_manager(self) -> Optional[UserSessionManager]: """The Nodes User Session Manager.""" - return self.software_manager.software.get("UserSessionManager") # noqa + return self.software_manager.software.get("user-session-manager") # noqa @property def terminal(self) -> Optional[Terminal]: - """The Nodes Terminal.""" - return self.software_manager.software.get("Terminal") + """The Node's Terminal.""" + return self.software_manager.software.get("terminal") def local_login(self, username: str, password: str) -> Optional[str]: """ @@ -1670,7 +1742,7 @@ class Node(SimComponent): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return f"Cannot perform request on node '{self.node.hostname}' because it is not powered on." + return f"Cannot perform request on node '{self.node.config.hostname}' because it is not powered on." class _NodeIsOffValidator(RequestPermissionValidator): """ @@ -1689,7 +1761,7 @@ class Node(SimComponent): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return f"Cannot perform request on node '{self.node.hostname}' because it is not turned off." + return f"Cannot perform request on node '{self.node.config.hostname}' because it is not turned off." def _init_request_manager(self) -> RequestManager: """ @@ -1713,11 +1785,11 @@ class Node(SimComponent): if self.software_manager.software.get(application_name): self.sys_log.info(f"Can't install {application_name}. It's already installed.") return RequestResponse(status="success", data={"reason": "already installed"}) - application_class = Application._application_registry[application_name] + application_class = Application._registry[application_name] self.software_manager.install(application_class) application_instance = self.software_manager.software.get(application_name) self.applications[application_instance.uuid] = application_instance - _LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}") + _LOGGER.debug(f"Added application {application_instance.name} to node {self.config.hostname}") self._application_request_manager.add_request( application_name, RequestType(func=application_instance._request_manager) ) @@ -1831,7 +1903,7 @@ class Node(SimComponent): state = super().describe_state() state.update( { - "hostname": self.hostname, + "hostname": self.config.hostname, "operating_state": self.operating_state.value, "NICs": { eth_num: network_interface.describe_state() @@ -1841,7 +1913,7 @@ class Node(SimComponent): "applications": {app.name: app.describe_state() for app in self.applications.values()}, "services": {svc.name: svc.describe_state() for svc in self.services.values()}, "process": {proc.name: proc.describe_state() for proc in self.processes.values()}, - "revealed_to_red": self.revealed_to_red, + "revealed_to_red": self.config.revealed_to_red, } ) return state @@ -1853,14 +1925,14 @@ class Node(SimComponent): def show_open_ports(self, markdown: bool = False): """Prints a table of the open ports on the Node.""" - table = PrettyTable(["Port", "Name"]) + table = PrettyTable(["Port"]) if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Open Ports" + table.title = f"{self.config.hostname} Open Ports" for port in self.software_manager.get_open_ports(): - if port.value > 0: - table.add_row([port.value, port.name]) + if port > 0: + table.add_row([port]) print(table.get_string(sortby="Port")) @property @@ -1884,7 +1956,7 @@ class Node(SimComponent): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Network Interface Cards" + table.title = f"{self.config.hostname} Network Interface Cards" for port, network_interface in self.network_interface.items(): ip_address = "" if hasattr(network_interface, "ip_address"): @@ -1919,38 +1991,38 @@ class Node(SimComponent): network_interface.apply_timestep(timestep=timestep) # count down to boot up - if self.start_up_countdown > 0: - self.start_up_countdown -= 1 + if self.config.start_up_countdown > 0: + self.config.start_up_countdown -= 1 else: if self.operating_state == NodeOperatingState.BOOTING: self.operating_state = NodeOperatingState.ON - self.sys_log.info(f"{self.hostname}: Turned on") + self.sys_log.info(f"{self.config.hostname}: Turned on") for network_interface in self.network_interfaces.values(): network_interface.enable() self._start_up_actions() # count down to shut down - if self.shut_down_countdown > 0: - self.shut_down_countdown -= 1 + if self.config.shut_down_countdown > 0: + self.config.shut_down_countdown -= 1 else: if self.operating_state == NodeOperatingState.SHUTTING_DOWN: self.operating_state = NodeOperatingState.OFF - self.sys_log.info(f"{self.hostname}: Turned off") + self.sys_log.info(f"{self.config.hostname}: Turned off") self._shut_down_actions() # if resetting turn back on - if self.is_resetting: - self.is_resetting = False + if self.config.is_resetting: + self.config.is_resetting = False self.power_on() # time steps which require the node to be on if self.operating_state == NodeOperatingState.ON: # node scanning - if self.node_scan_countdown > 0: - self.node_scan_countdown -= 1 + if self.config.node_scan_countdown > 0: + self.config.node_scan_countdown -= 1 - if self.node_scan_countdown == 0: + if self.config.node_scan_countdown == 0: # scan everything! for process_id in self.processes: self.processes[process_id].scan() @@ -1966,10 +2038,10 @@ class Node(SimComponent): # scan file system self.file_system.scan(instant_scan=True) - if self.red_scan_countdown > 0: - self.red_scan_countdown -= 1 + if self.config.red_scan_countdown > 0: + self.config.red_scan_countdown -= 1 - if self.red_scan_countdown == 0: + if self.config.red_scan_countdown == 0: # scan processes for process_id in self.processes: self.processes[process_id].reveal_to_red() @@ -2026,7 +2098,7 @@ class Node(SimComponent): to the red agent. """ - self.node_scan_countdown = self.node_scan_duration + self.config.node_scan_countdown = self.config.node_scan_duration return True def reveal_to_red(self) -> bool: @@ -2042,12 +2114,12 @@ class Node(SimComponent): `revealed_to_red` to `True`. """ - self.red_scan_countdown = self.node_scan_duration + self.config.red_scan_countdown = self.config.node_scan_duration return True def power_on(self) -> bool: """Power on the Node, enabling its NICs if it is in the OFF state.""" - if self.start_up_duration <= 0: + if self.config.start_up_duration <= 0: self.operating_state = NodeOperatingState.ON self._start_up_actions() self.sys_log.info("Power on") @@ -2056,14 +2128,14 @@ class Node(SimComponent): return True if self.operating_state == NodeOperatingState.OFF: self.operating_state = NodeOperatingState.BOOTING - self.start_up_countdown = self.start_up_duration + self.config.start_up_countdown = self.config.start_up_duration return True return False def power_off(self) -> bool: """Power off the Node, disabling its NICs if it is in the ON state.""" - if self.shut_down_duration <= 0: + if self.config.shut_down_duration <= 0: self._shut_down_actions() self.operating_state = NodeOperatingState.OFF self.sys_log.info("Power off") @@ -2072,7 +2144,7 @@ class Node(SimComponent): for network_interface in self.network_interfaces.values(): network_interface.disable() self.operating_state = NodeOperatingState.SHUTTING_DOWN - self.shut_down_countdown = self.shut_down_duration + self.config.shut_down_countdown = self.config.shut_down_duration return True return False @@ -2084,7 +2156,7 @@ class Node(SimComponent): Applying more timesteps will eventually turn the node back on. """ if self.operating_state.ON: - self.is_resetting = True + self.config.is_resetting = True self.sys_log.info("Resetting") self.power_off() return True @@ -2189,10 +2261,6 @@ class Node(SimComponent): for app_id in self.applications: self.applications[app_id].close() - # Turn off all processes in the node - # for process_id in self.processes: - # self.processes[process_id] - def _start_up_actions(self): """Actions to perform when the node is starting up.""" # Turn on all the services in the node @@ -2203,10 +2271,6 @@ class Node(SimComponent): for app_id in self.applications: self.applications[app_id].run() - # Turn off all processes in the node - # for process_id in self.processes: - # self.processes[process_id] - def _install_system_software(self) -> None: """Preinstall required software.""" for _, software_class in self.SYSTEM_SOFTWARE.items(): diff --git a/src/primaite/simulator/network/hardware/network_interface/__init__.py b/src/primaite/simulator/network/hardware/network_interface/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/network/hardware/network_interface/__init__.py +++ b/src/primaite/simulator/network/hardware/network_interface/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/network/hardware/network_interface/wireless/__init__.py b/src/primaite/simulator/network/hardware/network_interface/wireless/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/network/hardware/network_interface/wireless/__init__.py +++ b/src/primaite/simulator/network/hardware/network_interface/wireless/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py index a9a31768..3997872c 100644 --- a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py +++ b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Dict from primaite.simulator.network.hardware.base import ( diff --git a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py index eebaedc5..9bc4cd6f 100644 --- a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py +++ b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Dict from primaite.simulator.network.hardware.base import ( diff --git a/src/primaite/simulator/network/hardware/node_operating_state.py b/src/primaite/simulator/network/hardware/node_operating_state.py index e64ef08b..8771cb84 100644 --- a/src/primaite/simulator/network/hardware/node_operating_state.py +++ b/src/primaite/simulator/network/hardware/node_operating_state.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum diff --git a/src/primaite/simulator/network/hardware/nodes/__init__.py b/src/primaite/simulator/network/hardware/nodes/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/network/hardware/nodes/__init__.py +++ b/src/primaite/simulator/network/hardware/nodes/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/network/hardware/nodes/host/__init__.py b/src/primaite/simulator/network/hardware/nodes/host/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/__init__.py +++ b/src/primaite/simulator/network/hardware/nodes/host/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/network/hardware/nodes/host/computer.py b/src/primaite/simulator/network/hardware/nodes/host/computer.py index 68c72554..bee172d9 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/computer.py +++ b/src/primaite/simulator/network/hardware/nodes/host/computer.py @@ -1,11 +1,13 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import ClassVar, Dict +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from typing import ClassVar, Dict, Literal + +from pydantic import Field from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.system.services.ftp.ftp_client import FTPClient -class Computer(HostNode): +class Computer(HostNode, discriminator="computer"): """ A basic Computer class. @@ -33,6 +35,14 @@ class Computer(HostNode): * Web Browser """ - SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} + SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "ftp-client": FTPClient} + + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Computer class.""" + + type: Literal["computer"] = "computer" + hostname: str = "Computer" + + config: ConfigSchema = Field(default_factory=lambda: Computer.ConfigSchema()) pass diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index c197d30b..fcd0eb80 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -1,10 +1,13 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, ClassVar, Dict, Optional +from typing import Any, ClassVar, Dict, List, Literal, Optional + +from pydantic import Field from primaite import getLogger +from primaite.simulator.file_system.file_type import FileType from primaite.simulator.network.hardware.base import ( IPWiredNetworkInterface, Link, @@ -22,7 +25,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.icmp.icmp import ICMP from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.terminal.terminal import Terminal -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ipv4_address import IPV4Address _LOGGER = getLogger(__name__) @@ -44,8 +47,8 @@ class HostARP(ARP): :return: The MAC address of the default gateway if present in the ARP cache; otherwise, None. """ - if self.software_manager.node.default_gateway: - return self.get_arp_cache_mac_address(self.software_manager.node.default_gateway) + if self.software_manager.node.config.default_gateway: + return self.get_arp_cache_mac_address(self.software_manager.node.config.default_gateway) def get_default_gateway_network_interface(self) -> Optional[NIC]: """ @@ -53,8 +56,11 @@ class HostARP(ARP): :return: The NIC associated with the default gateway if it exists in the ARP cache; otherwise, None. """ - if self.software_manager.node.default_gateway and self.software_manager.node.has_enabled_network_interface: - return self.get_arp_cache_network_interface(self.software_manager.node.default_gateway) + if ( + self.software_manager.node.config.default_gateway + and self.software_manager.node.has_enabled_network_interface + ): + return self.get_arp_cache_network_interface(self.software_manager.node.config.default_gateway) def _get_arp_cache_mac_address( self, ip_address: IPV4Address, is_reattempt: bool = False, is_default_gateway_attempt: bool = False @@ -73,7 +79,7 @@ class HostARP(ARP): if arp_entry: return arp_entry.mac_address - if ip_address == self.software_manager.node.default_gateway: + if ip_address == self.software_manager.node.config.default_gateway: is_reattempt = True if not is_reattempt: self.send_arp_request(ip_address) @@ -81,11 +87,11 @@ class HostARP(ARP): ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt ) else: - if self.software_manager.node.default_gateway: + if self.software_manager.node.config.default_gateway: if not is_default_gateway_attempt: - self.send_arp_request(self.software_manager.node.default_gateway) + self.send_arp_request(self.software_manager.node.config.default_gateway) return self._get_arp_cache_mac_address( - ip_address=self.software_manager.node.default_gateway, + ip_address=self.software_manager.node.config.default_gateway, is_reattempt=True, is_default_gateway_attempt=True, ) @@ -116,7 +122,7 @@ class HostARP(ARP): if arp_entry: return self.software_manager.node.network_interfaces[arp_entry.network_interface_uuid] else: - if ip_address == self.software_manager.node.default_gateway: + if ip_address == self.software_manager.node.config.default_gateway: is_reattempt = True if not is_reattempt: self.send_arp_request(ip_address) @@ -124,11 +130,11 @@ class HostARP(ARP): ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt ) else: - if self.software_manager.node.default_gateway: + if self.software_manager.node.config.default_gateway: if not is_default_gateway_attempt: - self.send_arp_request(self.software_manager.node.default_gateway) + self.send_arp_request(self.software_manager.node.config.default_gateway) return self._get_arp_cache_network_interface( - ip_address=self.software_manager.node.default_gateway, + ip_address=self.software_manager.node.config.default_gateway, is_reattempt=True, is_default_gateway_attempt=True, ) @@ -262,7 +268,7 @@ class NIC(IPWiredNetworkInterface): return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}/{self.ip_address}" -class HostNode(Node): +class HostNode(Node, discriminator="host-node"): """ Represents a host node in the network. @@ -308,15 +314,15 @@ class HostNode(Node): """ SYSTEM_SOFTWARE: ClassVar[Dict] = { - "HostARP": HostARP, - "ICMP": ICMP, - "DNSClient": DNSClient, - "NTPClient": NTPClient, - "WebBrowser": WebBrowser, - "NMAP": NMAP, - "UserSessionManager": UserSessionManager, - "UserManager": UserManager, - "Terminal": Terminal, + "host-arp": HostARP, + "icmp": ICMP, + "dns-client": DNSClient, + "ntp-client": NTPClient, + "web-browser": WebBrowser, + "nmap": NMAP, + "user-session-manager": UserSessionManager, + "user-manager": UserManager, + "terminal": Terminal, } """List of system software that is automatically installed on nodes.""" @@ -325,9 +331,35 @@ class HostNode(Node): network_interface: Dict[int, NIC] = {} "The NICs on the node by port id." - def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): + class ConfigSchema(Node.ConfigSchema): + """Configuration Schema for HostNode class.""" + + type: Literal["host-node"] = "host-node" + hostname: str = "HostNode" + subnet_mask: IPV4Address = "255.255.255.0" + ip_address: IPV4Address + services: Any = None # temporarily unset to appease extra="forbid" + applications: Any = None # temporarily unset to appease extra="forbid" + folders: List[Dict] = {} # temporarily unset to appease extra="forbid" + network_interfaces: Any = None # temporarily unset to appease extra="forbid" + + config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema()) + + def __init__(self, **kwargs): super().__init__(**kwargs) - self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) + self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask)) + + for folder in self.config.folders: + # handle empty foler defined by just a string + self.file_system.create_folder(folder["folder_name"]) + + for file in folder.get("files", []): + self.file_system.create_file( + folder_name=folder["folder_name"], + file_name=file["file_name"], + size=file.get("size", 0), + file_type=FileType[file.get("type", "UNKNOWN").upper()], + ) @property def nmap(self) -> Optional[NMAP]: @@ -337,7 +369,7 @@ class HostNode(Node): :return: NMAP application installed on the Node. :rtype: Optional[NMAP] """ - return self.software_manager.software.get("NMAP") + return self.software_manager.software.get("nmap") @property def arp(self) -> Optional[ARP]: @@ -347,7 +379,7 @@ class HostNode(Node): :return: ARP Cache for given HostNode :rtype: Optional[ARP] """ - return self.software_manager.software.get("ARP") + return self.software_manager.software.get("arp") def default_gateway_hello(self): """ @@ -356,7 +388,7 @@ class HostNode(Node): This method is invoked to ensure the host node can communicate with its default gateway, primarily to confirm network connectivity and populate the ARP cache with the gateway's MAC address. """ - if self.operating_state == NodeOperatingState.ON and self.default_gateway: + if self.operating_state == NodeOperatingState.ON and self.config.default_gateway: self.software_manager.arp.get_default_gateway_mac_address() def receive_frame(self, frame: Frame, from_network_interface: NIC): @@ -379,8 +411,8 @@ class HostNode(Node): dst_port = frame.udp.dst_port can_accept_nmap = False - if self.software_manager.software.get("NMAP"): - if self.software_manager.software["NMAP"].operating_state == ApplicationOperatingState.RUNNING: + if self.software_manager.software.get("nmap"): + if self.software_manager.software["nmap"].operating_state == ApplicationOperatingState.RUNNING: can_accept_nmap = True accept_nmap = can_accept_nmap and frame.payload.__class__.__name__ == "PortScanPayload" diff --git a/src/primaite/simulator/network/hardware/nodes/host/server.py b/src/primaite/simulator/network/hardware/nodes/host/server.py index 379c9927..09c1708b 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/server.py +++ b/src/primaite/simulator/network/hardware/nodes/host/server.py @@ -1,8 +1,13 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +from typing import Literal + +from pydantic import Field + from primaite.simulator.network.hardware.nodes.host.host_node import HostNode -class Server(HostNode): +class Server(HostNode, discriminator="server"): """ A basic Server class. @@ -30,8 +35,24 @@ class Server(HostNode): * Web Browser """ + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Server class.""" -class Printer(HostNode): + type: Literal["server"] = "server" + hostname: str = "server" + + config: ConfigSchema = Field(default_factory=lambda: Server.ConfigSchema()) + + +class Printer(HostNode, discriminator="printer"): """Printer? I don't even know her!.""" # TODO: Implement printer-specific behaviour + + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Printer class.""" + + type: Literal["printer"] = "printer" + hostname: str = "printer" + + config: ConfigSchema = Field(default_factory=lambda: Printer.ConfigSchema()) diff --git a/src/primaite/simulator/network/hardware/nodes/network/__init__.py b/src/primaite/simulator/network/hardware/nodes/network/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/__init__.py +++ b/src/primaite/simulator/network/hardware/nodes/network/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 4510eac0..c872b8b3 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -1,12 +1,11 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address -from typing import Dict, Final, Union +from typing import Dict, Final, Literal, Union from prettytable import MARKDOWN, PrettyTable from pydantic import Field, validate_call from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ( AccessControlList, ACLAction, @@ -14,10 +13,10 @@ from primaite.simulator.network.hardware.nodes.network.router import ( RouterInterface, ) from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.sys_log import SysLog -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP EXTERNAL_PORT_ID: Final[int] = 1 """The Firewall port ID of the external port.""" @@ -27,7 +26,7 @@ DMZ_PORT_ID: Final[int] = 3 """The Firewall port ID of the DMZ port.""" -class Firewall(Router): +class Firewall(Router, discriminator="firewall"): """ A Firewall class that extends the functionality of a Router. @@ -50,7 +49,7 @@ class Firewall(Router): Example: >>> from primaite.simulator.network.transmission.network_layer import IPProtocol - >>> from primaite.simulator.network.transmission.transport_layer import Port + >>> from primaite.utils.validation.port import Port >>> firewall = Firewall(hostname="Firewall1") >>> firewall.configure_internal_port(ip_address="192.168.1.1", subnet_mask="255.255.255.0") >>> firewall.configure_external_port(ip_address="10.0.0.1", subnet_mask="255.255.255.0") @@ -58,8 +57,8 @@ class Firewall(Router): >>> # Permit HTTP traffic to the DMZ >>> firewall.dmz_inbound_acl.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, - ... dst_port=Port.HTTP, + ... protocol=IPProtocol["TCP"], + ... dst_port=Port["HTTP"], ... src_ip_address="0.0.0.0", ... src_wildcard_mask="0.0.0.0", ... dst_ip_address="172.16.0.0", @@ -99,11 +98,22 @@ class Firewall(Router): ) """Access Control List for managing traffic leaving towards an external network.""" - def __init__(self, hostname: str, **kwargs): - if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(hostname) + _identifier: str = "firewall" - super().__init__(hostname=hostname, num_ports=0, **kwargs) + class ConfigSchema(Router.ConfigSchema): + """Configuration Schema for Firewall 'Nodes' within PrimAITE.""" + + type: Literal["firewall"] = "firewall" + hostname: str = "firewall" + num_ports: int = 0 + + config: ConfigSchema = Field(default_factory=lambda: Firewall.ConfigSchema()) + + def __init__(self, **kwargs): + if not kwargs.get("sys_log"): + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) + + super().__init__(**kwargs) self.connect_nic( RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external") @@ -116,22 +126,23 @@ class Firewall(Router): ) # Update ACL objects with firewall's hostname and syslog to allow accurate logging self.internal_inbound_acl.sys_log = kwargs["sys_log"] - self.internal_inbound_acl.name = f"{hostname} - Internal Inbound" + self.internal_inbound_acl.name = f"{kwargs['config'].hostname} - Internal Inbound" self.internal_outbound_acl.sys_log = kwargs["sys_log"] - self.internal_outbound_acl.name = f"{hostname} - Internal Outbound" + self.internal_outbound_acl.name = f"{kwargs['config'].hostname} - Internal Outbound" self.dmz_inbound_acl.sys_log = kwargs["sys_log"] - self.dmz_inbound_acl.name = f"{hostname} - DMZ Inbound" + self.dmz_inbound_acl.name = f"{kwargs['config'].hostname} - DMZ Inbound" self.dmz_outbound_acl.sys_log = kwargs["sys_log"] - self.dmz_outbound_acl.name = f"{hostname} - DMZ Outbound" + self.dmz_outbound_acl.name = f"{kwargs['config'].hostname} - DMZ Outbound" self.external_inbound_acl.sys_log = kwargs["sys_log"] - self.external_inbound_acl.name = f"{hostname} - External Inbound" + self.external_inbound_acl.name = f"{kwargs['config'].hostname} - External Inbound" self.external_outbound_acl.sys_log = kwargs["sys_log"] - self.external_outbound_acl.name = f"{hostname} - External Outbound" + self.external_outbound_acl.name = f"{kwargs['config'].hostname} - External Outbound" + self.power_on() def _init_request_manager(self) -> RequestManager: """ @@ -231,7 +242,7 @@ class Firewall(Router): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Network Interfaces" + table.title = f"{self.config.hostname} Network Interfaces" ports = {"External": self.external_port, "Internal": self.internal_port, "DMZ": self.dmz_port} for port, network_interface in ports.items(): table.add_row( @@ -258,23 +269,15 @@ class Firewall(Router): :param dmz: If True, shows ACL rules for DMZ interfaces. :param markdown: If True, formats the output in Markdown, enhancing readability in Markdown-compatible viewers. """ - print(f"{self.hostname} Firewall Rules") - print() if external: self.external_inbound_acl.show(markdown) - print() self.external_outbound_acl.show(markdown) - print() if internal: self.internal_inbound_acl.show(markdown) - print() self.internal_outbound_acl.show(markdown) - print() if dmz: self.dmz_inbound_acl.show(markdown) - print() self.dmz_outbound_acl.show(markdown) - print() def receive_frame(self, frame: Frame, from_network_interface: RouterInterface): """ @@ -559,18 +562,14 @@ class Firewall(Router): self.dmz_port.enable() @classmethod - def from_config(cls, cfg: dict) -> "Firewall": + def from_config(cls, config: dict) -> "Firewall": """Create a firewall based on a config dict.""" - firewall = Firewall( - hostname=cfg["hostname"], - operating_state=NodeOperatingState.ON - if not (p := cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - if "ports" in cfg: - internal_port = cfg["ports"]["internal_port"] - external_port = cfg["ports"]["external_port"] - dmz_port = cfg["ports"].get("dmz_port") + firewall = Firewall(config=cls.ConfigSchema(**config)) + + if "ports" in config: + internal_port = config["ports"]["internal_port"] + external_port = config["ports"]["external_port"] + dmz_port = config["ports"].get("dmz_port") # configure internal port firewall.configure_internal_port( @@ -590,15 +589,15 @@ class Firewall(Router): ip_address=IPV4Address(dmz_port.get("ip_address")), subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")), ) - if "acl" in cfg: + if "acl" in config: # acl rules for internal_inbound_acl - if cfg["acl"]["internal_inbound_acl"]: - for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items(): + if config["acl"]["internal_inbound_acl"]: + for r_num, r_cfg in config["acl"]["internal_inbound_acl"].items(): firewall.internal_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -607,13 +606,13 @@ class Firewall(Router): ) # acl rules for internal_outbound_acl - if cfg["acl"]["internal_outbound_acl"]: - for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items(): + if config["acl"]["internal_outbound_acl"]: + for r_num, r_cfg in config["acl"]["internal_outbound_acl"].items(): firewall.internal_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -622,13 +621,13 @@ class Firewall(Router): ) # acl rules for dmz_inbound_acl - if cfg["acl"]["dmz_inbound_acl"]: - for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items(): + if config["acl"]["dmz_inbound_acl"]: + for r_num, r_cfg in config["acl"]["dmz_inbound_acl"].items(): firewall.dmz_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -637,13 +636,13 @@ class Firewall(Router): ) # acl rules for dmz_outbound_acl - if cfg["acl"]["dmz_outbound_acl"]: - for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items(): + if config["acl"]["dmz_outbound_acl"]: + for r_num, r_cfg in config["acl"]["dmz_outbound_acl"].items(): firewall.dmz_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -652,13 +651,13 @@ class Firewall(Router): ) # acl rules for external_inbound_acl - if cfg["acl"].get("external_inbound_acl"): - for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items(): + if config["acl"].get("external_inbound_acl"): + for r_num, r_cfg in config["acl"]["external_inbound_acl"].items(): firewall.external_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -667,13 +666,13 @@ class Firewall(Router): ) # acl rules for external_outbound_acl - if cfg["acl"].get("external_outbound_acl"): - for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items(): + if config["acl"].get("external_outbound_acl"): + for r_num, r_cfg in config["acl"]["external_outbound_acl"].items(): firewall.external_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -681,16 +680,16 @@ class Firewall(Router): position=r_num, ) - if "routes" in cfg: - for route in cfg.get("routes"): + if "routes" in config: + for route in config.get("routes"): firewall.route_table.add_route( address=IPv4Address(route.get("address")), subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), metric=float(route.get("metric", 0)), ) - if "default_route" in cfg: - next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None) + if "default_route" in config: + next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: firewall.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) diff --git a/src/primaite/simulator/network/hardware/nodes/network/network_node.py b/src/primaite/simulator/network/hardware/nodes/network/network_node.py index 5ff791cc..e5e77402 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/network_node.py +++ b/src/primaite/simulator/network/hardware/nodes/network/network_node.py @@ -1,13 +1,13 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from abc import abstractmethod -from typing import Optional +from typing import Any, Optional from primaite.simulator.network.hardware.base import NetworkInterface, Node from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.system.services.arp.arp import ARP -class NetworkNode(Node): +class NetworkNode(Node, discriminator="network-node"): """ Represents an abstract base class for a network node that can receive and process network frames. @@ -16,6 +16,11 @@ class NetworkNode(Node): provide functionality for receiving and processing frames received on their network interfaces. """ + class ConfigSchema(Node.ConfigSchema): + """Config schema for Node baseclass.""" + + num_ports: Any = None # temporarily unset to appease extra="forbid" + @abstractmethod def receive_frame(self, frame: Frame, from_network_interface: NetworkInterface): """ @@ -40,4 +45,4 @@ class NetworkNode(Node): :return: ARP Cache for given NetworkNode :rtype: Optional[ARP] """ - return self.software_manager.software.get("ARP") + return self.software_manager.software.get("arp") diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 8cdf3f86..33b55b9d 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1,13 +1,13 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations import secrets from enum import Enum from ipaddress import IPv4Address, IPv4Network -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union +from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -17,15 +17,15 @@ from primaite.simulator.network.hardware.nodes.network.network_node import Netwo from primaite.simulator.network.protocols.arp import ARPPacket from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.nmap import NMAP from primaite.simulator.system.core.session_manager import SessionManager from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.arp.arp import ARP from primaite.simulator.system.services.icmp.icmp import ICMP from primaite.simulator.system.services.terminal.terminal import Terminal -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import Port, PORT_LOOKUP @validate_call() @@ -106,7 +106,7 @@ class ACLRule(SimComponent): :ivar ACLAction action: Specifies whether to `PERMIT` or `DENY` the traffic that matches the rule conditions. The default action is `DENY`. - :ivar Optional[IPProtocol] protocol: The network protocol (e.g., TCP, UDP, ICMP) to match. If `None`, the rule + :ivar Optional[str] protocol: The network protocol (e.g., TCP, UDP, ICMP) to match. If `None`, the rule applies to all protocols. :ivar Optional[IPV4Address] src_ip_address: The source IP address to match. If combined with `src_wildcard_mask`, it specifies the start of an IP range. @@ -116,8 +116,8 @@ class ACLRule(SimComponent): `dst_wildcard_mask`, it specifies the start of an IP range. :ivar Optional[IPv4Address] dst_wildcard_mask: The wildcard mask for the destination IP address, defining the range of addresses to match. - :ivar Optional[Port] src_port: The source port number to match. Relevant for TCP/UDP protocols. - :ivar Optional[Port] dst_port: The destination port number to match. Relevant for TCP/UDP protocols. + :ivar Optional[int] src_port: The source port number to match. Relevant for TCP/UDP protocols. + :ivar Optional[int] dst_port: The destination port number to match. Relevant for TCP/UDP protocols. """ action: ACLAction = ACLAction.DENY @@ -149,13 +149,13 @@ class ACLRule(SimComponent): """ state = super().describe_state() state["action"] = self.action.value - state["protocol"] = self.protocol.name if self.protocol else None + state["protocol"] = self.protocol if self.protocol else None state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None - state["src_port"] = self.src_port.name if self.src_port else None + state["src_port"] = self.src_port if self.src_port else None state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None - state["dst_port"] = self.dst_port.name if self.dst_port else None + state["dst_port"] = self.dst_port if self.dst_port else None state["match_count"] = self.match_count return state @@ -265,7 +265,7 @@ class AccessControlList(SimComponent): >>> acl = AccessControlList() >>> acl.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.0", ... src_wildcard_mask="0.0.0.255", ... dst_ip_address="192.168.2.0", @@ -323,13 +323,13 @@ class AccessControlList(SimComponent): func=lambda request, context: RequestResponse.from_bool( self.add_rule( action=ACLAction[request[0]], - protocol=None if request[1] == "ALL" else IPProtocol[request[1]], + protocol=None if request[1] == "ALL" else request[1], src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]), src_wildcard_mask=None if request[3] == "NONE" else IPv4Address(request[3]), - src_port=None if request[4] == "ALL" else Port[request[4]], + src_port=None if request[4] == "ALL" else request[4], dst_ip_address=None if request[5] == "ALL" else IPv4Address(request[5]), dst_wildcard_mask=None if request[6] == "NONE" else IPv4Address(request[6]), - dst_port=None if request[7] == "ALL" else Port[request[7]], + dst_port=None if request[7] == "ALL" else request[7], position=int(request[8]), ) ) @@ -399,11 +399,11 @@ class AccessControlList(SimComponent): >>> router = Router("router") >>> router.add_rule( ... action=ACLAction.DENY, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.0", ... src_wildcard_mask="0.0.0.255", ... dst_ip_address="10.10.10.5", - ... dst_port=Port.SSH, + ... dst_port=Port["SSH"], ... position=5 ... ) >>> # This permits SSH traffic from the 192.168.1.0/24 subnet to the 10.10.10.5 server. @@ -411,10 +411,10 @@ class AccessControlList(SimComponent): >>> # Then if we want to allow a specific IP address from this subnet to SSH into the server >>> router.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.25", ... dst_ip_address="10.10.10.5", - ... dst_port=Port.SSH, + ... dst_port=Port["SSH"], ... position=4 ... ) @@ -553,13 +553,13 @@ class AccessControlList(SimComponent): [ index, rule.action.name if rule.action else "ANY", - rule.protocol.name if rule.protocol else "ANY", + rule.protocol if rule.protocol else "ANY", rule.src_ip_address if rule.src_ip_address else "ANY", rule.src_wildcard_mask if rule.src_wildcard_mask else "ANY", - f"{rule.src_port.value} ({rule.src_port.name})" if rule.src_port else "ANY", + f"{rule.src_port}" if rule.src_port else "ANY", rule.dst_ip_address if rule.dst_ip_address else "ANY", rule.dst_wildcard_mask if rule.dst_wildcard_mask else "ANY", - f"{rule.dst_port.value} ({rule.dst_port.name})" if rule.dst_port else "ANY", + f"{rule.dst_port}" if rule.dst_port else "ANY", rule.match_count, ] ) @@ -1185,7 +1185,7 @@ class RouterSessionManager(SessionManager): return outbound_network_interface, dst_mac_address, dst_ip_address, src_port, dst_port, protocol, is_broadcast -class Router(NetworkNode): +class Router(NetworkNode, discriminator="router"): """ Represents a network router, managing routing and forwarding of IP packets across network interfaces. @@ -1202,13 +1202,25 @@ class Router(NetworkNode): RouteTable, RouterARP, and RouterICMP services. """ + class ConfigSchema(NetworkNode.ConfigSchema): + """Configuration Schema for Routers.""" + + type: Literal["router"] = "router" + hostname: str = "router" + num_ports: int = 5 + acl: Any = None # temporarily unset to appease extra="forbid" + routes: Any = None # temporarily unset to appease extra="forbid" + ports: Any = None # temporarily unset to appease extra="forbid" + default_route: Any = None # temporarily unset to appease extra="forbid" + + config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema()) + SYSTEM_SOFTWARE: ClassVar[Dict] = { - "UserSessionManager": UserSessionManager, - "UserManager": UserManager, - "Terminal": Terminal, + "user-session-manager": UserSessionManager, + "user-manager": UserManager, + "terminal": Terminal, } - num_ports: int network_interfaces: Dict[str, RouterInterface] = {} "The Router Interfaces on the node." network_interface: Dict[int, RouterInterface] = {} @@ -1216,19 +1228,21 @@ class Router(NetworkNode): acl: AccessControlList route_table: RouteTable - def __init__(self, hostname: str, num_ports: int = 5, **kwargs): + def __init__(self, **kwargs): if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(hostname) + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) if not kwargs.get("acl"): - kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=hostname) + kwargs["acl"] = AccessControlList( + sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname + ) if not kwargs.get("route_table"): kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"]) - super().__init__(hostname=hostname, num_ports=num_ports, **kwargs) + super().__init__(**kwargs) self.session_manager = RouterSessionManager(sys_log=self.sys_log) self.session_manager.node = self self.software_manager.session_manager = self.session_manager self.session_manager.software_manager = self.software_manager - for i in range(1, self.num_ports + 1): + for i in range(1, self.config.num_ports + 1): network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0") self.connect_nic(network_interface) self.network_interface[i] = network_interface @@ -1258,7 +1272,10 @@ class Router(NetworkNode): Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP and ICMP, which are necessary for basic network operations and diagnostics. """ - self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + self.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) def setup_for_episode(self, episode: int): """ @@ -1335,7 +1352,6 @@ class Router(NetworkNode): :return: A dictionary representing the current state. """ state = super().describe_state() - state["num_ports"] = self.num_ports state["acl"] = self.acl.describe_state() return state @@ -1357,9 +1373,9 @@ class Router(NetworkNode): """ dst_ip_address = frame.ip.dst_ip_address dst_port = None - if frame.ip.protocol == IPProtocol.TCP: + if frame.ip.protocol == PROTOCOL_LOOKUP["TCP"]: dst_port = frame.tcp.dst_port - elif frame.ip.protocol == IPProtocol.UDP: + elif frame.ip.protocol == PROTOCOL_LOOKUP["UDP"]: dst_port = frame.udp.dst_port if self.ip_is_router_interface(dst_ip_address) and ( @@ -1371,7 +1387,7 @@ class Router(NetworkNode): def subject_to_acl(self, frame: Frame) -> bool: """Check that frame is subject to ACL rules.""" - if frame.ip.protocol == IPProtocol.UDP and frame.is_arp: + if frame.ip.protocol == "udp" and frame.is_arp: return False return True @@ -1553,7 +1569,7 @@ class Router(NetworkNode): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Network Interfaces" + table.title = f"{self.config.hostname} Network Interfaces" for port, network_interface in self.network_interface.items(): table.add_row( [ @@ -1567,7 +1583,7 @@ class Router(NetworkNode): print(table) @classmethod - def from_config(cls, cfg: dict, **kwargs) -> "Router": + def from_config(cls, config: dict, **kwargs) -> "Router": """Create a router based on a config dict. Schema: @@ -1624,43 +1640,44 @@ class Router(NetworkNode): :return: Configured router. :rtype: Router """ - router = Router( - hostname=cfg["hostname"], - num_ports=int(cfg.get("num_ports", "5")), - operating_state=NodeOperatingState.ON - if not (p := cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - if "ports" in cfg: - for port_num, port_cfg in cfg["ports"].items(): + ports = config.pop("ports", None) + acl = config.pop("acl", None) + routes = config.pop("routes", None) + default_route = config.pop("default_route", None) + router = Router(config=Router.ConfigSchema(**config)) + if ports: + for port_num, port_cfg in ports.items(): router.configure_port( port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")), ) - if "acl" in cfg: - for r_num, r_cfg in cfg["acl"].items(): + if acl: + for r_num, r_cfg in acl.items(): router.acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"), position=r_num, ) - if "routes" in cfg: - for route in cfg.get("routes"): + if routes: + for route in routes: router.route_table.add_route( address=IPv4Address(route.get("address")), subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), metric=float(route.get("metric", 0)), ) - if "default_route" in cfg: - next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None) + if default_route: + next_hop_ip_address = default_route.get("next_hop_ip_address", None) if next_hop_ip_address: router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) + router.operating_state = ( + NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] + ) return router diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 4324ac94..6e5814d0 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -1,9 +1,10 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, Optional +from typing import Dict, Literal, Optional from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite import getLogger from primaite.exceptions import NetworkError @@ -87,15 +88,9 @@ class SwitchPort(WiredNetworkInterface): return False -class Switch(NetworkNode): - """ - A class representing a Layer 2 network switch. +class Switch(NetworkNode, discriminator="switch"): + """A class representing a Layer 2 network switch.""" - :ivar num_ports: The number of ports on the switch. Default is 24. - """ - - num_ports: int = 24 - "The number of ports on the switch." network_interfaces: Dict[str, SwitchPort] = {} "The SwitchPorts on the Switch." network_interface: Dict[int, SwitchPort] = {} @@ -103,9 +98,19 @@ class Switch(NetworkNode): mac_address_table: Dict[str, SwitchPort] = {} "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." + class ConfigSchema(NetworkNode.ConfigSchema): + """Configuration Schema for Switch nodes within PrimAITE.""" + + type: Literal["switch"] = "switch" + hostname: str = "Switch" + num_ports: int = 8 + "The number of ports on the switch." + + config: ConfigSchema = Field(default_factory=lambda: Switch.ConfigSchema()) + def __init__(self, **kwargs): super().__init__(**kwargs) - for i in range(1, self.num_ports + 1): + for i in range(1, self.config.num_ports + 1): self.connect_nic(SwitchPort()) def _install_system_software(self): @@ -121,7 +126,7 @@ class Switch(NetworkNode): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Switch Ports" + table.title = f"{self.config.hostname} Switch Ports" for port_num, port in self.network_interface.items(): table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"]) print(table) @@ -134,7 +139,6 @@ class Switch(NetworkNode): """ state = super().describe_state() state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()} - state["num_ports"] = self.num_ports # redundant? state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()} return state diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 3cb4c515..eab75ff1 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -1,16 +1,16 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Literal, Optional, Union -from pydantic import validate_call +from pydantic import Field, validate_call -from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, IPWirelessNetworkInterface +from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, FREQ_WIFI_2_4, IPWirelessNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP class WirelessAccessPoint(IPWirelessNetworkInterface): @@ -91,7 +91,7 @@ class WirelessAccessPoint(IPWirelessNetworkInterface): ) -class WirelessRouter(Router): +class WirelessRouter(Router, discriminator="wireless-router"): """ A WirelessRouter class that extends the functionality of a standard Router to include wireless capabilities. @@ -116,19 +116,32 @@ class WirelessRouter(Router): >>> wireless_router.configure_wireless_access_point( ... ip_address="10.10.10.1", ... subnet_mask="255.255.255.0" - ... frequency=AirSpaceFrequency.WIFI_2_4 + ... frequency="WIFI_2_4" ... ) """ network_interfaces: Dict[str, Union[RouterInterface, WirelessAccessPoint]] = {} network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {} - airspace: AirSpace - def __init__(self, hostname: str, airspace: AirSpace, **kwargs): - super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs) + class ConfigSchema(Router.ConfigSchema): + """Configuration Schema for WirelessRouter nodes within PrimAITE.""" + + type: Literal["wireless-router"] = "wireless-router" + hostname: str = "WirelessRouter" + num_ports: int = 0 + router_interface: Any = None # temporarily unset to appease extra="forbid" + wireless_access_point: Any = None # temporarily unset to appease extra="forbid" + + airspace: AirSpace + config: ConfigSchema = Field(default_factory=lambda: WirelessRouter.ConfigSchema()) + + def __init__(self, **kwargs): + super().__init__(**kwargs) self.connect_nic( - WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=airspace) + WirelessAccessPoint( + ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=self.airspace + ) ) self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")) @@ -153,7 +166,7 @@ class WirelessRouter(Router): self, ip_address: IPV4Address, subnet_mask: IPV4Address, - frequency: Optional[AirSpaceFrequency] = AirSpaceFrequency.WIFI_2_4, + frequency: Optional[AirSpaceFrequency] = FREQ_WIFI_2_4, ): """ Configures a wireless access point (WAP). @@ -166,12 +179,12 @@ class WirelessRouter(Router): :param ip_address: The IP address to be assigned to the wireless access point. :param subnet_mask: The subnet mask associated with the IP address - :param frequency: The operating frequency of the wireless access point, defined by the AirSpaceFrequency + :param frequency: The operating frequency of the wireless access point, defined by the air space frequency enum. This determines the frequency band (e.g., 2.4 GHz or 5 GHz) the access point will use for wireless - communication. Default is AirSpaceFrequency.WIFI_2_4. + communication. Default is "WIFI_2_4". """ if not frequency: - frequency = AirSpaceFrequency.WIFI_2_4 + frequency = FREQ_WIFI_2_4 self.sys_log.info("Configuring wireless access point") self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration @@ -226,7 +239,7 @@ class WirelessRouter(Router): ) @classmethod - def from_config(cls, cfg: Dict, **kwargs) -> "WirelessRouter": + def from_config(cls, config: Dict, airspace: AirSpace) -> "WirelessRouter": """Generate the wireless router from config. Schema: @@ -253,35 +266,35 @@ class WirelessRouter(Router): :return: WirelessRouter instance. :rtype: WirelessRouter """ - operating_state = ( - NodeOperatingState.ON if not (p := cfg.get("operating_state")) else NodeOperatingState[p.upper()] + router = cls(config=cls.ConfigSchema(**config), airspace=airspace) + router.operating_state = ( + NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] ) - router = cls(hostname=cfg["hostname"], operating_state=operating_state, airspace=kwargs["airspace"]) - if "router_interface" in cfg: - ip_address = cfg["router_interface"]["ip_address"] - subnet_mask = cfg["router_interface"]["subnet_mask"] + if "router_interface" in config: + ip_address = config["router_interface"]["ip_address"] + subnet_mask = config["router_interface"]["subnet_mask"] router.configure_router_interface(ip_address=ip_address, subnet_mask=subnet_mask) - if "wireless_access_point" in cfg: - ip_address = cfg["wireless_access_point"]["ip_address"] - subnet_mask = cfg["wireless_access_point"]["subnet_mask"] - frequency = AirSpaceFrequency[cfg["wireless_access_point"]["frequency"]] + if "wireless_access_point" in config: + ip_address = config["wireless_access_point"]["ip_address"] + subnet_mask = config["wireless_access_point"]["subnet_mask"] + frequency = AirSpaceFrequency._registry[config["wireless_access_point"]["frequency"]] router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency) - if "acl" in cfg: - for r_num, r_cfg in cfg["acl"].items(): + if "acl" in config: + for r_num, r_cfg in config["acl"].items(): router.acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), dst_ip_address=r_cfg.get("dst_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"), position=r_num, ) - if "routes" in cfg: - for route in cfg.get("routes"): + if "routes" in config: + for route in config.get("routes"): router.route_table.add_route( address=IPv4Address(route.get("address")), subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index ae6476c1..5d558e80 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address import yaml @@ -12,14 +12,14 @@ from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -40,46 +40,54 @@ def client_server_routed() -> Network: network = Network() # Router 1 - router_1 = Router(hostname="router_1", num_ports=3) + router_1 = Router(config=dict(hostname="router_1", num_ports=3)) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.2.1", subnet_mask="255.255.255.0") # Switch 1 - switch_1 = Switch(hostname="switch_1", num_ports=6) + switch_1 = Switch(config=dict(hostname="switch_1", num_ports=6)) switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[6]) router_1.enable_port(1) # Switch 2 - switch_2 = Switch(hostname="switch_2", num_ports=6) + switch_2 = Switch(config=dict(hostname="switch_2", num_ports=6)) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[6]) router_1.enable_port(2) # Client 1 client_1 = Computer( - hostname="client_1", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, + config=dict( + hostname="client_1", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.2.1", + start_up_duration=0, + ) ) client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) # Server 1 server_1 = Server( - hostname="server_1", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + config=dict( + hostname="server_1", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) ) server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) return network @@ -124,42 +132,51 @@ def arcd_uc2_network() -> Network: network = Network() # Router 1 - router_1 = Router(hostname="router_1", num_ports=5, start_up_duration=0) + router_1 = Router.from_config( + config={"type": "router", "hostname": "router_1", "num_ports": 5, "start_up_duration": 0} + ) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") # Switch 1 - switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) + switch_1 = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 8, "start_up_duration": 0} + ) switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8]) router_1.enable_port(1) # Switch 2 - switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) + switch_2 = Switch.from_config( + config={"type": "switch", "hostname": "switch_2", "num_ports": 8, "start_up_duration": 0} + ) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8]) router_1.enable_port(2) # Client 1 - client_1 = Computer( - hostname="client_1", - ip_address="192.168.10.21", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + client_1_cfg = { + "type": "computer", + "hostname": "client_1", + "ip_address": "192.168.10.21", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + client_1: Computer = Computer.from_config(config=client_1_cfg) + client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) client_1.software_manager.install(DatabaseClient) - db_client_1: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + db_client_1: DatabaseClient = client_1.software_manager.software.get("database-client") db_client_1.configure(server_ip_address=IPv4Address("192.168.1.14")) db_client_1.run() - web_browser_1 = client_1.software_manager.software.get("WebBrowser") + web_browser_1 = client_1.software_manager.software.get("web-browser") web_browser_1.target_url = "http://arcd.com/users/" client_1.software_manager.install(DataManipulationBot) - db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") + db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot") db_manipulation_bot.configure( server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE", @@ -168,20 +185,24 @@ def arcd_uc2_network() -> Network: ) # Client 2 - client_2 = Computer( - hostname="client_2", - ip_address="192.168.10.22", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + + client_2_cfg = { + "type": "computer", + "hostname": "client_2", + "ip_address": "192.168.10.22", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + client_2: Computer = Computer.from_config(config=client_2_cfg) + client_2.power_on() client_2.software_manager.install(DatabaseClient) - db_client_2 = client_2.software_manager.software.get("DatabaseClient") + db_client_2 = client_2.software_manager.software.get("database-client") db_client_2.configure(server_ip_address=IPv4Address("192.168.1.14")) db_client_2.run() - web_browser_2 = client_2.software_manager.software.get("WebBrowser") + web_browser_2 = client_2.software_manager.software.get("web-browser") web_browser_2.target_url = "http://arcd.com/users/" network.connect( endpoint_b=client_2.network_interface[1], @@ -189,48 +210,61 @@ def arcd_uc2_network() -> Network: ) # Domain Controller - domain_controller = Server( - hostname="domain_controller", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + + domain_controller_cfg = { + "type": "server", + "hostname": "domain_controller", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + domain_controller = Server.from_config(config=domain_controller_cfg) domain_controller.power_on() domain_controller.software_manager.install(DNSServer) network.connect(endpoint_b=domain_controller.network_interface[1], endpoint_a=switch_1.network_interface[1]) # Database Server - database_server = Server( - hostname="database_server", - ip_address="192.168.1.14", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + + database_server_cfg = { + "type": "server", + "hostname": "database_server", + "ip_address": "192.168.1.14", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + + database_server = Server.from_config(config=database_server_cfg) + database_server.power_on() network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3]) database_server.software_manager.install(DatabaseService) - database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa + database_service: DatabaseService = database_server.software_manager.software.get("database-service") # noqa database_service.start() database_service.configure_backup(backup_server=IPv4Address("192.168.1.16")) # Web Server - web_server = Server( - hostname="web_server", - ip_address="192.168.1.12", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + + web_server_cfg = { + "type": "server", + "hostname": "web_server", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + web_server = Server.from_config(config=web_server_cfg) + web_server.power_on() web_server.software_manager.install(DatabaseClient) - database_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") + database_client: DatabaseClient = web_server.software_manager.software.get("database-client") database_client.configure(server_ip_address=IPv4Address("192.168.1.14")) network.connect(endpoint_b=web_server.network_interface[1], endpoint_a=switch_1.network_interface[2]) database_client.run() @@ -239,51 +273,65 @@ def arcd_uc2_network() -> Network: web_server.software_manager.install(WebServer) # register the web_server to a domain - dns_server_service: DNSServer = domain_controller.software_manager.software.get("DNSServer") # noqa + dns_server_service: DNSServer = domain_controller.software_manager.software.get("dns-server") # noqa dns_server_service.dns_register("arcd.com", web_server.network_interface[1].ip_address) # Backup Server - backup_server = Server( - hostname="backup_server", - ip_address="192.168.1.16", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + backup_server_cfg = { + "type": "server", + "hostname": "backup_server", + "ip_address": "192.168.1.16", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + backup_server: Server = Server.from_config(config=backup_server_cfg) + backup_server.power_on() backup_server.software_manager.install(FTPServer) network.connect(endpoint_b=backup_server.network_interface[1], endpoint_a=switch_1.network_interface[4]) # Security Suite - security_suite = Server( - hostname="security_suite", - ip_address="192.168.1.110", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + security_suite_cfg = { + "type": "server", + "hostname": "security_suite", + "ip_address": "192.168.1.110", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + security_suite: Server = Server.from_config(config=security_suite_cfg) security_suite.power_on() network.connect(endpoint_b=security_suite.network_interface[1], endpoint_a=switch_1.network_interface[7]) security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0")) network.connect(endpoint_b=security_suite.network_interface[2], endpoint_a=switch_2.network_interface[7]) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Allow PostgreSQL requests router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3 + ) return network diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py index c9cff5de..a2e5f1fe 100644 --- a/src/primaite/simulator/network/nmne.py +++ b/src/primaite/simulator/network/nmne.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import List from pydantic import BaseModel, ConfigDict diff --git a/src/primaite/simulator/network/protocols/__init__.py b/src/primaite/simulator/network/protocols/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/network/protocols/__init__.py +++ b/src/primaite/simulator/network/protocols/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/network/protocols/arp.py b/src/primaite/simulator/network/protocols/arp.py index 9e7f7ebe..86e461d0 100644 --- a/src/primaite/simulator/network/protocols/arp.py +++ b/src/primaite/simulator/network/protocols/arp.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from ipaddress import IPv4Address diff --git a/src/primaite/simulator/network/protocols/dns.py b/src/primaite/simulator/network/protocols/dns.py index eb7b74ad..c0fed1aa 100644 --- a/src/primaite/simulator/network/protocols/dns.py +++ b/src/primaite/simulator/network/protocols/dns.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from ipaddress import IPv4Address diff --git a/src/primaite/simulator/network/protocols/ftp.py b/src/primaite/simulator/network/protocols/ftp.py index c570a634..fd8fdd2b 100644 --- a/src/primaite/simulator/network/protocols/ftp.py +++ b/src/primaite/simulator/network/protocols/ftp.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum from typing import Any, Optional, Union diff --git a/src/primaite/simulator/network/protocols/http.py b/src/primaite/simulator/network/protocols/http.py index 5390cd26..54abdd98 100644 --- a/src/primaite/simulator/network/protocols/http.py +++ b/src/primaite/simulator/network/protocols/http.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum, IntEnum from primaite.simulator.network.protocols.packet import DataPacket diff --git a/src/primaite/simulator/network/protocols/icmp.py b/src/primaite/simulator/network/protocols/icmp.py index 9f0626f0..fcbe15da 100644 --- a/src/primaite/simulator/network/protocols/icmp.py +++ b/src/primaite/simulator/network/protocols/icmp.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import secrets from enum import Enum from typing import Union diff --git a/src/primaite/simulator/network/protocols/masquerade.py b/src/primaite/simulator/network/protocols/masquerade.py index e2a7b6a0..e0ed26b7 100644 --- a/src/primaite/simulator/network/protocols/masquerade.py +++ b/src/primaite/simulator/network/protocols/masquerade.py @@ -1,16 +1,18 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum from typing import Optional from primaite.simulator.network.protocols.packet import DataPacket +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port class MasqueradePacket(DataPacket): """Represents an generic malicious packet that is masquerading as another protocol.""" - masquerade_protocol: Enum # The 'Masquerade' protocol that is currently in use + masquerade_protocol: IPProtocol # The 'Masquerade' protocol that is currently in use - masquerade_port: Enum # The 'Masquerade' port that is currently in use + masquerade_port: Port # The 'Masquerade' port that is currently in use class C2Packet(MasqueradePacket): diff --git a/src/primaite/simulator/network/protocols/ntp.py b/src/primaite/simulator/network/protocols/ntp.py index 74e02dab..c9b6f877 100644 --- a/src/primaite/simulator/network/protocols/ntp.py +++ b/src/primaite/simulator/network/protocols/ntp.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from datetime import datetime diff --git a/src/primaite/simulator/network/protocols/packet.py b/src/primaite/simulator/network/protocols/packet.py index 7eeec13b..6f28f716 100644 --- a/src/primaite/simulator/network/protocols/packet.py +++ b/src/primaite/simulator/network/protocols/packet.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Any from pydantic import BaseModel diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index be7f842f..03411fb5 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import IntEnum from typing import Optional diff --git a/src/primaite/simulator/network/transmission/__init__.py b/src/primaite/simulator/network/transmission/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/network/transmission/__init__.py +++ b/src/primaite/simulator/network/transmission/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index 86a6038b..a07194a4 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from datetime import datetime from typing import Any, Optional @@ -7,10 +7,12 @@ from pydantic import BaseModel from primaite import getLogger from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.protocols.packet import DataPacket -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol +from primaite.simulator.network.transmission.network_layer import IPPacket from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader from primaite.simulator.network.utils import convert_bytes_to_megabits +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -70,15 +72,15 @@ class Frame(BaseModel): msg = "Network Frame cannot have both a TCP header and a UDP header" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.TCP and not kwargs.get("tcp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["TCP"] and not kwargs.get("tcp"): msg = "Cannot build a Frame using the TCP IP Protocol without a TCPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.UDP and not kwargs.get("udp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["UDP"] and not kwargs.get("udp"): msg = "Cannot build a Frame using the UDP IP Protocol without a UDPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.ICMP and not kwargs.get("icmp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["ICMP"] and not kwargs.get("icmp"): msg = "Cannot build a Frame using the ICMP IP Protocol without a ICMPPacket" _LOGGER.error(msg) raise ValueError(msg) @@ -165,7 +167,7 @@ class Frame(BaseModel): :return: True if the Frame is an ARP packet, otherwise False. """ - return self.udp.dst_port == Port.ARP and self.udp.src_port == Port.ARP + return self.udp.dst_port == PORT_LOOKUP["ARP"] @property def is_icmp(self) -> bool: diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index d493cbdf..7a6b34c9 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -1,35 +1,15 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum from pydantic import BaseModel from primaite import getLogger -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address _LOGGER = getLogger(__name__) -class IPProtocol(Enum): - """ - Enum representing transport layer protocols in IP header. - - .. _List of IPProtocols: - """ - - NONE = "none" - """Placeholder for a non-protocol.""" - TCP = "tcp" - """Transmission Control Protocol.""" - UDP = "udp" - """User Datagram Protocol.""" - ICMP = "icmp" - """Internet Control Message Protocol.""" - - def model_dump(self) -> str: - """Return as JSON-serialisable string.""" - return self.name - - class Precedence(Enum): """ Enum representing the Precedence levels in Quality of Service (QoS) for IP packets. @@ -81,7 +61,7 @@ class IPPacket(BaseModel): >>> ip_packet = IPPacket( ... src_ip_address=IPv4Address('192.168.0.1'), ... dst_ip_address=IPv4Address('10.0.0.1'), - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... ttl=64, ... precedence=Precedence.CRITICAL ... ) @@ -91,7 +71,7 @@ class IPPacket(BaseModel): "Source IP address." dst_ip_address: IPV4Address "Destination IP address." - protocol: IPProtocol = IPProtocol.TCP + protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"] "IPProtocol." ttl: int = 64 "Time to Live (TTL) for the packet." diff --git a/src/primaite/simulator/network/transmission/primaite_layer.py b/src/primaite/simulator/network/transmission/primaite_layer.py index 981b6fbc..8ff4ac02 100644 --- a/src/primaite/simulator/network/transmission/primaite_layer.py +++ b/src/primaite/simulator/network/transmission/primaite_layer.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum from pydantic import BaseModel diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index 7f0d2d7a..689eea2f 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -1,82 +1,10 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum -from typing import List, Union +from typing import List from pydantic import BaseModel -class Port(Enum): - """ - Enumeration of common known TCP/UDP ports used by protocols for operation of network applications. - - .. _List of Ports: - """ - - UNUSED = -1 - "An unused port stub." - - NONE = 0 - "Place holder for a non-port." - WOL = 9 - "Wake-on-Lan (WOL) - Used to turn or awaken a computer from sleep mode by a network message." - FTP_DATA = 20 - "File Transfer [Default Data]" - FTP = 21 - "File Transfer Protocol (FTP) - FTP control (command)" - SSH = 22 - "Secure Shell (SSH) - Used for secure remote access and command execution." - SMTP = 25 - "Simple Mail Transfer Protocol (SMTP) - Used for email delivery between servers." - DNS = 53 - "Domain Name System (DNS) - Used for translating domain names to IP addresses." - HTTP = 80 - "HyperText Transfer Protocol (HTTP) - Used for web traffic." - POP3 = 110 - "Post Office Protocol version 3 (POP3) - Used for retrieving emails from a mail server." - SFTP = 115 - "Secure File Transfer Protocol (SFTP) - Used for secure file transfer over SSH." - NTP = 123 - "Network Time Protocol (NTP) - Used for clock synchronization between computer systems." - IMAP = 143 - "Internet Message Access Protocol (IMAP) - Used for retrieving emails from a mail server." - SNMP = 161 - "Simple Network Management Protocol (SNMP) - Used for network device management." - SNMP_TRAP = 162 - "SNMP Trap - Used for sending SNMP notifications (traps) to a network management system." - ARP = 219 - "Address resolution Protocol - Used to connect a MAC address to an IP address." - LDAP = 389 - "Lightweight Directory Access Protocol (LDAP) - Used for accessing and modifying directory information." - HTTPS = 443 - "HyperText Transfer Protocol Secure (HTTPS) - Used for secure web traffic." - SMB = 445 - "Server Message Block (SMB) - Used for file sharing and printer sharing in Windows environments." - IPP = 631 - "Internet Printing Protocol (IPP) - Used for printing over the internet or an intranet." - SQL_SERVER = 1433 - "Microsoft SQL Server Database Engine - Used for communication with the SQL Server." - MYSQL = 3306 - "MySQL Database Server - Used for MySQL database communication." - RDP = 3389 - "Remote Desktop Protocol (RDP) - Used for remote desktop access to Windows machines." - RTP = 5004 - "Real-time Transport Protocol (RTP) - Used for transmitting real-time media, e.g., audio and video." - RTP_ALT = 5005 - "Alternative port for RTP (RTP_ALT) - Used in some configurations for transmitting real-time media." - DNS_ALT = 5353 - "Alternative port for DNS (DNS_ALT) - Used in some configurations for DNS service." - HTTP_ALT = 8080 - "Alternative port for HTTP (HTTP_ALT) - Often used as an alternative HTTP port for web applications." - HTTPS_ALT = 8443 - "Alternative port for HTTPS (HTTPS_ALT) - Used in some configurations for secure web traffic." - POSTGRES_SERVER = 5432 - "Postgres SQL Server." - - def model_dump(self) -> str: - """Return a json-serialisable string.""" - return self.name - - class UDPHeader(BaseModel): """ Represents a UDP header for the transport layer of a Network Frame. @@ -87,13 +15,13 @@ class UDPHeader(BaseModel): :Example: >>> udp_header = UDPHeader( - ... src_port=Port.HTTP_ALT, - ... dst_port=Port.HTTP, + ... src_port=Port["HTTP_ALT"], + ... dst_port=Port["HTTP"], ... ) """ - src_port: Union[Port, int] - dst_port: Union[Port, int] + src_port: int + dst_port: int class TCPFlags(Enum): @@ -126,12 +54,12 @@ class TCPHeader(BaseModel): :Example: >>> tcp_header = TCPHeader( - ... src_port=Port.HTTP_ALT, - ... dst_port=Port.HTTP, + ... src_port=Port["HTTP_ALT"], + ... dst_port=Port["HTTP"], ... flags=[TCPFlags.SYN, TCPFlags.ACK] ... ) """ - src_port: Port - dst_port: Port + src_port: int + dst_port: int flags: List[TCPFlags] = [TCPFlags.SYN] diff --git a/src/primaite/simulator/network/utils.py b/src/primaite/simulator/network/utils.py index 4fd1834a..b4d6c815 100644 --- a/src/primaite/simulator/network/utils.py +++ b/src/primaite/simulator/network/utils.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Union diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 809b52db..abc83203 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Dict from primaite.interface.request import RequestResponse @@ -38,8 +38,8 @@ class Simulation(SimComponent): rm.add_request("network", RequestType(func=self.network._request_manager)) # pass through domain requests to the domain object rm.add_request("domain", RequestType(func=self.domain._request_manager)) - # if 'do_nothing' is requested, just return a success - rm.add_request("do_nothing", RequestType(func=lambda request, context: RequestResponse(status="success"))) + # if 'do-nothing' is requested, just return a success + rm.add_request("do-nothing", RequestType(func=lambda request, context: RequestResponse(status="success"))) return rm def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/__init__.py b/src/primaite/simulator/system/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/__init__.py +++ b/src/primaite/simulator/system/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/applications/__init__.py b/src/primaite/simulator/system/applications/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/applications/__init__.py +++ b/src/primaite/simulator/system/applications/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 741f491d..1de29c33 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -1,10 +1,12 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from abc import abstractmethod +from abc import ABC, abstractmethod from enum import Enum from typing import Any, ClassVar, Dict, Optional, Set, Type +from pydantic import Field + from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType from primaite.simulator.system.software import IOSoftware, SoftwareHealthState @@ -21,13 +23,20 @@ class ApplicationOperatingState(Enum): "The application is being installed or updated." -class Application(IOSoftware): +class Application(IOSoftware, ABC): """ Represents an Application in the simulation environment. Applications are user-facing programs that may perform input/output operations. """ + class ConfigSchema(IOSoftware.ConfigSchema, ABC): + """Config Schema for Application class.""" + + type: str + + config: ConfigSchema = Field(default_factory=lambda: Application.ConfigSchema()) + operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED "The current operating state of the Application." execution_control_status: str = "manual" @@ -41,21 +50,38 @@ class Application(IOSoftware): install_countdown: Optional[int] = None "The countdown to the end of the installation process. None if not currently installing" - _application_registry: ClassVar[Dict[str, Type["Application"]]] = {} + _registry: ClassVar[Dict[str, Type["Application"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: """ Register an application type. - :param identifier: Uniquely specifies an application class by name. Used for finding items by config. - :type identifier: str + :param discriminator: Uniquely specifies an application class by name. Used for finding items by config. + :type discriminator: Optional[str] :raises ValueError: When attempting to register an application with a name that is already allocated. """ super().__init_subclass__(**kwargs) - if identifier in cls._application_registry: - raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.") - cls._application_registry[identifier] = cls + if discriminator is None: + return + if discriminator in cls._registry: + raise ValueError(f"Tried to define new application {discriminator}, but this name is already reserved.") + cls._registry[discriminator] = cls + + @classmethod + def from_config(cls, config: Dict) -> "Application": + """Create an application from a config dictionary. + + :param config: dict of options for application components constructor + :type config: dict + :return: The application component. + :rtype: Application + """ + if config["type"] not in cls._registry: + raise ValueError(f"Invalid Application type {config['type']}") + application_class = cls._registry[config["type"]] + application_object = application_class(config=application_class.ConfigSchema(**config)) + return application_object def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 3f80c745..14f2db21 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from ipaddress import IPv4Address @@ -6,16 +6,16 @@ from typing import Any, Dict, Optional, Union from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel +from pydantic import BaseModel, Field from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.hardware.nodes.host.host_node import HostNode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.software_manager import SoftwareManager -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP class DatabaseClientConnection(BaseModel): @@ -37,7 +37,7 @@ class DatabaseClientConnection(BaseModel): @property def client(self) -> Optional[DatabaseClient]: """The DatabaseClient that holds this connection.""" - return self.parent_node.software_manager.software.get("DatabaseClient") + return self.parent_node.software_manager.software.get("database-client") def query(self, sql: str) -> bool: """ @@ -61,17 +61,25 @@ class DatabaseClientConnection(BaseModel): return str(self) -class DatabaseClient(Application, identifier="DatabaseClient"): +class DatabaseClient(Application, discriminator="database-client"): """ A DatabaseClient application. Extends the Application class to provide functionality for connecting, querying, and disconnecting from a Database Service. It mainly operates over TCP protocol. - - :ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None. """ + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for DatabaseClient.""" + + type: str = "database-client" + db_server_ip: Optional[IPV4Address] = None + server_password: Optional[str] = None + + config: ConfigSchema = Field(default_factory=lambda: DatabaseClient.ConfigSchema()) + server_ip_address: Optional[IPv4Address] = None + """The IPv4 address of the Database Service server, defaults to None.""" server_password: Optional[str] = None _query_success_tracker: Dict[str, bool] = {} """Keep track of connections that were established or verified during this step. Used for rewards.""" @@ -89,10 +97,12 @@ class DatabaseClient(Application, identifier="DatabaseClient"): """Native Client Connection for using the client directly (similar to psql in a terminal).""" def __init__(self, **kwargs): - kwargs["name"] = "DatabaseClient" - kwargs["port"] = Port.POSTGRES_SERVER - kwargs["protocol"] = IPProtocol.TCP + kwargs["name"] = "database-client" + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) + self.server_ip_address = self.config.db_server_ip + self.server_password = self.config.server_password def _init_request_manager(self) -> RequestManager: """ @@ -308,6 +318,9 @@ class DatabaseClient(Application, identifier="DatabaseClient"): """ if not self._can_perform_action(): return None + if self.server_ip_address is None: + self.sys_log.warning(f"{self.name}: Database server IP address not provided.") + return None connection_request_id = str(uuid4()) self._client_connection_requests[connection_request_id] = None diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index c87eaaf5..90debcd6 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -1,16 +1,16 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union from prettytable import PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import IPProtocol, is_valid_protocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import is_valid_port, Port, PORT_LOOKUP class PortScanPayload(SimComponent): @@ -37,14 +37,14 @@ class PortScanPayload(SimComponent): """ state = super().describe_state() state["ip_address"] = str(self.ip_address) - state["port"] = self.port.value - state["protocol"] = self.protocol.value + state["port"] = self.port + state["protocol"] = self.protocol state["request"] = self.request return state -class NMAP(Application, identifier="NMAP"): +class NMAP(Application, discriminator="nmap"): """ A class representing the NMAP application for network scanning. @@ -52,6 +52,13 @@ class NMAP(Application, identifier="NMAP"): as ping scans to discover active hosts and port scans to detect open ports on those hosts. """ + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for NMAP.""" + + type: str = "nmap" + + config: "NMAP.ConfigSchema" = Field(default_factory=lambda: NMAP.ConfigSchema()) + _active_port_scans: Dict[str, PortScanPayload] = {} _port_scan_responses: Dict[str, PortScanPayload] = {} @@ -63,9 +70,9 @@ class NMAP(Application, identifier="NMAP"): } def __init__(self, **kwargs): - kwargs["name"] = "NMAP" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["name"] = "nmap" + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) def _can_perform_network_action(self) -> bool: @@ -201,7 +208,7 @@ class NMAP(Application, identifier="NMAP"): if show: table = PrettyTable(["IP Address", "Can Ping"]) table.align = "l" - table.title = f"{self.software_manager.node.hostname} NMAP Ping Scan" + table.title = f"{self.software_manager.node.config.hostname} NMAP Ping Scan" ip_addresses = self._explode_ip_address_network_array(target_ip_address) @@ -272,8 +279,8 @@ class NMAP(Application, identifier="NMAP"): payload = PortScanPayload(ip_address=ip_address, port=port, protocol=protocol) self._active_port_scans[payload.uuid] = payload self.sys_log.info( - f"{self.name}: Sending port scan request over {payload.protocol.name} on port {payload.port.value} " - f"({payload.port.name}) to {payload.ip_address}" + f"{self.name}: Sending port scan request over {payload.protocol} on port {payload.port} " + f"({payload.port}) to {payload.ip_address}" ) self.software_manager.send_payload_to_session_manager( payload=payload, dest_ip_address=ip_address, src_port=port, dest_port=port, ip_protocol=protocol @@ -295,8 +302,8 @@ class NMAP(Application, identifier="NMAP"): self._active_port_scans.pop(payload.uuid) self._port_scan_responses[payload.uuid] = payload self.sys_log.info( - f"{self.name}: Received port scan response from {payload.ip_address} on port {payload.port.value} " - f"({payload.port.name}) over {payload.protocol.name}" + f"{self.name}: Received port scan response from {payload.ip_address} on port {payload.port} " + f"({payload.port}) over {payload.protocol}" ) def _process_port_scan_request(self, payload: PortScanPayload, session_id: str) -> None: @@ -311,8 +318,8 @@ class NMAP(Application, identifier="NMAP"): if self.software_manager.check_port_is_open(port=payload.port, protocol=payload.protocol): payload.request = False self.sys_log.info( - f"{self.name}: Responding to port scan request for port {payload.port.value} " - f"({payload.port.name}) over {payload.protocol.name}", + f"{self.name}: Responding to port scan request for port {payload.port} " + f"({payload.port}) over {payload.protocol}", ) self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) @@ -345,22 +352,22 @@ class NMAP(Application, identifier="NMAP"): """ ip_addresses = self._explode_ip_address_network_array(target_ip_address) - if isinstance(target_port, Port): + if is_valid_port(target_port): target_port = [target_port] elif target_port is None: - target_port = [port for port in Port if port not in {Port.NONE, Port.UNUSED}] + target_port = [PORT_LOOKUP[port] for port in PORT_LOOKUP if port not in {"NONE", "UNUSED"}] - if isinstance(target_protocol, IPProtocol): + if is_valid_protocol(target_protocol): target_protocol = [target_protocol] elif target_protocol is None: - target_protocol = [IPProtocol.TCP, IPProtocol.UDP] + target_protocol = [PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]] scan_type = self._determine_port_scan_type(list(ip_addresses), target_port) active_ports = {} if show: - table = PrettyTable(["IP Address", "Port", "Name", "Protocol"]) + table = PrettyTable(["IP Address", "Port", "Protocol"]) table.align = "l" - table.title = f"{self.software_manager.node.hostname} NMAP Port Scan ({scan_type})" + table.title = f"{self.software_manager.node.config.hostname} NMAP Port Scan ({scan_type})" self.sys_log.info(f"{self.name}: Starting port scan") for ip_address in ip_addresses: # Prevent port scan on this node @@ -369,13 +376,12 @@ class NMAP(Application, identifier="NMAP"): for protocol in target_protocol: for port in set(target_port): port_open = self._check_port_open_on_ip_address(ip_address=ip_address, port=port, protocol=protocol) - if port_open: if show: - table.add_row([ip_address, port.value, port.name, protocol.name]) + table.add_row([ip_address, port, protocol]) _ip_address = ip_address if not json_serializable else str(ip_address) - _protocol = protocol if not json_serializable else protocol.value - _port = port if not json_serializable else port.value + _protocol = protocol + _port = port if _ip_address not in active_ports: active_ports[_ip_address] = dict() if _protocol not in active_ports[_ip_address]: diff --git a/src/primaite/simulator/system/applications/red_applications/__init__.py b/src/primaite/simulator/system/applications/red_applications/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/applications/red_applications/__init__.py +++ b/src/primaite/simulator/system/applications/red_applications/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/applications/red_applications/c2/__init__.py b/src/primaite/simulator/system/applications/red_applications/c2/__init__.py index 60e39743..33cc555f 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/__init__.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/__init__.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Optional, Union from pydantic import BaseModel, Field, field_validator, ValidationInfo diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index 5d4cc8e0..1c2c1179 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -1,22 +1,22 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from abc import abstractmethod from enum import Enum from ipaddress import IPv4Address -from typing import Dict, Optional, Union +from typing import Dict, Optional, Set, Union -from pydantic import BaseModel, Field, validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.ip_protocol import IPProtocol, is_valid_protocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import is_valid_port, Port, PORT_LOOKUP class C2Command(Enum): @@ -45,10 +45,10 @@ class C2Payload(Enum): """C2 Input Command payload. Used by the C2 Server to send a command to the c2 beacon.""" OUTPUT = "output_command" - """C2 Output Command. Used by the C2 Beacon to send the results of a Input command to the c2 server.""" + """C2 Output Command. Used by the C2 Beacon to send the results of an Input command to the c2 server.""" -class AbstractC2(Application, identifier="AbstractC2"): +class AbstractC2(Application): """ An abstract command and control (c2) application. @@ -60,9 +60,25 @@ class AbstractC2(Application, identifier="AbstractC2"): Defaults to masquerading as HTTP (Port 80) via TCP. - Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite. + Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite. """ + class ConfigSchema(Application.ConfigSchema): + """Configuration for AbstractC2.""" + + keep_alive_frequency: int = Field(default=5, ge=1) + """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" + + masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"]) + """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" + + masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"]) + """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" + + listen_on_ports: Set[Port] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]} + + config: ConfigSchema = Field(default_factory=lambda: AbstractC2.ConfigSchema()) + c2_connection_active: bool = False """Indicates if the c2 server and c2 beacon are currently connected.""" @@ -75,19 +91,6 @@ class AbstractC2(Application, identifier="AbstractC2"): keep_alive_inactivity: int = 0 """Indicates how many timesteps since the last time the c2 application received a keep alive.""" - class _C2Opts(BaseModel): - """A Pydantic Schema for the different C2 configuration options.""" - - keep_alive_frequency: int = Field(default=5, ge=1) - """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" - - masquerade_protocol: IPProtocol = Field(default=IPProtocol.TCP) - """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" - - masquerade_port: Port = Field(default=Port.HTTP) - """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" - - c2_config: _C2Opts = _C2Opts() """ Holds the current configuration settings of the C2 Suite. @@ -100,6 +103,12 @@ class AbstractC2(Application, identifier="AbstractC2"): C2 beacon to reconfigure it's configuration settings. """ + def __init__(self, **kwargs): + """Initialise the C2 applications to by default listen for HTTP traffic.""" + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] + super().__init__(**kwargs) + def _craft_packet( self, c2_payload: C2Payload, c2_command: Optional[C2Command] = None, command_options: Optional[Dict] = {} ) -> C2Packet: @@ -118,13 +127,13 @@ class AbstractC2(Application, identifier="AbstractC2"): :type c2_command: C2Command. :param command_options: The relevant C2 Beacon parameters.F :type command_options: Dict - :return: Returns the construct C2Packet + :return: Returns the constructed C2Packet :rtype: C2Packet """ constructed_packet = C2Packet( - masquerade_protocol=self.c2_config.masquerade_protocol, - masquerade_port=self.c2_config.masquerade_port, - keep_alive_frequency=self.c2_config.keep_alive_frequency, + masquerade_protocol=self.config.masquerade_protocol, + masquerade_port=self.config.masquerade_port, + keep_alive_frequency=self.config.keep_alive_frequency, payload_type=c2_payload, command=c2_command, payload=command_options, @@ -140,13 +149,6 @@ class AbstractC2(Application, identifier="AbstractC2"): """ return super().describe_state() - def __init__(self, **kwargs): - """Initialise the C2 applications to by default listen for HTTP traffic.""" - kwargs["listen_on_ports"] = {Port.HTTP, Port.FTP, Port.DNS} - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.TCP - super().__init__(**kwargs) - @property def _host_ftp_client(self) -> Optional[FTPClient]: """Return the FTPClient that is installed C2 Application's host. @@ -160,11 +162,11 @@ class AbstractC2(Application, identifier="AbstractC2"): :return: An FTPClient object is successful, else None :rtype: union[FTPClient, None] """ - ftp_client: Union[FTPClient, None] = self.software_manager.software.get("FTPClient") + ftp_client: Union[FTPClient, None] = self.software_manager.software.get("ftp-client") if ftp_client is None: self.sys_log.warning(f"{self.__class__.__name__}: No FTPClient. Attempting to install.") self.software_manager.install(FTPClient) - ftp_client = self.software_manager.software.get("FTPClient") + ftp_client = self.software_manager.software.get("ftp-client") # Force start if the service is stopped. if ftp_client.operating_state == ServiceOperatingState.STOPPED: @@ -187,11 +189,11 @@ class AbstractC2(Application, identifier="AbstractC2"): :return: An FTPServer object is successful, else None :rtype: Optional[FTPServer] """ - ftp_server: Optional[FTPServer] = self.software_manager.software.get("FTPServer") + ftp_server: Optional[FTPServer] = self.software_manager.software.get("ftp-server") if ftp_server is None: self.sys_log.warning(f"{self.__class__.__name__}:No FTPServer installed. Attempting to install FTPServer.") self.software_manager.install(FTPServer) - ftp_server = self.software_manager.software.get("FTPServer") + ftp_server = self.software_manager.software.get("ftp-server") # Force start if the service is stopped. if ftp_server.operating_state == ServiceOperatingState.STOPPED: @@ -330,8 +332,8 @@ class AbstractC2(Application, identifier="AbstractC2"): if self.send( payload=keep_alive_packet, dest_ip_address=self.c2_remote_connection, - dest_port=self.c2_config.masquerade_port, - ip_protocol=self.c2_config.masquerade_protocol, + dest_port=self.config.masquerade_port, + ip_protocol=self.config.masquerade_protocol, session_id=session_id, ): # Setting the keep_alive_sent guard condition to True. This is used to prevent packet storms. @@ -340,8 +342,8 @@ class AbstractC2(Application, identifier="AbstractC2"): self.sys_log.info(f"{self.name}: Keep Alive sent to {self.c2_remote_connection}") self.sys_log.debug( f"{self.name}: Keep Alive sent to {self.c2_remote_connection} " - f"Masquerade Port: {self.c2_config.masquerade_port} " - f"Masquerade Protocol: {self.c2_config.masquerade_protocol} " + f"Masquerade Port: {self.config.masquerade_port} " + f"Masquerade Protocol: {self.config.masquerade_protocol} " ) return True else: @@ -366,8 +368,8 @@ class AbstractC2(Application, identifier="AbstractC2"): :return: True on successful configuration, false otherwise. :rtype: bool """ - # Validating that they are valid Enums. - if not isinstance(payload.masquerade_port, Port) or not isinstance(payload.masquerade_protocol, IPProtocol): + # Validating that they are valid Ports and Protocols. + if not is_valid_port(payload.masquerade_port) or not is_valid_protocol(payload.masquerade_protocol): self.sys_log.warning( f"{self.name}: Received invalid Masquerade Values within Keep Alive." f"Port: {payload.masquerade_port} Protocol: {payload.masquerade_protocol}." @@ -376,15 +378,15 @@ class AbstractC2(Application, identifier="AbstractC2"): # Updating the C2 Configuration attribute. - self.c2_config.masquerade_port = payload.masquerade_port - self.c2_config.masquerade_protocol = payload.masquerade_protocol - self.c2_config.keep_alive_frequency = payload.keep_alive_frequency + self.config.masquerade_port = payload.masquerade_port + self.config.masquerade_protocol = payload.masquerade_protocol + self.config.keep_alive_frequency = payload.keep_alive_frequency self.sys_log.debug( f"{self.name}: C2 Config Resolved Config from Keep Alive:" - f"Masquerade Port: {self.c2_config.masquerade_port}" - f"Masquerade Protocol: {self.c2_config.masquerade_protocol}" - f"Keep Alive Frequency: {self.c2_config.keep_alive_frequency}" + f"Masquerade Port: {self.config.masquerade_port}" + f"Masquerade Protocol: {self.config.masquerade_protocol}" + f"Keep Alive Frequency: {self.config.keep_alive_frequency}" ) # This statement is intended to catch on the C2 Application that is listening for connection. @@ -410,8 +412,8 @@ class AbstractC2(Application, identifier="AbstractC2"): self.keep_alive_inactivity = 0 self.keep_alive_frequency = 5 self.c2_remote_connection = None - self.c2_config.masquerade_port = Port.HTTP - self.c2_config.masquerade_protocol = IPProtocol.TCP + self.config.masquerade_port = PORT_LOOKUP["HTTP"] + self.config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"] @abstractmethod def _confirm_remote_connection(self, timestep: int) -> bool: diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index fa0271e5..486a0eaf 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -1,23 +1,23 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from enum import Enum +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts, RansomwareOpts, TerminalOpts from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import Port, PORT_LOOKUP -class C2Beacon(AbstractC2, identifier="C2Beacon"): +class C2Beacon(AbstractC2, discriminator="c2-beacon"): """ C2 Beacon Application. @@ -33,19 +33,34 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): 2. Leveraging the terminal application to execute requests (dependent on the command given) 3. Sending the RequestResponse back to the C2 Server (Command output) - Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite. + Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite. """ + class ConfigSchema(AbstractC2.ConfigSchema): + """ConfigSchema for C2Beacon.""" + + type: str = "c2-beacon" + c2_server_ip_address: Optional[IPV4Address] = None + keep_alive_frequency: int = 5 + masquerade_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"] + masquerade_port: Port = PORT_LOOKUP["HTTP"] + + config: ConfigSchema = Field(default_factory=lambda: C2Beacon.ConfigSchema()) + keep_alive_attempted: bool = False """Indicates if a keep alive has been attempted to be sent this timestep. Used to prevent packet storms.""" terminal_session: TerminalClientConnection = None "The currently in use terminal session." + def __init__(self, **kwargs): + kwargs["name"] = "c2-beacon" + super().__init__(**kwargs) + @property def _host_terminal(self) -> Optional[Terminal]: - """Return the Terminal that is installed on the same machine as the C2 Beacon.""" - host_terminal: Terminal = self.software_manager.software.get("Terminal") + """Return the terminal that is installed on the same machine as the C2 Beacon.""" + host_terminal: Terminal = self.software_manager.software.get("terminal") if host_terminal is None: self.sys_log.warning(f"{self.__class__.__name__} cannot find a terminal on its host.") return host_terminal @@ -53,7 +68,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): @property def _host_ransomware_script(self) -> RansomwareScript: """Return the RansomwareScript that is installed on the same machine as the C2 Beacon.""" - ransomware_script: RansomwareScript = self.software_manager.software.get("RansomwareScript") + ransomware_script: RansomwareScript = self.software_manager.software.get("ransomware-script") if ransomware_script is None: self.sys_log.warning(f"{self.__class__.__name__} cannot find installed ransomware on its host.") return ransomware_script @@ -112,26 +127,22 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self.configure( c2_server_ip_address=c2_remote_ip, keep_alive_frequency=frequency, - masquerade_protocol=IPProtocol[protocol], - masquerade_port=Port[port], + masquerade_protocol=protocol, + masquerade_port=port, ) ) rm.add_request("configure", request_type=RequestType(func=_configure)) return rm - def __init__(self, **kwargs): - kwargs["name"] = "C2Beacon" - super().__init__(**kwargs) - # Configure is practically setter method for the ``c2.config`` attribute that also ties into the request manager. @validate_call def configure( self, c2_server_ip_address: IPv4Address = None, keep_alive_frequency: int = 5, - masquerade_protocol: Enum = IPProtocol.TCP, - masquerade_port: Enum = Port.HTTP, + masquerade_protocol: str = PROTOCOL_LOOKUP["TCP"], + masquerade_port: int = PORT_LOOKUP["HTTP"], ) -> bool: """ Configures the C2 beacon to communicate with the C2 server. @@ -147,7 +158,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): masquerade_port | What port should the C2 traffic use? (TCP or UDP) These configuration options are used to reassign the fields in the inherited inner class - ``c2_config``. + ``config``. If a connection is already in progress then this method also sends a keep alive to the C2 Server in order for the C2 Server to sync with the new configuration settings. @@ -163,9 +174,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): :return: Returns True if the configuration was successful, False otherwise. """ self.c2_remote_connection = IPv4Address(c2_server_ip_address) - self.c2_config.keep_alive_frequency = keep_alive_frequency - self.c2_config.masquerade_port = masquerade_port - self.c2_config.masquerade_protocol = masquerade_protocol + self.config.keep_alive_frequency = keep_alive_frequency + self.config.masquerade_port = masquerade_port + self.config.masquerade_protocol = masquerade_protocol self.sys_log.info( f"{self.name}: Configured {self.name} with remote C2 server connection: {c2_server_ip_address=}." ) @@ -264,14 +275,12 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): if self.send( payload=output_packet, dest_ip_address=self.c2_remote_connection, - dest_port=self.c2_config.masquerade_port, - ip_protocol=self.c2_config.masquerade_protocol, + dest_port=self.config.masquerade_port, + ip_protocol=self.config.masquerade_protocol, session_id=session_id, ): self.sys_log.info(f"{self.name}: Command output sent to {self.c2_remote_connection}") - self.sys_log.debug( - f"{self.name}: on {self.c2_config.masquerade_port} via {self.c2_config.masquerade_protocol}" - ) + self.sys_log.debug(f"{self.name}: on {self.config.masquerade_port} via {self.config.masquerade_protocol}") return True else: self.sys_log.warning( @@ -291,7 +300,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): :payload C2Packet: The incoming INPUT command. :type Masquerade Packet: C2Packet. - :return: Returns the Request Response returned by the Terminal execute method. + :return: Returns the Request Response returned by the terminal execute method. :rtype: Request Response """ command_opts = RansomwareOpts.model_validate(payload.payload) @@ -315,7 +324,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): :payload C2Packet: The incoming INPUT command. :type Masquerade Packet: C2Packet. - :return: Returns the Request Response returned by the Terminal execute method. + :return: Returns the Request Response returned by the terminal execute method. :rtype: Request Response """ if self._host_ransomware_script is None: @@ -342,7 +351,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): :payload C2Packet: The incoming INPUT command. :type Masquerade Packet: C2Packet. - :return: Returns a tuple containing Request Response returned by the Terminal execute method. + :return: Returns a tuple containing Request Response returned by the terminal execute method. :rtype: Request Response """ if self._host_ftp_server is None: @@ -363,7 +372,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): ) # Using the terminal to start the FTP Client on the remote machine. - self.terminal_session.execute(command=["service", "start", "FTPClient"]) + self.terminal_session.execute(command=["service", "start", "ftp-client"]) # Need to supply to the FTP Client the C2 Beacon's host IP. host_network_interfaces = self.software_manager.node.network_interfaces @@ -421,7 +430,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): # Using the terminal to send the target data back to the C2 Beacon. exfil_response: RequestResponse = RequestResponse.from_bool( - self.terminal_session.execute(command=["service", "FTPClient", "send", ftp_opts]) + self.terminal_session.execute(command=["service", "ftp-client", "send", ftp_opts]) ) # Validating that we successfully received the target data. @@ -463,14 +472,14 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): def _command_terminal(self, payload: C2Packet) -> RequestResponse: """ - C2 Command: Terminal. + C2 Command: terminal. Creates a request that executes a terminal command. This request is then sent to the terminal service in order to be executed. :payload C2Packet: The incoming INPUT command. :type Masquerade Packet: C2Packet. - :return: Returns the Request Response returned by the Terminal execute method. + :return: Returns the Request Response returned by the terminal execute method. :rtype: Request Response """ command_opts = TerminalOpts.model_validate(payload.payload) @@ -563,7 +572,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): :rtype bool: """ self.keep_alive_attempted = False # Resetting keep alive sent. - if self.keep_alive_inactivity == self.c2_config.keep_alive_frequency: + if self.keep_alive_inactivity == self.config.keep_alive_frequency: self.sys_log.info( f"{self.name}: Attempting to Send Keep Alive to {self.c2_remote_connection} at timestep {timestep}." ) @@ -628,9 +637,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self.c2_connection_active, self.c2_remote_connection, self.keep_alive_inactivity, - self.c2_config.keep_alive_frequency, - self.c2_config.masquerade_protocol, - self.c2_config.masquerade_port, + self.config.keep_alive_frequency, + self.config.masquerade_protocol, + self.config.masquerade_port, ] ) print(table) diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py index f948d696..987029e4 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py @@ -1,8 +1,8 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -16,7 +16,7 @@ from primaite.simulator.system.applications.red_applications.c2 import ( from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload -class C2Server(AbstractC2, identifier="C2Server"): +class C2Server(AbstractC2, discriminator="c2-server"): """ C2 Server Application. @@ -31,9 +31,16 @@ class C2Server(AbstractC2, identifier="C2Server"): 1. Sending commands to the C2 Beacon. (Command input) 2. Parsing terminal RequestResponses back to the Agent. - Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite. + Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite. """ + class ConfigSchema(AbstractC2.ConfigSchema): + """ConfigSchema for C2Server.""" + + type: str = "c2-server" + + config: ConfigSchema = Field(default_factory=lambda: C2Server.ConfigSchema()) + current_command_output: RequestResponse = None """The Request Response by the last command send. This attribute is updated by the method _handle_command_output.""" @@ -118,7 +125,7 @@ class C2Server(AbstractC2, identifier="C2Server"): return rm def __init__(self, **kwargs): - kwargs["name"] = "C2Server" + kwargs["name"] = "c2-server" super().__init__(**kwargs) self.run() @@ -251,8 +258,8 @@ class C2Server(AbstractC2, identifier="C2Server"): payload=command_packet, dest_ip_address=self.c2_remote_connection, session_id=self.c2_session.uuid, - dest_port=self.c2_config.masquerade_port, - ip_protocol=self.c2_config.masquerade_protocol, + dest_port=self.config.masquerade_port, + ip_protocol=self.config.masquerade_protocol, ): self.sys_log.info(f"{self.name}: Successfully sent {given_command}.") self.sys_log.info(f"{self.name}: Awaiting command response {given_command}.") @@ -334,11 +341,11 @@ class C2Server(AbstractC2, identifier="C2Server"): :return: Returns False if the C2 beacon is considered dead. Otherwise True. :rtype bool: """ - if self.keep_alive_inactivity > self.c2_config.keep_alive_frequency: + if self.keep_alive_inactivity > self.config.keep_alive_frequency: self.sys_log.info(f"{self.name}: C2 Beacon connection considered dead due to inactivity.") self.sys_log.debug( f"{self.name}: Did not receive expected keep alive connection from {self.c2_remote_connection}" - f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.c2_config.keep_alive_frequency}" + f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.config.keep_alive_frequency}" f"{self.name}: Last Keep Alive received at {(timestep - self.keep_alive_inactivity)}" ) self._reset_c2_connection() @@ -389,8 +396,8 @@ class C2Server(AbstractC2, identifier="C2Server"): [ self.c2_connection_active, self.c2_remote_connection, - self.c2_config.masquerade_protocol, - self.c2_config.masquerade_port, + self.config.masquerade_protocol, + self.config.masquerade_port, ] ) print(table) diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index fefb22c3..17862df4 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -1,16 +1,19 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import IntEnum from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -37,9 +40,22 @@ class DataManipulationAttackStage(IntEnum): "Signifies that the attack has failed." -class DataManipulationBot(Application, identifier="DataManipulationBot"): +class DataManipulationBot(Application, discriminator="data-manipulation-bot"): """A bot that simulates a script which performs a SQL injection attack.""" + class ConfigSchema(Application.ConfigSchema): + """Configuration schema for DataManipulationBot.""" + + type: str = "data-manipulation-bot" + server_ip: Optional[IPV4Address] = None + server_password: Optional[str] = None + payload: str = "DELETE" + port_scan_p_of_success: float = 0.1 + data_manipulation_p_of_success: float = 0.1 + repeat: bool = True + + config: "DataManipulationBot.ConfigSchema" = Field(default_factory=lambda: DataManipulationBot.ConfigSchema()) + payload: Optional[str] = None port_scan_p_of_success: float = 0.1 data_manipulation_p_of_success: float = 0.1 @@ -49,13 +65,20 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"): "Whether to repeat attacking once finished." def __init__(self, **kwargs): - kwargs["name"] = "DataManipulationBot" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["name"] = "data-manipulation-bot" + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None + self.server_ip_address = self.config.server_ip + self.server_password = self.config.server_password + self.payload = self.config.payload + self.port_scan_p_of_success = self.config.port_scan_p_of_success + self.data_manipulation_p_of_success = self.config.data_manipulation_p_of_success + self.repeat = self.config.repeat + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -71,7 +94,7 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"): @property def _host_db_client(self) -> DatabaseClient: """Return the database client that is installed on the same machine as the DataManipulationBot.""" - db_client = self.software_manager.software.get("DatabaseClient") + db_client = self.software_manager.software.get("database-client") if db_client is None: self.sys_log.warning(f"{self.__class__.__name__} cannot find a database client on its host.") return db_client diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index fcad3b3e..1528de57 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -1,14 +1,17 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import IntEnum from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.utils.validation.ipv4_address import ipv4_validator, IPV4Address +from primaite.utils.validation.port import Port, PORT_LOOKUP, port_validator _LOGGER = getLogger(__name__) @@ -29,9 +32,23 @@ class DoSAttackStage(IntEnum): "Attack is completed." -class DoSBot(DatabaseClient, identifier="DoSBot"): +class DoSBot(DatabaseClient, discriminator="dos-bot"): """A bot that simulates a Denial of Service attack.""" + class ConfigSchema(DatabaseClient.ConfigSchema): + """ConfigSchema for DoSBot.""" + + type: str = "dos-bot" + target_ip_address: Optional[IPV4Address] = None + target_port: Port = PORT_LOOKUP["POSTGRES_SERVER"] + payload: Optional[str] = None + repeat: bool = False + port_scan_p_of_success: float = 0.1 + dos_intensity: float = 1.0 + max_sessions: int = 1000 + + config: "DoSBot.ConfigSchema" = Field(default_factory=lambda: DoSBot.ConfigSchema()) + target_ip_address: Optional[IPv4Address] = None """IP address of the target service.""" @@ -55,8 +72,14 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): def __init__(self, **kwargs): super().__init__(**kwargs) - self.name = "DoSBot" - self.max_sessions = 1000 # override normal max sessions + self.name = "dos-bot" + self.target_ip_address = self.config.target_ip_address + self.target_port = self.config.target_port + self.payload = self.config.payload + self.repeat = self.config.repeat + self.port_scan_p_of_success = self.config.port_scan_p_of_success + self.dos_intensity = self.config.dos_intensity + self.max_sessions = self.config.max_sessions def _init_request_manager(self) -> RequestManager: """ @@ -83,9 +106,9 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): :rtype: RequestResponse """ if "target_ip_address" in request[-1]: - request[-1]["target_ip_address"] = IPv4Address(request[-1]["target_ip_address"]) + request[-1]["target_ip_address"] = ipv4_validator(request[-1]["target_ip_address"]) if "target_port" in request[-1]: - request[-1]["target_port"] = Port[request[-1]["target_port"]] + request[-1]["target_port"] = port_validator(request[-1]["target_port"]) return RequestResponse.from_bool(self.configure(**request[-1])) rm.add_request("configure", request_type=RequestType(func=_configure)) @@ -94,7 +117,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): def configure( self, target_ip_address: IPv4Address, - target_port: Optional[Port] = Port.POSTGRES_SERVER, + target_port: Optional[int] = PORT_LOOKUP["POSTGRES_SERVER"], payload: Optional[str] = None, repeat: bool = False, port_scan_p_of_success: float = 0.1, @@ -105,7 +128,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): Configure the Denial of Service bot. :param: target_ip_address: The IP address of the Node containing the target service. - :param: target_port: The port of the target service. Optional - Default is `Port.HTTP` + :param: target_port: The port of the target service. Optional - Default is `Port["HTTP"]` :param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None` :param: repeat: If True, the bot will maintain the attack. Optional - Default is `True` :param: port_scan_p_of_success: The chance of the port scan being successful. Optional - Default is 0.1 (10%) diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index 2046affc..450311ba 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -1,23 +1,35 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP -class RansomwareScript(Application, identifier="RansomwareScript"): +class RansomwareScript(Application, discriminator="ransomware-script"): """Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network. :ivar payload: The attack stage query payload. (Default ENCRYPT) """ + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for RansomwareScript.""" + + type: str = "ransomware-script" + server_ip: Optional[IPV4Address] = None + server_password: Optional[str] = None + payload: str = "ENCRYPT" + + config: "RansomwareScript.ConfigSchema" = Field(default_factory=lambda: RansomwareScript.ConfigSchema()) + server_ip_address: Optional[IPv4Address] = None """IP address of node which hosts the database.""" server_password: Optional[str] = None @@ -26,12 +38,15 @@ class RansomwareScript(Application, identifier="RansomwareScript"): "Payload String for the payload stage" def __init__(self, **kwargs): - kwargs["name"] = "RansomwareScript" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["name"] = "ransomware-script" + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None + self.server_ip_address = self.config.server_ip + self.server_password = self.config.server_password + self.payload = self.config.payload def describe_state(self) -> Dict: """ @@ -48,7 +63,7 @@ class RansomwareScript(Application, identifier="RansomwareScript"): @property def _host_db_client(self) -> DatabaseClient: """Return the database client that is installed on the same machine as the Ransomware Script.""" - db_client: DatabaseClient = self.software_manager.software.get("DatabaseClient") + db_client: DatabaseClient = self.software_manager.software.get("database-client") if db_client is None: self.sys_log.warning(f"{self.__class__.__name__} cannot find a database client on its host.") return db_client diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 73791676..f4944652 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -1,10 +1,10 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum from ipaddress import IPv4Address from typing import Dict, List, Optional from urllib.parse import urlparse -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from primaite import getLogger from primaite.interface.request import RequestResponse @@ -15,22 +15,28 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class WebBrowser(Application, identifier="WebBrowser"): +class WebBrowser(Application, discriminator="web-browser"): """ Represents a web browser in the simulation environment. The application requests and loads web pages using its domain name and requesting IP addresses using DNS. """ - target_url: Optional[str] = None + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for WebBrowser.""" + + type: str = "web-browser" + target_url: Optional[str] = None + + config: "WebBrowser.ConfigSchema" = Field(default_factory=lambda: WebBrowser.ConfigSchema()) domain_name_ip_address: Optional[IPv4Address] = None "The IP address of the domain name for the webpage." @@ -42,11 +48,11 @@ class WebBrowser(Application, identifier="WebBrowser"): """Keep a log of visited websites and information about the visit, such as response code.""" def __init__(self, **kwargs): - kwargs["name"] = "WebBrowser" - kwargs["protocol"] = IPProtocol.TCP + kwargs["name"] = "web-browser" + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port.HTTP + kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) self.run() @@ -86,7 +92,7 @@ class WebBrowser(Application, identifier="WebBrowser"): :param: url: The address of the web page the browser requests :type: url: str """ - url = url or self.target_url + url = url or self.config.target_url if not self._can_perform_action(): return False @@ -102,7 +108,7 @@ class WebBrowser(Application, identifier="WebBrowser"): return False # get the IP address of the domain name via DNS - dns_client: DNSClient = self.software_manager.software.get("DNSClient") + dns_client: DNSClient = self.software_manager.software.get("dns-client") domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname) # if domain does not exist, the request fails @@ -126,7 +132,7 @@ class WebBrowser(Application, identifier="WebBrowser"): if self.send( payload=payload, dest_ip_address=self.domain_name_ip_address, - dest_port=parsed_url.port if parsed_url.port else Port.HTTP, + dest_port=parsed_url.port if parsed_url.port else PORT_LOOKUP["HTTP"], ): self.sys_log.info( f"{self.name}: Received HTTP {payload.request_method.name} " @@ -154,7 +160,7 @@ class WebBrowser(Application, identifier="WebBrowser"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = Port.HTTP, + dest_port: Optional[Port] = PORT_LOOKUP["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/src/primaite/simulator/system/core/__init__.py b/src/primaite/simulator/system/core/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/core/__init__.py +++ b/src/primaite/simulator/system/core/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index ea8b00a5..813c288e 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import json import logging from pathlib import Path diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 677ff477..26e3be79 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from ipaddress import IPv4Address, IPv4Network @@ -10,11 +10,13 @@ from primaite.simulator.core import SimComponent from primaite.simulator.network.protocols.arp import ARPPacket from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.network_layer import IPPacket +from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP if TYPE_CHECKING: - from primaite.simulator.network.hardware.base import NetworkInterface + from primaite.simulator.network.hardware.base import NetworkInterface, Node from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.core.sys_log import SysLog @@ -34,7 +36,7 @@ class Session(SimComponent): :param connected: A flag indicating whether the session is connected. """ - protocol: IPProtocol + protocol: str with_ip_address: IPv4Address src_port: Optional[Port] dst_port: Optional[Port] @@ -119,7 +121,7 @@ class SessionManager: """ protocol = frame.ip.protocol with_ip_address = frame.ip.src_ip_address - if protocol == IPProtocol.TCP: + if protocol == PROTOCOL_LOOKUP["TCP"]: if inbound_frame: src_port = frame.tcp.src_port dst_port = frame.tcp.dst_port @@ -127,7 +129,7 @@ class SessionManager: dst_port = frame.tcp.src_port src_port = frame.tcp.dst_port with_ip_address = frame.ip.dst_ip_address - elif protocol == IPProtocol.UDP: + elif protocol == PROTOCOL_LOOKUP["UDP"]: if inbound_frame: src_port = frame.udp.src_port dst_port = frame.udp.dst_port @@ -262,7 +264,7 @@ class SessionManager: src_port: Optional[Port] = None, dst_port: Optional[Port] = None, session_id: Optional[str] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + ip_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"], icmp_packet: Optional[ICMPPacket] = None, ) -> Union[Any, None]: """ @@ -286,7 +288,7 @@ class SessionManager: dst_mac_address = payload.target_mac_addr outbound_network_interface = self.resolve_outbound_network_interface(payload.target_ip_address) is_broadcast = payload.request - ip_protocol = IPProtocol.UDP + ip_protocol = PROTOCOL_LOOKUP["UDP"] else: vals = self.resolve_outbound_transmission_details( dst_ip_address=dst_ip_address, @@ -311,26 +313,26 @@ class SessionManager: if not outbound_network_interface or not dst_mac_address: return False - if not (src_port or dst_port): + if src_port is None and dst_port is None: raise ValueError( "Failed to resolve src or dst port. Have you sent the port from the service or application?" ) tcp_header = None udp_header = None - if ip_protocol == IPProtocol.TCP: + if ip_protocol == PROTOCOL_LOOKUP["TCP"]: tcp_header = TCPHeader( src_port=dst_port, dst_port=dst_port, ) - elif ip_protocol == IPProtocol.UDP: + elif ip_protocol == PROTOCOL_LOOKUP["UDP"]: udp_header = UDPHeader( src_port=dst_port, dst_port=dst_port, ) # TODO: Only create IP packet if not ARP # ip_packet = None - # if dst_port != Port.ARP: + # if dst_port != Port["ARP"]: # IPPacket( # src_ip_address=outbound_network_interface.ip_address, # dst_ip_address=dst_ip_address, @@ -387,7 +389,7 @@ class SessionManager: elif frame.udp: dst_port = frame.udp.dst_port elif frame.icmp: - dst_port = Port.NONE + dst_port = PORT_LOOKUP["NONE"] self.software_manager.receive_payload_from_session_manager( payload=frame.payload, port=dst_port, @@ -413,5 +415,5 @@ class SessionManager: table.align = "l" table.title = f"{self.sys_log.hostname} Session Manager" for session in self.sessions_by_key.values(): - table.add_row([session.with_ip_address, session.dst_port.value, session.protocol.name]) + table.add_row([session.with_ip_address, session.dst_port, session.protocol]) print(table) diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index d45611ed..67e555ae 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from copy import deepcopy from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union @@ -8,12 +8,12 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.core import RequestType from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import IOSoftware +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP if TYPE_CHECKING: from primaite.simulator.system.core.session_manager import SessionManager @@ -60,12 +60,12 @@ class SoftwareManager: @property def arp(self) -> "ARP": """Provides access to the ARP service instance, if installed.""" - return self.software.get("ARP") # noqa + return self.software.get("arp") # noqa @property def icmp(self) -> "ICMP": """Provides access to the ICMP service instance, if installed.""" - return self.software.get("ICMP") # noqa + return self.software.get("icmp") # noqa def get_open_ports(self) -> List[Port]: """ @@ -106,7 +106,7 @@ class SoftwareManager: return True return False - def install(self, software_class: Type[IOSoftware], **install_kwargs): + def install(self, software_class: Type[IOSoftware], software_config: Optional[IOSoftware.ConfigSchema] = None): """ Install an Application or Service. @@ -115,13 +115,22 @@ class SoftwareManager: if software_class in self._software_class_to_name_map: self.sys_log.warning(f"Cannot install {software_class} as it is already installed") return - software = software_class( - software_manager=self, - sys_log=self.sys_log, - file_system=self.file_system, - dns_server=self.dns_server, - **install_kwargs, - ) + if software_config is None: + software = software_class( + software_manager=self, + sys_log=self.sys_log, + file_system=self.file_system, + dns_server=self.dns_server, + ) + else: + software = software_class( + software_manager=self, + sys_log=self.sys_log, + file_system=self.file_system, + dns_server=self.dns_server, + config=software_config, + ) + software.parent = self.node if isinstance(software, Application): self.node.applications[software.uuid] = software @@ -131,6 +140,7 @@ class SoftwareManager: elif isinstance(software, Service): self.node.services[software.uuid] = software self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager)) + software.start() software.install() software.software_manager = self self.software[software.name] = software @@ -191,7 +201,7 @@ class SoftwareManager: dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, src_port: Optional[Port] = None, dest_port: Optional[Port] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + ip_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"], session_id: Optional[str] = None, ) -> bool: """ @@ -234,7 +244,7 @@ class SoftwareManager: :param session: The transport session the payload originates from. """ if payload.__class__.__name__ == "PortScanPayload": - self.software.get("NMAP").receive(payload=payload, session_id=session_id) + self.software.get("nmap").receive(payload=payload, session_id=session_id) return main_receiver = self.port_protocol_mapping.get((port, protocol), None) if main_receiver: @@ -267,7 +277,7 @@ class SoftwareManager: table.set_style(MARKDOWN) table.align = "l" table.title = f"{self.sys_log.hostname} Software Manager" - for software in self.port_protocol_mapping.values(): + for software in self.software.values(): software_type = "Service" if isinstance(software, Service) else "Application" table.add_row( [ @@ -275,8 +285,8 @@ class SoftwareManager: software_type, software.operating_state.name, software.health_state_actual.name, - software.port.value if software.port != Port.NONE else None, - software.protocol.value, + software.port if software.port != PORT_LOOKUP["NONE"] else None, + software.protocol, ] ) print(table) diff --git a/src/primaite/simulator/system/core/sys_log.py b/src/primaite/simulator/system/core/sys_log.py index 9e22696d..741e5d33 100644 --- a/src/primaite/simulator/system/core/sys_log.py +++ b/src/primaite/simulator/system/core/sys_log.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import logging from pathlib import Path diff --git a/src/primaite/simulator/system/processes/__init__.py b/src/primaite/simulator/system/processes/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/processes/__init__.py +++ b/src/primaite/simulator/system/processes/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/processes/process.py b/src/primaite/simulator/system/processes/process.py index 225505c8..ad2babc1 100644 --- a/src/primaite/simulator/system/processes/process.py +++ b/src/primaite/simulator/system/processes/process.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from abc import abstractmethod from enum import Enum from typing import Dict diff --git a/src/primaite/simulator/system/services/__init__.py b/src/primaite/simulator/system/services/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/__init__.py +++ b/src/primaite/simulator/system/services/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/access/__init__.py b/src/primaite/simulator/system/services/access/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/access/__init__.py +++ b/src/primaite/simulator/system/services/access/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/access/user_manager.py b/src/primaite/simulator/system/services/access/user_manager.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/access/user_manager.py +++ b/src/primaite/simulator/system/services/access/user_manager.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/access/user_session_manager.py b/src/primaite/simulator/system/services/access/user_session_manager.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/access/user_session_manager.py +++ b/src/primaite/simulator/system/services/access/user_session_manager.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/arp/__init__.py b/src/primaite/simulator/system/services/arp/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/arp/__init__.py +++ b/src/primaite/simulator/system/services/arp/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index 9314bea7..3302041d 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -1,20 +1,21 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from abc import abstractmethod from typing import Any, Dict, Optional, Union from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP -class ARP(Service): +class ARP(Service, discriminator="arp"): """ The ARP (Address Resolution Protocol) Service. @@ -22,12 +23,19 @@ class ARP(Service): sends ARP requests and replies, and processes incoming ARP packets. """ + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for ARP.""" + + type: str = "arp" + + config: "ARP.ConfigSchema" = Field(default_factory=lambda: ARP.ConfigSchema()) + arp: Dict[IPV4Address, ARPEntry] = {} def __init__(self, **kwargs): - kwargs["name"] = "ARP" - kwargs["port"] = Port.ARP - kwargs["protocol"] = IPProtocol.UDP + kwargs["name"] = "arp" + kwargs["port"] = PORT_LOOKUP["ARP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -130,8 +138,8 @@ class ARP(Service): break if use_default_gateway: - if self.software_manager.node.default_gateway: - target_ip_address = self.software_manager.node.default_gateway + if self.software_manager.node.config.default_gateway: + target_ip_address = self.software_manager.node.config.default_gateway else: return diff --git a/src/primaite/simulator/system/services/database/__init__.py b/src/primaite/simulator/system/services/database/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/database/__init__.py +++ b/src/primaite/simulator/system/services/database/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index b38e87b4..edc3f6b4 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -1,31 +1,40 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Any, Dict, List, Literal, Optional, Union from uuid import uuid4 +from pydantic import Field + from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.file_system.folder import Folder -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class DatabaseService(Service): +class DatabaseService(Service, discriminator="database-service"): """ A class for simulating a generic SQL Server service. This class inherits from the `Service` class and provides methods to simulate a SQL database. """ - password: Optional[str] = None - """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for DatabaseService.""" + + type: str = "database-service" + backup_server_ip: Optional[IPv4Address] = None + db_password: Optional[str] = None + """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" + + config: ConfigSchema = Field(default_factory=lambda: DatabaseService.ConfigSchema()) backup_server_ip: IPv4Address = None """IP address of the backup server.""" @@ -37,11 +46,21 @@ class DatabaseService(Service): """File name of latest backup.""" def __init__(self, **kwargs): - kwargs["name"] = "DatabaseService" - kwargs["port"] = Port.POSTGRES_SERVER - kwargs["protocol"] = IPProtocol.TCP + kwargs["name"] = "database-service" + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self._create_db_file() + self.backup_server_ip = self.config.backup_server_ip + + @property + def password(self) -> Optional[str]: + """Convenience property for accessing the password.""" + return self.config.db_password + + @password.setter + def password(self, val: str) -> None: + self.config.db_password = val def install(self): """ @@ -51,7 +70,7 @@ class DatabaseService(Service): """ super().install() - if not self.parent.software_manager.software.get("FTPClient"): + if not self.parent.software_manager.software.get("ftp-client"): self.parent.sys_log.info(f"{self.name}: Installing FTPClient to enable database backups") self.parent.software_manager.install(FTPClient) @@ -75,7 +94,7 @@ class DatabaseService(Service): return False software_manager: SoftwareManager = self.software_manager - ftp_client_service: FTPClient = software_manager.software.get("FTPClient") + ftp_client_service: FTPClient = software_manager.software.get("ftp-client") if not ftp_client_service: self.sys_log.error( @@ -109,7 +128,7 @@ class DatabaseService(Service): return False software_manager: SoftwareManager = self.software_manager - ftp_client_service: FTPClient = software_manager.software.get("FTPClient") + ftp_client_service: FTPClient = software_manager.software.get("ftp-client") if not ftp_client_service: self.sys_log.error( @@ -206,7 +225,7 @@ class DatabaseService(Service): SoftwareHealthState.FIXING, SoftwareHealthState.COMPROMISED, ]: - if self.password == password: + if self.config.db_password == password: status_code = 200 # ok connection_id = self._generate_connection_id() # try to create connection diff --git a/src/primaite/simulator/system/services/dns/__init__.py b/src/primaite/simulator/system/services/dns/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/dns/__init__.py +++ b/src/primaite/simulator/system/services/dns/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index d7ba0cd4..8b16af69 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -1,32 +1,44 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address -from typing import Dict, Optional +from typing import Dict, Optional, TYPE_CHECKING + +from pydantic import Field from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import Port, PORT_LOOKUP + +if TYPE_CHECKING: + from primaite.simulator.network.hardware.base import Node _LOGGER = getLogger(__name__) -class DNSClient(Service): +class DNSClient(Service, discriminator="dns-client"): """Represents a DNS Client as a Service.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for DNSClient.""" + + type: str = "dns-client" + dns_server: Optional[IPV4Address] = None + "The DNS Server the client sends requests to." + + config: ConfigSchema = Field(default_factory=lambda: DNSClient.ConfigSchema()) dns_cache: Dict[str, IPv4Address] = {} "A dict of known mappings between domain/URLs names and IPv4 addresses." - dns_server: Optional[IPv4Address] = None - "The DNS Server the client sends requests to." def __init__(self, **kwargs): - kwargs["name"] = "DNSClient" - kwargs["port"] = Port.DNS + kwargs["name"] = "dns-client" + kwargs["port"] = PORT_LOOKUP["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() @@ -43,6 +55,15 @@ class DNSClient(Service): state = super().describe_state() return state + @property + def dns_server(self) -> Optional[IPV4Address]: + """Convenience property for accessing the dns server configuration.""" + return self.config.dns_server + + @dns_server.setter + def dns_server(self, val: IPV4Address) -> None: + self.config.dns_server = val + def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address) -> bool: """ Adds a domain name to the DNS Client cache. @@ -71,6 +92,14 @@ class DNSClient(Service): if not self._can_perform_action(): return False + # check if the domain is already in the DNS cache + if target_domain in self.dns_cache: + self.sys_log.info( + f"{self.name}: Domain lookup for {target_domain} successful," + f"resolves to {self.dns_cache[target_domain]}" + ) + return True + # check if DNS server is configured if self.dns_server is None: self.sys_log.warning(f"{self.name}: DNS Server is not configured") @@ -79,31 +108,23 @@ class DNSClient(Service): # check if the target domain is in the client's DNS cache payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain)) - # check if the domain is already in the DNS cache - if target_domain in self.dns_cache: - self.sys_log.info( - f"{self.name}: Domain lookup for {target_domain} successful," - f"resolves to {self.dns_cache[target_domain]}" - ) - return True + # return False if already reattempted + if is_reattempt: + self.sys_log.warning(f"{self.name}: Domain lookup for {target_domain} failed") + return False else: - # return False if already reattempted - if is_reattempt: - self.sys_log.warning(f"{self.name}: Domain lookup for {target_domain} failed") - return False - else: - # send a request to check if domain name exists in the DNS Server - software_manager: SoftwareManager = self.software_manager - software_manager.send_payload_to_session_manager( - payload=payload, dest_ip_address=self.dns_server, dest_port=Port.DNS - ) + # send a request to check if domain name exists in the DNS Server + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload=payload, dest_ip_address=self.dns_server, dest_port=PORT_LOOKUP["DNS"] + ) - # recursively re-call the function passing is_reattempt=True - return self.check_domain_exists( - target_domain=target_domain, - session_id=session_id, - is_reattempt=True, - ) + # recursively re-call the function passing is_reattempt=True + return self.check_domain_exists( + target_domain=target_domain, + session_id=session_id, + is_reattempt=True, + ) def send( self, @@ -160,3 +181,9 @@ class DNSClient(Service): self.sys_log.warning(f"Failed to resolve domain name {payload.dns_request.domain_name_request}") return False + + def install(self) -> None: + """Set the DNS server to be the node's DNS server unless a different one was already provided.""" + self.parent: Node + if self.parent and not self.dns_server: + self.config.dns_server = self.parent.dns_server diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 8a4bbaed..696af993 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -1,32 +1,42 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Any, Dict, Optional from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class DNSServer(Service): +class DNSServer(Service, discriminator="dns-server"): """Represents a DNS Server as a Service.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for DNSServer.""" + + type: str = "dns-server" + domain_mapping: dict = {} + + config: ConfigSchema = Field(default_factory=lambda: DNSServer.ConfigSchema()) + dns_table: Dict[str, IPv4Address] = {} "A dict of mappings between domain names and IPv4 addresses." def __init__(self, **kwargs): - kwargs["name"] = "DNSServer" - kwargs["port"] = Port.DNS + kwargs["name"] = "dns-server" + kwargs["port"] = PORT_LOOKUP["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) + self.dns_table = self.config.domain_mapping self.start() def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/ftp/__init__.py b/src/primaite/simulator/system/services/ftp/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/ftp/__init__.py +++ b/src/primaite/simulator/system/services/ftp/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index f823e42c..5e97243a 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -1,32 +1,42 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC +from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class FTPClient(FTPServiceABC): +class FTPClient(FTPServiceABC, discriminator="ftp-client"): """ A class for simulating an FTP client service. - This class inherits from the `Service` class and provides methods to emulate FTP + This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP RFC 959: https://datatracker.ietf.org/doc/html/rfc959 """ + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for FTPClient.""" + + type: str = "ftp-client" + + config: ConfigSchema = Field(default_factory=lambda: FTPClient.ConfigSchema()) + def __init__(self, **kwargs): - kwargs["name"] = "FTPClient" - kwargs["port"] = Port.FTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["name"] = "ftp-client" + kwargs["port"] = PORT_LOOKUP["FTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() @@ -104,10 +114,11 @@ class FTPClient(FTPServiceABC): def _connect_to_server( self, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[Port] = PORT_LOOKUP["FTP"], session_id: Optional[str] = None, is_reattempt: Optional[bool] = False, ) -> bool: + self._active = True """ Connects the client to a given FTP server. @@ -124,13 +135,13 @@ class FTPClient(FTPServiceABC): # normally FTP will choose a random port for the transfer, but using the FTP command port will do for now # create FTP packet - payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=Port.FTP) + payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=PORT_LOOKUP["FTP"]) if self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id): if payload.status_code == FTPStatusCode.OK: self.sys_log.info( f"{self.name}: Successfully connected to FTP Server " - f"{dest_ip_address} via port {payload.ftp_command_args.value}" + f"{dest_ip_address} via port {payload.ftp_command_args}" ) self.add_connection(connection_id="server_connection", session_id=session_id) return True @@ -139,7 +150,7 @@ class FTPClient(FTPServiceABC): # reattempt failed self.sys_log.warning( f"{self.name}: Unable to connect to FTP Server " - f"{dest_ip_address} via port {payload.ftp_command_args.value}" + f"{dest_ip_address} via port {payload.ftp_command_args}" ) return False else: @@ -152,7 +163,7 @@ class FTPClient(FTPServiceABC): return False def _disconnect_from_server( - self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[Port] = Port.FTP + self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[Port] = PORT_LOOKUP["FTP"] ) -> bool: """ Connects the client from a given FTP server. @@ -164,6 +175,7 @@ class FTPClient(FTPServiceABC): :param: is_reattempt: Set to True if attempt to disconnect from FTP Server has been attempted. Default False. :type: is_reattempt: Optional[bool] """ + self._active = True # send a disconnect request payload to FTP server payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.QUIT) software_manager: SoftwareManager = self.software_manager @@ -179,7 +191,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[Port] = PORT_LOOKUP["FTP"], session_id: Optional[str] = None, ) -> bool: """ @@ -203,12 +215,13 @@ class FTPClient(FTPServiceABC): :param: dest_file_name: The name of the file to be saved on the FTP Server. :type: dest_file_name: str - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. :type: dest_port: Optional[Port] :param: session_id: The id of the session :type: session_id: Optional[str] """ + self._active = True # check if the file to transfer exists on the client file_to_transfer: File = self.file_system.get_file(folder_name=src_folder_name, file_name=src_file_name) if not file_to_transfer: @@ -241,7 +254,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[Port] = PORT_LOOKUP["FTP"], ) -> bool: """ Request a file from a target IP address. @@ -263,9 +276,10 @@ class FTPClient(FTPServiceABC): :param: dest_file_name: The name of the file to be saved on the FTP Server. :type: dest_file_name: str - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. - :type: dest_port: Optional[Port] + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. + :type: dest_port: Optional[int] """ + self._active = True # check if FTP is currently connected to IP self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) @@ -317,6 +331,7 @@ class FTPClient(FTPServiceABC): This helps prevent an FTP request loop - FTP client and servers can exist on the same node. """ + self._active = True if not self._can_perform_action(): return False diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index f02d01f4..86e07c54 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -1,33 +1,46 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Any, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class FTPServer(FTPServiceABC): +class FTPServer(FTPServiceABC, discriminator="ftp-server"): """ A class for simulating an FTP server service. - This class inherits from the `Service` class and provides methods to emulate FTP + This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP RFC 959: https://datatracker.ietf.org/doc/html/rfc959 """ + class ConfigSchema(FTPServiceABC.ConfigSchema): + """ConfigSchema for FTPServer.""" + + type: str = "ftp-server" + server_password: Optional[str] = None + + config: ConfigSchema = Field(default_factory=lambda: FTPServer.ConfigSchema()) server_password: Optional[str] = None - """Password needed to connect to FTP server. Default is None.""" def __init__(self, **kwargs): - kwargs["name"] = "FTPServer" - kwargs["port"] = Port.FTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["name"] = "ftp-server" + kwargs["port"] = PORT_LOOKUP["FTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() + @property + def server_password(self) -> Optional[str]: + """Convenience method for accessing FTP server password.""" + return self.config.server_password + def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. @@ -52,7 +65,7 @@ class FTPServer(FTPServiceABC): # process server specific commands, otherwise call super if payload.ftp_command == FTPCommand.PORT: # check that the port is valid - if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535): + if is_valid_port(payload.ftp_command_args): # return successful connection self.add_connection(connection_id=session_id, session_id=session_id) payload.status_code = FTPStatusCode.OK diff --git a/src/primaite/simulator/system/services/ftp/ftp_service.py b/src/primaite/simulator/system/services/ftp/ftp_service.py index 689a3da7..13acda70 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_service.py +++ b/src/primaite/simulator/system/services/ftp/ftp_service.py @@ -1,12 +1,14 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from abc import ABC from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import StrictBool + from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.services.service import Service +from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.utils.validation.port import Port class FTPServiceABC(Service, ABC): @@ -16,9 +18,22 @@ class FTPServiceABC(Service, ABC): Contains shared methods between both classes. """ + _active: StrictBool = False + """Flag that is True on timesteps where service transmits data and False when idle. Used for describe_state.""" + + def pre_timestep(self, timestep: int) -> None: + """When a new timestep begins, clear the _active attribute.""" + self._active = False + return super().pre_timestep(timestep) + def describe_state(self) -> Dict: """Returns a Dict of the FTPService state.""" - return super().describe_state() + state = super().describe_state() + + # override so that the service is shows as running only if actively transmitting data this timestep + if self.operating_state == ServiceOperatingState.RUNNING and not self._active: + state["operating_state"] = ServiceOperatingState.STOPPED.value + return state def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ @@ -29,6 +44,7 @@ class FTPServiceABC(Service, ABC): :param: session_id: session ID linked to the FTP Packet. Optional. :type: session_id: Optional[str] """ + self._active = True if payload.ftp_command is not None: self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.") @@ -51,6 +67,7 @@ class FTPServiceABC(Service, ABC): :param: payload: The FTP Packet that contains the file data :type: FTPPacket """ + self._active = True try: file_name = payload.ftp_command_args["dest_file_name"] folder_name = payload.ftp_command_args["dest_folder_name"] @@ -97,7 +114,7 @@ class FTPServiceABC(Service, ABC): :param: dest_ip_address: The IP address of the machine that hosts the FTP Server. :type: dest_ip_address: Optional[IPv4Address] - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. :type: dest_port: Optional[Port] :param: session_id: session ID linked to the FTP Packet. Optional. @@ -106,6 +123,7 @@ class FTPServiceABC(Service, ABC): :param: is_response: is true if the data being sent is in response to a request. Default False. :type: is_response: bool """ + self._active = True # send STOR request payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.STOR, @@ -135,6 +153,7 @@ class FTPServiceABC(Service, ABC): :param: payload: The FTP Packet that contains the file data :type: FTPPacket """ + self._active = True try: # find the file file_name = payload.ftp_command_args["src_file_name"] @@ -181,6 +200,7 @@ class FTPServiceABC(Service, ABC): :return: True if successful, False otherwise. """ + self._active = True self.sys_log.info(f"{self.name}: Sending FTP {payload.ftp_command.name} {payload.ftp_command_args}") return super().send( diff --git a/src/primaite/simulator/system/services/icmp/__init__.py b/src/primaite/simulator/system/services/icmp/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/icmp/__init__.py +++ b/src/primaite/simulator/system/services/icmp/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/icmp/icmp.py b/src/primaite/simulator/system/services/icmp/icmp.py index 6741d86a..207940cf 100644 --- a/src/primaite/simulator/system/services/icmp/icmp.py +++ b/src/primaite/simulator/system/services/icmp/icmp.py @@ -1,20 +1,22 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import secrets from ipaddress import IPv4Address from typing import Any, Dict, Optional, Tuple, Union +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class ICMP(Service): +class ICMP(Service, discriminator="icmp"): """ The Internet Control Message Protocol (ICMP) service. @@ -22,12 +24,19 @@ class ICMP(Service): network diagnostics, notably the ping command. """ + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for ICMP.""" + + type: str = "icmp" + + config: "ICMP.ConfigSchema" = Field(default_factory=lambda: ICMP.ConfigSchema()) + request_replies: Dict = {} def __init__(self, **kwargs): - kwargs["name"] = "ICMP" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.ICMP + kwargs["name"] = "icmp" + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["ICMP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/icmp/router_icmp.py b/src/primaite/simulator/system/services/icmp/router_icmp.py index 4fdc6baa..4c69e381 100644 --- a/src/primaite/simulator/system/services/icmp/router_icmp.py +++ b/src/primaite/simulator/system/services/icmp/router_icmp.py @@ -1,13 +1,13 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -# class RouterICMP(ICMP): +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +# class RouterICMP(icmp): # """ -# A class to represent a router's Internet Control Message Protocol (ICMP) handler. +# A class to represent a router's Internet Control Message Protocol (icmp) handler. # # :param sys_log: System log for logging network events and errors. # :type sys_log: SysLog -# :param arp_cache: The ARP cache for resolving MAC addresses. +# :param arp_cache: The arp cache for resolving MAC addresses. # :type arp_cache: ARPCache -# :param router: The router to which this ICMP handler belongs. +# :param router: The router to which this icmp handler belongs. # :type router: Router # """ # @@ -19,7 +19,7 @@ # # def process_icmp(self, frame: Frame, from_network_interface: NIC, is_reattempt: bool = False): # """ -# Process incoming ICMP frames based on ICMP type. +# Process incoming icmp frames based on icmp type. # # :param frame: The incoming frame to process. # :param from_network_interface: The network interface where the frame is coming from. @@ -36,13 +36,13 @@ # self.sys_log.info(f"Received echo request from {frame.ip.src_ip_address}") # target_mac_address = self.arp.get_arp_cache_mac_address(frame.ip.src_ip_address) # src_nic = self.arp.get_arp_cache_network_interface(frame.ip.src_ip_address) -# tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) +# tcp_header = TCPHeader(src_port=Port["arp"], dst_port=Port["arp"]) # # # Network Layer # ip_packet = IPPacket( # src_ip_address=network_interface.ip_address, # dst_ip_address=frame.ip.src_ip_address, -# protocol=IPProtocol.ICMP, +# protocol=IPProtocol["icmp"], # ) # # Data Link Layer # ethernet_header = EthernetHeader( @@ -54,7 +54,7 @@ # identifier=frame.icmp.identifier, # sequence=frame.icmp.sequence + 1, # ) -# payload = secrets.token_urlsafe(int(32 / 1.3)) # Standard ICMP 32 bytes size +# payload = secrets.token_urlsafe(int(32 / 1.3)) # Standard icmp 32 bytes size # frame = Frame( # ethernet=ethernet_header, # ip=ip_packet, diff --git a/src/primaite/simulator/system/services/ntp/__init__.py b/src/primaite/simulator/system/services/ntp/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/ntp/__init__.py +++ b/src/primaite/simulator/system/services/ntp/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 8924a821..6bd1f4bb 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -1,29 +1,40 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from datetime import datetime from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class NTPClient(Service): +class NTPClient(Service, discriminator="ntp-client"): """Represents a NTP client as a service.""" - ntp_server: Optional[IPv4Address] = None - "The NTP server the client sends requests to." + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for NTPClient.""" + + type: str = "ntp-client" + ntp_server_ip: Optional[IPV4Address] = None + "The NTP server the client sends requests to." + + config: ConfigSchema = Field(default_factory=lambda: NTPClient.ConfigSchema()) + time: Optional[datetime] = None def __init__(self, **kwargs): - kwargs["name"] = "NTPClient" - kwargs["port"] = Port.NTP - kwargs["protocol"] = IPProtocol.UDP + kwargs["name"] = "ntp-client" + kwargs["port"] = PORT_LOOKUP["NTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) + self.ntp_server = self.config.ntp_server_ip self.start() def configure(self, ntp_server_ip_address: IPv4Address) -> None: @@ -33,8 +44,8 @@ class NTPClient(Service): :param ntp_server_ip_address: IPv4 address of NTP server. :param ntp_client_ip_Address: IPv4 address of NTP client. """ - self.ntp_server = ntp_server_ip_address - self.sys_log.info(f"{self.name}: ntp_server: {self.ntp_server}") + self.config.ntp_server_ip = ntp_server_ip_address + self.sys_log.info(f"{self.name}: ntp_server: {self.config.ntp_server_ip}") def describe_state(self) -> Dict: """ @@ -55,7 +66,7 @@ class NTPClient(Service): payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: Port = Port.NTP, + dest_port: Port = PORT_LOOKUP["NTP"], **kwargs, ) -> bool: """Requests NTP data from NTP server. @@ -96,10 +107,10 @@ class NTPClient(Service): def request_time(self) -> None: """Send request to ntp_server.""" - if self.ntp_server: + if self.config.ntp_server_ip: self.software_manager.session_manager.receive_payload_from_software_manager( payload=NTPPacket(), - dst_ip_address=self.ntp_server, + dst_ip_address=self.config.ntp_server_ip, src_port=self.port, dst_port=self.port, ip_protocol=self.protocol, diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 547bbc06..05696d9f 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -1,23 +1,32 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from datetime import datetime from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class NTPServer(Service): +class NTPServer(Service, discriminator="ntp-server"): """Represents a NTP server as a service.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for NTPServer.""" + + type: str = "ntp-server" + + config: ConfigSchema = Field(default_factory=lambda: NTPServer.ConfigSchema()) + def __init__(self, **kwargs): - kwargs["name"] = "NTPServer" - kwargs["port"] = Port.NTP - kwargs["protocol"] = IPProtocol.UDP + kwargs["name"] = "ntp-server" + kwargs["port"] = PORT_LOOKUP["NTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 5adea6e7..a7b8fd09 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -1,9 +1,11 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from abc import abstractmethod +from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional, Type + +from pydantic import Field from primaite import getLogger from primaite.interface.request import RequestFormat, RequestResponse @@ -37,6 +39,13 @@ class Service(IOSoftware): Services are programs that run in the background and may perform input/output operations. """ + class ConfigSchema(IOSoftware.ConfigSchema, ABC): + """Config Schema for Service class.""" + + type: str + + config: "Service.ConfigSchema" = Field(default_factory=lambda: Service.ConfigSchema()) + operating_state: ServiceOperatingState = ServiceOperatingState.STOPPED "The current operating state of the Service." @@ -46,9 +55,44 @@ class Service(IOSoftware): restart_countdown: Optional[int] = None "If currently restarting, how many timesteps remain until the restart is finished." + _registry: ClassVar[Dict[str, Type["Service"]]] = {} + """Registry of service types. Automatically populated when subclasses are defined.""" + def __init__(self, **kwargs): super().__init__(**kwargs) + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: + """ + Register a hostnode type. + + :param discriminator: Uniquely specifies an hostnode class by name. Used for finding items by config. + :type discriminator: str + :raises ValueError: When attempting to register an hostnode with a name that is already allocated. + """ + super().__init_subclass__(**kwargs) + if discriminator is None: + return + # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. + discriminator = discriminator.lower() + if discriminator in cls._registry: + raise ValueError(f"Tried to define new hostnode {discriminator}, but this name is already reserved.") + cls._registry[discriminator] = cls + + @classmethod + def from_config(cls, config: Dict) -> "Service": + """Create a service from a config dictionary. + + :param config: dict of options for service components constructor + :type config: dict + :return: The service component. + :rtype: Service + """ + if config["type"] not in cls._registry: + raise ValueError(f"Invalid service type {config['type']}") + service_class = cls._registry[config["type"]] + service_object = service_class(config=service_class.ConfigSchema(**config)) + return service_object + def _can_perform_action(self) -> bool: """ Checks if the service can perform actions. @@ -212,14 +256,14 @@ class Service(IOSoftware): def disable(self) -> bool: """Disable the service.""" - self.sys_log.info(f"Disabling Application {self.name}") + self.sys_log.info(f"Disabling Service {self.name}") self.operating_state = ServiceOperatingState.DISABLED return True def enable(self) -> bool: """Enable the disabled service.""" if self.operating_state == ServiceOperatingState.DISABLED: - self.sys_log.info(f"Enabling Application {self.name}") + self.sys_log.info(f"Enabling Service {self.name}") self.operating_state = ServiceOperatingState.STOPPED return True return False diff --git a/src/primaite/simulator/system/services/terminal/__init__.py b/src/primaite/simulator/system/services/terminal/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/terminal/__init__.py +++ b/src/primaite/simulator/system/services/terminal/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index ed6854f4..112f6abc 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations from abc import abstractmethod @@ -7,7 +7,7 @@ from ipaddress import IPv4Address from typing import Any, Dict, List, Optional, Union from uuid import uuid4 -from pydantic import BaseModel +from pydantic import BaseModel, Field from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -17,10 +17,10 @@ from primaite.simulator.network.protocols.ssh import ( SSHTransportMessage, SSHUserCredentials, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP # TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor @@ -129,9 +129,16 @@ class RemoteTerminalConnection(TerminalClientConnection): return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id) -class Terminal(Service): +class Terminal(Service, discriminator="terminal"): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for Terminal.""" + + type: str = "terminal" + + config: "Terminal.ConfigSchema" = Field(default_factory=lambda: Terminal.ConfigSchema()) + _client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {} """Dictionary of connect requests made to remote nodes.""" @@ -139,9 +146,9 @@ class Terminal(Service): """Last response received from RequestManager, for returning remote RequestResponse.""" def __init__(self, **kwargs): - kwargs["name"] = "Terminal" - kwargs["port"] = Port.SSH - kwargs["protocol"] = IPProtocol.TCP + kwargs["name"] = "terminal" + kwargs["port"] = PORT_LOOKUP["SSH"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) @property @@ -187,7 +194,7 @@ class Terminal(Service): return RequestResponse(status="failure", data={}) rm.add_request( - "ssh_to_remote", + "node_session_remote_login", request_type=RequestType(func=_remote_login), ) @@ -304,7 +311,6 @@ class Terminal(Service): :param password: Password for login. :return: boolean, True if successful, else False """ - # TODO: Un-comment this when UserSessionManager is merged. connection_uuid = self.parent.user_session_manager.local_login(username=username, password=password) if connection_uuid: self.sys_log.info(f"{self.name}: Login request authorised, connection uuid: {connection_uuid}") diff --git a/src/primaite/simulator/system/services/web_server/__init__.py b/src/primaite/simulator/system/services/web_server/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/simulator/system/services/web_server/__init__.py +++ b/src/primaite/simulator/system/services/web_server/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 0df47999..3f8760c4 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -1,8 +1,10 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Any, Dict, List, Optional from urllib.parse import urlparse +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.http import ( HttpRequestMethod, @@ -10,18 +12,25 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClientConnection from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class WebServer(Service): +class WebServer(Service, discriminator="web-server"): """Class used to represent a Web Server Service in simulation.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for WebServer.""" + + type: str = "web-server" + + config: ConfigSchema = Field(default_factory=lambda: WebServer.ConfigSchema()) + response_codes_this_timestep: List[HttpStatusCode] = [] def describe_state(self) -> Dict: @@ -48,11 +57,11 @@ class WebServer(Service): return super().pre_timestep(timestep) def __init__(self, **kwargs): - kwargs["name"] = "WebServer" - kwargs["protocol"] = IPProtocol.TCP + kwargs["name"] = "web-server" + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port.HTTP + kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) self._install_web_files() @@ -139,7 +148,7 @@ class WebServer(Service): return True # otherwise, try to create db connection - db_client = self.software_manager.software.get("DatabaseClient") + db_client = self.software_manager.software.get("database-client") if db_client is None: return False # database client not installed diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index f1d1b9a1..86b57818 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -1,22 +1,22 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import copy -from abc import abstractmethod +from abc import ABC, abstractmethod from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import Field +from pydantic import BaseModel, ConfigDict, Field from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port if TYPE_CHECKING: from primaite.simulator.system.core.software_manager import SoftwareManager @@ -70,7 +70,7 @@ class SoftwareCriticality(Enum): "The highest level of criticality." -class Software(SimComponent): +class Software(SimComponent, ABC): """ A base class representing software in a simulator environment. @@ -78,14 +78,22 @@ class Software(SimComponent): It outlines the fundamental attributes and behaviors expected of any software in the simulation. """ + class ConfigSchema(BaseModel, ABC): + """Configurable options for all software.""" + + model_config = ConfigDict(extra="forbid") + starting_health_state: SoftwareHealthState = SoftwareHealthState.GOOD + criticality: SoftwareCriticality = SoftwareCriticality.LOWEST + fixing_duration: int = 2 + + config: ConfigSchema = Field(default_factory=lambda: Software.ConfigSchema()) + name: str "The name of the software." health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED "The actual health state of the software." health_state_visible: SoftwareHealthState = SoftwareHealthState.UNUSED "The health state of the software visible to the red agent." - criticality: SoftwareCriticality = SoftwareCriticality.LOWEST - "The criticality level of the software." fixing_count: int = 0 "The count of patches applied to the software, defaults to 0." scanning_count: int = 0 @@ -100,11 +108,13 @@ class Software(SimComponent): "The FileSystem of the Node the Software is installed on." folder: Optional[Folder] = None "The folder on the file system the Software uses." - fixing_duration: int = 2 - "The number of ticks it takes to patch the software." _fixing_countdown: Optional[int] = None "Current number of ticks left to patch the software." + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.health_state_actual = self.config.starting_health_state # don't remove this + def _init_request_manager(self) -> RequestManager: """ Initialise the request manager. @@ -152,7 +162,7 @@ class Software(SimComponent): { "health_state_actual": self.health_state_actual.value, "health_state_visible": self.health_state_visible.value, - "criticality": self.criticality.value, + "criticality": self.config.criticality.value, "fixing_count": self.fixing_count, "scanning_count": self.scanning_count, "revealed_to_red": self.revealed_to_red, @@ -201,7 +211,7 @@ class Software(SimComponent): def fix(self) -> bool: """Perform a fix on the software.""" if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD): - self._fixing_countdown = self.fixing_duration + self._fixing_countdown = self.config.fixing_duration self.set_health_state(SoftwareHealthState.FIXING) return True return False @@ -233,7 +243,7 @@ class Software(SimComponent): super().pre_timestep(timestep) -class IOSoftware(Software): +class IOSoftware(Software, ABC): """ Represents software in a simulator environment that is capable of input/output operations. @@ -243,6 +253,13 @@ class IOSoftware(Software): required. """ + class ConfigSchema(Software.ConfigSchema, ABC): + """Configuration options for all IO Software.""" + + listen_on_ports: Set[Port] = Field(default_factory=set) + + config: ConfigSchema = Field(default_factory=lambda: IOSoftware.ConfigSchema()) + installing_count: int = 0 "The number of times the software has been installed. Default is 0." max_sessions: int = 100 @@ -260,6 +277,10 @@ class IOSoftware(Software): _connections: Dict[str, Dict] = {} "Active connections." + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.listen_on_ports = self.config.listen_on_ports + @abstractmethod def describe_state(self) -> Dict: """ @@ -277,7 +298,7 @@ class IOSoftware(Software): "max_sessions": self.max_sessions, "tcp": self.tcp, "udp": self.udp, - "port": self.port.value, + "port": self.port, } ) return state @@ -294,7 +315,7 @@ class IOSoftware(Software): """ if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON: self.software_manager.node.sys_log.error( - f"{self.name} Error: {self.software_manager.node.hostname} is not powered on." + f"{self.name} Error: {self.software_manager.node.config.hostname} is not powered on." ) return False return True @@ -386,8 +407,8 @@ class IOSoftware(Software): payload: Any, session_id: Optional[str] = None, dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - dest_port: Optional[Port] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + dest_port: Optional[int] = None, + ip_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"], **kwargs, ) -> bool: """ diff --git a/src/primaite/utils/__init__.py b/src/primaite/utils/__init__.py index 4d7c430e..1dced372 100644 --- a/src/primaite/utils/__init__.py +++ b/src/primaite/utils/__init__.py @@ -1,2 +1,2 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Utilities for PrimAITE.""" diff --git a/src/primaite/utils/cli/__init__.py b/src/primaite/utils/cli/__init__.py index be6c00e7..836b79af 100644 --- a/src/primaite/utils/cli/__init__.py +++ b/src/primaite/utils/cli/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/utils/cli/dev_cli.py b/src/primaite/utils/cli/dev_cli.py index 8946a4ca..581cd0b1 100644 --- a/src/primaite/utils/cli/dev_cli.py +++ b/src/primaite/utils/cli/dev_cli.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import click import typer from rich import print diff --git a/src/primaite/utils/cli/primaite_config_utils.py b/src/primaite/utils/cli/primaite_config_utils.py index 635be5a7..1fefd0a4 100644 --- a/src/primaite/utils/cli/primaite_config_utils.py +++ b/src/primaite/utils/cli/primaite_config_utils.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Dict, Optional import yaml diff --git a/src/primaite/utils/converters.py b/src/primaite/utils/converters.py index f803851d..95956448 100644 --- a/src/primaite/utils/converters.py +++ b/src/primaite/utils/converters.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum from typing import Any, Dict diff --git a/src/primaite/utils/package_data.py b/src/primaite/utils/package_data.py index af0252f9..ed091dd0 100644 --- a/src/primaite/utils/package_data.py +++ b/src/primaite/utils/package_data.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import os from logging import Logger from pathlib import Path diff --git a/src/primaite/utils/session_metadata_parser.py b/src/primaite/utils/session_metadata_parser.py index f6594666..1a7345ea 100644 --- a/src/primaite/utils/session_metadata_parser.py +++ b/src/primaite/utils/session_metadata_parser.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # flake8: noqa raise DeprecationWarning( "Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet." diff --git a/src/primaite/utils/session_output_reader.py b/src/primaite/utils/session_output_reader.py index b9ad68a1..f27f6143 100644 --- a/src/primaite/utils/session_output_reader.py +++ b/src/primaite/utils/session_output_reader.py @@ -1,9 +1,8 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # flake8: noqa raise DeprecationWarning( "Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet." ) -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from pathlib import Path from typing import Any, Dict, Tuple, Union diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py index 75a97f60..4049a8c1 100644 --- a/src/primaite/utils/session_output_writer.py +++ b/src/primaite/utils/session_output_writer.py @@ -1,9 +1,8 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # flake8: noqa raise DeprecationWarning( "Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet." ) -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK import csv from logging import Logger from typing import Final, List, Tuple, TYPE_CHECKING, Union diff --git a/src/primaite/utils/validation/__init__.py b/src/primaite/utils/validation/__init__.py new file mode 100644 index 00000000..836b79af --- /dev/null +++ b/src/primaite/utils/validation/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/src/primaite/utils/validation/ip_protocol.py b/src/primaite/utils/validation/ip_protocol.py new file mode 100644 index 00000000..654a5156 --- /dev/null +++ b/src/primaite/utils/validation/ip_protocol.py @@ -0,0 +1,47 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +# Define a custom IP protocol validator +from typing import Any + +from pydantic import BeforeValidator, TypeAdapter, ValidationError +from typing_extensions import Annotated, Final + +PROTOCOL_LOOKUP: dict[str, str] = dict( + NONE="none", + TCP="tcp", + UDP="udp", + ICMP="icmp", +) +""" +Lookup table used for compatibility with PrimAITE <= 3.3. Configs with the capitalised protocol names are converted +to lowercase at runtime. +""" +VALID_PROTOCOLS = ["none", "tcp", "udp", "icmp"] +"""Supported protocols.""" + + +def protocol_validator(v: Any) -> str: + """ + Validate that IP Protocols are chosen from the list of supported IP Protocols. + + The protocol list is dynamic because plugins are able to extend it, therefore it is necessary to use this custom + validator instead of being able to specify a union of string literals. + """ + if isinstance(v, str) and v in PROTOCOL_LOOKUP: + return PROTOCOL_LOOKUP[v] + if v in VALID_PROTOCOLS: + return v + raise ValueError(f"{v} is not a valid IP Protocol. It must be one of the following: {VALID_PROTOCOLS}") + + +IPProtocol: Final[Annotated] = Annotated[str, BeforeValidator(protocol_validator)] +"""Validates that IP Protocols used in the simulation belong to the list of supported protocols.""" +_IPProtocolTypeAdapter = TypeAdapter(IPProtocol) + + +def is_valid_protocol(v: Any) -> bool: + """Convenience method to return true if the value matches the schema, and false otherwise.""" + try: + _IPProtocolTypeAdapter.validate_python(v) + return True + except ValidationError: + return False diff --git a/src/primaite/utils/validators.py b/src/primaite/utils/validation/ipv4_address.py similarity index 79% rename from src/primaite/utils/validators.py rename to src/primaite/utils/validation/ipv4_address.py index 139d303c..1dc6c74e 100644 --- a/src/primaite/utils/validators.py +++ b/src/primaite/utils/validation/ipv4_address.py @@ -1,4 +1,6 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + + from ipaddress import IPv4Address from typing import Any, Final @@ -6,6 +8,9 @@ from pydantic import BeforeValidator from typing_extensions import Annotated +# Define a custom type IPV4Address using the typing_extensions.Annotated. +# Annotated is used to attach metadata to type hints. In this case, it's used to associate the ipv4_validator +# with the IPv4Address type, ensuring that any usage of IPV4Address undergoes validation before assignment. def ipv4_validator(v: Any) -> IPv4Address: """ Validate the input and ensure it can be converted to an IPv4Address instance. @@ -24,12 +29,9 @@ def ipv4_validator(v: Any) -> IPv4Address: return IPv4Address(v) -# Define a custom type IPV4Address using the typing_extensions.Annotated. -# Annotated is used to attach metadata to type hints. In this case, it's used to associate the ipv4_validator -# with the IPv4Address type, ensuring that any usage of IPV4Address undergoes validation before assignment. IPV4Address: Final[Annotated] = Annotated[IPv4Address, BeforeValidator(ipv4_validator)] """ -IPv4Address with with IPv4Address with with pre-validation and auto-conversion from str using ipv4_validator.. +IPv4Address with pre-validation and auto-conversion from str using ipv4_validator.. This type is essentially an IPv4Address from the standard library's ipaddress module, but with added validation logic. If you use this custom type, the ipv4_validator function @@ -37,3 +39,12 @@ will automatically check and convert the input value to an instance of IPv4Addre any Pydantic model uses it. This ensures that any field marked with this type is not just an IPv4Address in form, but also valid according to the rules defined in ipv4_validator. """ + + +def str_ip(value: Any) -> str: + """Make sure it's a valid IP, but represent it as a string.""" + # TODO: this is a bit of a hack, we should change RequestResponse to be able to handle IPV4Address objects + return str(IPV4Address(value)) + + +StrIP: Final[Annotated] = Annotated[str, BeforeValidator(str_ip)] diff --git a/src/primaite/utils/validation/port.py b/src/primaite/utils/validation/port.py new file mode 100644 index 00000000..564e843c --- /dev/null +++ b/src/primaite/utils/validation/port.py @@ -0,0 +1,70 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +# Define a custom port validator +from typing import Any + +from pydantic import BeforeValidator, TypeAdapter, ValidationError +from typing_extensions import Annotated, Final + +PORT_LOOKUP: dict[str, int] = dict( + UNUSED=-1, + NONE=0, + WOL=9, + FTP_DATA=20, + FTP=21, + SSH=22, + SMTP=25, + DNS=53, + HTTP=80, + POP3=110, + SFTP=115, + NTP=123, + IMAP=143, + SNMP=161, + SNMP_TRAP=162, + ARP=219, + LDAP=389, + HTTPS=443, + SMB=445, + IPP=631, + SQL_SERVER=1433, + MYSQL=3306, + RDP=3389, + RTP=5004, + RTP_ALT=5005, + DNS_ALT=5353, + HTTP_ALT=8080, + HTTPS_ALT=8443, + POSTGRES_SERVER=5432, +) +""" +Lookup table used for compatibility with PrimAITE <= 3.3. Configs with named ports names are converted +to port integers at runtime. +""" + + +def port_validator(v: Any) -> int: + """ + Validate that Ports are chosen from the list of supported Ports. + + The protocol list is dynamic because plugins are able to extend it, therefore it is necessary to use this custom + validator instead of being able to specify a union of string literals. + """ + if isinstance(v, str) and v in PORT_LOOKUP: + v = PORT_LOOKUP[v] + if isinstance(v, int) and (0 <= v <= 65535): + return v + raise ValueError(f"{v} is not a valid Port. It must be an integer in the range [0,65535] or ") + + +Port: Final[Annotated] = Annotated[int, BeforeValidator(port_validator)] +"""Validates that network ports lie in the appropriate range of [0,65535].""" +_PortTypeAdapter = TypeAdapter(Port) + + +def is_valid_port(v: Any) -> bool: + """Convenience method to return true if the value matches the schema, and false otherwise.""" + try: + _PortTypeAdapter.validate_python(v) + return True + except ValidationError: + return False diff --git a/tests/__init__.py b/tests/__init__.py index 846ec808..900649b2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from pathlib import Path from typing import Final diff --git a/tests/assets/configs/action_penalty.yaml b/tests/assets/configs/action_penalty.yaml index 1771ba5f..6700172e 100644 --- a/tests/assets/configs/action_penalty.yaml +++ b/tests/assets/configs/action_penalty.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_agent_actions: false save_step_metadata: false @@ -24,20 +27,20 @@ agents: - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -69,15 +72,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -91,499 +94,432 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - - type: ROUTER_ACL_REMOVERULE - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE - action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_CHECKHASH" + action: "node-file-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_CHECKHASH" + action: "node-folder-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 0 + node_name: domain_controller 19: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 0 + node_name: domain_controller 20: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: domain_controller 21: - action: NODE_RESET + action: node-reset options: - node_id: 0 + node_name: domain_controller 22: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 1 + node_name: web_server 23: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 1 + node_name: web_server 24: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: web_server 25: - action: NODE_RESET + action: node-reset options: - node_id: 1 + node_name: web_server 26: # old action num: 18 - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 27: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 2 + node_name: database_server 28: - action: NODE_STARTUP + action: node-startup options: - node_id: 2 + node_name: database_server 29: - action: NODE_RESET + action: node-reset options: - node_id: 2 + node_name: database_server 30: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 3 + node_name: backup_server 31: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 3 + node_name: backup_server 32: - action: NODE_STARTUP + action: node-startup options: - node_id: 3 + node_name: backup_server 33: - action: NODE_RESET + action: node-reset options: - node_id: 3 + node_name: backup_server 34: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 4 + node_name: security_suite 35: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 4 + node_name: security_suite 36: - action: NODE_STARTUP + action: node-startup options: - node_id: 4 + node_name: security_suite 37: - action: NODE_RESET + action: node-reset options: - node_id: 4 + node_name: security_suite 38: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 5 + node_name: client_1 39: # old action num: 19 # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 40: # old action num: 20 - action: NODE_STARTUP + action: node-startup options: - node_id: 5 + node_name: client_1 41: # old action num: 21 - action: NODE_RESET + action: node-reset options: - node_id: 5 + node_name: client_1 42: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 6 + node_name: client_2 43: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 6 + node_name: client_2 44: - action: NODE_STARTUP + action: node-startup options: - node_id: 6 + node_name: client_2 45: - action: NODE_RESET + action: node-reset options: - node_id: 6 + node_name: client_2 - 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" - action: "ROUTER_ACL_ADDRULE" + 46: # old action num: 22 # "acl: ADDRULE - Block outgoing traffic from client 1" + action: "router-acl-add-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 47: # old action num: 23 # "acl: ADDRULE - Block outgoing traffic from client 2" + action: "router-acl-add-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 48: # old action num: 24 # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 49: # old action num: 25 # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 50: # old action num: 26 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 51: # old action num: 27 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 52: # old action num: 28 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 0 53: # old action num: 29 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 1 54: # old action num: 30 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 2 55: # old action num: 31 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 3 56: # old action num: 32 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 4 57: # old action num: 33 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 5 58: # old action num: 34 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 6 59: # old action num: 35 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 7 60: # old action num: 36 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 8 61: # old action num: 37 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_nodename: router_1 + target_router: router_1 position: 9 62: # old action num: 38 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 63: # old action num: 39 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 64: # old action num: 40 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 65: # old action num: 41 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 66: # old action num: 42 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 67: # old action num: 43 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 68: # old action num: 44 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 69: # old action num: 45 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 70: # old action num: 46 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 71: # old action num: 47 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 72: # old action num: 48 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 73: # old action num: 49 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 74: # old action num: 50 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 75: # old action num: 51 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 76: # old action num: 52 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 77: # old action num: 53 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 - - - - options: - nodes: - - node_name: domain_controller - - node_name: web_server - applications: - - application_name: DatabaseClient - services: - - service_name: WebServer - - node_name: database_server - folders: - - folder_name: database - files: - - file_name: database.db - services: - - service_name: DatabaseService - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 - + node_name: client_2 + nic_num: 1 reward_function: reward_components: - - type: ACTION_PENALTY + - type: action-penalty weight: 1.0 options: action_penalty: -0.75 @@ -652,7 +588,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DNSServer + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.12 # web server @@ -664,9 +600,9 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: WebServer + - type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 @@ -678,10 +614,10 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: DatabaseService + - type: database-service options: backup_server_ip: 192.168.1.16 - - type: FTPClient + - type: ftp-client - hostname: backup_server type: server @@ -690,7 +626,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: FTPServer + - type: ftp-server - hostname: security_suite type: server @@ -710,20 +646,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client - hostname: client_2 type: computer @@ -732,20 +668,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index c83cadc8..b8551caf 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + game: ports: - ARP @@ -12,78 +15,37 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent - observation_space: null - action_space: - action_list: - - type: DONOTHING - options: - nodes: - - node_name: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 - - reward_function: - reward_components: - - type: DUMMY - + type: probabilistic-agent agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + action_probabilities: + 0: 1.0 - ref: data_manipulation_attacker team: RED - type: RedDatabaseCorruptingAgent - - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - - type: NODE_FILE_DELETE - - type: NODE_FILE_CORRUPT - - type: NODE_OS_SCAN - options: - nodes: - - node_name: client_1 - applications: - - application_name: DataManipulationBot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY - + type: red-database-corrupting-agent agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + possible_start_nodes: [client_1,] + target_application: data-manipulation-bot + start_step: 25 + frequency: 20 + variance: 5 - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -115,15 +77,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -137,390 +99,336 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - - type: ROUTER_ACL_REMOVERULE - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_CHECKHASH" + action: "node-file-checkhash" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 1 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_CHECKHASH" + action: "node-folder-checkhash" options: - node_id: 2 - folder_id: 1 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 1 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 1 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 19: # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 20: - action: "NODE_STARTUP" + action: "node-startup" options: - node_id: 5 + node_name: client_1 21: - action: "NODE_RESET" + action: "node-reset" options: - node_id: 5 - 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + node_name: client_1 + 22: # "acl: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) + action: "router-acl-add-rule" options: target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 23: # "acl: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) + action: "router-acl-add-rule" options: target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 24: # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 25: # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 26: - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 27: - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 28: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 0 29: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 1 30: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 2 31: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 3 32: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 4 33: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 5 34: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 6 35: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 7 36: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 8 37: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 9 38: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 39: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 40: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 41: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 42: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 43: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 44: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 45: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 46: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 47: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 48: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 49: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 50: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 51: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 52: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 53: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 - - - - options: - nodes: - - node_name: domain_controller - - node_name: web_server - - node_name: database_server - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 + node_name: client_2 + nic_num: 1 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.5 options: node_hostname: database_server @@ -528,7 +436,7 @@ agents: file_name: database.db - - type: WEB_SERVER_404_PENALTY + - type: web-server-404-penalty weight: 0.5 options: node_hostname: web_server @@ -587,7 +495,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DNSServer + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.12 # web server @@ -599,9 +507,9 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: WebServer + - type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 @@ -613,7 +521,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: DatabaseService + - type: database-service - type: server hostname: backup_server @@ -622,7 +530,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: FTPServer + - type: ftp-server - type: server hostname: security_suite @@ -642,14 +550,14 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.1 data_manipulation_p_of_success: 0.1 payload: "DELETE" server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client - type: computer hostname: client_2 @@ -658,9 +566,9 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: WebBrowser + - type: web-browser services: - - type: DNSClient + - type: dns-client links: - endpoint_a_hostname: router_1 diff --git a/tests/assets/configs/basic_c2_setup.yaml b/tests/assets/configs/basic_c2_setup.yaml index 0cae2ba0..9b569b44 100644 --- a/tests/assets/configs/basic_c2_setup.yaml +++ b/tests/assets/configs/basic_c2_setup.yaml @@ -4,6 +4,9 @@ # | node_a |------| switch_1 |------| node_b | # -------------- -------------- -------------- # +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true @@ -40,7 +43,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 applications: - - type: C2Server + - type: c2-server options: listen_on_ports: - 80 @@ -52,7 +55,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 applications: - - type: C2Beacon + - type: c2-beacon options: c2_server_ip_address: 192.168.10.21 keep_alive_frequency: 5 diff --git a/tests/assets/configs/basic_firewall.yaml b/tests/assets/configs/basic_firewall.yaml index 0253a4d2..26038270 100644 --- a/tests/assets/configs/basic_firewall.yaml +++ b/tests/assets/configs/basic_firewall.yaml @@ -4,6 +4,8 @@ # | client_1 |------| switch_1 |------| client_2 | # -------------- -------------- -------------- # +metadata: + version: 3.0 io_settings: save_step_metadata: false @@ -26,40 +28,23 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent - observation_space: null + type: probabilistic-agent + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY + node_name: client_2 + application_name: web-browser agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + action_probabilities: + 0: 0.4 + 1: 0.6 simulation: network: diff --git a/tests/assets/configs/basic_node_with_software_listening_ports.yaml b/tests/assets/configs/basic_node_with_software_listening_ports.yaml index 53eee87f..6372de54 100644 --- a/tests/assets/configs/basic_node_with_software_listening_ports.yaml +++ b/tests/assets/configs/basic_node_with_software_listening_ports.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true @@ -26,13 +29,13 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 services: - - type: DatabaseService + - type: database-service options: backup_server_ip: 10.10.1.12 listen_on_ports: - 631 applications: - - type: WebBrowser + - type: web-browser options: target_url: http://sometech.ai listen_on_ports: diff --git a/tests/assets/configs/basic_node_with_users.yaml b/tests/assets/configs/basic_node_with_users.yaml index 064519dd..20331ff2 100644 --- a/tests/assets/configs/basic_node_with_users.yaml +++ b/tests/assets/configs/basic_node_with_users.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index 03cf2207..76b4dfb9 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -4,6 +4,9 @@ # | client_1 |------| switch_1 |------| client_2 | # -------------- -------------- -------------- # +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true @@ -41,52 +44,37 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent - observation_space: null + type: probabilistic-agent + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 1 + node_name: client_2 + application_name: web-browser reward_function: reward_components: - - type: DUMMY + - type: dummy agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 - - + action_probabilities: + 0: 0.4 + 1: 0.6 - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: @@ -121,51 +109,33 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: - switch_1:eth-1<->client_1:eth-1 - switch_1:eth-2<->client_2:eth-1 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - action_map: 0: - action: DONOTHING + action: do-nothing options: {} - options: - nodes: - - node_name: switch - - node_name: client_1 - - node_name: client_2 - - node_name: client_3 - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.23 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.5 options: node_hostname: database_server @@ -173,7 +143,7 @@ agents: file_name: database.db - - type: WEB_SERVER_404_PENALTY + - type: web-server-404-penalty weight: 0.5 options: node_hostname: web_server @@ -198,48 +168,45 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: RansomwareScript - - type: WebBrowser + - type: ransomware-script + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.10 server_password: arcd - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.21 server_password: arcd - - type: DoSBot + - type: dos-bot options: target_ip_address: 192.168.10.21 payload: SPOOF DATA port_scan_p_of_success: 0.8 services: - - type: DNSClient - options: - dns_server: 192.168.1.10 - - type: DNSServer + - type: dns-client + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.10 - - type: DatabaseService + - type: database-service options: backup_server_ip: 192.168.1.10 - - type: WebServer - - type: FTPServer - options: - server_password: arcd - - type: NTPClient + - type: web-server + - type: ftp-server + - type: ntp-client options: ntp_server_ip: 192.168.1.10 - - type: NTPServer - file_system: - - root: - - "test.txt" + - type: ntp-server + folders: + - folder_name: root + files: + - file_name: test.txt - hostname: client_2 type: computer ip_address: 192.168.10.22 diff --git a/tests/assets/configs/data_manipulation.yaml b/tests/assets/configs/data_manipulation.yaml index 97442903..59f97644 100644 --- a/tests/assets/configs/data_manipulation.yaml +++ b/tests/assets/configs/data_manipulation.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_agent_actions: true save_step_metadata: false @@ -24,98 +27,72 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_2 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_2 + application_name: database-client reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_2 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_2 - ref: client_1_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_1 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_1 + application_name: web-browser reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_1 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_1 @@ -126,52 +103,30 @@ agents: - ref: data_manipulation_attacker team: RED - type: RedDatabaseCorruptingAgent - - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: DataManipulationBot - - node_name: client_2 - applications: - - application_name: DataManipulationBot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY - + type: red-database-corrupting-agent agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + possible_start_nodes: [client_1, client_2] + target_application: data-manipulation-bot + start_step: 25 + frequency: 20 + variance: 5 - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -208,15 +163,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -230,511 +185,444 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - - type: ROUTER_ACL_REMOVERULE - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE - action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-file-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-folder-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 0 + node_name: domain_controller 19: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 0 + node_name: domain_controller 20: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: domain_controller 21: - action: NODE_RESET + action: node-reset options: - node_id: 0 + node_name: domain_controller 22: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 1 + node_name: web_server 23: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 1 + node_name: web_server 24: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: web_server 25: - action: NODE_RESET + action: node-reset options: - node_id: 1 + node_name: web_server 26: # old action num: 18 - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 27: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 2 + node_name: database_server 28: - action: NODE_STARTUP + action: node-startup options: - node_id: 2 + node_name: database_server 29: - action: NODE_RESET + action: node-reset options: - node_id: 2 + node_name: database_server 30: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 3 + node_name: backup_server 31: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 3 + node_name: backup_server 32: - action: NODE_STARTUP + action: node-startup options: - node_id: 3 + node_name: backup_server 33: - action: NODE_RESET + action: node-reset options: - node_id: 3 + node_name: backup_server 34: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 4 + node_name: security_suite 35: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 4 + node_name: security_suite 36: - action: NODE_STARTUP + action: node-startup options: - node_id: 4 + node_name: security_suite 37: - action: NODE_RESET + action: node-reset options: - node_id: 4 + node_name: security_suite 38: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 5 + node_name: client_1 39: # old action num: 19 # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 40: # old action num: 20 - action: NODE_STARTUP + action: node-startup options: - node_id: 5 + node_name: client_1 41: # old action num: 21 - action: NODE_RESET + action: node-reset options: - node_id: 5 + node_name: client_1 42: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 6 + node_name: client_2 43: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 6 + node_name: client_2 44: - action: NODE_STARTUP + action: node-startup options: - node_id: 6 + node_name: client_2 45: - action: NODE_RESET + action: node-reset options: - node_id: 6 + node_name: client_2 - 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" - action: "ROUTER_ACL_ADDRULE" + 46: # old action num: 22 # "acl: ADDRULE - Block outgoing traffic from client 1" + action: "router-acl-add-rule" options: target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 47: # old action num: 23 # "acl: ADDRULE - Block outgoing traffic from client 2" + action: "router-acl-add-rule" options: target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 48: # old action num: 24 # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 49: # old action num: 25 # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 50: # old action num: 26 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 51: # old action num: 27 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 52: # old action num: 28 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 0 53: # old action num: 29 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 1 54: # old action num: 30 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 2 55: # old action num: 31 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 3 56: # old action num: 32 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 4 57: # old action num: 33 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 5 58: # old action num: 34 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 6 59: # old action num: 35 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 7 60: # old action num: 36 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 8 61: # old action num: 37 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 9 62: # old action num: 38 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 63: # old action num: 39 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 64: # old action num: 40 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 65: # old action num: 41 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 66: # old action num: 42 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 67: # old action num: 43 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 68: # old action num: 44 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 69: # old action num: 45 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 70: # old action num: 46 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 71: # old action num: 47 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 72: # old action num: 48 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 73: # old action num: 49 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 74: # old action num: 50 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 75: # old action num: 51 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 76: # old action num: 52 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 77: # old action num: 53 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 - - - - options: - nodes: - - node_name: domain_controller - - node_name: web_server - applications: - - application_name: DatabaseClient - services: - - service_name: WebServer - - node_name: database_server - folders: - - folder_name: database - files: - - file_name: database.db - services: - - service_name: DatabaseService - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 - + node_name: client_2 + nic_num: 1 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_1_green_user - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_2_green_user @@ -804,7 +692,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DNSServer + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.12 # web server @@ -816,9 +704,9 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: WebServer + - type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 @@ -830,10 +718,10 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: DatabaseService + - type: database-service options: backup_server_ip: 192.168.1.16 - - type: FTPClient + - type: ftp-client - hostname: backup_server type: server @@ -842,7 +730,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: FTPServer + - type: ftp-server - hostname: security_suite type: server @@ -862,20 +750,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client - hostname: client_2 type: computer @@ -884,20 +772,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client links: - endpoint_a_hostname: router_1 diff --git a/tests/assets/configs/dmz_network.yaml b/tests/assets/configs/dmz_network.yaml index 52316260..0accb3b2 100644 --- a/tests/assets/configs/dmz_network.yaml +++ b/tests/assets/configs/dmz_network.yaml @@ -30,6 +30,9 @@ # | external_computer |------| switch_3 |------| external_server | # ----------------------- -------------- --------------------- # +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true @@ -51,40 +54,23 @@ game: agents: - ref: client_1_green_user team: GREEN - type: ProbabilisticAgent - observation_space: null + type: probabilistic-agent + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 - options: - nodes: - - node_name: client_1 - applications: - - application_name: WebBrowser - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY + node_name: client_1 + application_name: web-browser agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + action_probabilities: + 0: 0.4 + 1: 0.6 simulation: @@ -240,7 +226,7 @@ simulation: start_up_duration: 0 shut_down_duration: 0 services: - - type: DNSServer + - type: dns-server links: - endpoint_a_hostname: client_1 endpoint_a_port: 1 diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 3d60eb6e..6085b1e7 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + game: ports: - ARP @@ -12,90 +15,65 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent - observation_space: null + type: probabilistic-agent + action_space: - action_list: - - type: DONOTHING action_map: 0: - action: DONOTHING + action: do-nothing options: {} - options: - nodes: - - node_name: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 reward_function: reward_components: - - type: DUMMY + - type: dummy agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + action_probabilities: + 0: 1.0 - ref: data_manipulation_attacker team: RED - type: RedDatabaseCorruptingAgent + type: red-database-corrupting-agent + - observation_space: null action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - - type: NODE_FILE_DELETE - - type: NODE_FILE_CORRUPT - - type: NODE_OS_SCAN action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 - options: - nodes: - - node_name: client_1 - applications: - - application_name: DataManipulationBot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 + node_name: client_1 + application_name: data-manipulation-bot reward_function: reward_components: - - type: DUMMY + - type: dummy agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + possible_start_nodes: [client_1,] + target_application: data-manipulation-bot + start_step: 25 + frequency: 20 + variance: 5 - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -127,15 +105,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -149,390 +127,336 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - - type: ROUTER_ACL_REMOVERULE - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_CHECKHASH" + action: "node-file-checkhash" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 1 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_CHECKHASH" + action: "node-folder-checkhash" options: - node_id: 2 - folder_id: 1 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 1 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 1 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 19: # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 20: - action: "NODE_STARTUP" + action: "node-startup" options: - node_id: 5 + node_name: client_1 21: - action: "NODE_RESET" + action: "node-reset" options: - node_id: 5 - 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + node_name: client_1 + 22: # "acl: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) + action: "router-acl-add-rule" options: target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 23: # "acl: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) + action: "router-acl-add-rule" options: target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 24: # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 25: # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 26: - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 27: - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 28: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 0 29: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 1 30: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 2 31: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 3 32: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 4 33: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 5 34: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 6 35: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 7 36: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 8 37: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 9 38: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 39: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 40: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 41: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 42: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 43: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 44: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 45: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 46: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 47: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 48: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 49: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 50: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 51: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 52: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 53: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 - - - - options: - nodes: - - node_name: domain_controller - - node_name: web_server - - node_name: database_server - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 + node_name: client_2 + nic_num: 1 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.5 options: node_hostname: database_server @@ -540,7 +464,7 @@ agents: file_name: database.db - - type: WEB_SERVER_404_PENALTY + - type: web-server-404-penalty weight: 0.5 options: node_hostname: web_server @@ -599,7 +523,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DNSServer + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.12 # web server @@ -611,9 +535,9 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: WebServer + - type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 @@ -625,7 +549,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: DatabaseService + - type: database-service - type: server hostname: backup_server @@ -634,7 +558,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: FTPServer + - type: ftp-server - type: server hostname: security_suite @@ -654,14 +578,14 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.1 data_manipulation_p_of_success: 0.1 payload: "DELETE" server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client - type: computer hostname: client_2 @@ -670,9 +594,9 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: WebBrowser + - type: web-browser services: - - type: DNSClient + - type: dns-client links: - endpoint_a_hostname: router_1 diff --git a/tests/assets/configs/extended_config.yaml b/tests/assets/configs/extended_config.yaml new file mode 100644 index 00000000..a58a9d4a --- /dev/null +++ b/tests/assets/configs/extended_config.yaml @@ -0,0 +1,840 @@ +metadata: + version: 3.0 + +io_settings: + save_agent_actions: true + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + sys_log_level: WARNING + + +game: + max_episode_length: 128 + ports: + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +agents: + - ref: client_2_green_user + team: GREEN + type: probabilistic-agent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + + action_space: + action_map: + 0: + action: do-nothing + options: {} + 1: + action: node-application-execute + options: + node_name: client_2 + application_name: web-browser + 2: + action: node-application-execute + options: + node_name: client_2 + application_name: database-client + + reward_function: + reward_components: + - type: webpage-unavailable-penalty + weight: 0.25 + options: + node_hostname: client_2 + - type: green-admin-database-unreachable-penalty + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: probabilistic-agent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + + action_space: + action_map: + 0: + action: do-nothing + options: {} + 1: + action: node-application-execute + options: + node_name: client_1 + application_name: web-browser + 2: + action: node-application-execute + options: + node_name: client_1 + application_name: database-client + + reward_function: + reward_components: + - type: webpage-unavailable-penalty + weight: 0.25 + options: + node_hostname: client_1 + - type: green-admin-database-unreachable-penalty + weight: 0.05 + options: + node_hostname: client_1 + + + + + + - ref: data_manipulation_attacker + team: RED + type: red-database-corrupting-agent + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + possible_start_nodes: [client_1, client_2] + target_application: data-manipulation-bot + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: proxy-agent + + observation_space: + type: custom + options: + components: + - type: nodes + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: web-server + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - HTTP + - POSTGRES_SERVER + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: links + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "none" + label: ICS + options: {} + + action_space: + action_map: + 0: + action: do-nothing + options: {} + # scan webapp service + 1: + action: node-service-scan + options: + node_name: web_server + service_name: web-server + # stop webapp service + 2: + action: node-service-stop + options: + node_name: web_server + service_name: web-server + # start webapp service + 3: + action: "node-service-start" + options: + node_name: web_server + service_name: web-server + 4: + action: "node-service-pause" + options: + node_name: web_server + service_name: web-server + 5: + action: "node-service-resume" + options: + node_name: web_server + service_name: web-server + 6: + action: "node-service-restart" + options: + node_name: web_server + service_name: web-server + 7: + action: "node-service-disable" + options: + node_name: web_server + service_name: web-server + 8: + action: "node-service-enable" + options: + node_name: web_server + service_name: web-server + 9: # check database.db file + action: "node-file-scan" + options: + node_name: database_server + folder_name: database + file_name: database.db + 10: + action: "node-file-checkhash" + options: + node_name: database_server + folder_name: database + file_name: database.db + 11: + action: "node-file-delete" + options: + node_name: database_server + folder_name: database + file_name: database.db + 12: + action: "node-file-repair" + options: + node_name: database_server + folder_name: database + file_name: database.db + 13: + action: "node-service-fix" + options: + node_name: database_server + service_name: database-service + 14: + action: "node-folder-scan" + options: + node_name: database_server + folder_name: database + 15: + action: "node-folder-checkhash" + options: + node_name: database_server + folder_name: database + 16: + action: "node-folder-repair" + options: + node_name: database_server + folder_name: database + 17: + action: "node-folder-restore" + options: + node_name: database_server + folder_name: database + 18: + action: "node-os-scan" + options: + node_name: domain_controller + 19: + action: "node-shutdown" + options: + node_name: domain_controller + 20: + action: node-startup + options: + node_name: domain_controller + 21: + action: node-reset + options: + node_name: domain_controller + 22: + action: "node-os-scan" + options: + node_name: web_server + 23: + action: "node-shutdown" + options: + node_name: web_server + 24: + action: node-startup + options: + node_name: web_server + 25: + action: node-reset + options: + node_name: web_server + 26: # old action num: 18 + action: "node-os-scan" + options: + node_name: database_server + 27: + action: "node-shutdown" + options: + node_name: database_server + 28: + action: node-startup + options: + node_name: database_server + 29: + action: node-reset + options: + node_name: database_server + 30: + action: "node-os-scan" + options: + node_name: backup_server + 31: + action: "node-shutdown" + options: + node_name: backup_server + 32: + action: node-startup + options: + node_name: backup_server + 33: + action: node-reset + options: + node_name: backup_server + 34: + action: "node-os-scan" + options: + node_name: security_suite + 35: + action: "node-shutdown" + options: + node_name: security_suite + 36: + action: node-startup + options: + node_name: security_suite + 37: + action: node-reset + options: + node_name: security_suite + 38: + action: "node-os-scan" + options: + node_name: client_1 + 39: # old action num: 19 # shutdown client 1 + action: "node-shutdown" + options: + node_name: client_1 + 40: # old action num: 20 + action: node-startup + options: + node_name: client_1 + 41: # old action num: 21 + action: node-reset + options: + node_name: client_1 + 42: + action: "node-os-scan" + options: + node_name: client_2 + 43: + action: "node-shutdown" + options: + node_name: client_2 + 44: + action: node-startup + options: + node_name: client_2 + 45: + action: node-reset + options: + node_name: client_2 + + 46: # old action num: 22 # "acl: ADDRULE - Block outgoing traffic from client 1" + action: "router-acl-add-rule" + options: + target_router: router_1 + position: 1 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 47: # old action num: 23 # "acl: ADDRULE - Block outgoing traffic from client 2" + action: "router-acl-add-rule" + options: + target_router: router_1 + position: 2 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "router-acl-add-rule" + options: + target_router: router_1 + position: 3 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "router-acl-add-rule" + options: + target_router: router_1 + position: 4 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE + 50: # old action num: 26 + action: "router-acl-add-rule" + options: + target_router: router_1 + position: 5 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE + 51: # old action num: 27 + action: "router-acl-add-rule" + options: + target_router: router_1 + position: 6 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE + 52: # old action num: 28 + action: "router-acl-remove-rule" + options: + target_router: router_1 + position: 0 + 53: # old action num: 29 + action: "router-acl-remove-rule" + options: + target_router: router_1 + position: 1 + 54: # old action num: 30 + action: "router-acl-remove-rule" + options: + target_router: router_1 + position: 2 + 55: # old action num: 31 + action: "router-acl-remove-rule" + options: + target_router: router_1 + position: 3 + 56: # old action num: 32 + action: "router-acl-remove-rule" + options: + target_router: router_1 + position: 4 + 57: # old action num: 33 + action: "router-acl-remove-rule" + options: + target_router: router_1 + position: 5 + 58: # old action num: 34 + action: "router-acl-remove-rule" + options: + target_router: router_1 + position: 6 + 59: # old action num: 35 + action: "router-acl-remove-rule" + options: + target_router: router_1 + position: 7 + 60: # old action num: 36 + action: "router-acl-remove-rule" + options: + target_router: router_1 + position: 8 + 61: # old action num: 37 + action: "router-acl-remove-rule" + options: + target_router: router_1 + position: 9 + 62: # old action num: 38 + action: "host-nic-disable" + options: + node_name: domain_controller + nic_num: 1 + 63: # old action num: 39 + action: "host-nic-enable" + options: + node_name: domain_controller + nic_num: 1 + 64: # old action num: 40 + action: "host-nic-disable" + options: + node_name: web_server + nic_num: 1 + 65: # old action num: 41 + action: "host-nic-enable" + options: + node_name: web_server + nic_num: 1 + 66: # old action num: 42 + action: "host-nic-disable" + options: + node_name: database_server + nic_num: 1 + 67: # old action num: 43 + action: "host-nic-enable" + options: + node_name: database_server + nic_num: 1 + 68: # old action num: 44 + action: "host-nic-disable" + options: + node_name: backup_server + nic_num: 1 + 69: # old action num: 45 + action: "host-nic-enable" + options: + node_name: backup_server + nic_num: 1 + 70: # old action num: 46 + action: "host-nic-disable" + options: + node_name: security_suite + nic_num: 1 + 71: # old action num: 47 + action: "host-nic-enable" + options: + node_name: security_suite + nic_num: 1 + 72: # old action num: 48 + action: "host-nic-disable" + options: + node_name: security_suite + nic_num: 2 + 73: # old action num: 49 + action: "host-nic-enable" + options: + node_name: security_suite + nic_num: 2 + 74: # old action num: 50 + action: "host-nic-disable" + options: + node_name: client_1 + nic_num: 1 + 75: # old action num: 51 + action: "host-nic-enable" + options: + node_name: client_1 + nic_num: 1 + 76: # old action num: 52 + action: "host-nic-disable" + options: + node_name: client_2 + nic_num: 1 + 77: # old action num: 53 + action: "host-nic-enable" + options: + node_name: client_2 + nic_num: 1 + + + + reward_function: + reward_components: + - type: database-file-integrity + weight: 0.40 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + - type: shared-reward + weight: 1.0 + options: + agent_name: client_1_green_user + + - type: shared-reward + weight: 1.0 + options: + agent_name: client_2_green_user + + agent_settings: + flatten_obs: true + action_masking: true + + + + + +simulation: + network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE + nodes: + + - hostname: router_1 + type: router + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + 18: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 19: + action: PERMIT + src_port: DNS + dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - hostname: switch_1 + type: switch + num_ports: 8 + + - hostname: switch_2 + type: gigaswitch + num_ports: 8 + + - hostname: domain_controller + type: server + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - type: dns-server + options: + domain_mapping: + arcd.com: 192.168.1.12 # web server + + - hostname: web_server + type: server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: web-server + applications: + - type: database-client + options: + db_server_ip: 192.168.1.14 + + + - hostname: database_server + type: server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: database-service + options: + backup_server_ip: 192.168.1.16 + - type: ftp-client + + - hostname: backup_server + type: server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: ftp-server + + - hostname: security_suite + type: server + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + network_interfaces: + 2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + - hostname: client_1 + type: supercomputer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: data-manipulation-bot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: web-browser + options: + target_url: http://arcd.com/users/ + - type: extended-application + options: + target_url: http://arcd.com/users/ + - type: database-client + options: + db_server_ip: 192.168.1.14 + services: + - type: dns-client + - type: database-service + options: + backup_server_ip: 192.168.1.16 + - type: extended-service + + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: web-browser + options: + target_url: http://arcd.com/users/ + - type: data-manipulation-bot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: database-client + options: + db_server_ip: 192.168.1.14 + services: + - type: dns-client + + links: + - endpoint_a_hostname: router_1 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 8 + - endpoint_a_hostname: router_1 + endpoint_a_port: 2 + endpoint_b_hostname: switch_2 + endpoint_b_port: 8 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: domain_controller + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: web_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 3 + endpoint_b_hostname: database_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 4 + endpoint_b_hostname: backup_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 2 diff --git a/tests/assets/configs/firewall_actions_network.yaml b/tests/assets/configs/firewall_actions_network.yaml index 2292616d..66470f5a 100644 --- a/tests/assets/configs/firewall_actions_network.yaml +++ b/tests/assets/configs/firewall_actions_network.yaml @@ -30,6 +30,9 @@ # | external_computer |------| switch_3 |------| external_server | # ----------------------- -------------- --------------------- # +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true @@ -51,13 +54,13 @@ game: agents: - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: @@ -77,199 +80,174 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: - client_1:eth-1<->switch_1:eth-1 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: FIREWALL_ACL_ADDRULE - - type: FIREWALL_ACL_REMOVERULE - - type: NETWORK_PORT_DISABLE - - type: NETWORK_PORT_ENABLE action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: FIREWALL_ACL_ADDRULE + action: firewall-acl-add-rule options: + type: firewall-acl-add-rule target_firewall_nodename: firewall firewall_port_name: internal firewall_port_direction: inbound position: 1 - permission: 1 - source_ip_id: 2 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: PERMIT + src_ip: 192.168.0.10 + dst_ip: ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 2: - action: FIREWALL_ACL_REMOVERULE + action: firewall-acl-remove-rule options: target_firewall_nodename: firewall firewall_port_name: internal firewall_port_direction: inbound position: 1 3: - action: FIREWALL_ACL_ADDRULE + action: firewall-acl-add-rule options: target_firewall_nodename: firewall firewall_port_name: internal firewall_port_direction: outbound position: 1 - permission: 2 - source_ip_id: 2 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 2 - dest_port_id: 3 - protocol_id: 2 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.0.10 # client 1 + dst_ip: ALL + src_port: ARP + dst_port: DNS + protocol_name: icmp + src_wildcard: NONE + dst_wildcard: NONE 4: - action: FIREWALL_ACL_REMOVERULE + action: firewall-acl-remove-rule options: target_firewall_nodename: firewall firewall_port_name: internal firewall_port_direction: outbound position: 1 5: - action: FIREWALL_ACL_ADDRULE + action: firewall-acl-add-rule options: target_firewall_nodename: firewall firewall_port_name: dmz firewall_port_direction: inbound position: 1 - permission: 2 - source_ip_id: 3 # dmz_server - dest_ip_id: 2 # client_1 - source_port_id: 4 - dest_port_id: 4 - protocol_id: 4 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.10 # dmz_server + dst_ip: 192.168.0.10 # client_1 + src_port: HTTP + dst_port: HTTP + protocol_name: UDP + src_wildcard: NONE + dst_wildcard: NONE 6: - action: FIREWALL_ACL_REMOVERULE + action: firewall-acl-remove-rule options: target_firewall_nodename: firewall firewall_port_name: dmz firewall_port_direction: inbound position: 1 7: - action: FIREWALL_ACL_ADDRULE + action: firewall-acl-add-rule options: target_firewall_nodename: firewall firewall_port_name: dmz firewall_port_direction: outbound position: 2 - permission: 2 - source_ip_id: 3 # dmz_server - dest_ip_id: 2 # client_1 - source_port_id: 4 - dest_port_id: 4 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.10 # dmz_server + dst_ip: 192.168.0.10 # client_1 + src_port: HTTP + dst_port: HTTP + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 8: - action: FIREWALL_ACL_REMOVERULE + action: firewall-acl-remove-rule options: target_firewall_nodename: firewall firewall_port_name: dmz firewall_port_direction: outbound position: 2 9: - action: FIREWALL_ACL_ADDRULE + action: firewall-acl-add-rule options: target_firewall_nodename: firewall firewall_port_name: external firewall_port_direction: inbound position: 10 - permission: 2 - source_ip_id: 4 # external_computer - dest_ip_id: 3 # dmz - source_port_id: 5 - dest_port_id: 5 - protocol_id: 2 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.20.10 # external_computer + dst_ip: 192.168.10.10 # dmz + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + protocol_name: icmp + src_wildcard: NONE + dst_wildcard: NONE 10: - action: FIREWALL_ACL_REMOVERULE + action: firewall-acl-remove-rule options: target_firewall_nodename: firewall firewall_port_name: external firewall_port_direction: inbound position: 10 11: - action: FIREWALL_ACL_ADDRULE + action: firewall-acl-add-rule options: target_firewall_nodename: firewall firewall_port_name: external firewall_port_direction: outbound position: 1 - permission: 2 - source_ip_id: 4 # external_computer - dest_ip_id: 2 # client_1 - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.20.10 # external_computer + dst_ip: 192.168.0.10 # client_1 + src_port: ALL + dst_port: ALL + protocol_name: NONE + src_wildcard: NONE + dst_wildcard: NONE 12: - action: FIREWALL_ACL_REMOVERULE + action: firewall-acl-remove-rule options: target_firewall_nodename: firewall firewall_port_name: external firewall_port_direction: outbound position: 1 13: - action: NETWORK_PORT_DISABLE + action: network-port-disable options: + type: network-port-disable target_nodename: firewall - port_id: 3 + port_num: 3 14: - action: NETWORK_PORT_ENABLE + action: network-port-enable options: + type: network-port-enable target_nodename: firewall - port_id: 3 - options: - nodes: - - node_name: client_1 - - node_name: dmz_server - - node_name: external_computer - ip_list: - - 192.168.0.10 - - 192.168.10.10 - - 192.168.20.10 - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + port_num: 3 - agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 @@ -426,7 +404,7 @@ simulation: start_up_duration: 0 shut_down_duration: 0 services: - - type: DNSServer + - type: dns-server links: - endpoint_a_hostname: client_1 endpoint_a_port: 1 diff --git a/tests/assets/configs/software_fix_duration.yaml b/tests/assets/configs/fixing_duration_one_item.yaml similarity index 62% rename from tests/assets/configs/software_fix_duration.yaml rename to tests/assets/configs/fixing_duration_one_item.yaml index 1a28258b..02b69e5c 100644 --- a/tests/assets/configs/software_fix_duration.yaml +++ b/tests/assets/configs/fixing_duration_one_item.yaml @@ -4,6 +4,9 @@ # | client_1 |------| switch_1 |------| client_2 | # -------------- -------------- -------------- # +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true @@ -26,52 +29,33 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent - observation_space: null + type: probabilistic-agent + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY - + node_name: client_1 + application_name: web-browser agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 - + action_probabilities: + 0: 0.4 + 1: 0.6 - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: @@ -100,51 +84,33 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: - switch_1:eth-1<->client_1:eth-1 - switch_1:eth-2<->client_2:eth-1 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - action_map: 0: - action: DONOTHING + action: do-nothing options: {} - options: - nodes: - - node_name: switch - - node_name: client_1 - - node_name: client_2 - - node_name: client_3 - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.23 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.5 options: node_hostname: database_server @@ -152,7 +118,7 @@ agents: file_name: database.db - - type: WEB_SERVER_404_PENALTY + - type: web-server-404-penalty weight: 0.5 options: node_hostname: web_server @@ -177,66 +143,46 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: NMAP - options: - fix_duration: 1 - - type: RansomwareScript - options: - fix_duration: 1 - - type: WebBrowser + - type: ransomware-script + - type: web-browser options: target_url: http://arcd.com/users/ - fix_duration: 1 - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.10 server_password: arcd - fix_duration: 1 - - type: DataManipulationBot + fixing_duration: 1 + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.21 server_password: arcd - fix_duration: 1 - - type: DoSBot + - type: dos-bot options: target_ip_address: 192.168.10.21 payload: SPOOF DATA port_scan_p_of_success: 0.8 - fix_duration: 1 services: - - type: DNSClient + - type: dns-client + - type: dns-server options: - dns_server: 192.168.1.10 - fix_duration: 3 - - type: DNSServer - options: - fix_duration: 3 domain_mapping: arcd.com: 192.168.1.10 - - type: DatabaseService + - type: database-service options: + fixing_duration: 5 backup_server_ip: 192.168.1.10 - fix_duration: 3 - - type: WebServer - options: - fix_duration: 3 - - type: FTPClient - options: - fix_duration: 3 - - type: FTPServer + - type: web-server + - type: ftp-client + - type: ftp-server options: server_password: arcd - fix_duration: 3 - - type: NTPClient + - type: ntp-client options: ntp_server_ip: 192.168.1.10 - fix_duration: 3 - - type: NTPServer - options: - fix_duration: 3 + - type: ntp-server - hostname: client_2 type: computer ip_address: 192.168.10.22 @@ -244,14 +190,12 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.10 server_password: arcd services: - - type: DNSClient - options: - dns_server: 192.168.1.10 + - type: dns-client links: - endpoint_a_hostname: switch_1 diff --git a/tests/assets/configs/install_and_configure_apps.yaml b/tests/assets/configs/install_and_configure_apps.yaml index 6b548f7e..35546902 100644 --- a/tests/assets/configs/install_and_configure_apps.yaml +++ b/tests/assets/configs/install_and_configure_apps.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: false @@ -16,82 +19,65 @@ game: agents: - ref: agent_1 team: BLUE - type: ProxyAgent + type: proxy-agent + - observation_space: null action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_INSTALL - - type: CONFIGURE_DATABASE_CLIENT - - type: CONFIGURE_DOSBOT - - type: CONFIGURE_RANSOMWARE_SCRIPT - - type: NODE_APPLICATION_REMOVE action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_INSTALL + action: node-application-install options: - node_id: 0 - application_name: DatabaseClient + node_name: client_1 + application_name: database-client 2: - action: NODE_APPLICATION_INSTALL + action: node-application-install options: - node_id: 1 - application_name: RansomwareScript + node_name: client_2 + application_name: ransomware-script 3: - action: NODE_APPLICATION_INSTALL + action: node-application-install options: - node_id: 2 - application_name: DoSBot + node_name: client_3 + application_name: dos-bot 4: - action: CONFIGURE_DATABASE_CLIENT + action: configure-database-client options: - node_id: 0 - config: - server_ip_address: 10.0.0.5 + node_name: client_1 + server_ip_address: 10.0.0.5 5: - action: CONFIGURE_DATABASE_CLIENT + action: configure-database-client options: - node_id: 0 - config: - server_password: correct_password + node_name: client_1 + server_password: correct_password 6: - action: CONFIGURE_RANSOMWARE_SCRIPT + action: configure-ransomware-script options: - node_id: 1 - config: - server_ip_address: 10.0.0.5 - server_password: correct_password - payload: ENCRYPT + node_name: client_2 + server_ip_address: 10.0.0.5 + server_password: correct_password + payload: ENCRYPT 7: - action: CONFIGURE_DOSBOT + action: configure-dos-bot options: - node_id: 2 - config: - target_ip_address: 10.0.0.5 - target_port: POSTGRES_SERVER - payload: DELETE - repeat: true - port_scan_p_of_success: 1.0 - dos_intensity: 1.0 - max_sessions: 1000 + node_name: client_3 + target_ip_address: 10.0.0.5 + target_port: POSTGRES_SERVER + payload: DELETE + repeat: true + port_scan_p_of_success: 1.0 + dos_intensity: 1.0 + max_sessions: 1000 8: - action: NODE_APPLICATION_INSTALL + action: node-application-install options: - node_id: 1 - application_name: DatabaseClient - options: - nodes: - - node_name: client_1 - - node_name: client_2 - - node_name: client_3 - ip_list: [] - reward_function: - reward_components: - - type: DUMMY + node_name: client_2 + application_name: database-client + agent_settings: + flatten_obs: True + action_masking: False simulation: network: @@ -120,7 +106,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 10.0.0.1 services: - - type: DatabaseService + - type: database-service options: db_password: correct_password links: diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index a2d64605..de0cdad9 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_agent_actions: false save_step_metadata: false @@ -20,98 +23,72 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_2 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_2 + application_name: database-client reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_2 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_2 - ref: client_1_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_1 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_1 + application_name: web-browser reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_1 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_1 @@ -122,55 +99,31 @@ agents: - ref: data_manipulation_attacker team: RED - type: RedDatabaseCorruptingAgent - - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - - type: NODE_FILE_DELETE - - type: NODE_FILE_CORRUPT - - type: NODE_OS_SCAN - options: - nodes: - - node_name: client_1 - applications: - - application_name: DataManipulationBot - - node_name: client_2 - applications: - - application_name: DataManipulationBot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY + type: red-database-corrupting-agent agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + possible_start_nodes: [client_1, client_2] + target_application: data-manipulation-bot + start_step: 25 + frequency: 20 + variance: 5 - ref: defender_1 team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -202,15 +155,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -224,508 +177,442 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - - type: ROUTER_ACL_REMOVERULE - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE - action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-file-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-folder-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 0 + node_name: domain_controller 19: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 0 + node_name: domain_controller 20: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: domain_controller 21: - action: NODE_RESET + action: node-reset options: - node_id: 0 + node_name: domain_controller 22: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 1 + node_name: web_server 23: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 1 + node_name: web_server 24: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: web_server 25: - action: NODE_RESET + action: node-reset options: - node_id: 1 + node_name: web_server 26: # old action num: 18 - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 27: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 2 + node_name: database_server 28: - action: NODE_STARTUP + action: node-startup options: - node_id: 2 + node_name: database_server 29: - action: NODE_RESET + action: node-reset options: - node_id: 2 + node_name: database_server 30: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 3 + node_name: backup_server 31: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 3 + node_name: backup_server 32: - action: NODE_STARTUP + action: node-startup options: - node_id: 3 + node_name: backup_server 33: - action: NODE_RESET + action: node-reset options: - node_id: 3 + node_name: backup_server 34: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 4 + node_name: security_suite 35: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 4 + node_name: security_suite 36: - action: NODE_STARTUP + action: node-startup options: - node_id: 4 + node_name: security_suite 37: - action: NODE_RESET + action: node-reset options: - node_id: 4 + node_name: security_suite 38: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 5 + node_name: client_1 39: # old action num: 19 # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 40: # old action num: 20 - action: NODE_STARTUP + action: node-startup options: - node_id: 5 + node_name: client_1 41: # old action num: 21 - action: NODE_RESET + action: node-reset options: - node_id: 5 + node_name: client_1 42: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 6 + node_name: client_2 43: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 6 + node_name: client_2 44: - action: NODE_STARTUP + action: node-startup options: - node_id: 6 + node_name: client_2 45: - action: NODE_RESET + action: node-reset options: - node_id: 6 + node_name: client_2 - 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" - action: "ROUTER_ACL_ADDRULE" + 46: # old action num: 22 # "acl: ADDRULE - Block outgoing traffic from client 1" + action: "router-acl-add-rule" options: target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 47: # old action num: 23 # "acl: ADDRULE - Block outgoing traffic from client 2" + action: "router-acl-add-rule" options: target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 48: # old action num: 24 # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 49: # old action num: 25 # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 50: # old action num: 26 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 51: # old action num: 27 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 52: # old action num: 28 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 0 53: # old action num: 29 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 1 54: # old action num: 30 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 2 55: # old action num: 31 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 3 56: # old action num: 32 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 4 57: # old action num: 33 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 5 58: # old action num: 34 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 6 59: # old action num: 35 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 7 60: # old action num: 36 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 8 61: # old action num: 37 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 9 62: # old action num: 38 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 63: # old action num: 39 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 64: # old action num: 40 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 65: # old action num: 41 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 66: # old action num: 42 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 67: # old action num: 43 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 68: # old action num: 44 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 69: # old action num: 45 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 70: # old action num: 46 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 71: # old action num: 47 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 72: # old action num: 48 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 73: # old action num: 49 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 74: # old action num: 50 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 75: # old action num: 51 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 76: # old action num: 52 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 77: # old action num: 53 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 - - - options: - nodes: - - node_name: domain_controller - - node_name: web_server - applications: - - application_name: DatabaseClient - services: - - service_name: WebServer - - node_name: database_server - folders: - - folder_name: database - files: - - file_name: database.db - services: - - service_name: DatabaseService - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 - + node_name: client_2 + nic_num: 1 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_1_green_user - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_2_green_user @@ -737,20 +624,20 @@ agents: - ref: defender_2 team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -782,15 +669,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -804,512 +691,443 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - options: - target_router: router_1 - - type: ROUTER_ACL_REMOVERULE - options: - target_router: router_1 - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE - action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-file-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-folder-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 0 + node_name: domain_controller 19: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 0 + node_name: domain_controller 20: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: domain_controller 21: - action: NODE_RESET + action: node-reset options: - node_id: 0 + node_name: domain_controller 22: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 1 + node_name: web_server 23: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 1 + node_name: web_server 24: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: web_server 25: - action: NODE_RESET + action: node-reset options: - node_id: 1 + node_name: web_server 26: # old action num: 18 - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 27: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 2 + node_name: database_server 28: - action: NODE_STARTUP + action: node-startup options: - node_id: 2 + node_name: database_server 29: - action: NODE_RESET + action: node-reset options: - node_id: 2 + node_name: database_server 30: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 3 + node_name: backup_server 31: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 3 + node_name: backup_server 32: - action: NODE_STARTUP + action: node-startup options: - node_id: 3 + node_name: backup_server 33: - action: NODE_RESET + action: node-reset options: - node_id: 3 + node_name: backup_server 34: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 4 + node_name: security_suite 35: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 4 + node_name: security_suite 36: - action: NODE_STARTUP + action: node-startup options: - node_id: 4 + node_name: security_suite 37: - action: NODE_RESET + action: node-reset options: - node_id: 4 + node_name: security_suite 38: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 5 + node_name: client_1 39: # old action num: 19 # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 40: # old action num: 20 - action: NODE_STARTUP + action: node-startup options: - node_id: 5 + node_name: client_1 41: # old action num: 21 - action: NODE_RESET + action: node-reset options: - node_id: 5 + node_name: client_1 42: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 6 + node_name: client_2 43: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 6 + node_name: client_2 44: - action: NODE_STARTUP + action: node-startup options: - node_id: 6 + node_name: client_2 45: - action: NODE_RESET + action: node-reset options: - node_id: 6 + node_name: client_2 - 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" - action: "ROUTER_ACL_ADDRULE" + 46: # old action num: 22 # "acl: ADDRULE - Block outgoing traffic from client 1" + action: "router-acl-add-rule" options: target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 47: # old action num: 23 # "acl: ADDRULE - Block outgoing traffic from client 2" + action: "router-acl-add-rule" options: target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 48: # old action num: 24 # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 49: # old action num: 25 # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 50: # old action num: 26 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 51: # old action num: 27 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 52: # old action num: 28 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 0 53: # old action num: 29 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 1 54: # old action num: 30 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 2 55: # old action num: 31 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 3 56: # old action num: 32 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 4 57: # old action num: 33 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 5 58: # old action num: 34 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 6 59: # old action num: 35 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 7 60: # old action num: 36 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 8 61: # old action num: 37 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 9 62: # old action num: 38 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 63: # old action num: 39 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 64: # old action num: 40 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 65: # old action num: 41 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 66: # old action num: 42 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 67: # old action num: 43 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 68: # old action num: 44 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 69: # old action num: 45 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 70: # old action num: 46 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 71: # old action num: 47 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 72: # old action num: 48 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 73: # old action num: 49 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 74: # old action num: 50 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 75: # old action num: 51 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 76: # old action num: 52 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 77: # old action num: 53 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 - - options: - nodes: - - node_name: domain_controller - - node_name: web_server - applications: - - application_name: DatabaseClient - services: - - service_name: WebServer - - node_name: database_server - folders: - - folder_name: database - files: - - file_name: database.db - services: - - service_name: DatabaseService - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 - reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_1_green_user - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_2_green_user @@ -1379,7 +1197,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DNSServer + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.12 # web server @@ -1391,9 +1209,9 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: WebServer + - type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 @@ -1405,10 +1223,10 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: DatabaseService + - type: database-service options: backup_server_ip: 192.168.1.16 - - type: FTPClient + - type: ftp-client - hostname: backup_server type: server @@ -1417,7 +1235,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: FTPServer + - type: ftp-server - hostname: security_suite type: server @@ -1437,20 +1255,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client - hostname: client_2 type: computer @@ -1459,20 +1277,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client diff --git a/tests/assets/configs/nmap_network_service_recon_red_agent_config.yaml b/tests/assets/configs/nmap_network_service_recon_red_agent_config.yaml index c5508f13..f7b8431e 100644 --- a/tests/assets/configs/nmap_network_service_recon_red_agent_config.yaml +++ b/tests/assets/configs/nmap_network_service_recon_red_agent_config.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true @@ -21,33 +24,18 @@ game: agents: - ref: client_1_red_nmap team: RED - type: ProbabilisticAgent - observation_space: null + type: probabilistic-agent + action_space: - options: - nodes: - - node_name: client_1 - applications: - - application_name: NMAP - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 1 - action_list: - - type: NODE_NMAP_NETWORK_SERVICE_RECON action_map: 0: - action: NODE_NMAP_NETWORK_SERVICE_RECON + action: node-network-service-recon options: source_node: client_1 target_ip_address: 192.168.10.0/24 target_port: 80 target_protocol: tcp - - reward_function: - reward_components: - - type: DUMMY - + show: false agent_settings: action_probabilities: 0: 1.0 diff --git a/tests/assets/configs/nmap_ping_scan_red_agent_config.yaml b/tests/assets/configs/nmap_ping_scan_red_agent_config.yaml index 33ba3d19..112d7266 100644 --- a/tests/assets/configs/nmap_ping_scan_red_agent_config.yaml +++ b/tests/assets/configs/nmap_ping_scan_red_agent_config.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true @@ -21,30 +24,16 @@ game: agents: - ref: client_1_red_nmap team: RED - type: ProbabilisticAgent - observation_space: null + type: probabilistic-agent + action_space: - options: - nodes: - - node_name: client_1 - applications: - - application_name: NMAP - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 1 - action_list: - - type: NODE_NMAP_PING_SCAN action_map: 0: - action: NODE_NMAP_PING_SCAN + action: node-nmap-ping-scan options: source_node: client_1 target_ip_address: 192.168.1.0/24 - - reward_function: - reward_components: - - type: DUMMY + show: False agent_settings: action_probabilities: diff --git a/tests/assets/configs/nmap_port_scan_red_agent_config.yaml b/tests/assets/configs/nmap_port_scan_red_agent_config.yaml index 8ed715c1..acd5319a 100644 --- a/tests/assets/configs/nmap_port_scan_red_agent_config.yaml +++ b/tests/assets/configs/nmap_port_scan_red_agent_config.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true @@ -21,23 +24,12 @@ game: agents: - ref: client_1_red_nmap team: RED - type: ProbabilisticAgent - observation_space: null + type: probabilistic-agent + action_space: - options: - nodes: - - node_name: client_1 - applications: - - application_name: NMAP - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 1 - action_list: - - type: NODE_NMAP_PORT_SCAN action_map: 0: - action: NODE_NMAP_PORT_SCAN + action: node-nmap-port-scan options: source_node: client_1 target_ip_address: 192.168.10.0/24 @@ -47,10 +39,7 @@ agents: - 80 - 123 - 219 - - reward_function: - reward_components: - - type: DUMMY + show: false agent_settings: action_probabilities: diff --git a/tests/assets/configs/no_nodes_links_agents_network.yaml b/tests/assets/configs/no_nodes_links_agents_network.yaml index b20835bc..ed279b51 100644 --- a/tests/assets/configs/no_nodes_links_agents_network.yaml +++ b/tests/assets/configs/no_nodes_links_agents_network.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true diff --git a/tests/assets/configs/nodes_with_initial_files.yaml b/tests/assets/configs/nodes_with_initial_files.yaml index fad6cffd..d4c6406b 100644 --- a/tests/assets/configs/nodes_with_initial_files.yaml +++ b/tests/assets/configs/nodes_with_initial_files.yaml @@ -29,52 +29,36 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent - observation_space: null + type: periodic-agent action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: node_id: 0 application_id: 0 - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + possible_start_nodes: [client_2,] + target_application: web-browser + start_step: 5 + frequency: 4 + variance: 3 - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: @@ -111,51 +95,32 @@ agents: - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: - switch_1:eth-1<->client_1:eth-1 - switch_1:eth-2<->client_2:eth-1 - - type: "NONE" + - type: none label: ICS options: {} action_space: - action_list: - - type: DONOTHING - action_map: 0: - action: DONOTHING + action: do-nothing options: {} - options: - nodes: - - node_name: switch - - node_name: client_1 - - node_name: client_2 - - node_name: client_3 - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.23 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.5 options: node_hostname: database_server folder_name: database file_name: database.db - - - type: WEB_SERVER_404_PENALTY + - type: web-server-404-penalty weight: 0.5 options: node_hostname: web_server @@ -180,60 +145,62 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: RansomwareScript - - type: WebBrowser + - type: ransomware-script + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.10 server_password: arcd - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.21 server_password: arcd - - type: DoSBot + - type: dos-bot options: target_ip_address: 192.168.10.21 payload: SPOOF DATA port_scan_p_of_success: 0.8 services: - - type: DNSClient + - type: dns-client options: dns_server: 192.168.1.10 - - type: DNSServer + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.10 - - type: DatabaseService + - type: database-service options: backup_server_ip: 192.168.1.10 - - type: WebServer - - type: FTPServer + - type: web-server + - type: ftp-server options: server_password: arcd - - type: NTPClient + - type: ntp-client options: ntp_server_ip: 192.168.1.10 - - type: NTPServer + - type: ntp-server - hostname: client_2 type: computer ip_address: 192.168.10.22 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 dns_server: 192.168.1.10 - file_system: - - empty_folder - - downloads: - - "test.txt" - - "suh_con.dn" - - root: - - passwords: - size: 69 - type: TXT + folders: + - folder_name: empty_folder + - folder_name: downloads + files: + - file_name: "test.txt" + - file_name: "another_file.pwtwoti" + - folder_name: root + files: + - file_name: passwords + size: 663 + type: TXT # pre installed services and applications - hostname: client_3 type: computer diff --git a/tests/assets/configs/scenario_with_placeholders/greens_1.yaml b/tests/assets/configs/scenario_with_placeholders/greens_1.yaml index 98d2392a..3f9b65f4 100644 --- a/tests/assets/configs/scenario_with_placeholders/greens_1.yaml +++ b/tests/assets/configs/scenario_with_placeholders/greens_1.yaml @@ -1,34 +1,26 @@ agents: &greens - ref: green_A team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.2 1: 0.8 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client - applications: - - application_name: DatabaseClient action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client + application_name: database-client reward_function: reward_components: - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 1.0 options: node_hostname: client diff --git a/tests/assets/configs/scenario_with_placeholders/greens_2.yaml b/tests/assets/configs/scenario_with_placeholders/greens_2.yaml index 17a5977b..77a689e7 100644 --- a/tests/assets/configs/scenario_with_placeholders/greens_2.yaml +++ b/tests/assets/configs/scenario_with_placeholders/greens_2.yaml @@ -1,34 +1,26 @@ agents: &greens - ref: green_B team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.95 1: 0.05 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client - applications: - - application_name: DatabaseClient action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client + application_name: database-client reward_function: reward_components: - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 1.0 options: node_hostname: client diff --git a/tests/assets/configs/scenario_with_placeholders/reds_1.yaml b/tests/assets/configs/scenario_with_placeholders/reds_1.yaml index 31675a0b..b95955b4 100644 --- a/tests/assets/configs/scenario_with_placeholders/reds_1.yaml +++ b/tests/assets/configs/scenario_with_placeholders/reds_1.yaml @@ -1,26 +1,11 @@ reds: &reds - ref: red_A team: RED - type: RedDatabaseCorruptingAgent - - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client - applications: - - application_name: DataManipulationBot - - reward_function: - reward_components: - - type: DUMMY + type: red-database-corrupting-agent agent_settings: - start_settings: - start_step: 10 - frequency: 10 - variance: 0 + possible_start_nodes: [client,] + target_application: data-manipulation-bot + start_step: 10 + frequency: 10 + variance: 0 diff --git a/tests/assets/configs/scenario_with_placeholders/reds_2.yaml b/tests/assets/configs/scenario_with_placeholders/reds_2.yaml index c5572b89..a4a7550a 100644 --- a/tests/assets/configs/scenario_with_placeholders/reds_2.yaml +++ b/tests/assets/configs/scenario_with_placeholders/reds_2.yaml @@ -1,26 +1,10 @@ reds: &reds - ref: red_B team: RED - type: RedDatabaseCorruptingAgent - - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client - applications: - - application_name: DataManipulationBot - - reward_function: - reward_components: - - type: DUMMY - + type: red-database-corrupting-agent agent_settings: - start_settings: - start_step: 3 - frequency: 2 - variance: 1 + possible_start_nodes: [client_1,] + target_application: data-manipulation-bot + start_step: 3 + frequency: 2 + variance: 1 diff --git a/tests/assets/configs/scenario_with_placeholders/scenario.yaml b/tests/assets/configs/scenario_with_placeholders/scenario.yaml index ef930a1a..57fe59ab 100644 --- a/tests/assets/configs/scenario_with_placeholders/scenario.yaml +++ b/tests/assets/configs/scenario_with_placeholders/scenario.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_agent_actions: true save_step_metadata: false @@ -26,12 +29,12 @@ agents: - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: routers: [] @@ -46,7 +49,7 @@ agents: include_num_access: false include_nmne: false - - type: LINKS + - type: links label: LINKS options: link_references: @@ -54,69 +57,50 @@ agents: - server:eth-1<->switch_1:eth-2 action_space: - action_list: - - type: DONOTHING - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_SHUTDOWN + action: node-shutdown options: - node_id: 0 + node_name: client 2: - action: NODE_SHUTDOWN + action: node-shutdown options: - node_id: 1 + node_name: server 3: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: client 4: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: server 5: - action: HOST_NIC_DISABLE + action: host-nic-disable options: - node_id: 0 - nic_id: 0 + node_name: client + nic_num: 1 6: - action: HOST_NIC_DISABLE + action: host-nic-disable options: - node_id: 1 - nic_id: 0 + node_name: server + nic_num: 1 7: - action: HOST_NIC_ENABLE + action: host-nic-enable options: - node_id: 0 - nic_id: 0 + node_name: client + nic_num: 1 8: - action: HOST_NIC_ENABLE + action: host-nic-enable options: - node_id: 1 - nic_id: 0 - options: - nodes: - - node_name: client - - node_name: server - - max_folders_per_node: 0 - max_files_per_folder: 0 - max_services_per_node: 0 - max_nics_per_node: 1 - max_acl_rules: 0 - ip_list: - - 192.168.1.2 - - 192.168.1.3 + node_name: server + nic_num: 1 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.40 options: node_hostname: database_server @@ -136,10 +120,10 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.3 - - type: DataManipulationBot + - type: data-manipulation-bot options: server_ip: 192.168.1.3 payload: "DELETE" @@ -154,7 +138,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DatabaseService + - type: database-service links: - endpoint_a_hostname: client diff --git a/tests/assets/configs/shared_rewards.yaml b/tests/assets/configs/shared_rewards.yaml index 81cb85f7..5aeb99fa 100644 --- a/tests/assets/configs/shared_rewards.yaml +++ b/tests/assets/configs/shared_rewards.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_agent_actions: false save_step_metadata: false @@ -23,150 +26,103 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_2 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_2 + application_name: database-client reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_2 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_2 - ref: client_1_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_1 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_1 + application_name: database-client reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_1 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_1 - ref: data_manipulation_attacker team: RED - type: RedDatabaseCorruptingAgent - - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: DataManipulationBot - - node_name: client_2 - applications: - - application_name: DataManipulationBot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY + type: red-database-corrupting-agent agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + possible_start_nodes: [client_1, client_2] + target_application: data-manipulation-bot + start_step: 25 + frequency: 20 + variance: 5 - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -198,15 +154,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -220,503 +176,436 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - - type: ROUTER_ACL_REMOVERULE - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE - action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_CHECKHASH" + action: "node-file-checkhash" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_CHECKHASH" + action: "node-folder-checkhash" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 0 + node_name: domain_controller 19: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 0 + node_name: domain_controller 20: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: domain_controller 21: - action: NODE_RESET + action: node-reset options: - node_id: 0 + node_name: domain_controller 22: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 1 + node_name: web_server 23: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 1 + node_name: web_server 24: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: web_server 25: - action: NODE_RESET + action: node-reset options: - node_id: 1 + node_name: web_server 26: # old action num: 18 - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 27: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 2 + node_name: database_server 28: - action: NODE_STARTUP + action: node-startup options: - node_id: 2 + node_name: database_server 29: - action: NODE_RESET + action: node-reset options: - node_id: 2 + node_name: database_server 30: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 3 + node_name: backup_server 31: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 3 + node_name: backup_server 32: - action: NODE_STARTUP + action: node-startup options: - node_id: 3 + node_name: backup_server 33: - action: NODE_RESET + action: node-reset options: - node_id: 3 + node_name: backup_server 34: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 4 + node_name: security_suite 35: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 4 + node_name: security_suite 36: - action: NODE_STARTUP + action: node-startup options: - node_id: 4 + node_name: security_suite 37: - action: NODE_RESET + action: node-reset options: - node_id: 4 + node_name: security_suite 38: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 5 + node_name: client_1 39: # old action num: 19 # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 40: # old action num: 20 - action: NODE_STARTUP + action: node-startup options: - node_id: 5 + node_name: client_1 41: # old action num: 21 - action: NODE_RESET + action: node-reset options: - node_id: 5 + node_name: client_1 42: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 6 + node_name: client_2 43: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 6 + node_name: client_2 44: - action: NODE_STARTUP + action: node-startup options: - node_id: 6 + node_name: client_2 45: - action: NODE_RESET + action: node-reset options: - node_id: 6 + node_name: client_2 - 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" - action: "ROUTER_ACL_ADDRULE" + 46: # old action num: 22 # "acl: ADDRULE - Block outgoing traffic from client 1" + action: "router-acl-add-rule" options: target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 47: # old action num: 23 # "acl: ADDRULE - Block outgoing traffic from client 2" + action: "router-acl-add-rule" options: target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 48: # old action num: 24 # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 49: # old action num: 25 # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 50: # old action num: 26 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 51: # old action num: 27 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 52: # old action num: 28 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 0 53: # old action num: 29 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 1 54: # old action num: 30 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 2 55: # old action num: 31 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 3 56: # old action num: 32 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 4 57: # old action num: 33 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 5 58: # old action num: 34 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 6 59: # old action num: 35 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 7 60: # old action num: 36 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 8 61: # old action num: 37 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 9 62: # old action num: 38 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 63: # old action num: 39 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 64: # old action num: 40 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 65: # old action num: 41 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 66: # old action num: 42 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 67: # old action num: 43 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 68: # old action num: 44 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 69: # old action num: 45 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 70: # old action num: 46 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 71: # old action num: 47 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 72: # old action num: 48 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 73: # old action num: 49 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 74: # old action num: 50 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 75: # old action num: 51 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 76: # old action num: 52 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 77: # old action num: 53 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 - - - - options: - nodes: - - node_name: domain_controller - - node_name: web_server - applications: - - application_name: DatabaseClient - services: - - service_name: WebServer - - node_name: database_server - folders: - - folder_name: database - files: - - file_name: database.db - services: - - service_name: DatabaseService - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 - + node_name: client_2 + nic_num: 1 reward_function: reward_components: - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_1_green_user - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_2_green_user @@ -787,7 +676,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DNSServer + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.12 # web server @@ -799,9 +688,9 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: WebServer + - type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 @@ -813,10 +702,10 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: DatabaseService + - type: database-service options: backup_server_ip: 192.168.1.16 - - type: FTPClient + - type: ftp-client - hostname: backup_server type: server @@ -825,7 +714,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: FTPServer + - type: ftp-server - hostname: security_suite type: server @@ -845,20 +734,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client - hostname: client_2 type: computer @@ -867,20 +756,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client diff --git a/tests/assets/configs/fix_duration_one_item.yaml b/tests/assets/configs/software_fixing_duration.yaml similarity index 67% rename from tests/assets/configs/fix_duration_one_item.yaml rename to tests/assets/configs/software_fixing_duration.yaml index bd0fb61f..66ba6f18 100644 --- a/tests/assets/configs/fix_duration_one_item.yaml +++ b/tests/assets/configs/software_fixing_duration.yaml @@ -4,6 +4,9 @@ # | client_1 |------| switch_1 |------| client_2 | # -------------- -------------- -------------- # +metadata: + version: 3.0 + io_settings: save_step_metadata: false save_pcap_logs: true @@ -26,52 +29,33 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent - observation_space: null + type: probabilistic-agent + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY - + node_name: client_2 + application_name: web-browser agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 - + action_probabilities: + 0: 0.4 + 1: 0.6 - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: @@ -100,51 +84,33 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: - switch_1:eth-1<->client_1:eth-1 - switch_1:eth-2<->client_2:eth-1 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - action_map: 0: - action: DONOTHING + action: do-nothing options: {} - options: - nodes: - - node_name: switch - - node_name: client_1 - - node_name: client_2 - - node_name: client_3 - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.23 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.5 options: node_hostname: database_server @@ -152,7 +118,7 @@ agents: file_name: database.db - - type: WEB_SERVER_404_PENALTY + - type: web-server-404-penalty weight: 0.5 options: node_hostname: web_server @@ -177,48 +143,66 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: RansomwareScript - - type: WebBrowser + - type: nmap + options: + fixing_duration: 1 + - type: ransomware-script + options: + fixing_duration: 1 + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DatabaseClient + fixing_duration: 1 + - type: database-client options: db_server_ip: 192.168.1.10 server_password: arcd - fix_duration: 1 - - type: DataManipulationBot + fixing_duration: 1 + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.21 server_password: arcd - - type: DoSBot + fixing_duration: 1 + - type: dos-bot options: target_ip_address: 192.168.10.21 payload: SPOOF DATA port_scan_p_of_success: 0.8 + fixing_duration: 1 services: - - type: DNSClient + - type: dns-client options: dns_server: 192.168.1.10 - - type: DNSServer + fixing_duration: 3 + - type: dns-server options: + fixing_duration: 3 domain_mapping: arcd.com: 192.168.1.10 - - type: DatabaseService + - type: database-service options: - fix_duration: 5 backup_server_ip: 192.168.1.10 - - type: WebServer - - type: FTPClient - - type: FTPServer + fixing_duration: 3 + - type: web-server + options: + fixing_duration: 3 + - type: ftp-client + options: + fixing_duration: 3 + - type: ftp-server options: server_password: arcd - - type: NTPClient + fixing_duration: 3 + - type: ntp-client options: ntp_server_ip: 192.168.1.10 - - type: NTPServer + fixing_duration: 3 + - type: ntp-server + options: + fixing_duration: 3 - hostname: client_2 type: computer ip_address: 192.168.10.22 @@ -226,14 +210,12 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.10 server_password: arcd services: - - type: DNSClient - options: - dns_server: 192.168.1.10 + - type: dns-client links: - endpoint_a_hostname: switch_1 diff --git a/tests/assets/configs/test_application_install.yaml b/tests/assets/configs/test_application_install.yaml index 3a3a6890..8a292f83 100644 --- a/tests/assets/configs/test_application_install.yaml +++ b/tests/assets/configs/test_application_install.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_agent_actions: true save_step_metadata: false @@ -23,98 +26,72 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_2 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_2 + application_name: database-client reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_2 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_2 - ref: client_1_green_user team: GREEN - type: ProbabilisticAgent + type: probabilistic-agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - observation_space: null + action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: client_1 + application_name: web-browser 2: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 1 + node_name: client_1 + application_name: web-browser reward_function: reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY + - type: webpage-unavailable-penalty weight: 0.25 options: node_hostname: client_1 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + - type: green-admin-database-unreachable-penalty weight: 0.05 options: node_hostname: client_1 @@ -125,52 +102,31 @@ agents: - ref: data_manipulation_attacker team: RED - type: RedDatabaseCorruptingAgent - - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: DataManipulationBot - - node_name: client_2 - applications: - - application_name: DataManipulationBot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY + type: red-database-corrupting-agent agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + possible_start_nodes: [client_1, client_2] + target_application: data-manipulation-bot + start_step: 25 + frequency: 20 + variance: 5 - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -202,15 +158,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -224,543 +180,468 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - - type: ROUTER_ACL_REMOVERULE - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE - - type: NODE_APPLICATION_INSTALL - - type: NODE_APPLICATION_REMOVE - - type: NODE_APPLICATION_EXECUTE - - type: CONFIGURE_DOSBOT - action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-file-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + action: "node-folder-scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 0 + node_name: domain_controller 19: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 0 + node_name: domain_controller 20: - action: NODE_STARTUP + action: node-startup options: - node_id: 0 + node_name: domain_controller 21: - action: NODE_RESET + action: node-reset options: - node_id: 0 + node_name: domain_controller 22: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 1 + node_name: web_server 23: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 1 + node_name: web_server 24: - action: NODE_STARTUP + action: node-startup options: - node_id: 1 + node_name: web_server 25: - action: NODE_RESET + action: node-reset options: - node_id: 1 + node_name: web_server 26: # old action num: 18 - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 27: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 2 + node_name: database_server 28: - action: NODE_STARTUP + action: node-startup options: - node_id: 2 + node_name: database_server 29: - action: NODE_RESET + action: node-reset options: - node_id: 2 + node_name: database_server 30: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 3 + node_name: backup_server 31: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 3 + node_name: backup_server 32: - action: NODE_STARTUP + action: node-startup options: - node_id: 3 + node_name: backup_server 33: - action: NODE_RESET + action: node-reset options: - node_id: 3 + node_name: backup_server 34: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 4 + node_name: security_suite 35: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 4 + node_name: security_suite 36: - action: NODE_STARTUP + action: node-startup options: - node_id: 4 + node_name: security_suite 37: - action: NODE_RESET + action: node-reset options: - node_id: 4 + node_name: security_suite 38: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 5 + node_name: client_1 39: # old action num: 19 # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 40: # old action num: 20 - action: NODE_STARTUP + action: node-startup options: - node_id: 5 + node_name: client_1 41: # old action num: 21 - action: NODE_RESET + action: node-reset options: - node_id: 5 + node_name: client_1 42: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 6 + node_name: client_2 43: - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 6 + node_name: client_2 44: - action: NODE_STARTUP + action: node-startup options: - node_id: 6 + node_name: client_2 45: - action: NODE_RESET + action: node-reset options: - node_id: 6 + node_name: client_2 - 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" - action: "ROUTER_ACL_ADDRULE" + 46: # old action num: 22 # "acl: ADDRULE - Block outgoing traffic from client 1" + action: "router-acl-add-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 47: # old action num: 23 # "acl: ADDRULE - Block outgoing traffic from client 2" + action: "router-acl-add-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 48: # old action num: 24 # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 49: # old action num: 25 # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 50: # old action num: 26 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 51: # old action num: 27 - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 52: # old action num: 28 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 0 53: # old action num: 29 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 1 54: # old action num: 30 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 2 55: # old action num: 31 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 3 56: # old action num: 32 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 4 57: # old action num: 33 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 5 58: # old action num: 34 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 6 59: # old action num: 35 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 7 60: # old action num: 36 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 8 61: # old action num: 37 - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: - target_router_hostname: router_1 + target_router: router_1 position: 9 62: # old action num: 38 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 63: # old action num: 39 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 64: # old action num: 40 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 65: # old action num: 41 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 66: # old action num: 42 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 67: # old action num: 43 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 68: # old action num: 44 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 69: # old action num: 45 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 70: # old action num: 46 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 71: # old action num: 47 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 72: # old action num: 48 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 73: # old action num: 49 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 74: # old action num: 50 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 75: # old action num: 51 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 76: # old action num: 52 - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 77: # old action num: 53 - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 78: - action: NODE_APPLICATION_INSTALL + action: node-application-install options: - node_id: 0 - application_name: DoSBot + node_name: domain_controller + application_name: dos-bot 79: - action: NODE_APPLICATION_REMOVE + action: node-application-remove options: - node_id: 0 - application_name: DoSBot + node_name: domain_controller + application_name: dos-bot 80: - action: NODE_APPLICATION_REMOVE + action: node-application-remove options: - node_id: 0 - application_name: WebBrowser + node_name: domain_controller + application_name: web-browser 81: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 + node_name: domain_controller + application_name: dos-bot 82: - action: CONFIGURE_DOSBOT + action: configure-dos-bot options: - node_id: 0 - config: - target_ip_address: 192.168.1.14 - target_port: POSTGRES_SERVER - - - - - options: - nodes: - - node_name: domain_controller - applications: - - application_name: DoSBot - - node_name: web_server - applications: - - application_name: DatabaseClient - services: - - service_name: WebServer - - node_name: database_server - folders: - - folder_name: database - files: - - file_name: database.db - services: - - service_name: DatabaseService - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 - + node_name: domain_controller + target_ip_address: 192.168.1.14 + target_port: POSTGRES_SERVER reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_1_green_user - - type: SHARED_REWARD + - type: shared-reward weight: 1.0 options: agent_name: client_2_green_user @@ -831,7 +712,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DNSServer + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.12 # web server @@ -843,9 +724,9 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: WebServer + - type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 @@ -857,10 +738,10 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: DatabaseService + - type: database-service options: backup_server_ip: 192.168.1.16 - - type: FTPClient + - type: ftp-client - hostname: backup_server type: server @@ -869,7 +750,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: FTPServer + - type: ftp-server - hostname: security_suite type: server @@ -889,20 +770,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client - hostname: client_2 type: computer @@ -911,20 +792,20 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: WebBrowser + - type: web-browser options: target_url: http://arcd.com/users/ - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.8 data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 27cfa240..ad43732f 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + io_settings: save_agent_actions: true save_step_metadata: true @@ -20,91 +23,58 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent - observation_space: null + type: probabilistic-agent + action_space: - action_list: - - type: DONOTHING action_map: 0: - action: DONOTHING + action: do-nothing options: {} - options: - nodes: - - node_name: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 - - reward_function: - reward_components: - - type: DUMMY agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + action_probabilities: + 0: 1.0 - ref: data_manipulation_attacker team: RED - type: RedDatabaseCorruptingAgent + type: red-database-corrupting-agent + - observation_space: null action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - - type: NODE_FILE_DELETE - - type: NODE_FILE_CORRUPT - - type: NODE_OS_SCAN action_map: 0: - action: DONOTHING + action: do-nothing options: {} 1: - action: NODE_APPLICATION_EXECUTE + action: node-application-execute options: - node_id: 0 - application_id: 0 - options: - nodes: - - node_name: client_1 - applications: - - application_name: DataManipulationBot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY + node_name: client_1 + application_name: data-manipulation-bot agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 + possible_start_nodes: [client_1,] + target_application: data-manipulation-bot + start_step: 25 + frequency: 20 + variance: 5 - ref: defender team: BLUE - type: ProxyAgent + type: proxy-agent observation_space: - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -136,15 +106,15 @@ agents: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -158,400 +128,335 @@ agents: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {} action_space: - action_list: - - type: DONOTHING - - type: NODE_SERVICE_SCAN - - type: NODE_SERVICE_STOP - - type: NODE_SERVICE_START - - type: NODE_SERVICE_PAUSE - - type: NODE_SERVICE_RESUME - - type: NODE_SERVICE_RESTART - - type: NODE_SERVICE_DISABLE - - type: NODE_SERVICE_ENABLE - - type: NODE_SERVICE_FIX - - type: NODE_FILE_SCAN - - type: NODE_FILE_CHECKHASH - - type: NODE_FILE_DELETE - - type: NODE_FILE_REPAIR - - type: NODE_FILE_RESTORE - - type: NODE_FOLDER_SCAN - - type: NODE_FOLDER_CHECKHASH - - type: NODE_FOLDER_REPAIR - - type: NODE_FOLDER_RESTORE - - type: NODE_OS_SCAN - - type: NODE_SHUTDOWN - - type: NODE_STARTUP - - type: NODE_RESET - - type: ROUTER_ACL_ADDRULE - - type: ROUTER_ACL_REMOVERULE - - type: HOST_NIC_ENABLE - - type: HOST_NIC_DISABLE - action_map: 0: - action: DONOTHING + action: do-nothing options: {} # scan webapp service 1: - action: NODE_SERVICE_SCAN + action: node-service-scan options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # stop webapp service 2: - action: NODE_SERVICE_STOP + action: node-service-stop options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server # start webapp service 3: - action: "NODE_SERVICE_START" + action: "node-service-start" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 4: - action: "NODE_SERVICE_PAUSE" + action: "node-service-pause" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 5: - action: "NODE_SERVICE_RESUME" + action: "node-service-resume" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 6: - action: "NODE_SERVICE_RESTART" + action: "node-service-restart" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 7: - action: "NODE_SERVICE_DISABLE" + action: "node-service-disable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 8: - action: "NODE_SERVICE_ENABLE" + action: "node-service-enable" options: - node_id: 1 - service_id: 0 + node_name: web_server + service_name: web-server 9: # check database.db file - action: "NODE_FILE_SCAN" + action: "node-file-scan" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 10: - action: "NODE_FILE_CHECKHASH" + action: "node-file-checkhash" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 11: - action: "NODE_FILE_DELETE" + action: "node-file-delete" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 12: - action: "NODE_FILE_REPAIR" + action: "node-file-repair" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_name: database_server + folder_name: database + file_name: database.db 13: - action: "NODE_SERVICE_FIX" + action: "node-service-fix" options: - node_id: 2 - service_id: 0 + node_name: database_server + service_name: database-service 14: - action: "NODE_FOLDER_SCAN" + action: "node-folder-scan" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 15: - action: "NODE_FOLDER_CHECKHASH" + action: "node-folder-checkhash" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 16: - action: "NODE_FOLDER_REPAIR" + action: "node-folder-repair" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 17: - action: "NODE_FOLDER_RESTORE" + action: "node-folder-restore" options: - node_id: 2 - folder_id: 0 + node_name: database_server + folder_name: database 18: - action: "NODE_OS_SCAN" + action: "node-os-scan" options: - node_id: 2 + node_name: database_server 19: # shutdown client 1 - action: "NODE_SHUTDOWN" + action: "node-shutdown" options: - node_id: 5 + node_name: client_1 20: - action: "NODE_STARTUP" + action: "node-startup" options: - node_id: 5 + node_name: client_1 21: - action: "NODE_RESET" + action: "node-reset" options: - node_id: 5 - 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + node_name: client_1 + 22: # "acl: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) + action: "router-acl-add-rule" options: target_router: router_1 position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE + 23: # "acl: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) + action: "router-acl-add-rule" options: target_router: router_1 position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: ALL # ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: NONE + dst_wildcard: NONE 24: # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 25: # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.12 # web server + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 26: - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.21 # client 1 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 27: - action: "ROUTER_ACL_ADDRULE" + action: "router-acl-add-rule" options: target_router: router_1 position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: DENY + src_ip: 192.168.10.22 # client 2 + dst_ip: 192.168.1.14 # database + src_port: ALL + dst_port: ALL + protocol_name: TCP + src_wildcard: NONE + dst_wildcard: NONE 28: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 0 29: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 1 30: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 2 31: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 3 32: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 4 33: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 5 34: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 6 35: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 7 36: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 8 37: - action: "ROUTER_ACL_REMOVERULE" + action: "router-acl-remove-rule" options: target_router: router_1 position: 9 38: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 39: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 0 - nic_id: 0 + node_name: domain_controller + nic_num: 1 40: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 41: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 1 - nic_id: 0 + node_name: web_server + nic_num: 1 42: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 43: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 2 - nic_id: 0 + node_name: database_server + nic_num: 1 44: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 45: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 3 - nic_id: 0 + node_name: backup_server + nic_num: 1 46: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 47: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 0 + node_name: security_suite + nic_num: 1 48: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 49: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 4 - nic_id: 1 + node_name: security_suite + nic_num: 2 50: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 51: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 5 - nic_id: 0 + node_name: client_1 + nic_num: 1 52: - action: "HOST_NIC_DISABLE" + action: "host-nic-disable" options: - node_id: 6 - nic_id: 0 + node_name: client_2 + nic_num: 1 53: - action: "HOST_NIC_ENABLE" + action: "host-nic-enable" options: - node_id: 6 - nic_id: 0 - - - options: - nodes: - - node_name: domain_controller - - node_name: web_server - applications: - - application_name: DatabaseClient - services: - - service_name: WebServer - - node_name: database_server - folders: - - folder_name: database - files: - - file_name: database.db - services: - - service_name: DatabaseService - - node_name: backup_server - - node_name: security_suite - - node_name: client_1 - - node_name: client_2 - - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 - max_acl_rules: 10 - ip_list: - - 192.168.1.10 - - 192.168.1.12 - - 192.168.1.14 - - 192.168.1.16 - - 192.168.1.110 - - 192.168.10.21 - - 192.168.10.22 - - 192.168.10.110 + node_name: client_2 + nic_num: 1 reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY + - type: database-file-integrity weight: 0.5 options: node_hostname: database_server @@ -559,7 +464,7 @@ agents: file_name: database.db - - type: WEB_SERVER_404_PENALTY + - type: web-server-404-penalty weight: 0.5 options: node_hostname: web_server @@ -619,7 +524,7 @@ simulation: subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 services: - - type: DNSServer + - type: dns-server options: domain_mapping: arcd.com: 192.168.1.12 # web server @@ -631,9 +536,9 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: WebServer + - type: web-server applications: - - type: DatabaseClient + - type: database-client options: db_server_ip: 192.168.1.14 @@ -645,7 +550,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: DatabaseService + - type: database-service options: backup_server_ip: 192.168.1.16 @@ -656,7 +561,7 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - - type: FTPServer + - type: ftp-server - type: server hostname: security_suite @@ -676,14 +581,14 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: DataManipulationBot + - type: data-manipulation-bot options: port_scan_p_of_success: 0.1 data_manipulation_p_of_success: 0.1 payload: "DELETE" server_ip: 192.168.1.14 services: - - type: DNSClient + - type: dns-client - type: computer hostname: client_2 @@ -692,16 +597,16 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - - type: WebBrowser + - type: web-browser services: - - type: DNSClient + - type: dns-client - type: printer hostname: HP_LaserJet_Pro_4102fdn_printer ip_address: 192.168.10.99 subnet_mask: 255.255.255.0 - - type: wireless_router + - type: wireless-router hostname: router_2 router_interface: ip_address: 192.169.1.1 diff --git a/tests/assets/configs/wireless_wan_network_config.yaml b/tests/assets/configs/wireless_wan_network_config.yaml index c8f61bad..45721d8a 100644 --- a/tests/assets/configs/wireless_wan_network_config.yaml +++ b/tests/assets/configs/wireless_wan_network_config.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + game: max_episode_length: 256 ports: @@ -24,7 +27,7 @@ simulation: default_gateway: 192.168.2.1 start_up_duration: 0 - - type: wireless_router + - type: wireless-router hostname: router_1 start_up_duration: 0 @@ -45,7 +48,7 @@ simulation: next_hop_ip_address: 192.168.1.2 metric: 0 - - type: wireless_router + - type: wireless-router hostname: router_2 start_up_duration: 0 diff --git a/tests/assets/configs/wireless_wan_network_config_freq_max_override.yaml b/tests/assets/configs/wireless_wan_network_config_freq_max_override.yaml index a327b0f5..20e48a89 100644 --- a/tests/assets/configs/wireless_wan_network_config_freq_max_override.yaml +++ b/tests/assets/configs/wireless_wan_network_config_freq_max_override.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + game: max_episode_length: 256 ports: @@ -28,7 +31,7 @@ simulation: default_gateway: 192.168.2.1 start_up_duration: 0 - - type: wireless_router + - type: wireless-router hostname: router_1 start_up_duration: 0 @@ -49,7 +52,7 @@ simulation: next_hop_ip_address: 192.168.1.2 metric: 0 - - type: wireless_router + - type: wireless-router hostname: router_2 start_up_duration: 0 diff --git a/tests/assets/configs/wireless_wan_network_config_freq_max_override_blocked.yaml b/tests/assets/configs/wireless_wan_network_config_freq_max_override_blocked.yaml index ff048c92..6342d1b1 100644 --- a/tests/assets/configs/wireless_wan_network_config_freq_max_override_blocked.yaml +++ b/tests/assets/configs/wireless_wan_network_config_freq_max_override_blocked.yaml @@ -1,3 +1,6 @@ +metadata: + version: 3.0 + game: max_episode_length: 256 ports: @@ -28,7 +31,7 @@ simulation: default_gateway: 192.168.2.1 start_up_duration: 0 - - type: wireless_router + - type: wireless-router hostname: router_1 start_up_duration: 0 @@ -49,7 +52,7 @@ simulation: next_hop_ip_address: 192.168.1.2 metric: 0 - - type: wireless_router + - type: wireless-router hostname: router_2 start_up_duration: 0 diff --git a/tests/conftest.py b/tests/conftest.py index fcce0cae..becb9a06 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import Any, Dict, Tuple +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from typing import Any, Dict, Optional, Tuple import pytest import yaml +from pydantic import Field from ray import init as rayinit from primaite import getLogger, PRIMAITE_PATHS @@ -10,6 +11,7 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent from primaite.game.game import PrimaiteGame from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.container import Network @@ -18,8 +20,6 @@ from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.web_browser import WebBrowser @@ -28,6 +28,8 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT rayinit() @@ -37,29 +39,43 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1 _LOGGER = getLogger(__name__) -class DummyService(Service): +class DummyService(Service, discriminator="dummy-service"): """Test Service class""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for DummyService.""" + + type: str = "dummy-service" + + config: ConfigSchema = Field(default_factory=lambda: DummyService.ConfigSchema()) + def describe_state(self) -> Dict: return super().describe_state() def __init__(self, **kwargs): - kwargs["name"] = "DummyService" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["name"] = "dummy-service" + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: pass -class DummyApplication(Application, identifier="DummyApplication"): +class DummyApplication(Application, discriminator="dummy-application"): """Test Application class""" + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for DummyApplication.""" + + type: str = "dummy-application" + + config: ConfigSchema = Field(default_factory=lambda: DummyApplication.ConfigSchema()) + def __init__(self, **kwargs): - kwargs["name"] = "DummyApplication" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["name"] = "dummy-application" + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -77,7 +93,7 @@ def uc2_network() -> Network: @pytest.fixture(scope="function") def service(file_system) -> DummyService: return DummyService( - name="DummyService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_service") + name="dummy-service", port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_service") ) @@ -89,8 +105,8 @@ def service_class(): @pytest.fixture(scope="function") def application(file_system) -> DummyApplication: return DummyApplication( - name="DummyApplication", - port=Port.ARP, + name="dummy-application", + port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_application"), ) @@ -103,7 +119,14 @@ def application_class(): @pytest.fixture(scope="function") def file_system() -> FileSystem: - computer = Computer(hostname="fs_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + computer_cfg = { + "type": "computer", + "hostname": "fs_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + computer = Computer.from_config(config=computer_cfg) computer.power_on() return computer.file_system @@ -113,23 +136,29 @@ def client_server() -> Tuple[Computer, Server]: network = Network() # Create Computer - computer = Computer( - hostname="computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() # Create Server - server = Server( - hostname="server", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + server_cfg = { + "type": "server", + "hostname": "server", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server: Server = Server.from_config(config=server_cfg) server.power_on() # Connect Computer and Server @@ -146,26 +175,33 @@ def client_switch_server() -> Tuple[Computer, Switch, Server]: network = Network() # Create Computer - computer = Computer( - hostname="computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() # Create Server - server = Server( - hostname="server", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + server_cfg = { + "type": "server", + "hostname": "server", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server: Server = Server.from_config(config=server_cfg) server.power_on() - switch = Switch(hostname="switch", start_up_duration=0) + # Create Switch + switch: Switch = Switch.from_config(config={"type": "switch", "hostname": "switch", "start_up_duration": 0}) switch.power_on() network.connect(endpoint_a=computer.network_interface[1], endpoint_b=switch.network_interface[1]) @@ -195,65 +231,94 @@ def example_network() -> Network: network = Network() # Router 1 - router_1 = Router(hostname="router_1", start_up_duration=0) + + router_1_cfg = {"hostname": "router_1", "type": "router", "start_up_duration": 0} + + # router_1 = Router(hostname="router_1", start_up_duration=0) + router_1 = Router.from_config(config=router_1_cfg) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") # Switch 1 - switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) + + switch_1_cfg = {"hostname": "switch_1", "type": "switch", "start_up_duration": 0} + + switch_1 = Switch.from_config(config=switch_1_cfg) + switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8]) router_1.enable_port(1) # Switch 2 - switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) + switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8, "start_up_duration": 0} + switch_2 = Switch.from_config(config=switch_2_config) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8]) router_1.enable_port(2) - # Client 1 - client_1 = Computer( - hostname="client_1", - ip_address="192.168.10.21", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - start_up_duration=0, - ) + # # Client 1 + + client_1_cfg = { + "type": "computer", + "hostname": "client_1", + "ip_address": "192.168.10.21", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } + + client_1 = Computer.from_config(config=client_1_cfg) + client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) - # Client 2 - client_2 = Computer( - hostname="client_2", - ip_address="192.168.10.22", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - start_up_duration=0, - ) + # # Client 2 + + client_2_cfg = { + "type": "computer", + "hostname": "client_2", + "ip_address": "192.168.10.22", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } + + client_2 = Computer.from_config(config=client_2_cfg) + client_2.power_on() network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2]) - # Server 1 - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + # # Server 1 + + server_1_cfg = { + "type": "server", + "hostname": "server_1", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server_1 = Server.from_config(config=server_1_cfg) + server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) - # DServer 2 - server_2 = Server( - hostname="server_2", - ip_address="192.168.1.14", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + # # DServer 2 + + server_2_cfg = { + "type": "server", + "hostname": "server_2", + "ip_address": "192.168.1.14", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server_2 = Server.from_config(config=server_2_cfg) + server_2.power_on() network.connect(endpoint_b=server_2.network_interface[1], endpoint_a=switch_1.network_interface[2]) @@ -261,26 +326,22 @@ def example_network() -> Network: assert all(link.is_up for link in network.links.values()) + client_1.software_manager.show() + return network -class ControlledAgent(AbstractAgent): +class ControlledAgent(AbstractAgent, discriminator="controlled-agent"): """Agent that can be controlled by the tests.""" - def __init__( - self, - agent_name: str, - action_space: ActionManager, - observation_space: ObservationManager, - reward_function: RewardFunction, - ) -> None: - super().__init__( - agent_name=agent_name, - action_space=action_space, - observation_space=observation_space, - reward_function=reward_function, - ) - self.most_recent_action: Tuple[str, Dict] + most_recent_action: Optional[Tuple[str, Dict]] = None + + class ConfigSchema(AbstractAgent.ConfigSchema): + """Configuration Schema for Abstract Agent used in tests.""" + + type: str = "controlled-agent" + + config: ConfigSchema = Field(default_factory=lambda: ControlledAgent.ConfigSchema()) def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]: """Return the agent's most recent action, formatted in CAOS format.""" @@ -299,29 +360,35 @@ def install_stuff_to_sim(sim: Simulation): # 1: Set up network hardware # 1.1: Configure the router - router = Router(hostname="router", num_ports=3, start_up_duration=0) + router = Router.from_config(config={"type": "router", "hostname": "router", "num_ports": 3, "start_up_duration": 0}) router.power_on() router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") # 1.2: Create and connect switches - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1 = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0} + ) switch_1.power_on() network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6]) router.enable_port(1) - switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0) + switch_2 = Switch.from_config( + config={"type": "switch", "hostname": "switch_2", "num_ports": 6, "start_up_duration": 0} + ) switch_2.power_on() network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6]) router.enable_port(2) # 1.3: Create and connect computer - client_1 = Computer( - hostname="client_1", - ip_address="10.0.1.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.1.1", - start_up_duration=0, - ) + client_1_cfg = { + "type": "computer", + "hostname": "client_1", + "ip_address": "10.0.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "10.0.1.1", + "start_up_duration": 0, + } + client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() network.connect( endpoint_a=client_1.network_interface[1], @@ -329,44 +396,50 @@ def install_stuff_to_sim(sim: Simulation): ) # 1.4: Create and connect servers - server_1 = Server( - hostname="server_1", - ip_address="10.0.2.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - start_up_duration=0, - ) + server_1_cfg = { + "type": "server", + "hostname": "server_1", + "ip_address": "10.0.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "10.0.2.1", + "start_up_duration": 0, + } + + server_1: Server = Server.from_config(config=server_1_cfg) server_1.power_on() network.connect(endpoint_a=server_1.network_interface[1], endpoint_b=switch_2.network_interface[1]) + server_2_cfg = { + "type": "server", + "hostname": "server_2", + "ip_address": "10.0.2.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "10.0.2.1", + "start_up_duration": 0, + } - server_2 = Server( - hostname="server_2", - ip_address="10.0.2.3", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - start_up_duration=0, - ) + server_2: Server = Server.from_config(config=server_2_cfg) server_2.power_on() network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) - # 2: Configure base ACL - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + # 2: Configure base acl + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3) # 3: Install server software server_1.software_manager.install(DNSServer) - dns_service: DNSServer = server_1.software_manager.software.get("DNSServer") # noqa + dns_service: DNSServer = server_1.software_manager.software.get("dns-server") # noqa dns_service.dns_register("www.example.com", server_2.network_interface[1].ip_address) server_2.software_manager.install(WebServer) # 3.1: Ensure that the dns clients are configured correctly - client_1.software_manager.software.get("DNSClient").dns_server = server_1.network_interface[1].ip_address - server_2.software_manager.software.get("DNSClient").dns_server = server_1.network_interface[1].ip_address + client_1.software_manager.software.get("dns-client").dns_server = server_1.network_interface[1].ip_address + server_2.software_manager.software.get("dns-client").dns_server = server_1.network_interface[1].ip_address # 4: Check that client came pre-installed with web browser and dns client - assert isinstance(client_1.software_manager.software.get("WebBrowser"), WebBrowser) - assert isinstance(client_1.software_manager.software.get("DNSClient"), DNSClient) + assert isinstance(client_1.software_manager.software.get("web-browser"), WebBrowser) + assert isinstance(client_1.software_manager.software.get("dns-client"), DNSClient) # 4.1: Create a file on the computer client_1.file_system.create_file("cat.png", 300, folder_name="downloads") @@ -374,35 +447,38 @@ def install_stuff_to_sim(sim: Simulation): # 5: Assert that the simulation starts off in the state that we expect assert len(sim.network.nodes) == 6 assert len(sim.network.links) == 5 + # 5.1: Assert the router is correctly configured r = sim.network.router_nodes[0] for i, acl_rule in enumerate(r.acl.acl): if i == 1: - assert acl_rule.src_port == acl_rule.dst_port == Port.DNS + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["DNS"] elif i == 3: - assert acl_rule.src_port == acl_rule.dst_port == Port.HTTP + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["HTTP"] + elif i == 22: + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["ARP"] elif i == 23: - assert acl_rule.protocol == IPProtocol.ICMP + assert acl_rule.protocol == PROTOCOL_LOOKUP["ICMP"] elif i == 24: ... else: assert acl_rule is None # 5.2: Assert the client is correctly configured - c: Computer = [node for node in sim.network.nodes.values() if node.hostname == "client_1"][0] - assert c.software_manager.software.get("WebBrowser") is not None - assert c.software_manager.software.get("DNSClient") is not None + c: Computer = [node for node in sim.network.nodes.values() if node.config.hostname == "client_1"][0] + assert c.software_manager.software.get("web-browser") is not None + assert c.software_manager.software.get("dns-client") is not None assert str(c.network_interface[1].ip_address) == "10.0.1.2" # 5.3: Assert that server_1 is correctly configured - s1: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_1"][0] + s1: Server = [node for node in sim.network.nodes.values() if node.config.hostname == "server_1"][0] assert str(s1.network_interface[1].ip_address) == "10.0.2.2" - assert s1.software_manager.software.get("DNSServer") is not None + assert s1.software_manager.software.get("dns-server") is not None # 5.4: Assert that server_2 is correctly configured - s2: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_2"][0] + s2: Server = [node for node in sim.network.nodes.values() if node.config.hostname == "server_2"][0] assert str(s2.network_interface[1].ip_address) == "10.0.2.3" - assert s2.software_manager.software.get("WebServer") is not None + assert s2.software_manager.software.get("web-server") is not None # 6: Return the simulation return sim @@ -415,100 +491,13 @@ def game_and_agent(): sim = game.simulation install_stuff_to_sim(sim) - actions = [ - {"type": "DONOTHING"}, - {"type": "NODE_SERVICE_SCAN"}, - {"type": "NODE_SERVICE_STOP"}, - {"type": "NODE_SERVICE_START"}, - {"type": "NODE_SERVICE_PAUSE"}, - {"type": "NODE_SERVICE_RESUME"}, - {"type": "NODE_SERVICE_RESTART"}, - {"type": "NODE_SERVICE_DISABLE"}, - {"type": "NODE_SERVICE_ENABLE"}, - {"type": "NODE_SERVICE_FIX"}, - {"type": "NODE_APPLICATION_EXECUTE"}, - {"type": "NODE_APPLICATION_SCAN"}, - {"type": "NODE_APPLICATION_CLOSE"}, - {"type": "NODE_APPLICATION_FIX"}, - {"type": "NODE_APPLICATION_INSTALL"}, - {"type": "NODE_APPLICATION_REMOVE"}, - {"type": "NODE_FILE_CREATE"}, - {"type": "NODE_FILE_SCAN"}, - {"type": "NODE_FILE_CHECKHASH"}, - {"type": "NODE_FILE_DELETE"}, - {"type": "NODE_FILE_REPAIR"}, - {"type": "NODE_FILE_RESTORE"}, - {"type": "NODE_FILE_CORRUPT"}, - {"type": "NODE_FILE_ACCESS"}, - {"type": "NODE_FOLDER_CREATE"}, - {"type": "NODE_FOLDER_SCAN"}, - {"type": "NODE_FOLDER_CHECKHASH"}, - {"type": "NODE_FOLDER_REPAIR"}, - {"type": "NODE_FOLDER_RESTORE"}, - {"type": "NODE_OS_SCAN"}, - {"type": "NODE_SHUTDOWN"}, - {"type": "NODE_STARTUP"}, - {"type": "NODE_RESET"}, - {"type": "ROUTER_ACL_ADDRULE"}, - {"type": "ROUTER_ACL_REMOVERULE"}, - {"type": "HOST_NIC_ENABLE"}, - {"type": "HOST_NIC_DISABLE"}, - {"type": "NETWORK_PORT_ENABLE"}, - {"type": "NETWORK_PORT_DISABLE"}, - {"type": "CONFIGURE_C2_BEACON"}, - {"type": "C2_SERVER_RANSOMWARE_LAUNCH"}, - {"type": "C2_SERVER_RANSOMWARE_CONFIGURE"}, - {"type": "C2_SERVER_TERMINAL_COMMAND"}, - {"type": "C2_SERVER_DATA_EXFILTRATE"}, - {"type": "NODE_ACCOUNTS_ADD_USER"}, - {"type": "NODE_ACCOUNTS_DISABLE_USER"}, - {"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"}, - {"type": "SSH_TO_REMOTE"}, - {"type": "SESSIONS_REMOTE_LOGOFF"}, - {"type": "NODE_SEND_REMOTE_COMMAND"}, - {"type": "NODE_SEND_LOCAL_COMMAND"}, - ] + config = { + "type": "controlled-agent", + "ref": "test_agent", + "team": "BLUE", + } - action_space = ActionManager( - actions=actions, # ALL POSSIBLE ACTIONS - nodes=[ - { - "node_name": "client_1", - "applications": [ - {"application_name": "WebBrowser"}, - {"application_name": "DoSBot"}, - {"application_name": "C2Server"}, - ], - "folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}], - }, - { - "node_name": "server_1", - "services": [{"service_name": "DNSServer"}], - "applications": [{"application_name": "C2Beacon"}], - }, - {"node_name": "server_2", "services": [{"service_name": "WebServer"}]}, - {"node_name": "router"}, - ], - max_folders_per_node=2, - max_files_per_folder=2, - max_services_per_node=2, - max_applications_per_node=3, - max_nics_per_node=2, - max_acl_rules=10, - protocols=["TCP", "UDP", "ICMP"], - ports=["HTTP", "DNS", "ARP"], - ip_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"], - act_map={}, - ) - observation_space = ObservationManager(NestedObservation(components={})) - reward_function = RewardFunction() - - test_agent = ControlledAgent( - agent_name="test_agent", - action_space=action_space, - observation_space=observation_space, - reward_function=reward_function, - ) + test_agent = ControlledAgent(config=config) game.agents["test_agent"] = test_agent diff --git a/tests/e2e_integration_tests/__init__.py b/tests/e2e_integration_tests/__init__.py index be6c00e7..836b79af 100644 --- a/tests/e2e_integration_tests/__init__.py +++ b/tests/e2e_integration_tests/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/e2e_integration_tests/action_masking/__init__.py b/tests/e2e_integration_tests/action_masking/__init__.py index be6c00e7..836b79af 100644 --- a/tests/e2e_integration_tests/action_masking/__init__.py +++ b/tests/e2e_integration_tests/action_masking/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py index addf6dca..6da801d4 100644 --- a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Dict import pytest @@ -12,6 +12,7 @@ from sb3_contrib import MaskablePPO from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests import TEST_ASSETS_ROOT CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" diff --git a/tests/e2e_integration_tests/environments/__init__.py b/tests/e2e_integration_tests/environments/__init__.py index be6c00e7..836b79af 100644 --- a/tests/e2e_integration_tests/environments/__init__.py +++ b/tests/e2e_integration_tests/environments/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py index 26e690d0..06b080d8 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import yaml from ray.rllib.algorithms.ppo import PPOConfig diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py index 265257e4..da0ca458 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import tempfile from pathlib import Path diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index a07d5d2e..9ca3525a 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Test that we can create a primaite environment and train sb3 agent with no crash.""" import tempfile from pathlib import Path diff --git a/tests/e2e_integration_tests/test_environment.py b/tests/e2e_integration_tests/test_environment.py index dcd51193..881681aa 100644 --- a/tests/e2e_integration_tests/test_environment.py +++ b/tests/e2e_integration_tests/test_environment.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pydantic import pytest import yaml diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index 7ec38d72..79d0db1b 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import yaml from primaite.game.game import PrimaiteGame @@ -14,13 +14,13 @@ from tests import TEST_ASSETS_ROOT def test_data_manipulation(uc2_network): """Tests the UC2 data manipulation scenario end-to-end. Is a work in progress.""" client_1: Computer = uc2_network.get_node_by_hostname("client_1") - db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") + db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot") database_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = database_server.software_manager.software.get("database-service") web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") + db_client: DatabaseClient = web_server.software_manager.software.get("database-client") db_connection: DatabaseClientConnection = db_client.get_new_connection() db_service.backup_database() @@ -49,7 +49,7 @@ def test_application_install_uninstall_on_uc2(): cfg = yaml.safe_load(f) env = PrimaiteGymEnv(env_config=cfg) - env.agent.flatten_obs = False + env.agent.config.agent_settings.flatten_obs = False env.reset() _, _, _, _, _ = env.step(0) @@ -61,7 +61,7 @@ def test_application_install_uninstall_on_uc2(): # Test we can Install the DoSBot app _, _, _, _, info = env.step(78) - assert "DoSBot" in domcon.software_manager.software + assert "dos-bot" in domcon.software_manager.software # installing takes 3 steps so let's wait for 3 steps env.step(0) @@ -75,13 +75,13 @@ def test_application_install_uninstall_on_uc2(): # Test we can Uninstall the DoSBot app _, _, _, _, info = env.step(79) - assert "DoSBot" not in domcon.software_manager.software + assert "dos-bot" not in domcon.software_manager.software # Test we cannot execute the DoSBot app as it was uninstalled _, _, _, _, info = env.step(81) assert info["agent_actions"]["defender"].response.status == "unreachable" # Test we can uninstall one of the default apps (WebBrowser) - assert "WebBrowser" in domcon.software_manager.software + assert "web-browser" in domcon.software_manager.software _, _, _, _, info = env.step(80) - assert "WebBrowser" not in domcon.software_manager.software + assert "web-browser" not in domcon.software_manager.software diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py index be6c00e7..836b79af 100644 --- a/tests/integration_tests/__init__.py +++ b/tests/integration_tests/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/cli/__init__.py b/tests/integration_tests/cli/__init__.py index cfce7ae6..603d228f 100644 --- a/tests/integration_tests/cli/__init__.py +++ b/tests/integration_tests/cli/__init__.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import List from typer.testing import CliRunner, Result diff --git a/tests/integration_tests/cli/test_dev_cli.py b/tests/integration_tests/cli/test_dev_cli.py index cd390555..16c3de9f 100644 --- a/tests/integration_tests/cli/test_dev_cli.py +++ b/tests/integration_tests/cli/test_dev_cli.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import os import shutil import tempfile diff --git a/tests/integration_tests/component_creation/__init__.py b/tests/integration_tests/component_creation/__init__.py index be6c00e7..836b79af 100644 --- a/tests/integration_tests/component_creation/__init__.py +++ b/tests/integration_tests/component_creation/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/component_creation/test_action_integration.py b/tests/integration_tests/component_creation/test_action_integration.py index 7bdc80fc..0fd0aa19 100644 --- a/tests/integration_tests/component_creation/test_action_integration.py +++ b/tests/integration_tests/component_creation/test_action_integration.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.core import RequestType from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server @@ -12,12 +12,18 @@ def test_passing_actions_down(monkeypatch) -> None: sim = Simulation() - pc1 = Computer(hostname="PC-1", ip_address="10.10.1.1", subnet_mask="255.255.255.0") + pc1 = Computer.from_config( + config={"type": "computer", "hostname": "PC-1", "ip_address": "10.10.1.1", "subnet_mask": "255.255.255.0"} + ) pc1.start_up_duration = 0 pc1.power_on() - pc2 = Computer(hostname="PC-2", ip_address="10.10.1.2", subnet_mask="255.255.255.0") - srv = Server(hostname="WEBSERVER", ip_address="10.10.1.100", subnet_mask="255.255.255.0") - s1 = Switch(hostname="switch1") + pc2 = Computer.from_config( + config={"type": "computer", "hostname": "PC-2", "ip_address": "10.10.1.2", "subnet_mask": "255.255.255.0"} + ) + srv = Server.from_config( + config={"type": "server", "hostname": "WEBSERVER", "ip_address": "10.10.1.100", "subnet_mask": "255.255.255.0"} + ) + s1 = Switch.from_config(config={"type": "switch", "hostname": "switch1"}) for n in [pc1, pc2, srv, s1]: sim.network.add_node(n) @@ -48,6 +54,6 @@ def test_passing_actions_down(monkeypatch) -> None: assert not action_invoked # call the patched method - sim.apply_request(["network", "node", pc1.hostname, "file_system", "folder", "downloads", "repair"]) + sim.apply_request(["network", "node", pc1.config.hostname, "file_system", "folder", "downloads", "repair"]) assert action_invoked diff --git a/tests/integration_tests/component_creation/test_permission_system.py b/tests/integration_tests/component_creation/test_permission_system.py index baf75523..c7faa81b 100644 --- a/tests/integration_tests/component_creation/test_permission_system.py +++ b/tests/integration_tests/component_creation/test_permission_system.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum from typing import Dict, List, Literal diff --git a/tests/integration_tests/configuration_file_parsing/__init__.py b/tests/integration_tests/configuration_file_parsing/__init__.py index 7e23a4c2..09861acb 100644 --- a/tests/integration_tests/configuration_file_parsing/__init__.py +++ b/tests/integration_tests/configuration_file_parsing/__init__.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from pathlib import Path from typing import Union diff --git a/tests/integration_tests/configuration_file_parsing/nodes/__init__.py b/tests/integration_tests/configuration_file_parsing/nodes/__init__.py index be6c00e7..836b79af 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/__init__.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/__init__.py b/tests/integration_tests/configuration_file_parsing/nodes/network/__init__.py index be6c00e7..836b79af 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/__init__.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py index 457fdb42..234e7342 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address import pytest @@ -9,8 +9,8 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.configuration_file_parsing import BASIC_FIREWALL, DMZ_NETWORK, load_config @@ -68,44 +68,44 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed internal_inbound assert firewall.internal_inbound_acl.num_rules == 2 assert firewall.internal_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.internal_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.internal_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.internal_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.internal_inbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.internal_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed internal_outbound assert firewall.internal_outbound_acl.num_rules == 2 assert firewall.internal_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.internal_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.internal_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.internal_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.internal_outbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.internal_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_inbound assert firewall.dmz_inbound_acl.num_rules == 2 assert firewall.dmz_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.dmz_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.dmz_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.dmz_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.dmz_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.dmz_inbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.dmz_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_outbound assert firewall.dmz_outbound_acl.num_rules == 2 assert firewall.dmz_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.dmz_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.dmz_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.dmz_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.dmz_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.dmz_outbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.dmz_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed external_inbound assert firewall.external_inbound_acl.num_rules == 1 assert firewall.external_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.external_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.external_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.external_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] # external_inbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_inbound_acl.implicit_action == ACLAction.PERMIT @@ -113,8 +113,8 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed external_outbound assert firewall.external_outbound_acl.num_rules == 1 assert firewall.external_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.external_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.external_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.external_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] # external_outbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_outbound_acl.implicit_action == ACLAction.PERMIT diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py index ccde3a02..7ca3a6aa 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -1,13 +1,14 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config @@ -63,8 +64,8 @@ def test_router_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed assert router_1.acl.num_rules == 2 assert router_1.acl.acl[22].action == ACLAction.PERMIT - assert router_1.acl.acl[22].src_port == Port.ARP - assert router_1.acl.acl[22].dst_port == Port.ARP + assert router_1.acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert router_1.acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert router_1.acl.acl[23].action == ACLAction.PERMIT - assert router_1.acl.acl[23].protocol == IPProtocol.ICMP + assert router_1.acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert router_1.acl.implicit_action == ACLAction.DENY diff --git a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py index 8526ab78..f3911691 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py @@ -1,8 +1,10 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.config.load import data_manipulation_config_path from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index 3e06d371..fb34f43a 100644 --- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from pathlib import Path from typing import Union @@ -86,7 +86,7 @@ def test_node_software_install(): assert client_2.software_manager.software.get(software.__name__) is not None # check that applications have been installed on client 1 - for applications in Application._application_registry: + for applications in Application._registry: assert client_1.software_manager.software.get(applications) is not None # check that services have been installed on client 1 @@ -99,7 +99,7 @@ def test_web_browser_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - web_browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = client_1.software_manager.software.get("web-browser") assert web_browser.target_url == "http://arcd.com/users/" @@ -109,7 +109,7 @@ def test_database_client_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + database_client: DatabaseClient = client_1.software_manager.software.get("database-client") assert database_client.server_ip_address == IPv4Address("192.168.1.10") assert database_client.server_password == "arcd" @@ -120,7 +120,7 @@ def test_data_manipulation_bot_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") + data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot") assert data_manipulation_bot.server_ip_address == IPv4Address("192.168.1.21") assert data_manipulation_bot.payload == "DELETE" @@ -134,7 +134,7 @@ def test_dos_bot_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot") + dos_bot: DoSBot = client_1.software_manager.software.get("dos-bot") assert dos_bot.target_ip_address == IPv4Address("192.168.10.21") assert dos_bot.payload == "SPOOF DATA" @@ -149,7 +149,7 @@ def test_dns_client_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - dns_client: DNSClient = client_1.software_manager.software.get("DNSClient") + dns_client: DNSClient = client_1.software_manager.software.get("dns-client") assert dns_client.dns_server == IPv4Address("192.168.1.10") @@ -159,7 +159,7 @@ def test_dns_server_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - dns_server: DNSServer = client_1.software_manager.software.get("DNSServer") + dns_server: DNSServer = client_1.software_manager.software.get("dns-server") assert dns_server.dns_lookup("arcd.com") == IPv4Address("192.168.1.10") @@ -169,7 +169,7 @@ def test_database_service_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService") + database_service: DatabaseService = client_1.software_manager.software.get("database-service") assert database_service.backup_server_ip == IPv4Address("192.168.1.10") @@ -179,10 +179,10 @@ def test_web_server_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - web_server_service: WebServer = client_1.software_manager.software.get("WebServer") + web_server_service: WebServer = client_1.software_manager.software.get("web-server") # config should have also installed database client - web server service should be able to retrieve this - assert web_server_service.software_manager.software.get("DatabaseClient") is not None + assert web_server_service.software_manager.software.get("database-client") is not None def test_ftp_client_install(): @@ -190,7 +190,7 @@ def test_ftp_client_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - ftp_client_service: FTPClient = client_1.software_manager.software.get("FTPClient") + ftp_client_service: FTPClient = client_1.software_manager.software.get("ftp-client") assert ftp_client_service is not None @@ -199,9 +199,8 @@ def test_ftp_server_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - ftp_server_service: FTPServer = client_1.software_manager.software.get("FTPServer") + ftp_server_service: FTPServer = client_1.software_manager.software.get("ftp-server") assert ftp_server_service is not None - assert ftp_server_service.server_password == "arcd" def test_ntp_client_install(): @@ -209,7 +208,7 @@ def test_ntp_client_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - ntp_client_service: NTPClient = client_1.software_manager.software.get("NTPClient") + ntp_client_service: NTPClient = client_1.software_manager.software.get("ntp-client") assert ntp_client_service is not None assert ntp_client_service.ntp_server == IPv4Address("192.168.1.10") @@ -219,5 +218,5 @@ def test_ntp_server_install(): game = load_config(BASIC_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - ntp_server_service: NTPServer = client_1.software_manager.software.get("NTPServer") + ntp_server_service: NTPServer = client_1.software_manager.software.get("ntp-server") assert ntp_server_service is not None diff --git a/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py index 13be830b..1352f894 100644 --- a/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py +++ b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py @@ -1,9 +1,10 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest import yaml from primaite.session.environment import PrimaiteGymEnv from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests.conftest import TEST_ASSETS_ROOT folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders" diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py index 4098db7f..627fc53b 100644 --- a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from pathlib import Path from typing import Union diff --git a/tests/integration_tests/configuration_file_parsing/test_io_settings.py b/tests/integration_tests/configuration_file_parsing/test_io_settings.py index 82977b82..79812d80 100644 --- a/tests/integration_tests/configuration_file_parsing/test_io_settings.py +++ b/tests/integration_tests/configuration_file_parsing/test_io_settings.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from pathlib import Path from typing import Union diff --git a/tests/integration_tests/configuration_file_parsing/test_no_nodes_links_agents_config.py b/tests/integration_tests/configuration_file_parsing/test_no_nodes_links_agents_config.py index 26fc562d..016d264f 100644 --- a/tests/integration_tests/configuration_file_parsing/test_no_nodes_links_agents_config.py +++ b/tests/integration_tests/configuration_file_parsing/test_no_nodes_links_agents_config.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import yaml from primaite.game.game import PrimaiteGame diff --git a/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py index 49e90b54..4c99a39f 100644 --- a/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py +++ b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py @@ -25,21 +25,21 @@ def test_node_file_system_from_config(): client_1 = game.simulation.network.get_node_by_hostname("client_1") - assert client_1.software_manager.software.get("DatabaseService") # database service should be installed + assert client_1.software_manager.software.get("database-service") # database service should be installed assert client_1.file_system.get_file(folder_name="database", file_name="database.db") # database files should exist - assert client_1.software_manager.software.get("WebServer") # web server should be installed + assert client_1.software_manager.software.get("web-server") # web server should be installed assert client_1.file_system.get_file(folder_name="primaite", file_name="index.html") # web files should exist client_2 = game.simulation.network.get_node_by_hostname("client_2") # database service should not be installed - assert client_2.software_manager.software.get("DatabaseService") is None + assert client_2.software_manager.software.get("database-service") is None # database files should not exist assert client_2.file_system.get_file(folder_name="database", file_name="database.db") is None # web server should not be installed - assert client_2.software_manager.software.get("WebServer") is None + assert client_2.software_manager.software.get("web-server") is None # web files should not exist assert client_2.file_system.get_file(folder_name="primaite", file_name="index.html") is None @@ -50,7 +50,7 @@ def test_node_file_system_from_config(): password_file = client_2.file_system.get_file(folder_name="root", file_name="passwords.txt") assert password_file # should exist assert password_file.file_type is FileType.TXT - assert password_file.size is 69 + assert password_file.size == 663 downloads_folder = client_2.file_system.get_folder(folder_name="downloads") assert downloads_folder # downloads folder should exist @@ -59,6 +59,6 @@ def test_node_file_system_from_config(): assert test_txt # test.txt should exist assert test_txt.file_type is FileType.TXT - unknown_file_type = downloads_folder.get_file(file_name="suh_con.dn") + unknown_file_type = downloads_folder.get_file(file_name="another_file.pwtwoti") assert unknown_file_type # unknown_file_type should exist assert unknown_file_type.file_type is FileType.UNKNOWN diff --git a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py b/tests/integration_tests/configuration_file_parsing/test_software_fixing_duration.py similarity index 54% rename from tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py rename to tests/integration_tests/configuration_file_parsing/test_software_fixing_duration.py index dd38fafd..118b8c1f 100644 --- a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py +++ b/tests/integration_tests/configuration_file_parsing/test_software_fixing_duration.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import copy from pathlib import Path from typing import Union @@ -13,10 +13,10 @@ from primaite.simulator.system.services.database.database_service import Databas from primaite.simulator.system.services.dns.dns_client import DNSClient from tests import TEST_ASSETS_ROOT -TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml" -ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml" +TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fixing_duration.yaml" +ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fixing_duration_one_item.yaml" -TestApplications = ["DummyApplication", "BroadcastTestClient"] +TestApplications = ["dummy-application", "broadcast-test-client"] def load_config(config_path: Union[str, Path]) -> PrimaiteGame: @@ -27,55 +27,63 @@ def load_config(config_path: Union[str, Path]) -> PrimaiteGame: return PrimaiteGame.from_config(cfg) -def test_default_fix_duration(): - """Test that software with no defined fix duration in config uses the default fix duration of 2.""" +def test_default_fixing_duration(): + """Test that software with no defined fixing duration in config uses the default fixing duration of 2.""" game = load_config(TEST_CONFIG) client_2: Computer = game.simulation.network.get_node_by_hostname("client_2") - database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") - assert database_client.fixing_duration == 2 + database_client: DatabaseClient = client_2.software_manager.software.get("database-client") + assert database_client.config.fixing_duration == 2 - dns_client: DNSClient = client_2.software_manager.software.get("DNSClient") - assert dns_client.fixing_duration == 2 + dns_client: DNSClient = client_2.software_manager.software.get("dns-client") + assert dns_client.config.fixing_duration == 2 -def test_fix_duration_set_from_config(): - """Test to check that the fix duration set for applications and services works as intended.""" +def test_fixing_duration_set_from_config(): + """Test to check that the fixing duration set for applications and services works as intended.""" game = load_config(TEST_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") # in config - services take 3 timesteps to fix - for service in ["DNSClient", "DNSServer", "DatabaseService", "WebServer", "FTPClient", "FTPServer", "NTPServer"]: + for service in [ + "dns-client", + "dns-server", + "database-service", + "web-server", + "ftp-client", + "ftp-server", + "ntp-server", + ]: assert client_1.software_manager.software.get(service) is not None - assert client_1.software_manager.software.get(service).fixing_duration == 3 + assert client_1.software_manager.software.get(service).config.fixing_duration == 3 # in config - applications take 1 timestep to fix # remove test applications from list - applications = set(Application._application_registry) - set(TestApplications) + applications = set(Application._registry) - set(TestApplications) - for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]: + for application in ["ransomware-script", "web-browser", "data-manipulation-bot", "dos-bot", "database-client"]: assert client_1.software_manager.software.get(application) is not None - assert client_1.software_manager.software.get(application).fixing_duration == 1 + assert client_1.software_manager.software.get(application).config.fixing_duration == 1 -def test_fix_duration_for_one_item(): - """Test that setting fix duration for one application does not affect other components.""" +def test_fixing_duration_for_one_item(): + """Test that setting fixing duration for one application does not affect other components.""" game = load_config(ONE_ITEM_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") # in config - services take 3 timesteps to fix - for service in ["DNSClient", "DNSServer", "WebServer", "FTPClient", "FTPServer", "NTPServer"]: + for service in ["dns-client", "dns-server", "web-server", "ftp-client", "ftp-server", "ntp-server"]: assert client_1.software_manager.software.get(service) is not None - assert client_1.software_manager.software.get(service).fixing_duration == 2 + assert client_1.software_manager.software.get(service).config.fixing_duration == 2 # in config - applications take 1 timestep to fix # remove test applications from list - for applications in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot"]: + for applications in ["ransomware-script", "web-browser", "data-manipulation-bot", "dos-bot"]: assert client_1.software_manager.software.get(applications) is not None - assert client_1.software_manager.software.get(applications).fixing_duration == 2 + assert client_1.software_manager.software.get(applications).config.fixing_duration == 2 - database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") - assert database_client.fixing_duration == 1 + database_client: DatabaseClient = client_1.software_manager.software.get("database-client") + assert database_client.config.fixing_duration == 1 - database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService") - assert database_service.fixing_duration == 5 + database_service: DatabaseService = client_1.software_manager.software.get("database-service") + assert database_service.config.fixing_duration == 5 diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py new file mode 100644 index 00000000..5ea85c57 --- /dev/null +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -0,0 +1,229 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from enum import Enum +from ipaddress import IPv4Address +from typing import Dict, List, Optional +from urllib.parse import urlparse + +from pydantic import BaseModel, ConfigDict, Field + +from primaite import getLogger +from primaite.interface.request import RequestResponse +from primaite.simulator.core import RequestManager, RequestType +from primaite.simulator.network.protocols.http import ( + HttpRequestMethod, + HttpRequestPacket, + HttpResponsePacket, + HttpStatusCode, +) +from primaite.simulator.system.applications.application import Application +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP + +_LOGGER = getLogger(__name__) + + +class ExtendedApplication(Application, discriminator="extended-application"): + """ + Clone of web browser that uses the extension framework instead of being part of PrimAITE directly. + + The application requests and loads web pages using its domain name and requesting IP addresses using DNS. + """ + + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for ExtendedApplication.""" + + type: str = "extended-application" + target_url: Optional[str] = None + + config: "ExtendedApplication.ConfigSchema" = Field(default_factory=lambda: ExtendedApplication.ConfigSchema()) + + target_url: Optional[str] = None + + domain_name_ip_address: Optional[IPv4Address] = None + "The IP address of the domain name for the webpage." + + latest_response: Optional[HttpResponsePacket] = None + """Keeps track of the latest HTTP response.""" + + history: List["BrowserHistoryItem"] = [] + """Keep a log of visited websites and information about the visit, such as response code.""" + + def __init__(self, **kwargs): + kwargs["name"] = "extended-application" + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] + # default for web is port 80 + if kwargs.get("port") is None: + kwargs["port"] = PORT_LOOKUP["HTTP"] + + super().__init__(**kwargs) + self.target_url = self.config.target_url + self.run() + + def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ + rm = super()._init_request_manager() + rm.add_request( + name="execute", + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self.get_webpage()) + ), # noqa + ) + + return rm + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of the WebBrowser. + + :return: A dictionary capturing the current state of the WebBrowser and its child objects. + """ + state = super().describe_state() + state["history"] = [hist_item.state() for hist_item in self.history] + return state + + def get_webpage(self, url: Optional[str] = None) -> bool: + """ + Retrieve the webpage. + + This should send a request to the web server which also requests for a list of users + + :param: url: The address of the web page the browser requests + :type: url: str + """ + url = url or self.target_url + if not self._can_perform_action(): + return False + + self.num_executions += 1 # trying to connect counts as an execution + + # reset latest response + self.latest_response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND) + + try: + parsed_url = urlparse(url) + except Exception: + self.sys_log.warning(f"{url} is not a valid URL") + return False + + # get the IP address of the domain name via DNS + dns_client: DNSClient = self.software_manager.software.get("dns-client") + domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname) + + # if domain does not exist, the request fails + if domain_exists: + # set current domain name IP address + self.domain_name_ip_address = dns_client.dns_cache[parsed_url.hostname] + else: + # check if url is an ip address + try: + self.domain_name_ip_address = IPv4Address(parsed_url.hostname) + except Exception: + # unable to deal with this request + self.sys_log.warning(f"{self.name}: Unable to resolve URL {url}") + return False + + # create HTTPRequest payload + payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url=url) + + # send request - As part of the self.send call, a response will be received and stored in the + # self.latest_response variable + if self.send( + payload=payload, + dest_ip_address=self.domain_name_ip_address, + dest_port=parsed_url.port if parsed_url.port else PORT_LOOKUP["HTTP"], + ): + self.sys_log.info( + f"{self.name}: Received HTTP {payload.request_method.name} " + f"Response {payload.request_url} - {self.latest_response.status_code.value}" + ) + self.history.append( + WebBrowser.BrowserHistoryItem( + url=url, + status=self.BrowserHistoryItem._HistoryItemStatus.LOADED, + response_code=self.latest_response.status_code, + ) + ) + return self.latest_response.status_code is HttpStatusCode.OK + else: + self.sys_log.warning(f"{self.name}: Error sending Http Packet") + self.sys_log.debug(f"{self.name}: {payload=}") + self.history.append( + WebBrowser.BrowserHistoryItem( + url=url, status=self.BrowserHistoryItem._HistoryItemStatus.SERVER_UNREACHABLE + ) + ) + return False + + def send( + self, + payload: HttpRequestPacket, + dest_ip_address: Optional[IPv4Address] = None, + dest_port: Optional[int] = PORT_LOOKUP["HTTP"], + session_id: Optional[str] = None, + **kwargs, + ) -> bool: + """ + Sends a payload to the SessionManager. + + :param payload: The payload to be sent. + :param dest_ip_address: The ip address of the payload destination. + :param dest_port: The port of the payload destination. + :param session_id: The Session ID the payload is to originate from. Optional. + + :return: True if successful, False otherwise. + """ + self.sys_log.info(f"{self.name}: Sending HTTP {payload.request_method.name} {payload.request_url}") + + return super().send( + payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id, **kwargs + ) + + def receive(self, payload: HttpResponsePacket, session_id: Optional[str] = None, **kwargs) -> bool: + """ + Receives a payload from the SessionManager. + + :param payload: The payload to be sent. + :param session_id: The Session ID the payload is to originate from. Optional. + :return: True if successful, False otherwise. + """ + if not isinstance(payload, HttpResponsePacket): + self.sys_log.warning(f"{self.name} received a packet that is not an HttpResponsePacket") + self.sys_log.debug(f"{self.name}: {payload=}") + return False + self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}") + self.latest_response = payload + return True + + class BrowserHistoryItem(BaseModel): + """Simple representation of browser history, used for tracking success of web requests to calculate rewards.""" + + model_config = ConfigDict(extra="forbid") + """Error if incorrect specification.""" + + url: str + """The URL that was attempted to be fetched by the browser""" + + class _HistoryItemStatus(Enum): + NOT_SENT = "NOT_SENT" + PENDING = "PENDING" + SERVER_UNREACHABLE = "SERVER_UNREACHABLE" + LOADED = "LOADED" + + status: _HistoryItemStatus = _HistoryItemStatus.PENDING + + response_code: Optional[HttpStatusCode] = None + """HTTP response code that was received, or PENDING if a response was not yet received.""" + + def state(self) -> Dict: + """Return the contents of this dataclass as a dict for use with describe_state method.""" + if self.status == self._HistoryItemStatus.LOADED: + outcome = self.response_code.value + else: + outcome = self.status.value + return {"url": self.url, "outcome": outcome} diff --git a/tests/integration_tests/extensions/nodes/giga_switch.py b/tests/integration_tests/extensions/nodes/giga_switch.py new file mode 100644 index 00000000..5c202ed2 --- /dev/null +++ b/tests/integration_tests/extensions/nodes/giga_switch.py @@ -0,0 +1,125 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from typing import Dict, Literal + +from prettytable import MARKDOWN, PrettyTable + +from primaite import _LOGGER +from primaite.exceptions import NetworkError +from primaite.simulator.network.hardware.base import Link +from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode +from primaite.simulator.network.hardware.nodes.network.switch import SwitchPort +from primaite.simulator.network.transmission.data_link_layer import Frame + + +class GigaSwitch(NetworkNode, discriminator="gigaswitch"): + """ + A class representing a Layer 2 network switch. + + :ivar num_ports: The number of ports on the switch. Default is 24. + """ + + class ConfigSchema(NetworkNode.ConfigSchema): + type: Literal["gigaswitch"] = "gigaswitch" + + num_ports: int = 24 + "The number of ports on the switch." + network_interfaces: Dict[str, SwitchPort] = {} + "The SwitchPorts on the Switch." + network_interface: Dict[int, SwitchPort] = {} + "The SwitchPorts on the Switch by port id." + mac_address_table: Dict[str, SwitchPort] = {} + "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." + + def __init__(self, **kwargs): + print("--- Extended Component: GigaSwitch ---") + super().__init__(**kwargs) + for i in range(1, self.num_ports + 1): + self.connect_nic(SwitchPort()) + + def _install_system_software(self): + pass + + def show(self, markdown: bool = False): + """ + Prints a table of the SwitchPorts on the Switch. + + :param markdown: If True, outputs the table in markdown format. Default is False. + """ + table = PrettyTable(["Port", "MAC Address", "Speed", "Status"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.config.hostname} Switch Ports" + for port_num, port in self.network_interface.items(): + table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"]) + print(table) + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. + + :return: Current state of this object and child objects. + """ + state = super().describe_state() + state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()} + state["num_ports"] = self.num_ports # redundant? + state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()} + return state + + def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort): + """ + Private method to add an entry to the MAC address table. + + :param mac_address: MAC address to be added. + :param switch_port: Corresponding SwitchPort object. + """ + mac_table_port = self.mac_address_table.get(mac_address) + if not mac_table_port: + self.mac_address_table[mac_address] = switch_port + self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}") + else: + if mac_table_port != switch_port: + self.mac_address_table.pop(mac_address) + self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}") + self._add_mac_table_entry(mac_address, switch_port) + + def receive_frame(self, frame: Frame, from_network_interface: SwitchPort): + """ + Forward a frame to the appropriate port based on the destination MAC address. + + :param frame: The Frame being received. + :param from_network_interface: The SwitchPort that received the frame. + """ + src_mac = frame.ethernet.src_mac_addr + dst_mac = frame.ethernet.dst_mac_addr + self._add_mac_table_entry(src_mac, from_network_interface) + + outgoing_port = self.mac_address_table.get(dst_mac) + if outgoing_port and dst_mac.lower() != "ff:ff:ff:ff:ff:ff": + outgoing_port.send_frame(frame) + else: + # If the destination MAC is not in the table, flood to all ports except incoming + for port in self.network_interface.values(): + if port.enabled and port != from_network_interface: + port.send_frame(frame) + + def disconnect_link_from_port(self, link: Link, port_number: int): + """ + Disconnect a given link from the specified port number on the switch. + + :param link: The Link object to be disconnected. + :param port_number: The port number on the switch from where the link should be disconnected. + :raise NetworkError: When an invalid port number is provided or the link does not match the connection. + """ + port = self.network_interface.get(port_number) + if port is None: + msg = f"Invalid port number {port_number} on the switch" + _LOGGER.error(msg) + raise NetworkError(msg) + + if port._connected_link != link: + msg = f"The link does not match the connection at port number {port_number}" + _LOGGER.error(msg) + raise NetworkError(msg) + + port.disconnect_link() diff --git a/tests/integration_tests/extensions/nodes/super_computer.py b/tests/integration_tests/extensions/nodes/super_computer.py new file mode 100644 index 00000000..4418e352 --- /dev/null +++ b/tests/integration_tests/extensions/nodes/super_computer.py @@ -0,0 +1,46 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from typing import ClassVar, Dict, Literal + +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC +from primaite.simulator.system.services.ftp.ftp_client import FTPClient +from primaite.utils.validation.ipv4_address import IPV4Address + + +class SuperComputer(HostNode, discriminator="supercomputer"): + """ + A basic Computer class. + + Example: + >>> pc_a = Computer( + hostname="pc_a", + ip_address="192.168.1.10", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1" + ) + >>> pc_a.power_on() + + Instances of computer come 'pre-packaged' with the following: + + * Core Functionality: + * Packet Capture + * Sys Log + * Services: + * ARP Service + * ICMP Service + * DNS Client + * FTP Client + * NTP Client + * Applications: + * Web Browser + """ + + class ConfigSchema(HostNode.ConfigSchema): + type: Literal["supercomputer"] = "supercomputer" + + SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "ftp-client": FTPClient} + + def __init__(self, **kwargs): + print("--- Extended Component: SuperComputer ---") + super().__init__(**kwargs) + + pass diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py new file mode 100644 index 00000000..b1cf7ed5 --- /dev/null +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -0,0 +1,439 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from ipaddress import IPv4Address +from typing import Any, Dict, List, Literal, Optional, Union +from uuid import uuid4 + +from pydantic import Field + +from primaite import getLogger +from primaite.simulator.file_system.file_system import File +from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus +from primaite.simulator.file_system.folder import Folder +from primaite.simulator.system.core.software_manager import SoftwareManager +from primaite.simulator.system.services.ftp.ftp_client import FTPClient +from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP + +_LOGGER = getLogger(__name__) + + +class ExtendedService(Service, discriminator="extended-service"): + """ + A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE. + + This class inherits from the `Service` class and provides methods to simulate a SQL database. + """ + + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for ExtendedService.""" + + type: str = "extended-service" + + backup_server_ip: IPv4Address = None + """IP address of the backup server.""" + + config: "ExtendedService.ConfigSchema" = Field(default_factory=lambda: ExtendedService.ConfigSchema()) + + password: Optional[str] = None + """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" + + latest_backup_directory: str = None + """Directory of latest backup.""" + + latest_backup_file_name: str = None + """File name of latest backup.""" + + def __init__(self, **kwargs): + kwargs["name"] = "extended-service" + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] + super().__init__(**kwargs) + self._create_db_file() + if kwargs.get("options"): + opt = kwargs["options"] + self.password = opt.get("db_password", None) + if "backup_server_ip" in opt: + self.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip"))) + + def install(self): + """ + Perform first-time setup of the ExtendedService. + + Installs an instance of FTPClient on the Node to enable database backup if it isn't installed already. + """ + super().install() + + if not self.parent.software_manager.software.get("ftp-client"): + self.parent.sys_log.info(f"{self.name}: Installing FTPClient to enable database backups") + self.parent.software_manager.install(FTPClient) + + def configure_backup(self, backup_server: IPv4Address): + """ + Set up the database backup. + + :param: backup_server_ip: The IP address of the backup server + """ + self.backup_server_ip = backup_server + + def backup_database(self) -> bool: + """Create a backup of the database to the configured backup server.""" + # check if this action can be performed + if not self._can_perform_action(): + return False + + # check if the backup server was configured + if self.backup_server_ip is None: + self.sys_log.warning(f"{self.name} - {self.sys_log.hostname}: not configured.") + return False + + software_manager: SoftwareManager = self.software_manager + ftp_client_service: FTPClient = software_manager.software.get("ftp-client") + + if not ftp_client_service: + self.sys_log.error( + f"{self.name}: Failed to perform database backup as the FTPClient software is not installed" + ) + return False + + # send backup copy of database file to FTP server + if not self.db_file: + self.sys_log.error(f"{self.name}: Attempted to backup database file but it doesn't exist.") + return False + + response = ftp_client_service.send_file( + dest_ip_address=self.backup_server_ip, + src_file_name=self.db_file.name, + src_folder_name="database", + dest_folder_name=str(self.uuid), + # Prevent's a filename clash with the real DatabaseService service implementation + dest_file_name="extended_service_database.db", + ) + + if response: + return True + + self.sys_log.error("Unable to create database backup.") + return False + + def restore_backup(self) -> bool: + """Restore a backup from backup server.""" + # check if this action can be performed + if not self._can_perform_action(): + return False + + software_manager: SoftwareManager = self.software_manager + ftp_client_service: FTPClient = software_manager.software.get("ftp-client") + + if not ftp_client_service: + self.sys_log.error( + f"{self.name}: Failed to restore database backup as the FTPClient software is not installed" + ) + return False + + # retrieve backup file from backup server + response = ftp_client_service.request_file( + src_folder_name=str(self.uuid), + src_file_name="extended_service_database.db", + dest_folder_name="downloads", + dest_file_name="extended_service_database.db", + dest_ip_address=self.backup_server_ip, + ) + + if not response: + self.sys_log.error("Unable to restore database backup.") + return False + + old_visible_state = SoftwareHealthState.GOOD + + # get db file regardless of whether or not it was deleted + db_file = self.file_system.get_file( + folder_name="database", file_name="extended_service_database.db", include_deleted=True + ) + + if db_file is None: + self.sys_log.warning("Database file not initialised.") + return False + + # if the file was deleted, get the old visible health state + if db_file.deleted: + old_visible_state = db_file.visible_health_status + else: + old_visible_state = self.db_file.visible_health_status + self.file_system.delete_file(folder_name="database", file_name="extended_service_database.db") + + # replace db file + self.file_system.copy_file( + src_folder_name="downloads", src_file_name="extended_service_database.db", dst_folder_name="database" + ) + + if self.db_file is None: + self.sys_log.error("Copying database backup failed.") + return False + + self.db_file.visible_health_status = old_visible_state + self.set_health_state(SoftwareHealthState.GOOD) + + return True + + def _create_db_file(self): + """Creates the Simulation File and sqlite file in the file system.""" + self.file_system.create_file(folder_name="database", file_name="extended_service_database.db") + + @property + def db_file(self) -> File: + """Returns the database file.""" + return self.file_system.get_file(folder_name="database", file_name="extended_service_database.db") + + def _return_database_folder(self) -> Folder: + """Returns the database folder.""" + return self.file_system.get_folder_by_id(self.db_file.folder_id) + + def _generate_connection_id(self) -> str: + """Generate a unique connection ID.""" + return str(uuid4()) + + def _process_connect( + self, + src_ip: IPv4Address, + connection_request_id: str, + password: Optional[str] = None, + session_id: Optional[str] = None, + ) -> Dict[str, Union[int, Dict[str, bool]]]: + """Process an incoming connection request. + + :param connection_id: A unique identifier for the connection + :type connection_id: str + :param password: Supplied password. It must match self.password for connection success, defaults to None + :type password: Optional[str], optional + :return: Response to connection request containing success info. + :rtype: Dict[str, Union[int, Dict[str, bool]]] + """ + self.sys_log.info(f"{self.name}: Processing new connection request ({connection_request_id}) from {src_ip}") + status_code = 500 # Default internal server error + connection_id = None + if self.operating_state == ServiceOperatingState.RUNNING: + status_code = 503 # service unavailable + if self.health_state_actual == SoftwareHealthState.OVERWHELMED: + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, service is at " + f"capacity." + ) + if self.health_state_actual in [ + SoftwareHealthState.GOOD, + SoftwareHealthState.FIXING, + SoftwareHealthState.COMPROMISED, + ]: + if self.password == password: + status_code = 200 # ok + connection_id = self._generate_connection_id() + # try to create connection + if not self.add_connection(connection_id=connection_id, session_id=session_id): + status_code = 500 + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, " + f"returning status code 500" + ) + else: + status_code = 401 # Unauthorised + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} unauthorised " + f"(incorrect password), returning status code 401" + ) + else: + status_code = 404 # service not found + return { + "status_code": status_code, + "type": "connect_response", + "response": status_code == 200, + "connection_id": connection_id, + "connection_request_id": connection_request_id, + } + + def _process_sql( + self, + query: Literal["SELECT", "DELETE", "INSERT", "ENCRYPT"], + query_id: str, + connection_id: Optional[str] = None, + ) -> Dict[str, Union[int, List[Any]]]: + """ + Executes the given SQL query and returns the result. + + Possible queries: + - SELECT : returns the data + - DELETE : deletes the data + - INSERT : inserts the data + - ENCRYPT : corrupts the data + + :param query: The SQL query to be executed. + :return: Dictionary containing status code and data fetched. + """ + self.sys_log.info(f"{self.name}: Running {query}") + + if not self.db_file: + self.sys_log.error(f"{self.name}: Failed to run {query} because the database file is missing.") + return {"status_code": 404, "type": "sql", "data": False} + + if self.health_state_actual is not SoftwareHealthState.GOOD: + self.sys_log.error(f"{self.name}: Failed to run {query} because the database service is unavailable.") + return {"status_code": 500, "type": "sql", "data": False} + + if query == "SELECT": + if self.db_file.health_status == FileSystemItemHealthStatus.CORRUPT: + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + elif self.db_file.health_status == FileSystemItemHealthStatus.GOOD: + return { + "status_code": 200, + "type": "sql", + "data": True, + "uuid": query_id, + "connection_id": connection_id, + } + else: + return {"status_code": 404, "type": "sql", "data": False} + elif query == "DELETE": + self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + elif query == "ENCRYPT": + self.file_system.num_file_creations += 1 + self.db_file.health_status = FileSystemItemHealthStatus.CORRUPT + self.db_file.num_access += 1 + database_folder = self._return_database_folder() + database_folder.health_status = FileSystemItemHealthStatus.CORRUPT + self.file_system.num_file_deletions += 1 + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + elif query == "INSERT": + if self.health_state_actual == SoftwareHealthState.GOOD: + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + else: + return {"status_code": 404, "type": "sql", "data": False} + elif query == "SELECT * FROM pg_stat_activity": + # Check if the connection is active. + if self.health_state_actual == SoftwareHealthState.GOOD: + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + else: + return {"status_code": 401, "data": False} + else: + # Invalid query + self.sys_log.warning(f"{self.name}: Invalid {query}") + return {"status_code": 500, "data": False} + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. + + Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. + + :return: Current state of this object and child objects. + :rtype: Dict + """ + return super().describe_state() + + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Processes the incoming SQL payload and sends the result back. + + :param payload: The SQL query to be executed. + :param session_id: The session identifier. + :return: True if the Status Code is 200, otherwise False. + """ + result = {"status_code": 500, "data": []} + # if server service is down, return error + if not self._can_perform_action(): + return False + + if isinstance(payload, dict) and payload.get("type"): + if payload["type"] == "connect_request": + src_ip = kwargs.get("frame").ip.src_ip_address + result = self._process_connect( + src_ip=src_ip, + password=payload.get("password"), + connection_request_id=payload.get("connection_request_id"), + session_id=session_id, + ) + elif payload["type"] == "disconnect": + if payload["connection_id"] in self.connections: + connection_id = payload["connection_id"] + connected_ip_address = self.connections[connection_id]["ip_address"] + frame = kwargs.get("frame") + if connected_ip_address == frame.ip.src_ip_address: + self.sys_log.info( + f"{self.name}: Received disconnect command for {connection_id=} from {connected_ip_address}" + ) + self.terminate_connection(connection_id=payload["connection_id"], send_disconnect=False) + else: + self.sys_log.warning( + f"{self.name}: Ignoring disconnect command for {connection_id=} as the command source " + f"({frame.ip.src_ip_address}) doesn't match the connection source ({connected_ip_address})" + ) + elif payload["type"] == "sql": + if payload.get("connection_id") in self.connections: + result = self._process_sql( + query=payload["sql"], query_id=payload["uuid"], connection_id=payload["connection_id"] + ) + else: + result = {"status_code": 401, "type": "sql"} + else: + self.sys_log.info(f"{self.name}: Ignoring payload as it is not a Database payload") + self.send(payload=result, session_id=session_id) + return True + + def send(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Send a SQL response back down to the SessionManager. + + :param payload: The SQL query results. + :param session_id: The session identifier. + :return: True if the Status Code is 200, otherwise False. + """ + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) + + return payload["status_code"] == 200 + + def apply_timestep(self, timestep: int) -> None: + """ + Apply a single timestep of simulation dynamics to this service. + + Here at the first step, the database backup is created, in addition to normal service update logic. + """ + if timestep == 1: + self.backup_database() + return super().apply_timestep(timestep) + + def _update_fix_status(self) -> None: + """Perform a database restore when the FIXING countdown is finished.""" + super()._update_fix_status() + if self._fixing_countdown is None: + self.restore_backup() diff --git a/tests/integration_tests/extensions/test_extendable_config.py b/tests/integration_tests/extensions/test_extendable_config.py new file mode 100644 index 00000000..34d1e418 --- /dev/null +++ b/tests/integration_tests/extensions/test_extendable_config.py @@ -0,0 +1,34 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +import os + +from primaite.config.load import get_extended_config_path +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from tests import TEST_ASSETS_ROOT +from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config +from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication +from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch + +# Import the extended components so that PrimAITE registers them +from tests.integration_tests.extensions.nodes.super_computer import SuperComputer +from tests.integration_tests.extensions.services.extended_service import ExtendedService + +CONFIG_PATH = TEST_ASSETS_ROOT / "configs/extended_config.yaml" + + +def test_extended_example_config(): + """Test that the example config can be parsed properly.""" + game = load_config(CONFIG_PATH) + network: Network = game.simulation.network + + assert len(network.nodes) == 10 # 10 nodes in example network + assert len(network.computer_nodes) == 1 + assert len(network.router_nodes) == 1 # 1 router in network + assert len(network.switch_nodes) == 1 # 1 switches in network + assert len(network.server_nodes) == 5 # 5 servers in network + + extended_host = network.get_node_by_hostname("client_1") + + assert "extended-application" in extended_host.software_manager.software + assert "extended-service" in extended_host.software_manager.software diff --git a/tests/integration_tests/game_layer/actions/__init__.py b/tests/integration_tests/game_layer/actions/__init__.py index be6c00e7..836b79af 100644 --- a/tests/integration_tests/game_layer/actions/__init__.py +++ b/tests/integration_tests/game_layer/actions/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/game_layer/actions/test_application_request_permission.py b/tests/integration_tests/game_layer/actions/test_application_request_permission.py index 36a7ae57..f1fc4b34 100644 --- a/tests/integration_tests/game_layer/actions/test_application_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_application_request_permission.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest @@ -28,27 +28,27 @@ def test_application_cannot_perform_actions_unless_running(game_and_agent_fixtur game, agent = game_and_agent_fixture client_1 = game.simulation.network.get_node_by_hostname("client_1") - browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + browser: WebBrowser = client_1.software_manager.software.get("web-browser") browser.close() assert browser.operating_state == ApplicationOperatingState.CLOSED - action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0}) + action = ("node-application-scan", {"node_name": "client_1", "application_name": "web-browser"}) agent.store_action(action) game.step() assert browser.operating_state == ApplicationOperatingState.CLOSED - action = ("NODE_APPLICATION_CLOSE", {"node_id": 0, "application_id": 0}) + action = ("node-application-close", {"node_name": "client_1", "application_name": "web-browser"}) agent.store_action(action) game.step() assert browser.operating_state == ApplicationOperatingState.CLOSED - action = ("NODE_APPLICATION_FIX", {"node_id": 0, "application_id": 0}) + action = ("node-application-fix", {"node_name": "client_1", "application_name": "web-browser"}) agent.store_action(action) game.step() assert browser.operating_state == ApplicationOperatingState.CLOSED - action = ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}) + action = ("node-application-execute", {"node_name": "client_1", "application_name": "web-browser"}) agent.store_action(action) game.step() assert browser.operating_state == ApplicationOperatingState.CLOSED diff --git a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py index 806ce063..7cab59ed 100644 --- a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py +++ b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Tuple @@ -11,13 +11,13 @@ from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Command, C2Server from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture @@ -26,13 +26,13 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=4) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=5) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=6) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=5) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=6) c2_server_host = game.simulation.network.get_node_by_hostname("client_1") c2_server_host.software_manager.install(software_class=C2Server) - c2_server: C2Server = c2_server_host.software_manager.software["C2Server"] + c2_server: C2Server = c2_server_host.software_manager.software["c2-server"] c2_server.run() return (game, agent) @@ -46,23 +46,21 @@ def test_c2_beacon_default(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgen server_1: Server = game.simulation.network.get_node_by_hostname("server_1") action = ( - "NODE_APPLICATION_INSTALL", - {"node_id": 1, "application_name": "C2Beacon"}, + "node-application-install", + {"node_name": "server_1", "application_name": "c2-beacon"}, ) agent.store_action(action) game.step() assert agent.history[-1].response.status == "success" action = ( - "CONFIGURE_C2_BEACON", + "configure-c2-beacon", { - "node_id": 1, - "config": { - "c2_server_ip_address": "10.0.1.2", - "keep_alive_frequency": 5, - "masquerade_protocol": "TCP", - "masquerade_port": "HTTP", - }, + "node_name": "server_1", + "c2_server_ip_address": "10.0.1.2", + "keep_alive_frequency": 5, + "masquerade_protocol": "TCP", + "masquerade_port": "HTTP", }, ) agent.store_action(action) @@ -70,15 +68,15 @@ def test_c2_beacon_default(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgen assert agent.history[-1].response.status == "success" action = ( - "NODE_APPLICATION_EXECUTE", - {"node_id": 1, "application_id": 0}, + "node-application-execute", + {"node_name": "server_1", "application_name": "c2-beacon"}, ) agent.store_action(action) game.step() assert agent.history[-1].response.status == "success" # Asserting that we've confirmed our connection - c2_beacon: C2Beacon = server_1.software_manager.software["C2Beacon"] + c2_beacon: C2Beacon = server_1.software_manager.software["c2-beacon"] assert c2_beacon.c2_connection_active == True @@ -93,9 +91,9 @@ def test_c2_server_ransomware(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA # Installing a database on Server_2 for the ransomware to attack server_2: Server = game.simulation.network.get_node_by_hostname("server_2") server_2.software_manager.install(DatabaseService) - server_2.software_manager.software["DatabaseService"].start() + server_2.software_manager.software["database-service"].start() # Configuring the C2 to connect to client 1 (C2 Server) - c2_beacon: C2Beacon = server_1.software_manager.software["C2Beacon"] + c2_beacon: C2Beacon = server_1.software_manager.software["c2-beacon"] c2_beacon.configure(c2_server_ip_address=IPv4Address("10.0.1.2")) c2_beacon.establish() assert c2_beacon.c2_connection_active == True @@ -103,17 +101,15 @@ def test_c2_server_ransomware(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA # C2 Action 1: Installing the RansomwareScript & Database client via Terminal action = ( - "C2_SERVER_TERMINAL_COMMAND", + "c2-server-terminal-command", { - "node_id": 0, + "node_name": "client_1", "ip_address": None, - "account": { - "username": "admin", - "password": "admin", - }, + "username": "admin", + "password": "admin", "commands": [ - ["software_manager", "application", "install", "RansomwareScript"], - ["software_manager", "application", "install", "DatabaseClient"], + ["software_manager", "application", "install", "ransomware-script"], + ["software_manager", "application", "install", "database-client"], ], }, ) @@ -122,10 +118,11 @@ def test_c2_server_ransomware(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA assert agent.history[-1].response.status == "success" action = ( - "C2_SERVER_RANSOMWARE_CONFIGURE", + "c2-server-ransomware-configure", { - "node_id": 0, - "config": {"server_ip_address": "10.0.2.3", "payload": "ENCRYPT"}, + "node_name": "client_1", + "server_ip_address": "10.0.2.3", + "payload": "ENCRYPT", }, ) agent.store_action(action) @@ -134,16 +131,16 @@ def test_c2_server_ransomware(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA # Stepping a few timesteps to allow for the RansowmareScript to finish installing. - action = ("DONOTHING", {}) + action = ("do-nothing", {}) agent.store_action(action) game.step() game.step() game.step() action = ( - "C2_SERVER_RANSOMWARE_LAUNCH", + "c2-server-ransomware-launch", { - "node_id": 0, + "node_name": "client_1", }, ) agent.store_action(action) @@ -165,10 +162,10 @@ def test_c2_server_data_exfiltration(game_and_agent_fixture: Tuple[PrimaiteGame, # Installing a database on Server_2 (creates a database.db file.) server_2: Server = game.simulation.network.get_node_by_hostname("server_2") server_2.software_manager.install(DatabaseService) - server_2.software_manager.software["DatabaseService"].start() + server_2.software_manager.software["database-service"].start() # Configuring the C2 to connect to client 1 (C2 Server) - c2_beacon: C2Beacon = server_1.software_manager.software["C2Beacon"] + c2_beacon: C2Beacon = server_1.software_manager.software["c2-beacon"] c2_beacon.configure(c2_server_ip_address=IPv4Address("10.0.1.2")) c2_beacon.establish() assert c2_beacon.c2_connection_active == True @@ -181,17 +178,15 @@ def test_c2_server_data_exfiltration(game_and_agent_fixture: Tuple[PrimaiteGame, # C2 Action: Data exfiltrate. action = ( - "C2_SERVER_DATA_EXFILTRATE", + "c2-server-data-exfiltrate", { - "node_id": 0, + "node_name": "client_1", "target_file_name": "database.db", "target_folder_name": "database", "exfiltration_folder_name": "spoils", "target_ip_address": "10.0.2.3", - "account": { - "username": "admin", - "password": "admin", - }, + "username": "admin", + "password": "admin", }, ) agent.store_action(action) diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index 0c9ec6f0..35d65a5a 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -1,22 +1,22 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address import pytest from pydantic import ValidationError -from primaite.game.agent.actions import ( +from primaite.game.agent.actions.software import ( ConfigureDatabaseClientAction, ConfigureDoSBotAction, ConfigureRansomwareScriptAction, ) from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent @@ -27,21 +27,18 @@ class TestConfigureDatabaseAction: def test_configure_ip_password(self, game_and_agent): game, agent = game_and_agent agent: ControlledAgent - agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager) # make sure there is a database client on this node client_1 = game.simulation.network.get_node_by_hostname("client_1") client_1.software_manager.install(DatabaseClient) - db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = client_1.software_manager.software["database-client"] action = ( - "CONFIGURE_DATABASE_CLIENT", + "configure-database-client", { - "node_id": 0, - "config": { - "server_ip_address": "192.168.1.99", - "server_password": "admin123", - }, + "node_name": "client_1", + "server_ip_address": "192.168.1.99", + "server_password": "admin123", }, ) agent.store_action(action) @@ -53,20 +50,17 @@ class TestConfigureDatabaseAction: def test_configure_ip(self, game_and_agent): game, agent = game_and_agent agent: ControlledAgent - agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager) # make sure there is a database client on this node client_1 = game.simulation.network.get_node_by_hostname("client_1") client_1.software_manager.install(DatabaseClient) - db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = client_1.software_manager.software["database-client"] action = ( - "CONFIGURE_DATABASE_CLIENT", + "configure-database-client", { - "node_id": 0, - "config": { - "server_ip_address": "192.168.1.99", - }, + "node_name": "client_1", + "server_ip_address": "192.168.1.99", }, ) agent.store_action(action) @@ -78,21 +72,18 @@ class TestConfigureDatabaseAction: def test_configure_password(self, game_and_agent): game, agent = game_and_agent agent: ControlledAgent - agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager) # make sure there is a database client on this node client_1 = game.simulation.network.get_node_by_hostname("client_1") client_1.software_manager.install(DatabaseClient) - db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = client_1.software_manager.software["database-client"] old_ip = db_client.server_ip_address action = ( - "CONFIGURE_DATABASE_CLIENT", + "configure-database-client", { - "node_id": 0, - "config": { - "server_password": "admin123", - }, + "node_name": "client_1", + "server_password": "admin123", }, ) agent.store_action(action) @@ -120,22 +111,19 @@ class TestConfigureRansomwareScriptAction: def test_configure_ip_password(self, game_and_agent, config): game, agent = game_and_agent agent: ControlledAgent - agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction( - agent.action_manager - ) # make sure there is a database client on this node client_1 = game.simulation.network.get_node_by_hostname("client_1") client_1.software_manager.install(RansomwareScript) - ransomware_script: RansomwareScript = client_1.software_manager.software["RansomwareScript"] + ransomware_script: RansomwareScript = client_1.software_manager.software["ransomware-script"] old_ip = ransomware_script.server_ip_address old_pw = ransomware_script.server_password old_payload = ransomware_script.payload action = ( - "CONFIGURE_RANSOMWARE_SCRIPT", - {"node_id": 0, "config": config}, + "configure-ransomware-script", + {"node_name": "client_1", **config}, ) agent.store_action(action) game.step() @@ -151,18 +139,15 @@ class TestConfigureRansomwareScriptAction: def test_invalid_config(self, game_and_agent): game, agent = game_and_agent agent: ControlledAgent - agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction( - agent.action_manager - ) # make sure there is a database client on this node client_1 = game.simulation.network.get_node_by_hostname("client_1") client_1.software_manager.install(RansomwareScript) - ransomware_script: RansomwareScript = client_1.software_manager.software["RansomwareScript"] + ransomware_script: RansomwareScript = client_1.software_manager.software["ransomware-script"] action = ( - "CONFIGURE_RANSOMWARE_SCRIPT", + "configure-ransomware-script", { - "node_id": 0, + "node_name": "client_1", "config": {"server_password": "admin123", "bad_option": 70}, }, ) @@ -172,35 +157,32 @@ class TestConfigureRansomwareScriptAction: class TestConfigureDoSBot: - def test_configure_DoSBot(self, game_and_agent): + def test_configure_dos_bot(self, game_and_agent): game, agent = game_and_agent agent: ControlledAgent - agent.action_manager.actions["CONFIGURE_DOSBOT"] = ConfigureDoSBotAction(agent.action_manager) client_1 = game.simulation.network.get_node_by_hostname("client_1") client_1.software_manager.install(DoSBot) - dos_bot: DoSBot = client_1.software_manager.software["DoSBot"] + dos_bot: DoSBot = client_1.software_manager.software["dos-bot"] action = ( - "CONFIGURE_DOSBOT", + "configure-dos-bot", { - "node_id": 0, - "config": { - "target_ip_address": "192.168.1.99", - "target_port": "POSTGRES_SERVER", - "payload": "HACC", - "repeat": False, - "port_scan_p_of_success": 0.875, - "dos_intensity": 0.75, - "max_sessions": 50, - }, + "node_name": "client_1", + "target_ip_address": "192.168.1.99", + "target_port": "POSTGRES_SERVER", + "payload": "HACC", + "repeat": False, + "port_scan_p_of_success": 0.875, + "dos_intensity": 0.75, + "max_sessions": 50, }, ) agent.store_action(action) game.step() assert dos_bot.target_ip_address == IPv4Address("192.168.1.99") - assert dos_bot.target_port == Port.POSTGRES_SERVER + assert dos_bot.target_port == PORT_LOOKUP["POSTGRES_SERVER"] assert dos_bot.payload == "HACC" assert not dos_bot.repeat assert dos_bot.port_scan_p_of_success == 0.875 @@ -214,11 +196,11 @@ class TestConfigureYAML: # make sure there's no db client on the node yet client_1 = env.game.simulation.network.get_node_by_hostname("client_1") - assert client_1.software_manager.software.get("DatabaseClient") is None + assert client_1.software_manager.software.get("database-client") is None # take the install action, check that the db gets installed, step to get it to finish installing env.step(1) - db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + db_client: DatabaseClient = client_1.software_manager.software.get("database-client") assert isinstance(db_client, DatabaseClient) assert db_client.operating_state == ApplicationOperatingState.INSTALLING env.step(0) @@ -239,14 +221,14 @@ class TestConfigureYAML: assert db_client.server_password == "correct_password" assert db_client.connect() - def test_configure_ransomware_script(self): + def test_c2_server_ransomware_configure(self): env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML) client_2 = env.game.simulation.network.get_node_by_hostname("client_2") - assert client_2.software_manager.software.get("RansomwareScript") is None + assert client_2.software_manager.software.get("ransomware-script") is None # install ransomware script env.step(2) - ransom = client_2.software_manager.software.get("RansomwareScript") + ransom = client_2.software_manager.software.get("ransomware-script") assert isinstance(ransom, RansomwareScript) assert ransom.operating_state == ApplicationOperatingState.INSTALLING env.step(0) @@ -268,17 +250,17 @@ class TestConfigureYAML: assert ransom.attack() db_server = env.game.simulation.network.get_node_by_hostname("server_1") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = db_server.software_manager.software.get("database-service") assert db_service.db_file.health_status == FileSystemItemHealthStatus.CORRUPT def test_configure_dos_bot(self): env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML) client_3 = env.game.simulation.network.get_node_by_hostname("client_3") - assert client_3.software_manager.software.get("DoSBot") is None + assert client_3.software_manager.software.get("dos-bot") is None # install DoSBot env.step(3) - bot = client_3.software_manager.software.get("DoSBot") + bot = client_3.software_manager.software.get("dos-bot") assert isinstance(bot, DoSBot) assert bot.operating_state == ApplicationOperatingState.INSTALLING env.step(0) diff --git a/tests/integration_tests/game_layer/actions/test_file_request_permission.py b/tests/integration_tests/game_layer/actions/test_file_request_permission.py index 1c143aed..905dbfc9 100644 --- a/tests/integration_tests/game_layer/actions/test_file_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_file_request_permission.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import uuid from typing import Tuple @@ -33,8 +33,8 @@ def test_create_file(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): assert client_1.file_system.get_file(folder_name=random_folder, file_name=random_file) is None action = ( - "NODE_FILE_CREATE", - {"node_id": 0, "folder_name": random_folder, "file_name": random_file}, + "node-file-create", + {"node_name": "client_1", "folder_name": random_folder, "file_name": random_file}, ) agent.store_action(action) game.step() @@ -51,8 +51,8 @@ def test_file_delete_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAge assert file.deleted is False action = ( - "NODE_FILE_DELETE", - {"node_id": 0, "folder_id": 0, "file_id": 0}, + "node-file-delete", + {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}, ) agent.store_action(action) game.step() @@ -69,11 +69,11 @@ def test_file_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent file.corrupt() assert file.health_status == FileSystemItemHealthStatus.CORRUPT - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE action = ( - "NODE_FILE_SCAN", - {"node_id": 0, "folder_id": 0, "file_id": 0}, + "node-file-scan", + {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}, ) agent.store_action(action) game.step() @@ -93,8 +93,8 @@ def test_file_repair_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAge assert file.health_status == FileSystemItemHealthStatus.CORRUPT action = ( - "NODE_FILE_REPAIR", - {"node_id": 0, "folder_id": 0, "file_id": 0}, + "node-file-repair", + {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}, ) agent.store_action(action) game.step() @@ -113,8 +113,8 @@ def test_file_restore_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAg assert file.health_status == FileSystemItemHealthStatus.CORRUPT action = ( - "NODE_FILE_RESTORE", - {"node_id": 0, "folder_id": 0, "file_id": 0}, + "node-file-restore", + {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}, ) agent.store_action(action) game.step() @@ -132,8 +132,8 @@ def test_file_corrupt_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAg assert file.health_status == FileSystemItemHealthStatus.GOOD action = ( - "NODE_FILE_CORRUPT", - {"node_id": 0, "folder_id": 0, "file_id": 0}, + "node-file-corrupt", + {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}, ) agent.store_action(action) game.step() @@ -150,8 +150,8 @@ def test_file_access_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAge assert file.num_access == 0 action = ( - "NODE_FILE_ACCESS", - {"node_id": 0, "folder_name": file.folder_name, "file_name": file.name}, + "node-file-access", + {"node_name": "client_1", "folder_name": file.folder_name, "file_name": file.name}, ) agent.store_action(action) game.step() diff --git a/tests/integration_tests/game_layer/actions/test_folder_request_permission.py b/tests/integration_tests/game_layer/actions/test_folder_request_permission.py index e5e0806a..1bd1add3 100644 --- a/tests/integration_tests/game_layer/actions/test_folder_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_folder_request_permission.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import uuid from typing import Tuple @@ -32,9 +32,9 @@ def test_create_folder(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): assert client_1.file_system.get_folder(folder_name=random_folder) is None action = ( - "NODE_FOLDER_CREATE", + "node-folder-create", { - "node_id": 0, + "node_name": "client_1", "folder_name": random_folder, }, ) @@ -52,18 +52,18 @@ def test_folder_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAge folder = client_1.file_system.get_folder(folder_name="downloads") assert folder.health_status == FileSystemItemHealthStatus.GOOD - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE folder.corrupt() assert folder.health_status == FileSystemItemHealthStatus.CORRUPT - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE action = ( - "NODE_FOLDER_SCAN", + "node-folder-scan", { - "node_id": 0, # client_1, - "folder_id": 0, # downloads + "node_name": "client_1", # client_1, + "folder_name": "downloads", # downloads }, ) agent.store_action(action) @@ -87,10 +87,10 @@ def test_folder_repair_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA assert folder.health_status == FileSystemItemHealthStatus.CORRUPT action = ( - "NODE_FOLDER_REPAIR", + "node-folder-repair", { - "node_id": 0, # client_1, - "folder_id": 0, # downloads + "node_name": "client_1", # client_1, + "folder_name": "downloads", # downloads }, ) agent.store_action(action) @@ -111,10 +111,10 @@ def test_folder_restore_action(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy assert folder.health_status == FileSystemItemHealthStatus.CORRUPT action = ( - "NODE_FOLDER_RESTORE", + "node-folder-restore", { - "node_id": 0, # client_1, - "folder_id": 0, # downloads + "node_name": "client_1", # client_1, + "folder_name": "downloads", # downloads }, ) agent.store_action(action) diff --git a/tests/integration_tests/game_layer/actions/test_nic_request_permission.py b/tests/integration_tests/game_layer/actions/test_nic_request_permission.py index d796b75e..a68e4b23 100644 --- a/tests/integration_tests/game_layer/actions/test_nic_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_nic_request_permission.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest @@ -29,10 +29,10 @@ def test_nic_cannot_be_turned_off_if_not_on(game_and_agent_fixture: Tuple[Primai assert nic.enabled is False action = ( - "HOST_NIC_DISABLE", + "host-nic-disable", { - "node_id": 0, # client_1 - "nic_id": 0, # the only nic (eth-1) + "node_name": "client_1", # client_1 + "nic_num": 1, # the only nic (eth-1) }, ) agent.store_action(action) @@ -50,10 +50,10 @@ def test_nic_cannot_be_turned_on_if_already_on(game_and_agent_fixture: Tuple[Pri assert nic.enabled action = ( - "HOST_NIC_ENABLE", + "host-nic-enable", { - "node_id": 0, # client_1 - "nic_id": 0, # the only nic (eth-1) + "node_name": "client_1", # client_1 + "nic_num": 1, # the only nic (eth-1) }, ) agent.store_action(action) @@ -71,10 +71,10 @@ def test_that_a_nic_can_be_enabled_and_disabled(game_and_agent_fixture: Tuple[Pr assert nic.enabled action = ( - "HOST_NIC_DISABLE", + "host-nic-disable", { - "node_id": 0, # client_1 - "nic_id": 0, # the only nic (eth-1) + "node_name": "client_1", # client_1 + "nic_num": 1, # the only nic (eth-1) }, ) agent.store_action(action) @@ -83,10 +83,10 @@ def test_that_a_nic_can_be_enabled_and_disabled(game_and_agent_fixture: Tuple[Pr assert nic.enabled is False action = ( - "HOST_NIC_ENABLE", + "host-nic-enable", { - "node_id": 0, # client_1 - "nic_id": 0, # the only nic (eth-1) + "node_name": "client_1", # client_1 + "nic_num": 1, # the only nic (eth-1) }, ) agent.store_action(action) diff --git a/tests/integration_tests/game_layer/actions/test_node_request_permission.py b/tests/integration_tests/game_layer/actions/test_node_request_permission.py index fdf04ad5..baddee46 100644 --- a/tests/integration_tests/game_layer/actions/test_node_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_node_request_permission.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest @@ -25,32 +25,35 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy game, agent = game_and_agent_fixture client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.config.shut_down_duration = 3 assert client_1.operating_state == NodeOperatingState.ON # turn it off - action = ("NODE_SHUTDOWN", {"node_id": 0}) + action = ("node-shutdown", {"node_name": "client_1"}) agent.store_action(action) game.step() assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN - for i in range(client_1.shut_down_duration + 1): - action = ("DONOTHING", {"node_id": 0}) + for i in range(client_1.config.shut_down_duration + 1): + action = ("do-nothing", {}) agent.store_action(action) game.step() assert client_1.operating_state == NodeOperatingState.OFF + client_1.config.start_up_duration = 3 + # turn it on - action = ("NODE_STARTUP", {"node_id": 0}) + action = ("node-startup", {"node_name": "client_1"}) agent.store_action(action) game.step() assert client_1.operating_state == NodeOperatingState.BOOTING - for i in range(client_1.start_up_duration + 1): - action = ("DONOTHING", {"node_id": 0}) + for i in range(client_1.config.start_up_duration + 1): + action = ("do-nothing", {}) agent.store_action(action) game.step() @@ -65,7 +68,7 @@ def test_node_cannot_be_started_up_if_node_is_already_on(game_and_agent_fixture: assert client_1.operating_state == NodeOperatingState.ON # turn it on - action = ("NODE_STARTUP", {"node_id": 0}) + action = ("node-startup", {"node_name": "client_1"}) agent.store_action(action) game.step() @@ -79,15 +82,15 @@ def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture: client_1 = game.simulation.network.get_node_by_hostname("client_1") client_1.power_off() - for i in range(client_1.shut_down_duration + 1): - action = ("DONOTHING", {"node_id": 0}) + for i in range(client_1.config.shut_down_duration + 1): + action = ("do-nothing", {}) agent.store_action(action) game.step() assert client_1.operating_state == NodeOperatingState.OFF # turn it ff - action = ("NODE_SHUTDOWN", {"node_id": 0}) + action = ("node-shutdown", {"node_name": "client_1"}) agent.store_action(action) game.step() diff --git a/tests/integration_tests/game_layer/actions/test_service_request_permission.py b/tests/integration_tests/game_layer/actions/test_service_request_permission.py index 3054c73b..9bf7a38c 100644 --- a/tests/integration_tests/game_layer/actions/test_service_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_service_request_permission.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest @@ -26,12 +26,12 @@ def test_service_start(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): game, agent = game_and_agent_fixture server_1: Server = game.simulation.network.get_node_by_hostname("server_1") - dns_server = server_1.software_manager.software.get("DNSServer") + dns_server = server_1.software_manager.software.get("dns-server") dns_server.pause() assert dns_server.operating_state == ServiceOperatingState.PAUSED - action = ("NODE_SERVICE_START", {"node_id": 1, "service_id": 0}) + action = ("node-service-start", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() assert dns_server.operating_state == ServiceOperatingState.PAUSED @@ -40,7 +40,7 @@ def test_service_start(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): assert dns_server.operating_state == ServiceOperatingState.STOPPED - action = ("NODE_SERVICE_START", {"node_id": 1, "service_id": 0}) + action = ("node-service-start", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() @@ -52,9 +52,9 @@ def test_service_resume(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]) game, agent = game_and_agent_fixture server_1: Server = game.simulation.network.get_node_by_hostname("server_1") - dns_server = server_1.software_manager.software.get("DNSServer") + dns_server = server_1.software_manager.software.get("dns-server") - action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0}) + action = ("node-service-resume", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() assert dns_server.operating_state == ServiceOperatingState.RUNNING @@ -63,7 +63,7 @@ def test_service_resume(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]) assert dns_server.operating_state == ServiceOperatingState.PAUSED - action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0}) + action = ("node-service-resume", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() @@ -75,32 +75,32 @@ def test_service_cannot_perform_actions_unless_running(game_and_agent_fixture: T game, agent = game_and_agent_fixture server_1: Server = game.simulation.network.get_node_by_hostname("server_1") - dns_server = server_1.software_manager.software.get("DNSServer") + dns_server = server_1.software_manager.software.get("dns-server") dns_server.stop() assert dns_server.operating_state == ServiceOperatingState.STOPPED - action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0}) + action = ("node-service-scan", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() assert dns_server.operating_state == ServiceOperatingState.STOPPED - action = ("NODE_SERVICE_PAUSE", {"node_id": 1, "service_id": 0}) + action = ("node-service-pause", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() assert dns_server.operating_state == ServiceOperatingState.STOPPED - action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0}) + action = ("node-service-resume", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() assert dns_server.operating_state == ServiceOperatingState.STOPPED - action = ("NODE_SERVICE_RESTART", {"node_id": 1, "service_id": 0}) + action = ("node-service-restart", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() assert dns_server.operating_state == ServiceOperatingState.STOPPED - action = ("NODE_SERVICE_FIX", {"node_id": 1, "service_id": 0}) + action = ("node-service-fix", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() assert dns_server.operating_state == ServiceOperatingState.STOPPED diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index c4247d6e..3ee97fb7 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest @@ -9,9 +9,9 @@ from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture @@ -20,7 +20,7 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) return (game, agent) @@ -32,13 +32,13 @@ def test_remote_login(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): client_1 = game.simulation.network.get_node_by_hostname("client_1") # create a new user account on server_1 that will be logged into remotely - server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm: UserManager = server_1.software_manager.software["user-manager"] server_1_usm.add_user("user123", "password", is_admin=True) action = ( - "SSH_TO_REMOTE", + "node-session-remote-login", { - "node_id": 0, + "node_name": "client_1", "username": "user123", "password": "password", "remote_ip": str(server_1.network_interface[1].ip_address), @@ -64,13 +64,13 @@ def test_remote_login_wrong_password(game_and_agent_fixture: Tuple[PrimaiteGame, client_1 = game.simulation.network.get_node_by_hostname("client_1") # create a new user account on server_1 that will be logged into remotely - server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm: UserManager = server_1.software_manager.software["user-manager"] server_1_usm.add_user("user123", "password", is_admin=True) action = ( - "SSH_TO_REMOTE", + "node-session-remote-login", { - "node_id": 0, + "node_name": "client_1", "username": "user123", "password": "wrong_password", "remote_ip": str(server_1.network_interface[1].ip_address), @@ -96,13 +96,13 @@ def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame client_1 = game.simulation.network.get_node_by_hostname("client_1") # create a new user account on server_1 that will be logged into remotely - server_1_um: UserManager = server_1.software_manager.software["UserManager"] + server_1_um: UserManager = server_1.software_manager.software["user-manager"] server_1_um.add_user("user123", "password", is_admin=True) action = ( - "NODE_ACCOUNTS_CHANGE_PASSWORD", + "node-account-change-password", { - "node_id": 1, # server_1 + "node_name": "server_1", # server_1 "username": "user123", "current_password": "password", "new_password": "different_password", @@ -121,14 +121,14 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam client_1 = game.simulation.network.get_node_by_hostname("client_1") # create a new user account on server_1 that will be logged into remotely - server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm: UserManager = server_1.software_manager.software["user-manager"] server_1_usm.add_user("user123", "password", is_admin=True) # Log in remotely action = ( - "SSH_TO_REMOTE", + "node-session-remote-login", { - "node_id": 0, + "node_name": "client_1", "username": "user123", "password": "password", "remote_ip": str(server_1.network_interface[1].ip_address), @@ -139,9 +139,9 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam # Change password action = ( - "NODE_ACCOUNTS_CHANGE_PASSWORD", + "node-account-change-password", { - "node_id": 1, # server_1 + "node_name": "server_1", # server_1 "username": "user123", "current_password": "password", "new_password": "different_password", @@ -152,9 +152,9 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam # Assert that the user cannot execute an action action = ( - "NODE_SEND_REMOTE_COMMAND", + "node-send-remote-command", { - "node_id": 0, + "node_name": "client_1", "remote_ip": str(server_1.network_interface[1].ip_address), "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], }, @@ -171,13 +171,13 @@ def test_local_terminal(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]) client_1 = game.simulation.network.get_node_by_hostname("client_1") # create a new user account on server_1 that will be logged into remotely - client_1_usm: UserManager = client_1.software_manager.software["UserManager"] + client_1_usm: UserManager = client_1.software_manager.software["user-manager"] client_1_usm.add_user("user123", "password", is_admin=True) action = ( - "NODE_SEND_LOCAL_COMMAND", + "node-send-local-command", { - "node_id": 0, + "node_name": "client_1", "username": "user123", "password": "password", "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], @@ -191,9 +191,9 @@ def test_local_terminal(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]) # Change password action = ( - "NODE_ACCOUNTS_CHANGE_PASSWORD", + "node-account-change-password", { - "node_id": 0, # server_1 + "node_name": "client_1", "username": "user123", "current_password": "password", "new_password": "different_password", @@ -203,9 +203,9 @@ def test_local_terminal(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]) game.step() action = ( - "NODE_SEND_LOCAL_COMMAND", + "node-send-local-command", { - "node_id": 0, + "node_name": "client_1", "username": "user123", "password": "password", "command": ["file_system", "create", "file", "folder123", "cat.pdf", False], diff --git a/tests/integration_tests/game_layer/actions/test_user_account_actions.py b/tests/integration_tests/game_layer/actions/test_user_account_actions.py index f97716c6..26b871db 100644 --- a/tests/integration_tests/game_layer/actions/test_user_account_actions.py +++ b/tests/integration_tests/game_layer/actions/test_user_account_actions.py @@ -3,7 +3,7 @@ import pytest from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.port import Port, PORT_LOOKUP @pytest.fixture @@ -27,8 +27,8 @@ def test_user_account_add_user_action(game_and_agent_fixture): # add admin account action = ( - "NODE_ACCOUNTS_ADD_USER", - {"node_id": 0, "username": "admin_2", "password": "e-tronic-boogaloo", "is_admin": True}, + "node-account-add-user", + {"node_name": "client_1", "username": "admin_2", "password": "e-tronic-boogaloo", "is_admin": True}, ) agent.store_action(action) game.step() @@ -38,8 +38,8 @@ def test_user_account_add_user_action(game_and_agent_fixture): # add non admin account action = ( - "NODE_ACCOUNTS_ADD_USER", - {"node_id": 0, "username": "leeroy.jenkins", "password": "no_plan_needed", "is_admin": False}, + "node-account-add-user", + {"node_name": "client_1", "username": "leeroy.jenkins", "password": "no_plan_needed", "is_admin": False}, ) agent.store_action(action) game.step() @@ -63,9 +63,9 @@ def test_user_account_disable_user_action(game_and_agent_fixture): # disable test account action = ( - "NODE_ACCOUNTS_DISABLE_USER", + "node-account-disable-user", { - "node_id": 0, + "node_name": "client_1", "username": "test", }, ) @@ -86,8 +86,8 @@ def test_user_account_change_password_action(game_and_agent_fixture): # change account password action = ( - "NODE_ACCOUNTS_CHANGE_PASSWORD", - {"node_id": 0, "username": "test", "current_password": "password", "new_password": "2Hard_2_Hack"}, + "node-account-change-password", + {"node_name": "client_1", "username": "test", "current_password": "password", "new_password": "2Hard_2_Hack"}, ) agent.store_action(action) game.step() @@ -100,16 +100,16 @@ def test_user_account_create_terminal_action(game_and_agent_fixture): game, agent = game_and_agent_fixture router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) server_1 = game.simulation.network.get_node_by_hostname("server_1") - server_1_usm = server_1.software_manager.software["UserManager"] + server_1_usm = server_1.software_manager.software["user-manager"] server_1_usm.add_user("user123", "password", is_admin=True) action = ( - "SSH_TO_REMOTE", + "node-session-remote-login", { - "node_id": 0, + "node_name": "client_1", "username": "user123", "password": "password", "remote_ip": str(server_1.network_interface[1].ip_address), @@ -121,11 +121,11 @@ def test_user_account_create_terminal_action(game_and_agent_fixture): # Create a new user account via terminal. action = ( - "NODE_SEND_REMOTE_COMMAND", + "node-send-remote-command", { - "node_id": 0, + "node_name": "client_1", "remote_ip": str(server_1.network_interface[1].ip_address), - "command": ["service", "UserManager", "add_user", "new_user", "new_pass", True], + "command": ["service", "user-manager", "add_user", "new_user", "new_pass", True], }, ) agent.store_action(action) @@ -140,16 +140,16 @@ def test_user_account_disable_terminal_action(game_and_agent_fixture): """Tests that agents can use the terminal to disable users.""" game, agent = game_and_agent_fixture router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) server_1 = game.simulation.network.get_node_by_hostname("server_1") - server_1_usm = server_1.software_manager.software["UserManager"] + server_1_usm = server_1.software_manager.software["user-manager"] server_1_usm.add_user("user123", "password", is_admin=True) action = ( - "SSH_TO_REMOTE", + "node-session-remote-login", { - "node_id": 0, + "node_name": "client_1", "username": "user123", "password": "password", "remote_ip": str(server_1.network_interface[1].ip_address), @@ -161,11 +161,11 @@ def test_user_account_disable_terminal_action(game_and_agent_fixture): # Disable a user via terminal action = ( - "NODE_SEND_REMOTE_COMMAND", + "node-send-remote-command", { - "node_id": 0, + "node_name": "client_1", "remote_ip": str(server_1.network_interface[1].ip_address), - "command": ["service", "UserManager", "disable_user", "user123"], + "command": ["service", "user-manager", "disable_user", "user123"], }, ) agent.store_action(action) diff --git a/tests/integration_tests/game_layer/observations/__init__.py b/tests/integration_tests/game_layer/observations/__init__.py index be6c00e7..836b79af 100644 --- a/tests/integration_tests/game_layer/observations/__init__.py +++ b/tests/integration_tests/game_layer/observations/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index f1d9d416..0a633b2d 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -1,13 +1,13 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -28,18 +28,18 @@ def test_acl_observations(simulation): # quick set up of ntp client_1.software_manager.install(NTPClient) - ntp_client: NTPClient = client_1.software_manager.software.get("NTPClient") + ntp_client: NTPClient = client_1.software_manager.software.get("ntp-client") ntp_client.configure(server.network_interface.get(1).ip_address) server.software_manager.install(NTPServer) # add router acl rule - router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, dst_port=PORT_LOOKUP["NTP"], src_port=PORT_LOOKUP["NTP"], position=1) acl_obs = ACLObservation( - where=["network", "nodes", router.hostname, "acl", "acl"], + where=["network", "nodes", router.config.hostname, "acl", "acl"], ip_list=[], - port_list=["NTP", "HTTP", "POSTGRES_SERVER"], - protocol_list=["TCP", "UDP", "ICMP"], + port_list=[123, 80, 5432], + protocol_list=["tcp", "udp", "icmp"], num_rules=10, wildcard_list=[], ) diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 6356c297..722fd294 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from gymnasium import spaces @@ -24,7 +24,7 @@ def test_file_observation(simulation): file = pc.file_system.create_file(file_name="dog.png") dog_file_obs = FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, file_system_requires_scan=True, ) @@ -32,11 +32,11 @@ def test_file_observation(simulation): assert dog_file_obs.space["health_status"] == spaces.Discrete(6) observation_state = dog_file_obs.observe(simulation.describe_state()) - assert observation_state.get("health_status") == 1 # good initial + assert observation_state.get("health_status") == 0 # initially unset file.corrupt() observation_state = dog_file_obs.observe(simulation.describe_state()) - assert observation_state.get("health_status") == 1 # scan file so this changes + assert observation_state.get("health_status") == 0 # still default unset value because no scan happened file.scan() file.apply_timestep(0) # apply time step @@ -47,7 +47,7 @@ def test_file_observation(simulation): def test_config_file_access_categories(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") file_obs = FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, file_system_requires_scan=True, thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}}, @@ -60,7 +60,7 @@ def test_config_file_access_categories(simulation): with pytest.raises(Exception): # should throw an error FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, file_system_requires_scan=True, thresholds={"file_access": {"low": 9, "medium": 6, "high": 9}}, @@ -69,7 +69,7 @@ def test_config_file_access_categories(simulation): with pytest.raises(Exception): # should throw an error FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, file_system_requires_scan=True, thresholds={"file_access": {"low": 3, "medium": 9, "high": 9}}, @@ -84,7 +84,7 @@ def test_folder_observation(simulation): file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder") root_folder_obs = FolderObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "test_folder"], include_num_access=False, file_system_requires_scan=True, num_files=1, @@ -95,11 +95,11 @@ def test_folder_observation(simulation): observation_state = root_folder_obs.observe(simulation.describe_state()) assert observation_state.get("FILES") is not None - assert observation_state.get("health_status") == 1 + assert observation_state.get("health_status") == 0 # initially unset file.corrupt() # corrupt just the file observation_state = root_folder_obs.observe(simulation.describe_state()) - assert observation_state.get("health_status") == 1 # scan folder to change this + assert observation_state.get("health_status") == 0 # still unset as no scan occurred yet folder.scan() for i in range(folder.scan_duration + 1): diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 34a37f5e..874fa49e 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -1,12 +1,12 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.game.agent.observations.firewall_observation import FirewallObservation from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def check_default_rules(acl_obs): @@ -25,14 +25,15 @@ def check_default_rules(acl_obs): def test_firewall_observation(): """Test adding/removing acl rules and enabling/disabling ports.""" net = Network() - firewall = Firewall(hostname="firewall", operating_state=NodeOperatingState.ON) + firewall_cfg = {"type": "firewall", "hostname": "firewall"} + firewall = Firewall.from_config(config=firewall_cfg) firewall_observation = FirewallObservation( where=[], num_rules=7, ip_list=["10.0.0.1", "10.0.0.2"], wildcard_list=["0.0.0.255", "0.0.0.1"], - port_list=["HTTP", "DNS"], - protocol_list=["TCP"], + port_list=[80, 53], + protocol_list=["tcp"], include_users=False, ) @@ -62,13 +63,13 @@ def test_firewall_observation(): # add a rule to the internal inbound and check that the observation is correct firewall.internal_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port.HTTP, - dst_port=Port.HTTP, + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], position=5, ) @@ -116,7 +117,9 @@ def test_firewall_observation(): assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4)) # connect a switch to the firewall and check that only the correct port is updated - switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + switch: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": "ON"} + ) link = net.connect(firewall.network_interface[1], switch.network_interface[1]) assert firewall.network_interface[1].enabled observation = firewall_observation.observe(firewall.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py index 7d1c1939..1ab50a68 100644 --- a/tests/integration_tests/game_layer/observations/test_link_observations.py +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from gymnasium import spaces @@ -56,12 +56,26 @@ def test_link_observation(): """Check the shape and contents of the link observation.""" net = Network() sim = Simulation(network=net) - switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON) - computer_1 = Computer( - hostname="computer_1", ip_address="10.0.0.1", subnet_mask="255.255.255.0", start_up_duration=0 + switch: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch", "num_ports": 5, "operating_state": "ON"} ) - computer_2 = Computer( - hostname="computer_2", ip_address="10.0.0.2", subnet_mask="255.255.255.0", start_up_duration=0 + computer_1: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer_1", + "ip_address": "10.0.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) + computer_2: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer_2", + "ip_address": "10.0.0.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) computer_1.power_on() computer_2.power_on() diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index d01d0c8e..046db4eb 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from pathlib import Path from typing import Union @@ -43,23 +43,23 @@ def simulation(example_network) -> Simulation: computer: Computer = example_network.get_node_by_hostname("client_1") server: Server = example_network.get_node_by_hostname("server_1") - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") web_browser.run() # Install DNS Client service on computer computer.software_manager.install(DNSClient) - dns_client: DNSClient = computer.software_manager.software.get("DNSClient") + dns_client: DNSClient = computer.software_manager.software.get("dns-client") # set dns server dns_client.dns_server = server.network_interface[1].ip_address # Install Web Server service on server server.software_manager.install(WebServer) - web_server_service: WebServer = server.software_manager.software.get("WebServer") + web_server_service: WebServer = server.software_manager.software.get("web-server") web_server_service.start() # Install DNS Server service on server server.software_manager.install(DNSServer) - dns_server: DNSServer = server.software_manager.software.get("DNSServer") + dns_server: DNSServer = server.software_manager.software.get("dns-server") # register arcd.com to DNS dns_server.dns_register( domain_name="arcd.com", @@ -75,7 +75,15 @@ def test_nic(simulation): nic: NIC = pc.network_interface[1] - nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True) + + # The Simulation object created by the fixture also creates the + # NICObservation class with the NICObservation.capture_nmnme class variable + # set to False. Under normal (non-test) circumstances this class variable + # is set from a config file such as data_manipulation.yaml. So although + # capture_nmne is set to True in the NetworkInterface class it's still False + # in the NICObservation class so we set it now. + nic_obs.capture_nmne = True # The Simulation object created by the fixture also creates the # NICObservation class with the NICObservation.capture_nmnme class variable @@ -116,7 +124,7 @@ def test_nic_categories(simulation): """Test the NIC observation nmne count categories.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") - nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True) assert nic_obs.high_nmne_threshold == 10 # default assert nic_obs.med_nmne_threshold == 5 # default @@ -126,7 +134,7 @@ def test_nic_categories(simulation): def test_config_nic_categories(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") nic_obs = NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], + where=["network", "nodes", pc.config.hostname, "NICs", 1], thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}}, include_nmne=True, ) @@ -138,7 +146,7 @@ def test_config_nic_categories(simulation): with pytest.raises(Exception): # should throw an error NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], + where=["network", "nodes", pc.config.hostname, "NICs", 1], thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}}, include_nmne=True, ) @@ -146,20 +154,27 @@ def test_config_nic_categories(simulation): with pytest.raises(Exception): # should throw an error NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], + where=["network", "nodes", pc.config.hostname, "NICs", 1], thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}}, include_nmne=True, ) def test_nic_monitored_traffic(simulation): - monitored_traffic = {"icmp": ["NONE"], "tcp": ["DNS"]} + monitored_traffic = { + "icmp": ["NONE"], + "tcp": [ + 53, + ], + } pc: Computer = simulation.network.get_node_by_hostname("client_1") pc2: Computer = simulation.network.get_node_by_hostname("client_2") nic_obs = NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic + where=["network", "nodes", pc.config.hostname, "NICs", 1], + include_nmne=False, + monitored_traffic=monitored_traffic, ) simulation.pre_timestep(0) # apply timestep to whole sim @@ -186,8 +201,8 @@ def test_nic_monitored_traffic(simulation): assert traffic_obs["tcp"][53]["outbound"] == 0 # send a database query - browser: WebBrowser = pc.software_manager.software.get("WebBrowser") - browser.target_url = f"http://arcd.com/" + browser: WebBrowser = pc.software_manager.software.get("web-browser") + browser.config.target_url = f"http://arcd.com/" browser.get_webpage() traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC") diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 9d60823b..aef60bc2 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import copy from uuid import uuid4 @@ -25,7 +25,7 @@ def test_host_observation(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") host_obs = HostObservation( - where=["network", "nodes", pc.hostname], + where=["network", "nodes", pc.config.hostname], num_applications=0, num_files=1, num_folders=1, @@ -58,7 +58,7 @@ def test_host_observation(simulation): observation_state = host_obs.observe(simulation.describe_state()) assert observation_state.get("operating_status") == 4 # shutting down - for i in range(pc.shut_down_duration + 1): + for i in range(pc.config.shut_down_duration + 1): pc.apply_timestep(i) observation_state = host_obs.observe(simulation.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index 48d29cfb..495e102d 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from pprint import pprint from primaite.game.agent.observations.acl_observation import ACLObservation @@ -8,15 +8,17 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_router_observation(): """Test adding/removing acl rules and enabling/disabling ports.""" net = Network() - router = Router(hostname="router", num_ports=5, operating_state=NodeOperatingState.ON) + router = Router.from_config( + config={"type": "router", "hostname": "router", "num_ports": 5, "operating_state": "ON"} + ) ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)] acl = ACLObservation( @@ -24,8 +26,8 @@ def test_router_observation(): num_rules=7, ip_list=["10.0.0.1", "10.0.0.2"], wildcard_list=["0.0.0.255", "0.0.0.1"], - port_list=["HTTP", "DNS"], - protocol_list=["TCP"], + port_list=[80, 53], + protocol_list=["tcp"], ) router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False) @@ -39,13 +41,13 @@ def test_router_observation(): # Add an ACL rule to the router router.acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port.HTTP, - dst_port=Port.HTTP, + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], position=5, ) # Observe the state using the RouterObservation instance @@ -89,7 +91,9 @@ def test_router_observation(): assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6)) # connect a switch to the router and check that only the correct port is updated - switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + switch: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": "ON"} + ) link = net.connect(router.network_interface[1], switch.network_interface[1]) assert router.network_interface[1].enabled observed_output = router_observation.observe(router.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index a0637969..1ebff10c 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from gymnasium import spaces @@ -26,11 +26,11 @@ def test_service_observation(simulation): # install software on the computer pc.software_manager.install(NTPServer) - ntp_server = pc.software_manager.software.get("NTPServer") + ntp_server = pc.software_manager.software.get("ntp-server") assert ntp_server service_obs = ServiceObservation( - where=["network", "nodes", pc.hostname, "services", "NTPServer"], services_requires_scan=True + where=["network", "nodes", pc.config.hostname, "services", "ntp-server"], services_requires_scan=True ) assert service_obs.space["operating_status"] == spaces.Discrete(7) @@ -53,11 +53,11 @@ def test_application_observation(simulation): # install software on the computer pc.software_manager.install(DatabaseClient) - web_browser: WebBrowser = pc.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = pc.software_manager.software.get("web-browser") assert web_browser app_obs = ApplicationObservation( - where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], applications_requires_scan=True + where=["network", "nodes", pc.config.hostname, "applications", "web-browser"], applications_requires_scan=True ) web_browser.close() @@ -79,7 +79,7 @@ def test_application_executions_categories(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") app_obs = ApplicationObservation( - where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], applications_requires_scan=False, thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}}, ) @@ -91,7 +91,7 @@ def test_application_executions_categories(simulation): with pytest.raises(Exception): # should throw an error ApplicationObservation( - where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], applications_requires_scan=False, thresholds={"app_executions": {"low": 9, "medium": 6, "high": 9}}, ) @@ -99,7 +99,7 @@ def test_application_executions_categories(simulation): with pytest.raises(Exception): # should throw an error ApplicationObservation( - where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], applications_requires_scan=False, thresholds={"app_executions": {"low": 3, "medium": 9, "high": 9}}, ) diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py index ca5e2543..de5c3c98 100644 --- a/tests/integration_tests/game_layer/observations/test_user_observations.py +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -1,9 +1,9 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" @@ -13,9 +13,9 @@ DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yam def env_with_ssh() -> PrimaiteGymEnv: """Build data manipulation environment with SSH port open on router.""" env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) - env.agent.flatten_obs = False + env.agent.config.agent_settings.flatten_obs = False router: Router = env.game.simulation.network.get_node_by_hostname("router_1") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3) + router.acl.add_rule(ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=3) return env diff --git a/tests/integration_tests/game_layer/test_RNG_seed.py b/tests/integration_tests/game_layer/test_RNG_seed.py index 508f35e6..2b80e153 100644 --- a/tests/integration_tests/game_layer/test_RNG_seed.py +++ b/tests/integration_tests/game_layer/test_RNG_seed.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from pprint import pprint import pytest @@ -25,12 +25,12 @@ def test_rng_seed_set(create_env): env.reset(seed=3) for i in range(100): env.step(0) - a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do-nothing"] env.reset(seed=3) for i in range(100): env.step(0) - b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do-nothing"] assert a == b @@ -46,12 +46,12 @@ def test_rng_seed_unset(create_env): env.reset() for i in range(100): env.step(0) - a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do-nothing"] env.reset() for i in range(100): env.step(0) - b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do-nothing"] assert a != b diff --git a/tests/integration_tests/game_layer/test_action_mask.py b/tests/integration_tests/game_layer/test_action_mask.py index 64464724..e0337929 100644 --- a/tests/integration_tests/game_layer/test_action_mask.py +++ b/tests/integration_tests/game_layer/test_action_mask.py @@ -1,7 +1,8 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.host_node import HostNode +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.system.services.service import ServiceOperatingState from tests.conftest import TEST_ASSETS_ROOT @@ -15,7 +16,6 @@ def test_mask_contents_correct(): net = sim.network mask = game.action_mask("defender") agent = env.agent - node_list = agent.action_manager.node_names action_map = agent.action_manager.action_map # CHECK NIC ENABLE/DISABLE ACTIONS @@ -23,10 +23,10 @@ def test_mask_contents_correct(): mask = game.action_mask("defender") act_type, act_params = action - if act_type == "NODE_NIC_ENABLE": - node_name = node_list[act_params["node_id"]] + if act_type == "node_nic_enable": + node_name = act_params["node_name"] node_obj = net.get_node_by_hostname(node_name) - nic_obj = node_obj.network_interface[act_params["nic_id"] + 1] + nic_obj = node_obj.network_interface[act_params["nic_num"]] assert nic_obj.enabled assert not mask[action_num] nic_obj.disable() @@ -34,10 +34,10 @@ def test_mask_contents_correct(): assert mask[action_num] nic_obj.enable() - if act_type == "NODE_NIC_DISABLE": - node_name = node_list[act_params["node_id"]] + if act_type == "node_nic_disable": + node_name = act_params["node_name"] node_obj = net.get_node_by_hostname(node_name) - nic_obj = node_obj.network_interface[act_params["nic_id"] + 1] + nic_obj = node_obj.network_interface[act_params["nic_num"]] assert nic_obj.enabled assert mask[action_num] nic_obj.disable() @@ -45,14 +45,14 @@ def test_mask_contents_correct(): assert not mask[action_num] nic_obj.enable() - if act_type == "ROUTER_ACL_ADDRULE": + if act_type == "router-acl-add-rule": assert mask[action_num] - if act_type == "ROUTER_ACL_REMOVERULE": + if act_type == "router-acl-remove-rule": assert mask[action_num] - if act_type == "NODE_RESET": - node_name = node_list[act_params["node_id"]] + if act_type == "node-reset": + node_name = act_params["node_name"] node_obj = net.get_node_by_hostname(node_name) assert node_obj.operating_state is NodeOperatingState.ON assert mask[action_num] @@ -61,8 +61,8 @@ def test_mask_contents_correct(): assert not mask[action_num] node_obj.operating_state = NodeOperatingState.ON - if act_type == "NODE_SHUTDOWN": - node_name = node_list[act_params["node_id"]] + if act_type == "node-shutdown": + node_name = act_params["node_name"] node_obj = net.get_node_by_hostname(node_name) assert node_obj.operating_state is NodeOperatingState.ON assert mask[action_num] @@ -71,8 +71,8 @@ def test_mask_contents_correct(): assert not mask[action_num] node_obj.operating_state = NodeOperatingState.ON - if act_type == "NODE_OS_SCAN": - node_name = node_list[act_params["node_id"]] + if act_type == "node-os-scan": + node_name = act_params["node_name"] node_obj = net.get_node_by_hostname(node_name) assert node_obj.operating_state is NodeOperatingState.ON assert mask[action_num] @@ -81,8 +81,8 @@ def test_mask_contents_correct(): assert not mask[action_num] node_obj.operating_state = NodeOperatingState.ON - if act_type == "NODE_STARTUP": - node_name = node_list[act_params["node_id"]] + if act_type == "node-startup": + node_name = act_params["node_name"] node_obj = net.get_node_by_hostname(node_name) assert node_obj.operating_state is NodeOperatingState.ON assert not mask[action_num] @@ -91,15 +91,15 @@ def test_mask_contents_correct(): assert mask[action_num] node_obj.operating_state = NodeOperatingState.ON - if act_type == "DONOTHING": + if act_type == "do-nothing": assert mask[action_num] - if act_type == "NODE_SERVICE_DISABLE": + if act_type == "node-service-disable": assert mask[action_num] - if act_type in ["NODE_SERVICE_SCAN", "NODE_SERVICE_STOP", "NODE_SERVICE_PAUSE"]: - node_name = node_list[act_params["node_id"]] - service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + if act_type in ["node-service-scan", "node-service-stop", "node-service-pause"]: + node_name = act_params["node_name"] + service_name = act_params["service_name"] node_obj = net.get_node_by_hostname(node_name) service_obj = node_obj.software_manager.software.get(service_name) assert service_obj.operating_state is ServiceOperatingState.RUNNING @@ -109,9 +109,9 @@ def test_mask_contents_correct(): assert not mask[action_num] service_obj.operating_state = ServiceOperatingState.RUNNING - if act_type == "NODE_SERVICE_RESUME": - node_name = node_list[act_params["node_id"]] - service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + if act_type == "node-service-resume": + node_name = act_params["node_name"] + service_name = act_params["service_name"] node_obj = net.get_node_by_hostname(node_name) service_obj = node_obj.software_manager.software.get(service_name) assert service_obj.operating_state is ServiceOperatingState.RUNNING @@ -121,9 +121,9 @@ def test_mask_contents_correct(): assert mask[action_num] service_obj.operating_state = ServiceOperatingState.RUNNING - if act_type == "NODE_SERVICE_START": - node_name = node_list[act_params["node_id"]] - service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + if act_type == "node-service-start": + node_name = act_params["node_name"] + service_name = act_params["service_name"] node_obj = net.get_node_by_hostname(node_name) service_obj = node_obj.software_manager.software.get(service_name) assert service_obj.operating_state is ServiceOperatingState.RUNNING @@ -133,9 +133,9 @@ def test_mask_contents_correct(): assert mask[action_num] service_obj.operating_state = ServiceOperatingState.RUNNING - if act_type == "NODE_SERVICE_ENABLE": - node_name = node_list[act_params["node_id"]] - service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + if act_type == "node-service-enable": + node_name = act_params["node_name"] + service_name = act_params["service_name"] node_obj = net.get_node_by_hostname(node_name) service_obj = node_obj.software_manager.software.get(service_name) assert service_obj.operating_state is ServiceOperatingState.RUNNING @@ -145,12 +145,10 @@ def test_mask_contents_correct(): assert mask[action_num] service_obj.operating_state = ServiceOperatingState.RUNNING - if act_type in ["NODE_FILE_SCAN", "NODE_FILE_CHECKHASH", "NODE_FILE_DELETE"]: - node_name = node_list[act_params["node_id"]] - folder_name = agent.action_manager.get_folder_name_by_idx(act_params["node_id"], act_params["folder_id"]) - file_name = agent.action_manager.get_file_name_by_idx( - act_params["node_id"], act_params["folder_id"], act_params["file_id"] - ) + if act_type in ["node-file-scan", "node-file-checkhash", "node-file-delete"]: + node_name = act_params["node_name"] + folder_name = act_params["folder_name"] + file_name = act_params["file_name"] node_obj = net.get_node_by_hostname(node_name) file_obj = node_obj.file_system.get_file(folder_name, file_name, include_deleted=True) assert not file_obj.deleted diff --git a/tests/integration_tests/game_layer/test_action_shapes.py b/tests/integration_tests/game_layer/test_action_shapes.py deleted file mode 100644 index 48500d8f..00000000 --- a/tests/integration_tests/game_layer/test_action_shapes.py +++ /dev/null @@ -1,21 +0,0 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import Tuple - -from primaite.game.agent.interface import ProxyAgent -from primaite.game.game import PrimaiteGame -from tests import TEST_ASSETS_ROOT - -FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network.yaml" - - -def test_router_acl_add_rule_action_shape(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): - """Test to check ROUTER_ADD_ACL_RULE has the expected action shape.""" - game, agent = game_and_agent - - # assert that the shape of the actions is correct - router_acl_add_rule_action = agent.action_manager.actions.get("ROUTER_ACL_ADDRULE") - assert router_acl_add_rule_action.shape.get("source_ip_id") == len(agent.action_manager.ip_address_list) - assert router_acl_add_rule_action.shape.get("dest_ip_id") == len(agent.action_manager.ip_address_list) - assert router_acl_add_rule_action.shape.get("source_port_id") == len(agent.action_manager.ports) - assert router_acl_add_rule_action.shape.get("dest_port_id") == len(agent.action_manager.ports) - assert router_acl_add_rule_action.shape.get("protocol_id") == len(agent.action_manager.protocols) diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index a9231632..764c207e 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # Plan for creating integration tests for the actions: # I need to test that the requests coming out of the actions have the intended effect on the simulation. # I can do this by creating a simulation, and then running the action on the simulation, and then checking @@ -21,21 +21,23 @@ from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network.yaml" def test_do_nothing_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): - """Test that the DoNothingAction can form a request and that it is accepted by the simulation.""" + """Test that the do_nothingAction can form a request and that it is accepted by the simulation.""" game, agent = game_and_agent - action = ("DONOTHING", {}) + action = ("do-nothing", {}) agent.store_action(action) game.step() @@ -51,12 +53,12 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy game, agent = game_and_agent # 1: Check that the service starts off in a good state, and that visible state is hidden until first scan - svc = game.simulation.network.get_node_by_hostname("server_1").software_manager.software.get("DNSServer") + svc = game.simulation.network.get_node_by_hostname("server_1").software_manager.software.get("dns-server") assert svc.health_state_actual == SoftwareHealthState.GOOD assert svc.health_state_visible == SoftwareHealthState.UNUSED # 2: Scan and check that the visible state is now correct - action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0}) + action = ("node-service-scan", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() assert svc.health_state_actual == SoftwareHealthState.GOOD @@ -67,7 +69,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy assert svc.health_state_visible == SoftwareHealthState.GOOD # 4: Scan and check that the visible state is now correct - action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0}) + action = ("node-service-scan", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() assert svc.health_state_actual == SoftwareHealthState.COMPROMISED @@ -84,11 +86,11 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA game, agent = game_and_agent # 1: Corrupt the service - svc = game.simulation.network.get_node_by_hostname("server_1").software_manager.software.get("DNSServer") + svc = game.simulation.network.get_node_by_hostname("server_1").software_manager.software.get("dns-server") svc.health_state_actual = SoftwareHealthState.COMPROMISED # 2: Apply a patch action - action = ("NODE_SERVICE_FIX", {"node_id": 1, "service_id": 0}) + action = ("node-service-fix", {"node_name": "server_1", "service_name": "dns-server"}) agent.store_action(action) game.step() @@ -96,7 +98,7 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA assert svc.health_state_actual == SoftwareHealthState.FIXING # 4: perform a few do-nothing steps and check that the service is now in the good state - action = ("DONOTHING", {}) + action = ("do-nothing", {}) agent.store_action(action) game.step() assert svc.health_state_actual == SoftwareHealthState.GOOD @@ -115,31 +117,31 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox server_1 = game.simulation.network.get_node_by_hostname("server_1") server_2 = game.simulation.network.get_node_by_hostname("server_2") router = game.simulation.network.get_node_by_hostname("router") - assert router.acl.num_rules == 3 + assert router.acl.num_rules == 4 assert client_1.ping("10.0.2.3") # client_1 can ping server_2 assert server_2.ping("10.0.1.2") # server_2 can ping client_1 # 2: Add a rule to block client 1 from reaching server 2 on router action = ( - "ROUTER_ACL_ADDRULE", + "router-acl-add-rule", { "target_router": "router", - "position": 4, # 4th rule - "permission": 2, # DENY - "source_ip_id": 3, # 10.0.1.2 (client_1) - "dest_ip_id": 6, # 10.0.2.3 (server_2) - "dest_port_id": 1, # ALL - "source_port_id": 1, # ALL - "protocol_id": 1, # ALL - "source_wildcard_id": 0, - "dest_wildcard_id": 0, + "position": 4, + "permission": "DENY", + "src_ip": "10.0.1.2", + "src_wildcard": "NONE", + "src_port": "ALL", + "dst_ip": "10.0.2.3", + "dst_wildcard": "NONE", + "dst_port": "ALL", + "protocol_name": "icmp", }, ) agent.store_action(action) game.step() - # 3: Check that the ACL now has 4 rules, and that client 1 cannot ping server 2 - assert router.acl.num_rules == 4 + # 3: Check that the acl now has 6 rules, and that client 1 cannot ping server 2 + assert router.acl.num_rules == 5 assert not client_1.ping("10.0.2.3") # Cannot ping server_2 assert client_1.ping("10.0.2.2") # Can ping server_1 assert not server_2.ping( @@ -148,25 +150,25 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox # 4: Add a rule to block server_1 from reaching server_2 on router (this should not affect comms as they are on same subnet) action = ( - "ROUTER_ACL_ADDRULE", + "router-acl-add-rule", { "target_router": "router", "position": 5, # 5th rule - "permission": 2, # DENY - "source_ip_id": 5, # 10.0.2.2 (server_1) - "dest_ip_id": 6, # 10.0.2.3 (server_2) - "dest_port_id": 1, # ALL - "source_port_id": 1, # ALL - "protocol_id": 1, # ALL - "source_wildcard_id": 0, - "dest_wildcard_id": 0, + "permission": "DENY", # DENY + "src_ip": "10.0.2.2", # 10.0.2.2 (server_1) + "src_wildcard": 0, + "src_port": "ALL", # ALL + "dst_ip": "10.0.2.3", # 10.0.2.3 (server_2) + "dst_wildcard": 0, + "dst_port": "ALL", # ALL + "protocol_name": "ALL", # ALL }, ) agent.store_action(action) game.step() - # 5: Check that the ACL now has 5 rules, but that server_1 can still ping server_2 - assert router.acl.num_rules == 5 + # 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2 + assert router.acl.num_rules == 6 assert server_1.ping("10.0.2.3") # Can ping server_2 @@ -177,16 +179,17 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P # 1: Check that http traffic is going across the network nicely. client_1 = game.simulation.network.get_node_by_hostname("client_1") server_1 = game.simulation.network.get_node_by_hostname("server_1") - router = game.simulation.network.get_node_by_hostname("router") + router: Router = game.simulation.network.get_node_by_hostname("router") + assert router.acl.num_rules == 4 - browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + browser: WebBrowser = client_1.software_manager.software.get("web-browser") browser.run() - browser.target_url = "http://www.example.com" + browser.config.target_url = "http://www.example.com" assert browser.get_webpage() # check that the browser can access example.com before we block it # 2: Remove rule that allows HTTP traffic across the network action = ( - "ROUTER_ACL_REMOVERULE", + "router-acl-remove-rule", { "target_router": "router", "position": 3, # 4th rule @@ -196,9 +199,9 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P game.step() # 3: Check that the ACL now has 2 rules, and that client 1 cannot access example.com - assert router.acl.num_rules == 2 + assert router.acl.num_rules == 3 assert not browser.get_webpage() - client_1.software_manager.software.get("DNSClient").dns_cache.clear() + client_1.software_manager.software.get("dns-client").dns_cache.clear() assert client_1.ping("10.0.2.2") # pinging still works because ICMP is allowed assert client_1.ping("10.0.2.3") @@ -212,17 +215,17 @@ def test_host_nic_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA server_1 = game.simulation.network.get_node_by_hostname("server_1") server_2 = game.simulation.network.get_node_by_hostname("server_2") - browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + browser: WebBrowser = client_1.software_manager.software.get("web-browser") browser.run() - browser.target_url = "http://www.example.com" + browser.config.target_url = "http://www.example.com" assert browser.get_webpage() # check that the browser can access example.com before we block it # 2: Disable the NIC on client_1 action = ( - "HOST_NIC_DISABLE", + "host-nic-disable", { - "node_id": 0, # client_1 - "nic_id": 0, # the only nic (eth-1) + "node_name": "client_1", # client_1 + "nic_num": 1, # the only nic (eth-1) }, ) agent.store_action(action) @@ -250,10 +253,10 @@ def test_host_nic_enable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAg # 2: Use action to enable nic action = ( - "HOST_NIC_ENABLE", + "host-nic-enable", { - "node_id": 0, # client_1 - "nic_id": 0, # the only nic (eth-1) + "node_name": "client_1", # client_1 + "nic_num": 1, # the only nic (eth-1) }, ) agent.store_action(action) @@ -273,15 +276,15 @@ def test_node_file_scan_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAge client_1 = game.simulation.network.get_node_by_hostname("client_1") file = client_1.file_system.get_file("downloads", "cat.png") assert file.health_status == FileSystemItemHealthStatus.GOOD - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE # 2: perform a scan and make sure nothing has changed action = ( - "NODE_FILE_SCAN", + "node-file-scan", { - "node_id": 0, # client_1, - "folder_id": 0, # downloads, - "file_id": 0, # cat.png + "node_name": "client_1", # client_1, + "folder_name": "downloads", # downloads, + "file_name": "cat.png", # cat.png }, ) agent.store_action(action) @@ -314,11 +317,11 @@ def test_node_file_delete_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA # 2: delete the file action = ( - "NODE_FILE_DELETE", + "node-file-delete", { - "node_id": 0, # client_1 - "folder_id": 0, # downloads - "file_id": 0, # cat.png + "node_name": "client_1", # client_1 + "folder_name": "downloads", # downloads + "file_name": "cat.png", # cat.png }, ) agent.store_action(action) @@ -334,14 +337,15 @@ def test_node_file_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): """Test that a file is created.""" game, agent = game_and_agent - client_1 = game.simulation.network.get_node_by_hostname("client_1") # + client_1 = game.simulation.network.get_node_by_hostname("client_1") action = ( - "NODE_FILE_CREATE", + "node-file-create", { - "node_id": 0, + "node_name": "client_1", "folder_name": "test", "file_name": "file.txt", + "force": "False", }, ) agent.store_action(action) @@ -357,9 +361,9 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): client_1 = game.simulation.network.get_node_by_hostname("client_1") # action = ( - "NODE_FILE_CREATE", + "node-file-create", { - "node_id": 0, + "node_name": "client_1", "folder_name": "test", "file_name": "file.txt", }, @@ -370,9 +374,9 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): assert client_1.file_system.get_file(folder_name="test", file_name="file.txt").num_access == 0 action = ( - "NODE_FILE_ACCESS", + "node-file-access", { - "node_id": 0, + "node_name": "client_1", "folder_name": "test", "file_name": "file.txt", }, @@ -390,9 +394,9 @@ def test_node_folder_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): client_1 = game.simulation.network.get_node_by_hostname("client_1") # action = ( - "NODE_FOLDER_CREATE", + "node-folder-create", { - "node_id": 0, + "node_name": "client_1", "folder_name": "test", }, ) @@ -411,17 +415,17 @@ def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteG server_1 = game.simulation.network.get_node_by_hostname("server_1") router = game.simulation.network.get_node_by_hostname("router") - browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + browser: WebBrowser = client_1.software_manager.software.get("web-browser") browser.run() - browser.target_url = "http://www.example.com" + browser.config.target_url = "http://www.example.com" assert browser.get_webpage() # check that the browser can access example.com before we block it # 2: Disable the NIC on client_1 action = ( - "NETWORK_PORT_DISABLE", + "network-port-disable", { "target_nodename": "router", # router - "port_id": 1, # port 1 + "port_num": 1, # port 1 }, ) agent.store_action(action) @@ -450,10 +454,10 @@ def test_network_router_port_enable_integration(game_and_agent: Tuple[PrimaiteGa # 2: Use action to enable port action = ( - "NETWORK_PORT_ENABLE", + "network-port-enable", { "target_nodename": "router", # router - "port_id": 1, # port 1 + "port_num": 1, # port 1 }, ) agent.store_action(action) @@ -471,16 +475,19 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P # 1: Check that http traffic is going across the network nicely. client_1 = game.simulation.network.get_node_by_hostname("client_1") - browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + browser: WebBrowser = client_1.software_manager.software.get("web-browser") browser.run() - browser.target_url = "http://www.example.com" + browser.config.target_url = "http://www.example.com" assert browser.get_webpage() # check that the browser can access example.com assert browser.health_state_actual == SoftwareHealthState.GOOD assert browser.health_state_visible == SoftwareHealthState.UNUSED # 2: Scan and check that the visible state is now correct - action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0}) + action = ( + "node-application-scan", + {"node_name": "client_1", "application_name": "web-browser"}, + ) agent.store_action(action) game.step() assert browser.health_state_actual == SoftwareHealthState.GOOD @@ -491,7 +498,10 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P assert browser.health_state_visible == SoftwareHealthState.GOOD # 4: Scan and check that the visible state is now correct - action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0}) + action = ( + "node-application-scan", + {"node_name": "client_1", "application_name": "web-browser"}, + ) agent.store_action(action) game.step() assert browser.health_state_actual == SoftwareHealthState.COMPROMISED @@ -508,11 +518,14 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr # 1: Check that http traffic is going across the network nicely. client_1 = game.simulation.network.get_node_by_hostname("client_1") - browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + browser: WebBrowser = client_1.software_manager.software.get("web-browser") browser.health_state_actual = SoftwareHealthState.COMPROMISED # 2: Apply a fix action - action = ("NODE_APPLICATION_FIX", {"node_id": 0, "application_id": 0}) + action = ( + "node-application-fix", + {"node_name": "client_1", "application_name": "web-browser"}, + ) agent.store_action(action) game.step() @@ -520,7 +533,7 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr assert browser.health_state_actual == SoftwareHealthState.FIXING # 4: perform a few do-nothing steps and check that the application is now in the good state - action = ("DONOTHING", {}) + action = ("do-nothing", {}) agent.store_action(action) game.step() assert browser.health_state_actual == SoftwareHealthState.GOOD @@ -533,12 +546,15 @@ def test_node_application_close_integration(game_and_agent: Tuple[PrimaiteGame, game, agent = game_and_agent client_1 = game.simulation.network.get_node_by_hostname("client_1") - browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + browser: WebBrowser = client_1.software_manager.software.get("web-browser") browser.run() assert browser.operating_state == ApplicationOperatingState.RUNNING # 2: Apply a close action - action = ("NODE_APPLICATION_CLOSE", {"node_id": 0, "application_id": 0}) + action = ( + "node-application-close", + {"node_name": "client_1", "application_name": "web-browser"}, + ) agent.store_action(action) game.step() @@ -549,25 +565,31 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl """Test that the NodeApplicationInstallAction and NodeApplicationRemoveAction can form a request and that it is accepted by the simulation. - When you initiate a install action, the Application will be installed and configured on the node. + When you initiate an install action, the Application will be installed and configured on the node. The remove action will uninstall the application from the node.""" game, agent = game_and_agent client_1 = game.simulation.network.get_node_by_hostname("client_1") - assert client_1.software_manager.software.get("DoSBot") is None + assert client_1.software_manager.software.get("dos-bot") is None - action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot"}) + action = ( + "node-application-install", + {"node_name": "client_1", "application_name": "dos-bot"}, + ) agent.store_action(action) game.step() - assert client_1.software_manager.software.get("DoSBot") is not None + assert client_1.software_manager.software.get("dos-bot") is not None - action = ("NODE_APPLICATION_REMOVE", {"node_id": 0, "application_name": "DoSBot"}) + action = ( + "node-application-remove", + {"node_name": "client_1", "application_name": "dos-bot"}, + ) agent.store_action(action) game.step() - assert client_1.software_manager.software.get("DoSBot") is None + assert client_1.software_manager.software.get("dos-bot") is None def test_firewall_acl_add_remove_rule_integration(): @@ -608,9 +630,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.internal_outbound_acl.acl[1].action.name == "DENY" assert firewall.internal_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.0.10") assert firewall.internal_outbound_acl.acl[1].dst_ip_address is None - assert firewall.internal_outbound_acl.acl[1].dst_port == Port.DNS - assert firewall.internal_outbound_acl.acl[1].src_port == Port.ARP - assert firewall.internal_outbound_acl.acl[1].protocol == IPProtocol.ICMP + assert firewall.internal_outbound_acl.acl[1].dst_port == PORT_LOOKUP["DNS"] + assert firewall.internal_outbound_acl.acl[1].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["ICMP"] env.step(4) # Remove ACL rule from Internal Outbound assert firewall.internal_outbound_acl.num_rules == 2 @@ -620,9 +642,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_inbound_acl.acl[1].action.name == "DENY" assert firewall.dmz_inbound_acl.acl[1].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_inbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_inbound_acl.acl[1].dst_port == Port.HTTP - assert firewall.dmz_inbound_acl.acl[1].src_port == Port.HTTP - assert firewall.dmz_inbound_acl.acl[1].protocol == IPProtocol.UDP + assert firewall.dmz_inbound_acl.acl[1].dst_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].src_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["UDP"] env.step(6) # Remove ACL rule from DMZ Inbound assert firewall.dmz_inbound_acl.num_rules == 2 @@ -632,9 +654,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_outbound_acl.acl[2].action.name == "DENY" assert firewall.dmz_outbound_acl.acl[2].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_outbound_acl.acl[2].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_outbound_acl.acl[2].dst_port == Port.HTTP - assert firewall.dmz_outbound_acl.acl[2].src_port == Port.HTTP - assert firewall.dmz_outbound_acl.acl[2].protocol == IPProtocol.TCP + assert firewall.dmz_outbound_acl.acl[2].dst_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].src_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].protocol == PROTOCOL_LOOKUP["TCP"] env.step(8) # Remove ACL rule from DMZ Outbound assert firewall.dmz_outbound_acl.num_rules == 2 @@ -644,9 +666,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.external_inbound_acl.acl[10].action.name == "DENY" assert firewall.external_inbound_acl.acl[10].src_ip_address == IPv4Address("192.168.20.10") assert firewall.external_inbound_acl.acl[10].dst_ip_address == IPv4Address("192.168.10.10") - assert firewall.external_inbound_acl.acl[10].dst_port == Port.POSTGRES_SERVER - assert firewall.external_inbound_acl.acl[10].src_port == Port.POSTGRES_SERVER - assert firewall.external_inbound_acl.acl[10].protocol == IPProtocol.ICMP + assert firewall.external_inbound_acl.acl[10].dst_port == PORT_LOOKUP["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].src_port == PORT_LOOKUP["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].protocol == PROTOCOL_LOOKUP["ICMP"] env.step(10) # Remove ACL rule from External Inbound assert firewall.external_inbound_acl.num_rules == 1 @@ -658,7 +680,7 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.external_outbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10") assert firewall.external_outbound_acl.acl[1].dst_port is None assert firewall.external_outbound_acl.acl[1].src_port is None - assert firewall.external_outbound_acl.acl[1].protocol is None + assert firewall.external_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["NONE"] env.step(12) # Remove ACL rule from External Outbound assert firewall.external_outbound_acl.num_rules == 1 diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index d5679007..17b9b71e 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from gymnasium import spaces from primaite.game.agent.observations.file_system_observations import FileObservation @@ -8,21 +8,18 @@ from primaite.simulator.sim_container import Simulation def test_file_observation(): sim = Simulation() - pc = Computer(hostname="beep", ip_address="123.123.123.123", subnet_mask="255.255.255.0") + pc: Computer = Computer.from_config( + config={"type": "computer", "hostname": "beep", "ip_address": "123.123.123.123", "subnet_mask": "255.255.255.0"} + ) sim.network.add_node(pc) f = pc.file_system.create_file(file_name="dog.png") state = sim.describe_state() dog_file_obs = FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, - file_system_requires_scan=True, + file_system_requires_scan=False, ) assert dog_file_obs.observe(state) == {"health_status": 1} assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)}) - - -# TODO: -# def test_file_num_access(): -# ... diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 58783d70..8aae311c 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest import yaml @@ -9,71 +9,81 @@ from primaite.interface.request import RequestResponse from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent -def test_WebpageUnavailablePenalty(game_and_agent): +def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that we get the right reward for failing to fetch a website.""" # set up the scenario, configure the web browser to the correct url game, agent = game_and_agent agent: ControlledAgent - comp = WebpageUnavailablePenalty(node_hostname="client_1") + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) + comp = WebpageUnavailablePenalty(config=schema) + client_1 = game.simulation.network.get_node_by_hostname("client_1") - browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + browser: WebBrowser = client_1.software_manager.software.get("web-browser") browser.run() - browser.target_url = "http://www.example.com" + browser.config.target_url = "http://www.example.com" agent.reward_function.register_component(comp, 0.7) # Check that before trying to fetch the webpage, the reward is 0.0 - agent.store_action(("DONOTHING", {})) + agent.store_action(("do-nothing", {})) game.step() assert agent.reward_function.current_reward == 0.0 # Check that successfully fetching the webpage yields a reward of 0.7 - agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})) + agent.store_action(("node-application-execute", {"node_name": "client_1", "application_name": "web-browser"})) game.step() assert agent.reward_function.current_reward == 0.7 # Block the web traffic, check that failing to fetch the webpage yields a reward of -0.7 router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.TCP, src_port=Port.HTTP, dst_port=Port.HTTP) - agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})) + router.acl.add_rule( + action=ACLAction.DENY, + protocol=PROTOCOL_LOOKUP["TCP"], + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], + ) + agent.store_action(("node-application-execute", {"node_name": "client_1", "application_name": "web-browser"})) game.step() assert agent.reward_function.current_reward == -0.7 -def test_uc2_rewards(game_and_agent): +def test_uc2_rewards(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that the reward component correctly applies a penalty when the selected client cannot reach the database.""" game, agent = game_and_agent agent: ControlledAgent server_1: Server = game.simulation.network.get_node_by_hostname("server_1") server_1.software_manager.install(DatabaseService) - db_service = server_1.software_manager.software.get("DatabaseService") + db_service = server_1.software_manager.software.get("database-service") db_service.start() client_1 = game.simulation.network.get_node_by_hostname("client_1") client_1.software_manager.install(DatabaseClient) - db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + db_client: DatabaseClient = client_1.software_manager.software.get("database-client") db_client.configure(server_ip_address=server_1.network_interface[1].ip_address) db_client.run() router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2) + router.acl.add_rule( + ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2 + ) - comp = GreenAdminDatabaseUnreachablePenalty("client_1") + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) + comp = GreenAdminDatabaseUnreachablePenalty(config=schema) - request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"] + request = ["network", "node", "client_1", "application", "database-client", "execute"] response = game.simulation.apply_request(request) state = game.get_sim_state() ahi = AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response + timestep=0, action="node-application-execute", parameters={}, request=request, response=response ) reward_value = comp.calculate(state, last_action_response=ahi) assert reward_value == 1.0 @@ -84,7 +94,7 @@ def test_uc2_rewards(game_and_agent): response = game.simulation.apply_request(request) state = game.get_sim_state() ahi = AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response + timestep=0, action="node-application-execute", parameters={}, request=request, response=response ) reward_value = comp.calculate( state, @@ -132,23 +142,25 @@ def test_action_penalty_loads_from_config(): act_penalty_obj = comp[0] if act_penalty_obj is None: pytest.fail("Action penalty reward component was not added to the agent from config.") - assert act_penalty_obj.action_penalty == -0.75 - assert act_penalty_obj.do_nothing_penalty == 0.125 + assert act_penalty_obj.config.action_penalty == -0.75 + assert act_penalty_obj.config.do_nothing_penalty == 0.125 def test_action_penalty(): """Test that the action penalty is correctly applied when agent performs any action""" # Create an ActionPenalty Reward - Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125) + # Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + Penalty = ActionPenalty(config=schema) - # Assert that penalty is applied if action isn't DONOTHING + # Assert that penalty is applied if action isn't do-nothing reward_value = Penalty.calculate( state={}, last_action_response=AgentHistoryItem( timestep=0, - action="NODE_APPLICATION_EXECUTE", - parameters={"node_id": 0, "application_id": 1}, + action="node-application-execute", + parameters={"node_name": "client", "application_name": "web-browser"}, request=["execute"], response=RequestResponse.from_bool(True), ), @@ -156,14 +168,14 @@ def test_action_penalty(): assert reward_value == -0.75 - # Assert that no penalty applied for a DONOTHING action + # Assert that no penalty applied for a do-nothing action reward_value = Penalty.calculate( state={}, last_action_response=AgentHistoryItem( timestep=0, - action="DONOTHING", + action="do-nothing", parameters={}, - request=["do_nothing"], + request=["do-nothing"], response=RequestResponse.from_bool(True), ), ) @@ -171,20 +183,21 @@ def test_action_penalty(): assert reward_value == 0.125 -def test_action_penalty_e2e(game_and_agent): +def test_action_penalty_e2e(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that we get the right reward for doing actions to fetch a website.""" game, agent = game_and_agent agent: ControlledAgent - comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125) + comp = ActionPenalty(config=schema) agent.reward_function.register_component(comp, 1.0) - action = ("DONOTHING", {}) + action = ("do-nothing", {}) agent.store_action(action) game.step() assert agent.reward_function.current_reward == 0.125 - action = ("NODE_FILE_SCAN", {"node_id": 0, "folder_id": 0, "file_id": 0}) + action = ("node-file-scan", {"node_name": "client", "folder_name": "downloads", "file_name": "document.pdf"}) agent.store_action(action) game.step() assert agent.reward_function.current_reward == -0.75 diff --git a/tests/integration_tests/network/__init__.py b/tests/integration_tests/network/__init__.py index be6c00e7..836b79af 100644 --- a/tests/integration_tests/network/__init__.py +++ b/tests/integration_tests/network/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/network/test_airspace_config.py b/tests/integration_tests/network/test_airspace_config.py index 78d00b47..fd3f6f28 100644 --- a/tests/integration_tests/network/test_airspace_config.py +++ b/tests/integration_tests/network/test_airspace_config.py @@ -1,8 +1,8 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import yaml from primaite.game.game import PrimaiteGame -from primaite.simulator.network.airspace import AirSpaceFrequency +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests import TEST_ASSETS_ROOT @@ -13,8 +13,8 @@ def test_override_freq_max_capacity_mbps(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 123.45 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_2_4") == 123.45 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_5") == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") @@ -32,8 +32,8 @@ def test_override_freq_max_capacity_mbps_blocked(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 0.0 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_2_4") == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_5") == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") diff --git a/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py b/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py index b7317c3d..479473d1 100644 --- a/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py +++ b/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.file_system.file_type import FileType from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.system.services.ftp.ftp_client import FTPClient @@ -19,11 +19,11 @@ def test_wireless_link_loading(wireless_wan_network): airspace = router_1.airspace client.software_manager.install(FTPClient) - ftp_client: FTPClient = client.software_manager.software.get("FTPClient") + ftp_client: FTPClient = client.software_manager.software.get("ftp-client") ftp_client.start() server.software_manager.install(FTPServer) - ftp_server: FTPServer = server.software_manager.software.get("FTPServer") + ftp_server: FTPServer = server.software_manager.software.get("ftp-server") ftp_server.start() client.file_system.create_file(file_name="mixtape", size=10 * 10**6, file_type=FileType.MP3, folder_name="music") diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index 80007c46..50a1173c 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -1,27 +1,35 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, List, Tuple import pytest +from pydantic import Field from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP -class BroadcastTestService(Service): +class BroadcastTestService(Service, discriminator="broadcast-test-service"): """A service for sending broadcast and unicast messages over a network.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for BroadcastTestService.""" + + type: str = "broadcast-test-service" + + config: "BroadcastTestService.ConfigSchema" = Field(default_factory=lambda: BroadcastTestService.ConfigSchema()) + def __init__(self, **kwargs): # Set default service properties for broadcasting - kwargs["name"] = "BroadcastService" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["name"] = "broadcast-test-service" + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -33,24 +41,33 @@ class BroadcastTestService(Service): super().send( payload="unicast", dest_ip_address=ip_address, - dest_port=Port.HTTP, + dest_port=PORT_LOOKUP["HTTP"], ) def broadcast(self, ip_network: IPv4Network): # Send a broadcast payload to an entire IP network - super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol) + super().send( + payload="broadcast", dest_ip_address=ip_network, dest_port=PORT_LOOKUP["HTTP"], ip_protocol=self.protocol + ) -class BroadcastTestClient(Application, identifier="BroadcastTestClient"): +class BroadcastTestClient(Application, discriminator="broadcast-test-client"): """A client application to receive broadcast and unicast messages.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for BroadcastTestClient.""" + + type: str = "broadcast-test-client" + + config: ConfigSchema = Field(default_factory=lambda: BroadcastTestClient.ConfigSchema()) + payloads_received: List = [] def __init__(self, **kwargs): # Set default client properties - kwargs["name"] = "BroadcastTestClient" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["name"] = "broadcast-test-client" + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -67,44 +84,55 @@ class BroadcastTestClient(Application, identifier="BroadcastTestClient"): def broadcast_network() -> Network: network = Network() - client_1 = Computer( - hostname="client_1", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + client_1_cfg = { + "type": "computer", + "hostname": "client_1", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() client_1.software_manager.install(BroadcastTestClient) - application_1 = client_1.software_manager.software["BroadcastTestClient"] + application_1 = client_1.software_manager.software["broadcast-test-client"] application_1.run() + client_2_cfg = { + "type": "computer", + "hostname": "client_2", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } - client_2 = Computer( - hostname="client_2", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + client_2: Computer = Computer.from_config(config=client_2_cfg) client_2.power_on() client_2.software_manager.install(BroadcastTestClient) - application_2 = client_2.software_manager.software["BroadcastTestClient"] + application_2 = client_2.software_manager.software["broadcast-test-client"] application_2.run() - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.1", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + server_1_cfg = { + "type": "server", + "hostname": "server_1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server_1: Server = Server.from_config(config=server_1_cfg) + server_1.power_on() server_1.software_manager.install(BroadcastTestService) - service: BroadcastTestService = server_1.software_manager.software["BroadcastService"] + service: BroadcastTestService = server_1.software_manager.software["broadcast-test-service"] service.start() - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0} + ) switch_1.power_on() network.connect(endpoint_a=client_1.network_interface[1], endpoint_b=switch_1.network_interface[1]) @@ -119,13 +147,13 @@ def broadcast_service_and_clients( broadcast_network, ) -> Tuple[BroadcastTestService, BroadcastTestClient, BroadcastTestClient]: client_1: BroadcastTestClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[ - "BroadcastTestClient" + "broadcast-test-client" ] client_2: BroadcastTestClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[ - "BroadcastTestClient" + "broadcast-test-client" ] service: BroadcastTestService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[ - "BroadcastService" + "broadcast-test-service" ] return service, client_1, client_2 diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index 1499df9a..80e7c3b3 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from itertools import product import yaml @@ -23,7 +23,7 @@ def test_capture_nmne(uc2_network: Network): of the "DELETE" SQL command as a malicious network event. """ web_server: Server = uc2_network.get_node_by_hostname("web_server") # noqa - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] # noqa + db_client: DatabaseClient = web_server.software_manager.software["database-client"] # noqa db_client_connection: DatabaseClientConnection = db_client.get_new_connection() db_server: Server = uc2_network.get_node_by_hostname("database_server") # noqa @@ -100,7 +100,7 @@ def test_describe_state_nmne(uc2_network: Network): only shows MNEs since the last time describe_state was called. """ web_server: Server = uc2_network.get_node_by_hostname("web_server") # noqa - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] # noqa + db_client: DatabaseClient = web_server.software_manager.software["database-client"] # noqa db_client_connection: DatabaseClientConnection = db_client.get_new_connection() db_server: Server = uc2_network.get_node_by_hostname("database_server") # noqa @@ -214,7 +214,7 @@ def test_capture_nmne_observations(uc2_network: Network): sim.network = uc2_network web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = web_server.software_manager.software["database-client"] db_client_connection: DatabaseClientConnection = db_client.get_new_connection() # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index b15ee51a..7c2c36c0 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address import pytest @@ -7,10 +7,10 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -41,7 +41,9 @@ def dmz_external_internal_network() -> Network: """ network = Network() - firewall_node: Firewall = Firewall(hostname="firewall_1", start_up_duration=0) + firewall_node: Firewall = Firewall.from_config( + config={"type": "firewall", "hostname": "firewall_1", "start_up_duration": 0} + ) firewall_node.power_on() # configure firewall ports firewall_node.configure_external_port( @@ -53,70 +55,83 @@ def dmz_external_internal_network() -> Network: ) # Allow ICMP - firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Allow ARP firewall_node.internal_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.internal_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.external_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.external_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + firewall_node.dmz_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + firewall_node.dmz_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) # external node - external_node = Computer( - hostname="external_node", - ip_address="192.168.10.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - start_up_duration=0, + external_node: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "external_node", + "ip_address": "192.168.10.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } ) external_node.power_on() external_node.software_manager.install(NTPServer) - ntp_service: NTPServer = external_node.software_manager.software["NTPServer"] + ntp_service: NTPServer = external_node.software_manager.software["ntp-server"] ntp_service.start() # connect external node to firewall node network.connect(endpoint_b=external_node.network_interface[1], endpoint_a=firewall_node.external_port) # internal node - internal_node = Computer( - hostname="internal_node", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + internal_node: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "internal_node", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) internal_node.power_on() internal_node.software_manager.install(NTPClient) - internal_ntp_client: NTPClient = internal_node.software_manager.software["NTPClient"] + internal_ntp_client: NTPClient = internal_node.software_manager.software["ntp-client"] internal_ntp_client.configure(external_node.network_interface[1].ip_address) internal_ntp_client.start() # connect external node to firewall node network.connect(endpoint_b=internal_node.network_interface[1], endpoint_a=firewall_node.internal_port) # dmz node - dmz_node = Computer( - hostname="dmz_node", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + dmz_node: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "dmz_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) dmz_node.power_on() - dmz_ntp_client: NTPClient = dmz_node.software_manager.software["NTPClient"] + dmz_ntp_client: NTPClient = dmz_node.software_manager.software["ntp-client"] dmz_ntp_client.configure(external_node.network_interface[1].ip_address) dmz_ntp_client.start() # connect external node to firewall node @@ -151,9 +166,9 @@ def test_nodes_can_ping_default_gateway(dmz_external_internal_network): internal_node = dmz_external_internal_network.get_node_by_hostname("internal_node") dmz_node = dmz_external_internal_network.get_node_by_hostname("dmz_node") - assert internal_node.ping(internal_node.default_gateway) # default gateway internal - assert dmz_node.ping(dmz_node.default_gateway) # default gateway dmz - assert external_node.ping(external_node.default_gateway) # default gateway external + assert internal_node.ping(internal_node.config.default_gateway) # default gateway internal + assert dmz_node.ping(dmz_node.config.default_gateway) # default gateway dmz + assert external_node.ping(external_node.config.default_gateway) # default gateway external def test_nodes_can_ping_default_gateway_on_another_subnet(dmz_external_internal_network): @@ -167,14 +182,14 @@ def test_nodes_can_ping_default_gateway_on_another_subnet(dmz_external_internal_ internal_node = dmz_external_internal_network.get_node_by_hostname("internal_node") dmz_node = dmz_external_internal_network.get_node_by_hostname("dmz_node") - assert internal_node.ping(external_node.default_gateway) # internal node to external default gateway - assert internal_node.ping(dmz_node.default_gateway) # internal node to dmz default gateway + assert internal_node.ping(external_node.config.default_gateway) # internal node to external default gateway + assert internal_node.ping(dmz_node.config.default_gateway) # internal node to dmz default gateway - assert dmz_node.ping(internal_node.default_gateway) # dmz node to internal default gateway - assert dmz_node.ping(external_node.default_gateway) # dmz node to external default gateway + assert dmz_node.ping(internal_node.config.default_gateway) # dmz node to internal default gateway + assert dmz_node.ping(external_node.config.default_gateway) # dmz node to external default gateway - assert external_node.ping(external_node.default_gateway) # external node to internal default gateway - assert external_node.ping(dmz_node.default_gateway) # external node to dmz default gateway + assert external_node.ping(external_node.config.default_gateway) # external node to internal default gateway + assert external_node.ping(dmz_node.config.default_gateway) # external node to dmz default gateway def test_nodes_can_ping_each_other(dmz_external_internal_network): @@ -210,8 +225,8 @@ def test_service_blocked(dmz_external_internal_network): firewall = dmz_external_internal_network.get_node_by_hostname("firewall_1") internal_node = dmz_external_internal_network.get_node_by_hostname("internal_node") dmz_node = dmz_external_internal_network.get_node_by_hostname("dmz_node") - internal_ntp_client: NTPClient = internal_node.software_manager.software["NTPClient"] - dmz_ntp_client: NTPClient = dmz_node.software_manager.software["NTPClient"] + internal_ntp_client: NTPClient = internal_node.software_manager.software["ntp-client"] + dmz_ntp_client: NTPClient = dmz_node.software_manager.software["ntp-client"] assert not internal_ntp_client.time @@ -257,13 +272,17 @@ def test_service_allowed_with_rule(dmz_external_internal_network): firewall = dmz_external_internal_network.get_node_by_hostname("firewall_1") internal_node = dmz_external_internal_network.get_node_by_hostname("internal_node") dmz_node = dmz_external_internal_network.get_node_by_hostname("dmz_node") - internal_ntp_client: NTPClient = internal_node.software_manager.software["NTPClient"] - dmz_ntp_client: NTPClient = dmz_node.software_manager.software["NTPClient"] + internal_ntp_client: NTPClient = internal_node.software_manager.software["ntp-client"] + dmz_ntp_client: NTPClient = dmz_node.software_manager.software["ntp-client"] assert not internal_ntp_client.time - firewall.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) - firewall.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) + firewall.internal_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) + firewall.internal_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) internal_ntp_client.request_time() @@ -271,8 +290,12 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not dmz_ntp_client.time - firewall.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) - firewall.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) + firewall.dmz_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) + firewall.dmz_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) dmz_ntp_client.request_time() diff --git a/tests/integration_tests/network/test_frame_transmission.py b/tests/integration_tests/network/test_frame_transmission.py index fc2d146e..6a514bdc 100644 --- a/tests/integration_tests/network/test_frame_transmission.py +++ b/tests/integration_tests/network/test_frame_transmission.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC @@ -10,25 +10,31 @@ def test_node_to_node_ping(): """Tests two Computers are able to ping each other.""" network = Network() - client_1 = Computer( - hostname="client_1", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + client_1: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "client_1", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) client_1.power_on() - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + server_1: Server = Server.from_config( + config={ + "type": "server", + "hostname": "server_1", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) server_1.power_on() - switch_1 = Switch(hostname="switch_1", start_up_duration=0) + switch_1: Switch = Switch.from_config(config={"type": "switch", "hostname": "switch_1", "start_up_duration": 0}) switch_1.power_on() network.connect(endpoint_a=client_1.network_interface[1], endpoint_b=switch_1.network_interface[1]) @@ -41,14 +47,38 @@ def test_multi_nic(): """Tests that Computers with multiple NICs can ping each other and the data go across the correct links.""" network = Network() - node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_a.power_on() - node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_b.power_on() node_b.connect_nic(NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0")) - node_c = Computer(hostname="node_c", ip_address="10.0.0.13", subnet_mask="255.0.0.0", start_up_duration=0) + node_c: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_c", + "ip_address": "10.0.0.13", + "subnet_mask": "255.0.0.0", + "start_up_duration": 0, + } + ) node_c.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) diff --git a/tests/integration_tests/network/test_multi_lan_internet_example_network.py b/tests/integration_tests/network/test_multi_lan_internet_example_network.py index bcc9ad94..381fea62 100644 --- a/tests/integration_tests/network/test_multi_lan_internet_example_network.py +++ b/tests/integration_tests/network/test_multi_lan_internet_example_network.py @@ -1,6 +1,7 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.networks import multi_lan_internet_network_example from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser @@ -12,7 +13,7 @@ def test_all_with_configured_dns_server_ip_can_resolve_url(): network = multi_lan_internet_network_example() for node in network.nodes.values(): - dns_client: DNSClient = node.software_manager.software.get("DNSClient") + dns_client: DNSClient = node.software_manager.software.get("dns-client") if not dns_client: continue @@ -24,8 +25,8 @@ def test_all_with_configured_dns_server_ip_can_resolve_url(): def test_external_pcs_can_access_sometech_website(): network = multi_lan_internet_network_example() - pc_1_browser: WebBrowser = network.get_node_by_hostname("pc_1").software_manager.software["WebBrowser"] - pc_2_browser: WebBrowser = network.get_node_by_hostname("pc_2").software_manager.software["WebBrowser"] + pc_1_browser: WebBrowser = network.get_node_by_hostname("pc_1").software_manager.software["web-browser"] + pc_2_browser: WebBrowser = network.get_node_by_hostname("pc_2").software_manager.software["web-browser"] assert pc_1_browser.get_webpage() assert pc_2_browser.get_webpage() @@ -34,8 +35,8 @@ def test_external_pcs_can_access_sometech_website(): def test_external_pcs_cannot_access_sometech_db(): network = multi_lan_internet_network_example() - pc_1_db_client: DatabaseClient = network.get_node_by_hostname("pc_1").software_manager.software["DatabaseClient"] - pc_2_db_client: DatabaseClient = network.get_node_by_hostname("pc_2").software_manager.software["DatabaseClient"] + pc_1_db_client: DatabaseClient = network.get_node_by_hostname("pc_1").software_manager.software["database-client"] + pc_2_db_client: DatabaseClient = network.get_node_by_hostname("pc_2").software_manager.software["database-client"] assert not pc_1_db_client.get_new_connection() assert not pc_2_db_client.get_new_connection() @@ -47,8 +48,8 @@ def test_external_pcs_cannot_access_ftp_on_sometech_storage_server(): some_tech_storage_srv = network.get_node_by_hostname("some_tech_storage_srv") some_tech_storage_srv.file_system.create_file(file_name="test.png") - pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["FTPClient"] - pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["FTPClient"] + pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["ftp-client"] + pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["ftp-client"] assert not pc_1_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -71,7 +72,7 @@ def test_sometech_webserver_can_access_sometech_db_server(): network = multi_lan_internet_network_example() web_db_client: DatabaseClient = network.get_node_by_hostname("some_tech_web_srv").software_manager.software[ - "DatabaseClient" + "database-client" ] assert web_db_client.get_new_connection() @@ -85,7 +86,7 @@ def test_sometech_webserver_cannot_access_ftp_on_sometech_storage_server(): web_server: Server = network.get_node_by_hostname("some_tech_web_srv") web_server.software_manager.install(FTPClient) - web_ftp_client: FTPClient = web_server.software_manager.software["FTPClient"] + web_ftp_client: FTPClient = web_server.software_manager.software["ftp-client"] assert not web_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -101,13 +102,13 @@ def test_sometech_dev_pcs_can_access_sometech_website(): some_tech_snr_dev_pc: Computer = network.get_node_by_hostname("some_tech_snr_dev_pc") - snr_dev_browser: WebBrowser = some_tech_snr_dev_pc.software_manager.software["WebBrowser"] + snr_dev_browser: WebBrowser = some_tech_snr_dev_pc.software_manager.software["web-browser"] assert snr_dev_browser.get_webpage() some_tech_jnr_dev_pc: Computer = network.get_node_by_hostname("some_tech_jnr_dev_pc") - jnr_dev_browser: WebBrowser = some_tech_jnr_dev_pc.software_manager.software["WebBrowser"] + jnr_dev_browser: WebBrowser = some_tech_jnr_dev_pc.software_manager.software["web-browser"] assert jnr_dev_browser.get_webpage() @@ -116,12 +117,12 @@ def test_sometech_dev_pcs_can_connect_to_sometech_db_server(): network = multi_lan_internet_network_example() some_tech_snr_dev_pc: Computer = network.get_node_by_hostname("some_tech_snr_dev_pc") - snr_dev_db_client: DatabaseClient = some_tech_snr_dev_pc.software_manager.software["DatabaseClient"] + snr_dev_db_client: DatabaseClient = some_tech_snr_dev_pc.software_manager.software["database-client"] assert snr_dev_db_client.get_new_connection() some_tech_jnr_dev_pc: Computer = network.get_node_by_hostname("some_tech_jnr_dev_pc") - jnr_dev_db_client: DatabaseClient = some_tech_jnr_dev_pc.software_manager.software["DatabaseClient"] + jnr_dev_db_client: DatabaseClient = some_tech_jnr_dev_pc.software_manager.software["database-client"] assert jnr_dev_db_client.get_new_connection() @@ -133,7 +134,7 @@ def test_sometech_snr_dev_can_access_ftp_on_sometech_storage_server(): some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_snr_dev_pc: Computer = network.get_node_by_hostname("some_tech_snr_dev_pc") - snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["FTPClient"] + snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["ftp-client"] assert snr_dev_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -151,7 +152,7 @@ def test_sometech_jnr_dev_cannot_access_ftp_on_sometech_storage_server(): some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_jnr_dev_pc: Computer = network.get_node_by_hostname("some_tech_jnr_dev_pc") - jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["FTPClient"] + jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["ftp-client"] assert not jnr_dev_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -167,7 +168,7 @@ def test_sometech_hr_pc_can_access_sometech_website(): some_tech_hr_pc: Computer = network.get_node_by_hostname("some_tech_hr_1") - hr_browser: WebBrowser = some_tech_hr_pc.software_manager.software["WebBrowser"] + hr_browser: WebBrowser = some_tech_hr_pc.software_manager.software["web-browser"] assert hr_browser.get_webpage() @@ -177,7 +178,7 @@ def test_sometech_hr_pc_cannot_access_sometech_db(): some_tech_hr_pc: Computer = network.get_node_by_hostname("some_tech_hr_1") - hr_db_client: DatabaseClient = some_tech_hr_pc.software_manager.software["DatabaseClient"] + hr_db_client: DatabaseClient = some_tech_hr_pc.software_manager.software["database-client"] assert not hr_db_client.get_new_connection() @@ -189,7 +190,7 @@ def test_sometech_hr_pc_cannot_access_ftp_on_sometech_storage_server(): some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_hr_pc: Computer = network.get_node_by_hostname("some_tech_hr_1") - hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["FTPClient"] + hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["ftp-client"] assert not hr_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, diff --git a/tests/integration_tests/network/test_network_creation.py b/tests/integration_tests/network/test_network_creation.py index 794ddde5..4d88eac3 100644 --- a/tests/integration_tests/network/test_network_creation.py +++ b/tests/integration_tests/network/test_network_creation.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC @@ -27,7 +27,15 @@ def test_network(example_network): def test_adding_removing_nodes(): """Check that we can create and add a node to a network.""" net = Network() - n1 = Computer(hostname="computer", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.add_node(n1) assert n1.parent is net assert n1 in net @@ -37,10 +45,18 @@ def test_adding_removing_nodes(): assert n1 not in net -def test_readding_node(): - """Check that warning is raised when readding a node.""" +def test_reading_node(): + """Check that warning is raised when reading a node.""" net = Network() - n1 = Computer(hostname="computer", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.add_node(n1) net.add_node(n1) assert n1.parent is net @@ -50,7 +66,15 @@ def test_readding_node(): def test_removing_nonexistent_node(): """Check that warning is raised when trying to remove a node that is not in the network.""" net = Network() - n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.remove_node(n1) assert n1.parent is None assert n1 not in net @@ -59,8 +83,24 @@ def test_removing_nonexistent_node(): def test_connecting_nodes(): """Check that two nodes on the network can be connected.""" net = Network() - n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0) - n2 = Computer(hostname="computer2", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) + n2: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer2", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.add_node(n1) net.add_node(n2) @@ -75,7 +115,15 @@ def test_connecting_nodes(): def test_connecting_node_to_itself_fails(): net = Network() - node = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node.power_on() node.connect_nic(NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0")) @@ -92,8 +140,24 @@ def test_connecting_node_to_itself_fails(): def test_disconnecting_nodes(): net = Network() - n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0) - n2 = Computer(hostname="computer2", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) + n2 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer2", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.connect(n1.network_interface[1], n2.network_interface[1]) assert len(net.links) == 1 diff --git a/tests/integration_tests/network/test_nic_link_connection.py b/tests/integration_tests/network/test_nic_link_connection.py index ab9160c8..8c45f511 100644 --- a/tests/integration_tests/network/test_nic_link_connection.py +++ b/tests/integration_tests/network/test_nic_link_connection.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.network.hardware.base import Link diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index e234b4e5..ccf7c8ff 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest @@ -6,34 +6,40 @@ import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def pc_a_pc_b_router_1() -> Tuple[Computer, Computer, Router]: network = Network() - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + pc_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) pc_a.power_on() - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + pc_b = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) pc_b.power_on() - router_1 = Router(hostname="router_1", start_up_duration=0) + router_1 = Router.from_config(config={"type": "router", "hostname": "router_1", "start_up_duration": 0}) router_1.power_on() router_1.configure_port(1, "192.168.0.1", "255.255.255.0") @@ -52,18 +58,21 @@ def multi_hop_network() -> Network: network = Network() # Configure PC A - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + pc_a: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) pc_a.power_on() network.add_node(pc_a) # Configure Router 1 - router_1 = Router(hostname="router_1", start_up_duration=0) + router_1: Router = Router.from_config(config={"type": "router", "hostname": "router_1", "start_up_duration": 0}) router_1.power_on() network.add_node(router_1) @@ -73,21 +82,27 @@ def multi_hop_network() -> Network: router_1.enable_port(2) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, + pc_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.2.1", + "start_up_duration": 0, + } ) pc_b.power_on() network.add_node(pc_b) # Configure Router 2 - router_2 = Router(hostname="router_2", start_up_duration=0) + router_2: Router = Router.from_config(config={"type": "router", "hostname": "router_2", "start_up_duration": 0}) router_2.power_on() network.add_node(router_2) @@ -110,13 +125,13 @@ def multi_hop_network() -> Network: def test_ping_default_gateway(pc_a_pc_b_router_1): pc_a, pc_b, router_1 = pc_a_pc_b_router_1 - assert pc_a.ping(pc_a.default_gateway) + assert pc_a.ping(pc_a.config.default_gateway) def test_ping_other_router_port(pc_a_pc_b_router_1): pc_a, pc_b, router_1 = pc_a_pc_b_router_1 - assert pc_a.ping(pc_b.default_gateway) + assert pc_a.ping(pc_b.config.default_gateway) def test_host_on_other_subnet(pc_a_pc_b_router_1): @@ -185,19 +200,23 @@ def test_routing_services(multi_hop_network): pc_b = multi_hop_network.get_node_by_hostname("pc_b") pc_a.software_manager.install(NTPClient) - ntp_client = pc_a.software_manager.software["NTPClient"] + ntp_client = pc_a.software_manager.software["ntp-client"] ntp_client.start() pc_b.software_manager.install(NTPServer) - pc_b.software_manager.software["NTPServer"].start() + pc_b.software_manager.software["ntp-server"].start() ntp_client.configure(ntp_server_ip_address=pc_b.network_interface[1].ip_address) router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21) - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=21 + ) + router_2.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=21 + ) assert ntp_client.time is None ntp_client.request_time() diff --git a/tests/integration_tests/network/test_switched_network.py b/tests/integration_tests/network/test_switched_network.py index ae0aa8a7..67392da3 100644 --- a/tests/integration_tests/network/test_switched_network.py +++ b/tests/integration_tests/network/test_switched_network.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK def test_switched_network(client_switch_server): """Tests a node can ping another node via the switch.""" computer, switch, server = client_switch_server diff --git a/tests/integration_tests/network/test_users_creation_from_config.py b/tests/integration_tests/network/test_users_creation_from_config.py index 8cd3b037..5340c369 100644 --- a/tests/integration_tests/network/test_users_creation_from_config.py +++ b/tests/integration_tests/network/test_users_creation_from_config.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import yaml from primaite.game.game import PrimaiteGame @@ -15,7 +15,7 @@ def test_users_from_config(): client_1 = network.get_node_by_hostname("client_1") - user_manager: UserManager = client_1.software_manager.software["UserManager"] + user_manager: UserManager = client_1.software_manager.software["user-manager"] assert len(user_manager.users) == 3 diff --git a/tests/integration_tests/network/test_wireless_router.py b/tests/integration_tests/network/test_wireless_router.py index 9a22208b..74b97c2f 100644 --- a/tests/integration_tests/network/test_wireless_router.py +++ b/tests/integration_tests/network/test_wireless_router.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest import yaml @@ -7,8 +7,8 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT @@ -17,18 +17,23 @@ def wireless_wan_network(): network = Network() # Configure PC A - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + pc_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) pc_a.power_on() network.add_node(pc_a) # Configure Router 1 - router_1 = WirelessRouter(hostname="router_1", start_up_duration=0, airspace=network.airspace) + router_1 = WirelessRouter.from_config( + config={"type": "wireless-router", "hostname": "router_1", "start_up_duration": 0}, airspace=network.airspace + ) router_1.power_on() network.add_node(router_1) @@ -37,21 +42,29 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, + pc_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.2.1", + "start_up_duration": 0, + } ) pc_b.power_on() network.add_node(pc_b) # Configure Router 2 - router_2 = WirelessRouter(hostname="router_2", start_up_duration=0, airspace=network.airspace) + router_2: WirelessRouter = WirelessRouter.from_config( + config={"type": "wireless-router", "hostname": "router_2", "start_up_duration": 0}, airspace=network.airspace + ) router_2.power_on() network.add_node(router_2) @@ -95,8 +108,8 @@ def wireless_wan_network_from_config_yaml(): def test_cross_wireless_wan_connectivity(wireless_wan_network): pc_a, pc_b, router_1, router_2 = wireless_wan_network # Ensure that PCs can ping across routers before any frequency change - assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully." - assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully." + assert pc_a.ping(pc_a.config.default_gateway), "PC A should ping its default gateway successfully." + assert pc_b.ping(pc_b.config.default_gateway), "PC B should ping its default gateway successfully." assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." @@ -106,8 +119,8 @@ def test_cross_wireless_wan_connectivity_from_yaml(wireless_wan_network_from_con pc_a = wireless_wan_network_from_config_yaml.get_node_by_hostname("pc_a") pc_b = wireless_wan_network_from_config_yaml.get_node_by_hostname("pc_b") - assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully." - assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully." + assert pc_a.ping(pc_a.config.default_gateway), "PC A should ping its default gateway successfully." + assert pc_b.ping(pc_b.config.default_gateway), "PC B should ping its default gateway successfully." assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." diff --git a/tests/integration_tests/system/__init__.py b/tests/integration_tests/system/__init__.py index be6c00e7..836b79af 100644 --- a/tests/integration_tests/system/__init__.py +++ b/tests/integration_tests/system/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index 9d12f2cf..62eda37d 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Tuple @@ -13,8 +13,6 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import AccessControlList, ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon @@ -25,6 +23,8 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT @@ -34,52 +34,64 @@ def basic_network() -> Network: # Creating two generic nodes for the C2 Server and the C2 Beacon. - node_a = Computer( - hostname="node_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.252", - default_gateway="192.168.0.1", - start_up_duration=0, - ) + node_a_cfg = { + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.252", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } + + node_a: Computer = Computer.from_config(config=node_a_cfg) node_a.power_on() node_a.software_manager.get_open_ports() node_a.software_manager.install(software_class=C2Server) - node_b = Computer( - hostname="node_b", - ip_address="192.168.255.2", - subnet_mask="255.255.255.248", - default_gateway="192.168.255.1", - start_up_duration=0, - ) + node_b_cfg = { + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.255.2", + "subnet_mask": "255.255.255.248", + "default_gateway": "192.168.255.1", + "start_up_duration": 0, + } + node_b: Computer = Computer.from_config(config=node_b_cfg) node_b.power_on() node_b.software_manager.install(software_class=C2Beacon) # Creating a generic computer for testing remote terminal connections. - node_c = Computer( - hostname="node_c", - ip_address="192.168.255.3", - subnet_mask="255.255.255.248", - default_gateway="192.168.255.1", - start_up_duration=0, - ) + node_c_cfg = { + "type": "computer", + "hostname": "node_c", + "ip_address": "192.168.255.3", + "subnet_mask": "255.255.255.248", + "default_gateway": "192.168.255.1", + "start_up_duration": 0, + } + + node_c: Computer = Computer.from_config(config=node_c_cfg) node_c.power_on() # Creating a router to sit between node 1 and node 2. - router = Router(hostname="router", num_ports=3, start_up_duration=0) + router = Router.from_config(config={"type": "router", "hostname": "router", "num_ports": 3, "start_up_duration": 0}) # Default allow all. router.acl.add_rule(action=ACLAction.PERMIT) router.power_on() # Creating switches for each client. - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1 = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0} + ) switch_1.power_on() # Connecting the switches to the router. router.configure_port(port=1, ip_address="192.168.0.1", subnet_mask="255.255.255.252") network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6]) - switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0) + switch_2 = Switch.from_config( + config={"type": "switch", "hostname": "switch_2", "num_ports": 6, "start_up_duration": 0} + ) switch_2.power_on() network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6]) @@ -99,15 +111,15 @@ def basic_network() -> Network: def setup_c2(given_network: Network): """Installs the C2 Beacon & Server, configures and then returns.""" computer_a: Computer = given_network.get_node_by_hostname("node_a") - c2_server: C2Server = computer_a.software_manager.software.get("C2Server") + c2_server: C2Server = computer_a.software_manager.software.get("c2-server") computer_a.software_manager.install(DatabaseService) - computer_a.software_manager.software["DatabaseService"].start() + computer_a.software_manager.software["database-service"].start() computer_b: Computer = given_network.get_node_by_hostname("node_b") - c2_beacon: C2Beacon = computer_b.software_manager.software.get("C2Beacon") + c2_beacon: C2Beacon = computer_b.software_manager.software.get("c2-beacon") computer_b.software_manager.install(DatabaseClient) - computer_b.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.2")) - computer_b.software_manager.software["DatabaseClient"].run() + computer_b.software_manager.software["database-client"].configure(server_ip_address=IPv4Address("192.168.0.2")) + computer_b.software_manager.software["database-client"].run() c2_beacon.configure(c2_server_ip_address="192.168.0.2", keep_alive_frequency=2) c2_server.run() @@ -173,13 +185,13 @@ def test_c2_suite_configure_request(basic_network): c2_beacon_config = { "c2_server_ip_address": "192.168.0.2", "keep_alive_frequency": 5, - "masquerade_protocol": "TCP", - "masquerade_port": "HTTP", + "masquerade_protocol": "tcp", + "masquerade_port": 80, } - network.apply_request(["node", "node_b", "application", "C2Beacon", "configure", c2_beacon_config]) + network.apply_request(["node", "node_b", "application", "c2-beacon", "configure", c2_beacon_config]) network.apply_timestep(0) - network.apply_request(["node", "node_b", "application", "C2Beacon", "execute"]) + network.apply_request(["node", "node_b", "application", "c2-beacon", "execute"]) assert c2_beacon.c2_connection_active is True assert c2_server.c2_connection_active is True @@ -195,13 +207,13 @@ def test_c2_suite_ransomware_commands(basic_network): # Testing Via Requests: computer_b.software_manager.install(software_class=RansomwareScript) ransomware_config = {"server_ip_address": "192.168.0.2"} - network.apply_request(["node", "node_a", "application", "C2Server", "ransomware_configure", ransomware_config]) + network.apply_request(["node", "node_a", "application", "c2-server", "ransomware_configure", ransomware_config]) - ransomware_script: RansomwareScript = computer_b.software_manager.software["RansomwareScript"] + ransomware_script: RansomwareScript = computer_b.software_manager.software["ransomware-script"] assert ransomware_script.server_ip_address == "192.168.0.2" - network.apply_request(["node", "node_a", "application", "C2Server", "ransomware_launch"]) + network.apply_request(["node", "node_a", "application", "c2-server", "ransomware_launch"]) database_file = computer_a.software_manager.file_system.get_file("database", "database.db") @@ -227,7 +239,7 @@ def test_c2_suite_acl_block(basic_network): assert c2_beacon.c2_connection_active == True # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) c2_beacon.apply_timestep(2) c2_beacon.apply_timestep(3) @@ -322,8 +334,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Confirm Default Setup ######################### # Permitting all HTTP & FTP traffic - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=1) c2_beacon.apply_timestep(0) assert c2_beacon.keep_alive_inactivity == 1 @@ -337,7 +349,7 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying HTTP Traffic ######################### # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) blocking_acl: AccessControlList = router.acl.acl[0] # Asserts to show the C2 Suite is unable to maintain connection: @@ -359,8 +371,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port.FTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=PORT_LOOKUP["FTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) c2_beacon.establish() @@ -407,8 +419,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying FTP Traffic & Enable HTTP ######################### # Blocking FTP and re-permitting HTTP: - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.FTP, dst_port=Port.FTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=1) blocking_acl: AccessControlList = router.acl.acl[1] # Asserts to show the C2 Suite is unable to maintain connection: @@ -430,8 +442,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port.HTTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=PORT_LOOKUP["HTTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) c2_beacon.establish() @@ -491,10 +503,16 @@ def test_c2_suite_yaml(): yaml_network = game.simulation.network computer_a: Computer = yaml_network.get_node_by_hostname("node_a") - c2_server: C2Server = computer_a.software_manager.software.get("C2Server") + c2_server: C2Server = computer_a.software_manager.software.get("c2-server") computer_b: Computer = yaml_network.get_node_by_hostname("node_b") - c2_beacon: C2Beacon = computer_b.software_manager.software.get("C2Beacon") + c2_beacon: C2Beacon = computer_b.software_manager.software.get("c2-beacon") + c2_beacon.configure( + c2_server_ip_address=c2_beacon.config.c2_server_ip_address, + keep_alive_frequency=c2_beacon.config.keep_alive_frequency, + masquerade_port=c2_beacon.config.masquerade_port, + masquerade_protocol=c2_beacon.config.masquerade_protocol, + ) assert c2_server.operating_state == ApplicationOperatingState.RUNNING diff --git a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py index 2e87578d..1a85b20d 100644 --- a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Tuple @@ -9,7 +9,6 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( @@ -19,6 +18,7 @@ from primaite.simulator.system.applications.red_applications.data_manipulation_b from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -27,20 +27,20 @@ def data_manipulation_bot_and_db_server(client_server) -> Tuple[DataManipulation # install db client on computer computer.software_manager.install(DatabaseClient) - db_client: DatabaseClient = computer.software_manager.software.get("DatabaseClient") + db_client: DatabaseClient = computer.software_manager.software.get("database-client") db_client.run() # Install DoSBot on computer computer.software_manager.install(DataManipulationBot) - data_manipulation_bot: DataManipulationBot = computer.software_manager.software.get("DataManipulationBot") + data_manipulation_bot: DataManipulationBot = computer.software_manager.software.get("data-manipulation-bot") data_manipulation_bot.configure( server_ip_address=IPv4Address(server.network_interface[1].ip_address), payload="DELETE" ) # Install DB Server service on server server.software_manager.install(DatabaseService) - db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service: DatabaseService = server.software_manager.software.get("database-service") db_server_service.start() return data_manipulation_bot, computer, db_server_service, server @@ -52,7 +52,10 @@ def data_manipulation_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") @@ -61,26 +64,26 @@ def data_manipulation_db_server_green_client(example_network) -> Network: # install db client on client 1 client_1.software_manager.install(DatabaseClient) - db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + db_client: DatabaseClient = client_1.software_manager.software.get("database-client") db_client.run() # install Data Manipulation bot on client 1 client_1.software_manager.install(DataManipulationBot) - data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") + data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot") data_manipulation_bot.configure( server_ip_address=IPv4Address(server.network_interface[1].ip_address), payload="DELETE" ) # install db server service on server server.software_manager.install(DatabaseService) - db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service: DatabaseService = server.software_manager.software.get("database-service") db_server_service.start() # Install DB client (green) on client 2 client_2.software_manager.install(DatabaseClient) - database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") + database_client: DatabaseClient = client_2.software_manager.software.get("database-client") database_client.configure(server_ip_address=IPv4Address(server.network_interface[1].ip_address)) database_client.run() @@ -134,13 +137,13 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_ network: Network = data_manipulation_db_server_green_client client_1: Computer = network.get_node_by_hostname("client_1") - data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") + data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot") client_2: Computer = network.get_node_by_hostname("client_2") - green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") + green_db_client: DatabaseClient = client_2.software_manager.software.get("database-client") server: Server = network.get_node_by_hostname("server_1") - db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service: DatabaseService = server.software_manager.software.get("database-service") green_db_connection: DatabaseClientConnection = green_db_client.get_new_connection() diff --git a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py index 68c1fbfe..47ddb504 100644 --- a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Tuple @@ -8,12 +8,12 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -23,15 +23,15 @@ def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseServ # Install DoSBot on computer computer.software_manager.install(DoSBot) - dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") + dos_bot: DoSBot = computer.software_manager.software.get("dos-bot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port.POSTGRES_SERVER, + target_port=PORT_LOOKUP["POSTGRES_SERVER"], ) # Install DB Server service on server server.software_manager.install(DatabaseService) - db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service: DatabaseService = server.software_manager.software.get("database-service") db_server_service.start() return dos_bot, computer, db_server_service, server @@ -43,7 +43,10 @@ def dos_bot_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") @@ -53,21 +56,21 @@ def dos_bot_db_server_green_client(example_network) -> Network: # install DoS bot on client 1 client_1.software_manager.install(DoSBot) - dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot") + dos_bot: DoSBot = client_1.software_manager.software.get("dos-bot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port.POSTGRES_SERVER, + target_port=PORT_LOOKUP["POSTGRES_SERVER"], ) # install db server service on server server.software_manager.install(DatabaseService) - db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service: DatabaseService = server.software_manager.software.get("database-service") db_server_service.start() # Install DB client (green) on client 2 client_2.software_manager.install(DatabaseClient) - database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") + database_client: DatabaseClient = client_2.software_manager.software.get("database-client") database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) database_client.run() @@ -156,13 +159,13 @@ def test_dos_blocks_green_agent_connection(dos_bot_db_server_green_client): network: Network = dos_bot_db_server_green_client client_1: Computer = network.get_node_by_hostname("client_1") - dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot") + dos_bot: DoSBot = client_1.software_manager.software.get("dos-bot") client_2: Computer = network.get_node_by_hostname("client_2") - green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") + green_db_client: DatabaseClient = client_2.software_manager.software.get("database-client") server: Server = network.get_node_by_hostname("server_1") - db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service: DatabaseService = server.software_manager.software.get("database-service") assert db_server_service.health_state_actual is SoftwareHealthState.GOOD diff --git a/tests/integration_tests/system/red_applications/test_ransomware_script.py b/tests/integration_tests/system/red_applications/test_ransomware_script.py index 97abafb5..a62b2fb6 100644 --- a/tests/integration_tests/system/red_applications/test_ransomware_script.py +++ b/tests/integration_tests/system/red_applications/test_ransomware_script.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Tuple @@ -9,11 +9,11 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -22,20 +22,20 @@ def ransomware_script_and_db_server(client_server) -> Tuple[RansomwareScript, Co # install db client on computer computer.software_manager.install(DatabaseClient) - db_client: DatabaseClient = computer.software_manager.software.get("DatabaseClient") + db_client: DatabaseClient = computer.software_manager.software.get("database-client") db_client.run() # Install DoSBot on computer computer.software_manager.install(RansomwareScript) - ransomware_script_application: RansomwareScript = computer.software_manager.software.get("RansomwareScript") + ransomware_script_application: RansomwareScript = computer.software_manager.software.get("ransomware-script") ransomware_script_application.configure( server_ip_address=IPv4Address(server.network_interface[1].ip_address), payload="ENCRYPT" ) # Install DB Server service on server server.software_manager.install(DatabaseService) - db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service: DatabaseService = server.software_manager.software.get("database-service") db_server_service.start() return ransomware_script_application, computer, db_server_service, server @@ -47,7 +47,10 @@ def ransomware_script_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") @@ -56,26 +59,26 @@ def ransomware_script_db_server_green_client(example_network) -> Network: # install db client on client 1 client_1.software_manager.install(DatabaseClient) - db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + db_client: DatabaseClient = client_1.software_manager.software.get("database-client") db_client.run() # install Ransomware Script bot on client 1 client_1.software_manager.install(RansomwareScript) - ransomware_script_application: RansomwareScript = client_1.software_manager.software.get("RansomwareScript") + ransomware_script_application: RansomwareScript = client_1.software_manager.software.get("ransomware-script") ransomware_script_application.configure( server_ip_address=IPv4Address(server.network_interface[1].ip_address), payload="ENCRYPT" ) # install db server service on server server.software_manager.install(DatabaseService) - db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service: DatabaseService = server.software_manager.software.get("database-service") db_server_service.start() # Install DB client (green) on client 2 client_2.software_manager.install(DatabaseClient) - database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") + database_client: DatabaseClient = client_2.software_manager.software.get("database-client") database_client.configure(server_ip_address=IPv4Address(server.network_interface[1].ip_address)) database_client.run() @@ -107,15 +110,15 @@ def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_ network: Network = ransomware_script_db_server_green_client client_1: Computer = network.get_node_by_hostname("client_1") - ransomware_script_application: RansomwareScript = client_1.software_manager.software.get("RansomwareScript") + ransomware_script_application: RansomwareScript = client_1.software_manager.software.get("ransomware-script") client_2: Computer = network.get_node_by_hostname("client_2") - green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") + green_db_client: DatabaseClient = client_2.software_manager.software.get("database-client") green_db_client.connect() green_db_client_connection: DatabaseClientConnection = green_db_client.get_new_connection() server: Server = network.get_node_by_hostname("server_1") - db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service: DatabaseService = server.software_manager.software.get("database-service") assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD assert green_db_client.query("SELECT") is True diff --git a/tests/integration_tests/system/test_application_on_node.py b/tests/integration_tests/system/test_application_on_node.py index ffb5cc7f..38a7ca03 100644 --- a/tests/integration_tests/system/test_application_on_node.py +++ b/tests/integration_tests/system/test_application_on_node.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest @@ -10,18 +10,21 @@ from primaite.simulator.system.applications.application import Application, Appl @pytest.fixture(scope="function") def populated_node(application_class) -> Tuple[Application, Computer]: - computer: Computer = Computer( - hostname="test_computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - shut_down_duration=0, + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "test_computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + "shut_down_duration": 0, + } ) computer.power_on() computer.software_manager.install(application_class) - app = computer.software_manager.software.get("DummyApplication") + app = computer.software_manager.software.get("dummy-application") app.run() return app, computer @@ -29,17 +32,20 @@ def populated_node(application_class) -> Tuple[Application, Computer]: def test_application_on_offline_node(application_class): """Test to check that the application cannot be interacted with when node it is on is off.""" - computer: Computer = Computer( - hostname="test_computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - shut_down_duration=0, + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "test_computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + "shut_down_duration": 0, + } ) computer.software_manager.install(application_class) - app: Application = computer.software_manager.software.get("DummyApplication") + app: Application = computer.software_manager.software.get("dummy-application") computer.power_off() diff --git a/tests/integration_tests/system/test_arp.py b/tests/integration_tests/system/test_arp.py index 6c7e853a..b9a92255 100644 --- a/tests/integration_tests/system/test_arp.py +++ b/tests/integration_tests/system/test_arp.py @@ -1,8 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterARP -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.arp.arp import ARP +from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.network.test_routing import multi_hop_network @@ -58,7 +57,7 @@ def test_arp_not_affected_by_acl(multi_hop_network): # Add explicit rule to block ARP traffic. This shouldn't actually stop ARP traffic # as it operates a different layer within the network. - router_1.acl.add_rule(action=ACLAction.DENY, src_port=Port.ARP, dst_port=Port.ARP, position=23) + router_1.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=23) pc_a_arp: ARP = pc_a.software_manager.arp diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 965b4ae8..6627b7a1 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Tuple @@ -20,22 +20,38 @@ from primaite.simulator.system.software import SoftwareHealthState @pytest.fixture(scope="function") def peer_to_peer() -> Tuple[Computer, Computer]: network = Network() - node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_a.power_on() node_a.software_manager.get_open_ports() - node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) assert node_a.ping("192.168.0.11") node_a.software_manager.install(DatabaseClient) - node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11")) - node_a.software_manager.software["DatabaseClient"].run() + node_a.software_manager.software["database-client"].configure(server_ip_address=IPv4Address("192.168.0.11")) + node_a.software_manager.software["database-client"].run() node_b.software_manager.install(DatabaseService) - database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa + database_service: DatabaseService = node_b.software_manager.software["database-service"] # noqa database_service.start() return node_a, node_b @@ -44,9 +60,9 @@ def peer_to_peer() -> Tuple[Computer, Computer]: def peer_to_peer_secure_db(peer_to_peer) -> Tuple[Computer, Computer]: node_a, node_b = peer_to_peer - database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa + database_service: DatabaseService = node_b.software_manager.software["database-service"] # noqa database_service.stop() - database_service.password = "12345" + database_service.config.db_password = "12345" database_service.start() return node_a, node_b @@ -54,9 +70,9 @@ def peer_to_peer_secure_db(peer_to_peer) -> Tuple[Computer, Computer]: def test_database_client_server_connection(peer_to_peer): node_a, node_b = peer_to_peer - db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = node_a.software_manager.software["database-client"] - db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] + db_service: DatabaseService = node_b.software_manager.software["database-service"] db_client.connect() @@ -71,9 +87,9 @@ def test_database_client_server_connection(peer_to_peer): def test_database_client_server_correct_password(peer_to_peer_secure_db): node_a, node_b = peer_to_peer_secure_db - db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = node_a.software_manager.software["database-client"] - db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] + db_service: DatabaseService = node_b.software_manager.software["database-service"] db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="12345") db_client.connect() @@ -84,9 +100,9 @@ def test_database_client_server_correct_password(peer_to_peer_secure_db): def test_database_client_server_incorrect_password(peer_to_peer_secure_db): node_a, node_b = peer_to_peer_secure_db - db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = node_a.software_manager.software["database-client"] - db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] + db_service: DatabaseService = node_b.software_manager.software["database-service"] # should fail db_client.connect() @@ -102,7 +118,7 @@ def test_database_client_server_incorrect_password(peer_to_peer_secure_db): def test_database_client_native_connection_query(uc2_network): """Tests DB query across the network returns HTTP status 200 and date.""" web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = web_server.software_manager.software["database-client"] db_client.connect() assert db_client.query(sql="SELECT") assert db_client.query(sql="INSERT") @@ -111,7 +127,7 @@ def test_database_client_native_connection_query(uc2_network): def test_database_client_connection_query(uc2_network): """Tests DB query across the network returns HTTP status 200 and date.""" web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = web_server.software_manager.software["database-client"] db_connection: DatabaseClientConnection = db_client.get_new_connection() @@ -122,13 +138,13 @@ def test_database_client_connection_query(uc2_network): def test_create_database_backup(uc2_network): """Run the backup_database method and check if the FTP server has the relevant file.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software["database-service"] # back up should be created assert db_service.backup_database() is True backup_server: Server = uc2_network.get_node_by_hostname("backup_server") - ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"] + ftp_server: FTPServer = backup_server.software_manager.software["ftp-server"] # backup file should exist in the backup server assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None @@ -137,7 +153,7 @@ def test_create_database_backup(uc2_network): def test_restore_backup(uc2_network): """Run the restore_backup method and check if the backup is properly restored.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software["database-service"] # create a back up assert db_service.backup_database() is True @@ -156,14 +172,14 @@ def test_restore_backup(uc2_network): def test_restore_backup_without_updating_scan(uc2_network): """Same test as restore backup but the file is previously seen as corrupted.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software["database-service"] # create a back up assert db_service.backup_database() is True db_service.db_file.corrupt() # corrupt the db assert db_service.db_file.health_status == FileSystemItemHealthStatus.CORRUPT # db file is actually corrupt - assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.GOOD # not scanned yet + assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.NONE # not scanned yet db_service.db_file.scan() # scan the db file @@ -184,13 +200,13 @@ def test_restore_backup_without_updating_scan(uc2_network): def test_restore_backup_after_deleting_file_without_updating_scan(uc2_network): """Same test as restore backup but the file is previously seen as corrupted.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software["database-service"] assert db_service.backup_database() is True db_service.db_file.corrupt() # corrupt the db assert db_service.db_file.health_status == FileSystemItemHealthStatus.CORRUPT # db file is actually corrupt - assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.GOOD # not scanned yet + assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.NONE # not scanned yet db_service.db_file.scan() # scan the db file @@ -217,7 +233,7 @@ def test_restore_backup_after_deleting_file_without_updating_scan(uc2_network): def test_database_service_fix(uc2_network): """Test that the software fix applies to database service.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software["database-service"] assert db_service.backup_database() is True @@ -232,7 +248,7 @@ def test_database_service_fix(uc2_network): assert db_service.health_state_actual == SoftwareHealthState.FIXING # apply timestep until the fix is applied - for i in range(db_service.fixing_duration + 1): + for i in range(db_service.config.fixing_duration + 1): uc2_network.apply_timestep(i) assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD @@ -242,10 +258,10 @@ def test_database_service_fix(uc2_network): def test_database_cannot_be_queried_while_fixing(uc2_network): """Tests that the database service cannot be queried if the service is being fixed.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software["database-service"] web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = web_server.software_manager.software["database-client"] db_connection: DatabaseClientConnection = db_client.get_new_connection() @@ -266,7 +282,7 @@ def test_database_cannot_be_queried_while_fixing(uc2_network): assert db_connection.query(sql="SELECT") is False # apply timestep until the fix is applied - for i in range(db_service.fixing_duration + 1): + for i in range(db_service.config.fixing_duration + 1): uc2_network.apply_timestep(i) assert db_service.health_state_actual == SoftwareHealthState.GOOD @@ -279,10 +295,10 @@ def test_database_cannot_be_queried_while_fixing(uc2_network): def test_database_can_create_connection_while_fixing(uc2_network): """Tests that connections cannot be created while the database is being fixed.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software["database-service"] client_2: Server = uc2_network.get_node_by_hostname("client_2") - db_client: DatabaseClient = client_2.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = client_2.software_manager.software["database-client"] db_connection: DatabaseClientConnection = db_client.get_new_connection() @@ -308,7 +324,7 @@ def test_database_can_create_connection_while_fixing(uc2_network): assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING # apply timestep until the fix is applied - for i in range(db_service.fixing_duration + 1): + for i in range(db_service.config.fixing_duration + 1): uc2_network.apply_timestep(i) assert db_service.health_state_actual == SoftwareHealthState.GOOD @@ -321,13 +337,13 @@ def test_database_can_create_connection_while_fixing(uc2_network): def test_database_client_cannot_query_offline_database_server(uc2_network): """Tests DB query across the network returns HTTP status 404 when db server is offline.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = db_server.software_manager.software.get("database-service") assert db_server.operating_state is NodeOperatingState.ON assert db_service.operating_state is ServiceOperatingState.RUNNING web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") + db_client: DatabaseClient = web_server.software_manager.software.get("database-client") db_client.connect() assert len(db_client.client_connections) @@ -338,7 +354,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network): assert db_connection.query("INSERT") is True db_server.power_off() - for i in range(db_server.shut_down_duration + 1): + for i in range(db_server.config.shut_down_duration + 1): uc2_network.apply_timestep(timestep=i) assert db_server.operating_state is NodeOperatingState.OFF @@ -351,8 +367,8 @@ def test_database_client_cannot_query_offline_database_server(uc2_network): def test_database_client_uninstall_terminates_connections(peer_to_peer): node_a, node_b = peer_to_peer - db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"] - db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa + db_client: DatabaseClient = node_a.software_manager.software["database-client"] + db_service: DatabaseService = node_b.software_manager.software["database-service"] # noqa db_connection: DatabaseClientConnection = db_client.get_new_connection() @@ -366,7 +382,7 @@ def test_database_client_uninstall_terminates_connections(peer_to_peer): assert db_connection.query("SELECT") # Perform the DatabaseClient uninstall - node_a.software_manager.uninstall("DatabaseClient") + node_a.software_manager.uninstall("database-client") # Check that all connection counters are updated accordingly and client connection can no longer query the database assert len(db_service.connections) == 0 @@ -381,8 +397,8 @@ def test_database_client_uninstall_terminates_connections(peer_to_peer): def test_database_service_can_terminate_connection(peer_to_peer): node_a, node_b = peer_to_peer - db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"] - db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa + db_client: DatabaseClient = node_a.software_manager.software["database-client"] + db_service: DatabaseService = node_b.software_manager.software["database-service"] # noqa db_connection: DatabaseClientConnection = db_client.get_new_connection() @@ -412,47 +428,65 @@ def test_database_service_can_terminate_connection(peer_to_peer): def test_client_connection_terminate_does_not_terminate_another_clients_connection(): network = Network() - db_server = Server( - hostname="db_client", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0 + db_server: Server = Server.from_config( + config={ + "type": "server", + "hostname": "db_client", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) db_server.power_on() db_server.software_manager.install(DatabaseService) - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] # noqa + db_service: DatabaseService = db_server.software_manager.software["database-service"] # noqa db_service.start() - client_a = Computer( - hostname="client_a", ip_address="192.168.0.12", subnet_mask="255.255.255.0", start_up_duration=0 + client_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "client_a", + "ip_address": "192.168.0.12", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) client_a.power_on() client_a.software_manager.install(DatabaseClient) - client_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11")) - client_a.software_manager.software["DatabaseClient"].run() + client_a.software_manager.software["database-client"].configure(server_ip_address=IPv4Address("192.168.0.11")) + client_a.software_manager.software["database-client"].run() - client_b = Computer( - hostname="client_b", ip_address="192.168.0.13", subnet_mask="255.255.255.0", start_up_duration=0 + client_b = Computer.from_config( + config={ + "type": "computer", + "hostname": "client_b", + "ip_address": "192.168.0.13", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) client_b.power_on() client_b.software_manager.install(DatabaseClient) - client_b.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11")) - client_b.software_manager.software["DatabaseClient"].run() + client_b.software_manager.software["database-client"].configure(server_ip_address=IPv4Address("192.168.0.11")) + client_b.software_manager.software["database-client"].run() - switch = Switch(hostname="switch", start_up_duration=0, num_ports=3) + switch = Switch.from_config(config={"type": "switch", "hostname": "switch", "start_up_duration": 0, "num_ports": 3}) switch.power_on() network.connect(endpoint_a=switch.network_interface[1], endpoint_b=db_server.network_interface[1]) network.connect(endpoint_a=switch.network_interface[2], endpoint_b=client_a.network_interface[1]) network.connect(endpoint_a=switch.network_interface[3], endpoint_b=client_b.network_interface[1]) - db_client_a: DatabaseClient = client_a.software_manager.software["DatabaseClient"] # noqa + db_client_a: DatabaseClient = client_a.software_manager.software["database-client"] # noqa db_connection_a = db_client_a.get_new_connection() assert db_connection_a.query("SELECT") assert len(db_service.connections) == 1 - db_client_b: DatabaseClient = client_b.software_manager.software["DatabaseClient"] # noqa + db_client_b: DatabaseClient = client_b.software_manager.software["database-client"] # noqa db_connection_b = db_client_b.get_new_connection() assert db_connection_b.query("SELECT") @@ -465,6 +499,14 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti def test_database_server_install_ftp_client(): - server = Server(hostname="db_server", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + server: Server = Server.from_config( + config={ + "type": "server", + "hostname": "db_server", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) server.software_manager.install(DatabaseService) - assert server.software_manager.software.get("FTPClient") + assert server.software_manager.software.get("ftp-client") diff --git a/tests/integration_tests/system/test_dns_client_server.py b/tests/integration_tests/system/test_dns_client_server.py index 480a90bc..068b94d2 100644 --- a/tests/integration_tests/system/test_dns_client_server.py +++ b/tests/integration_tests/system/test_dns_client_server.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Tuple @@ -18,14 +18,14 @@ def dns_client_and_dns_server(client_server) -> Tuple[DNSClient, Computer, DNSSe # Install DNS Client on computer computer.software_manager.install(DNSClient) - dns_client: DNSClient = computer.software_manager.software.get("DNSClient") + dns_client: DNSClient = computer.software_manager.software.get("dns-client") dns_client.start() # set server as DNS Server dns_client.dns_server = IPv4Address(server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address) # Install DNS Server on server server.software_manager.install(DNSServer) - dns_server: DNSServer = server.software_manager.software.get("DNSServer") + dns_server: DNSServer = server.software_manager.software.get("dns-server") dns_server.start() # register arcd.com as a domain dns_server.dns_register( @@ -72,7 +72,7 @@ def test_dns_client_requests_offline_dns_server(dns_client_and_dns_server): server.power_off() - for i in range(server.shut_down_duration + 1): + for i in range(server.config.shut_down_duration + 1): server.apply_timestep(timestep=i) assert server.operating_state == NodeOperatingState.OFF diff --git a/tests/integration_tests/system/test_ftp_client_server.py b/tests/integration_tests/system/test_ftp_client_server.py index 22c5d484..bb3aa8f2 100644 --- a/tests/integration_tests/system/test_ftp_client_server.py +++ b/tests/integration_tests/system/test_ftp_client_server.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest @@ -17,12 +17,12 @@ def ftp_client_and_ftp_server(client_server) -> Tuple[FTPClient, Computer, FTPSe # Install FTP Client service on computer computer.software_manager.install(FTPClient) - ftp_client: FTPClient = computer.software_manager.software.get("FTPClient") + ftp_client: FTPClient = computer.software_manager.software.get("ftp-client") ftp_client.start() # Install FTP Server service on server server.software_manager.install(FTPServer) - ftp_server: FTPServer = server.software_manager.software.get("FTPServer") + ftp_server: FTPServer = server.software_manager.software.get("ftp-server") ftp_server.start() return ftp_client, computer, ftp_server, server @@ -87,7 +87,7 @@ def test_ftp_client_tries_to_connect_to_offline_server(ftp_client_and_ftp_server server.power_off() - for i in range(server.shut_down_duration + 1): + for i in range(server.config.shut_down_duration + 1): server.apply_timestep(timestep=i) assert ftp_client.operating_state == ServiceOperatingState.RUNNING diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index 2b8691cc..e5f08a94 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -1,13 +1,13 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from enum import Enum from ipaddress import IPv4Address, IPv4Network import yaml from primaite.game.game import PrimaiteGame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.nmap import NMAP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT @@ -15,7 +15,7 @@ def test_ping_scan_all_on(example_network): network = example_network client_1 = network.get_node_by_hostname("client_1") - client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa + client_1_nmap: NMAP = client_1.software_manager.software["nmap"] # noqa expected_result = [IPv4Address("192.168.1.10"), IPv4Address("192.168.1.14")] actual_result = client_1_nmap.ping_scan(target_ip_address=["192.168.1.10", "192.168.1.14"]) @@ -27,7 +27,7 @@ def test_ping_scan_all_on_full_network(example_network): network = example_network client_1 = network.get_node_by_hostname("client_1") - client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa + client_1_nmap: NMAP = client_1.software_manager.software["nmap"] # noqa expected_result = [IPv4Address("192.168.1.1"), IPv4Address("192.168.1.10"), IPv4Address("192.168.1.14")] actual_result = client_1_nmap.ping_scan(target_ip_address=IPv4Network("192.168.1.0/24")) @@ -39,7 +39,7 @@ def test_ping_scan_some_on(example_network): network = example_network client_1 = network.get_node_by_hostname("client_1") - client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa + client_1_nmap: NMAP = client_1.software_manager.software["nmap"] # noqa network.get_node_by_hostname("server_2").power_off() @@ -53,7 +53,7 @@ def test_ping_scan_all_off(example_network): network = example_network client_1 = network.get_node_by_hostname("client_1") - client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa + client_1_nmap: NMAP = client_1.software_manager.software["nmap"] # noqa network.get_node_by_hostname("server_1").power_off() network.get_node_by_hostname("server_2").power_off() @@ -68,15 +68,17 @@ def test_port_scan_one_node_one_port(example_network): network = example_network client_1 = network.get_node_by_hostname("client_1") - client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa + client_1_nmap: NMAP = client_1.software_manager.software["nmap"] # noqa client_2 = network.get_node_by_hostname("client_2") actual_result = client_1_nmap.port_scan( - target_ip_address=client_2.network_interface[1].ip_address, target_port=Port.DNS, target_protocol=IPProtocol.TCP + target_ip_address=client_2.network_interface[1].ip_address, + target_port=PORT_LOOKUP["DNS"], + target_protocol=PROTOCOL_LOOKUP["TCP"], ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.DNS]}} + expected_result = {IPv4Address("192.168.10.22"): {PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["DNS"]]}} assert actual_result == expected_result @@ -97,18 +99,24 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network): network = example_network client_1 = network.get_node_by_hostname("client_1") - client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa + client_1_nmap: NMAP = client_1.software_manager.software["nmap"] # noqa actual_result = client_1_nmap.port_scan( target_ip_address=IPv4Network("192.168.10.0/24"), - target_port=[Port.ARP, Port.HTTP, Port.FTP, Port.DNS, Port.NTP], + target_port=[ + PORT_LOOKUP["ARP"], + PORT_LOOKUP["HTTP"], + PORT_LOOKUP["FTP"], + PORT_LOOKUP["DNS"], + PORT_LOOKUP["NTP"], + ], ) expected_result = { - IPv4Address("192.168.10.1"): {IPProtocol.UDP: [Port.ARP]}, + IPv4Address("192.168.10.1"): {PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"]]}, IPv4Address("192.168.10.22"): { - IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS], - IPProtocol.UDP: [Port.ARP, Port.NTP], + PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]], + PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"], PORT_LOOKUP["NTP"]], }, } @@ -119,13 +127,15 @@ def test_network_service_recon_all_ports_and_protocols(example_network): network = example_network client_1 = network.get_node_by_hostname("client_1") - client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa + client_1_nmap: NMAP = client_1.software_manager.software["nmap"] # noqa actual_result = client_1_nmap.network_service_recon( - target_ip_address=IPv4Network("192.168.10.0/24"), target_port=Port.HTTP, target_protocol=IPProtocol.TCP + target_ip_address=IPv4Network("192.168.10.0/24"), + target_port=PORT_LOOKUP["HTTP"], + target_protocol=PROTOCOL_LOOKUP["TCP"], ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.HTTP]}} + expected_result = {IPv4Address("192.168.10.22"): {PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["HTTP"]]}} assert sort_dict(actual_result) == sort_dict(expected_result) diff --git a/tests/integration_tests/system/test_ntp_client_server.py b/tests/integration_tests/system/test_ntp_client_server.py index 957c1aeb..6c700fb8 100644 --- a/tests/integration_tests/system/test_ntp_client_server.py +++ b/tests/integration_tests/system/test_ntp_client_server.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from time import sleep from typing import Tuple @@ -29,12 +29,12 @@ def create_ntp_network(client_server) -> Tuple[NTPClient, Computer, NTPServer, S server.power_on() server.software_manager.install(NTPServer) - ntp_server: NTPServer = server.software_manager.software.get("NTPServer") + ntp_server: NTPServer = server.software_manager.software.get("ntp-server") ntp_server.start() client.power_on() client.software_manager.install(NTPClient) - ntp_client: NTPClient = client.software_manager.software.get("NTPClient") + ntp_client: NTPClient = client.software_manager.software.get("ntp-client") ntp_client.start() return ntp_client, client, ntp_server, server @@ -43,8 +43,8 @@ def create_ntp_network(client_server) -> Tuple[NTPClient, Computer, NTPServer, S def test_ntp_client_server(create_ntp_network): ntp_client, client, ntp_server, server = create_ntp_network - ntp_server: NTPServer = server.software_manager.software["NTPServer"] - ntp_client: NTPClient = client.software_manager.software["NTPClient"] + ntp_server: NTPServer = server.software_manager.software["ntp-server"] + ntp_client: NTPClient = client.software_manager.software["ntp-client"] assert ntp_server.operating_state == ServiceOperatingState.RUNNING assert ntp_client.operating_state == ServiceOperatingState.RUNNING @@ -64,8 +64,8 @@ def test_ntp_client_server(create_ntp_network): def test_ntp_server_failure(create_ntp_network): ntp_client, client, ntp_server, server = create_ntp_network - ntp_server: NTPServer = server.software_manager.software["NTPServer"] - ntp_client: NTPClient = client.software_manager.software["NTPClient"] + ntp_server: NTPServer = server.software_manager.software["ntp-server"] + ntp_client: NTPClient = client.software_manager.software["ntp-client"] assert ntp_client.operating_state == ServiceOperatingState.RUNNING assert ntp_client.operating_state == ServiceOperatingState.RUNNING diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py index fd502a70..9c25d4f9 100644 --- a/tests/integration_tests/system/test_service_listening_on_ports.py +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Any, Dict, List, Set import yaml @@ -6,19 +6,26 @@ from pydantic import Field from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT -class _DatabaseListener(Service): - name: str = "DatabaseListener" - protocol: IPProtocol = IPProtocol.TCP - port: Port = Port.NONE - listen_on_ports: Set[Port] = {Port.POSTGRES_SERVER} +class _DatabaseListener(Service, discriminator="database-listener"): + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for _DatabaseListener.""" + + type: str = "database-listener" + listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]} + + config: "_DatabaseListener.ConfigSchema" = Field(default_factory=lambda: _DatabaseListener.ConfigSchema()) + name: str = "database-listener" + protocol: str = PROTOCOL_LOOKUP["TCP"] + port: int = PORT_LOOKUP["NONE"] + listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]} payloads_received: List[Any] = Field(default_factory=list) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: @@ -34,15 +41,15 @@ def test_http_listener(client_server): computer, server = client_server server.software_manager.install(DatabaseService) - server_db = server.software_manager.software["DatabaseService"] + server_db = server.software_manager.software["database-service"] server_db.start() server.software_manager.install(_DatabaseListener) - server_db_listener: _DatabaseListener = server.software_manager.software["DatabaseListener"] + server_db_listener: _DatabaseListener = server.software_manager.software["database-listener"] server_db_listener.start() computer.software_manager.install(DatabaseClient) - computer_db_client: DatabaseClient = computer.software_manager.software["DatabaseClient"] + computer_db_client: DatabaseClient = computer.software_manager.software["database-client"] computer_db_client.run() computer_db_client.server_ip_address = server.network_interface[1].ip_address @@ -51,8 +58,8 @@ def test_http_listener(client_server): computer.session_manager.receive_payload_from_software_manager( payload="masquerade as Database traffic", dst_ip_address=server.network_interface[1].ip_address, - dst_port=Port.POSTGRES_SERVER, - ip_protocol=IPProtocol.TCP, + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + ip_protocol=PROTOCOL_LOOKUP["TCP"], ) assert len(server_db_listener.payloads_received) == 1 @@ -76,9 +83,9 @@ def test_set_listen_on_ports_from_config(): network = PrimaiteGame.from_config(cfg=config_dict).simulation.network client: Computer = network.get_node_by_hostname("client") - assert Port.SMB in client.software_manager.get_open_ports() - assert Port.IPP in client.software_manager.get_open_ports() + assert PORT_LOOKUP["SMB"] in client.software_manager.get_open_ports() + assert PORT_LOOKUP["IPP"] in client.software_manager.get_open_ports() - web_browser = client.software_manager.software["WebBrowser"] + web_browser = client.software_manager.software["web-browser"] - assert not web_browser.listen_on_ports.difference({Port.SMB, Port.IPP}) + assert not web_browser.listen_on_ports.difference({PORT_LOOKUP["SMB"], PORT_LOOKUP["IPP"]}) diff --git a/tests/integration_tests/system/test_service_on_node.py b/tests/integration_tests/system/test_service_on_node.py index cf9728ce..5afb71dc 100644 --- a/tests/integration_tests/system/test_service_on_node.py +++ b/tests/integration_tests/system/test_service_on_node.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest @@ -13,17 +13,19 @@ from primaite.simulator.system.services.service import Service, ServiceOperating def populated_node( service_class, ) -> Tuple[Server, Service]: - server = Server( - hostname="server", - ip_address="192.168.0.1", - subnet_mask="255.255.255.0", - start_up_duration=0, - shut_down_duration=0, - ) + server_cfg = { + "type": "server", + "hostname": "server", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + "shut_down_duration": 0, + } + server: Server = Server.from_config(config=server_cfg) server.power_on() server.software_manager.install(service_class) - service = server.software_manager.software.get("DummyService") + service = server.software_manager.software.get("dummy-service") service.start() return server, service @@ -31,18 +33,20 @@ def populated_node( def test_service_on_offline_node(service_class): """Test to check that the service cannot be interacted with when node it is on is off.""" - computer: Computer = Computer( - hostname="test_computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - shut_down_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "test_computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + "shut_down_duration": 0, + } + computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() computer.software_manager.install(service_class) - service: Service = computer.software_manager.software.get("DummyService") + service: Service = computer.software_manager.software.get("dummy-service") computer.power_off() diff --git a/tests/integration_tests/system/test_user_session_manager_logins.py b/tests/integration_tests/system/test_user_session_manager_logins.py index 4318530c..9736232b 100644 --- a/tests/integration_tests/system/test_user_session_manager_logins.py +++ b/tests/integration_tests/system/test_user_session_manager_logins.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple from uuid import uuid4 @@ -14,21 +14,27 @@ from primaite.simulator.network.hardware.nodes.host.server import Server def client_server_network() -> Tuple[Computer, Server, Network]: network = Network() - client = Computer( - hostname="client", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + client = Computer.from_config( + config={ + "type": "computer", + "hostname": "client", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) client.power_on() - server = Server( - hostname="server", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + server = Server.from_config( + config={ + "type": "server", + "hostname": "server", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) server.power_on() diff --git a/tests/integration_tests/system/test_web_client_server.py b/tests/integration_tests/system/test_web_client_server.py index 05cbae4f..fef483e9 100644 --- a/tests/integration_tests/system/test_web_client_server.py +++ b/tests/integration_tests/system/test_web_client_server.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest @@ -20,23 +20,23 @@ def web_client_and_web_server(client_server) -> Tuple[WebBrowser, Computer, WebS # Install Web Browser on computer computer.software_manager.install(WebBrowser) - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") web_browser.run() # Install DNS Client service on computer computer.software_manager.install(DNSClient) - dns_client: DNSClient = computer.software_manager.software.get("DNSClient") + dns_client: DNSClient = computer.software_manager.software.get("dns-client") # set dns server dns_client.dns_server = server.network_interfaces[next(iter(server.network_interfaces))].ip_address # Install Web Server service on server server.software_manager.install(WebServer) - web_server_service: WebServer = server.software_manager.software.get("WebServer") + web_server_service: WebServer = server.software_manager.software.get("web-server") web_server_service.start() # Install DNS Server service on server server.software_manager.install(DNSServer) - dns_server: DNSServer = server.software_manager.software.get("DNSServer") + dns_server: DNSServer = server.software_manager.software.get("dns-server") # register arcd.com to DNS dns_server.dns_register( domain_name="arcd.com", @@ -51,7 +51,7 @@ def test_web_page_get_users_page_request_with_domain_name(web_client_and_web_ser web_browser_app, computer, web_server_service, server = web_client_and_web_server web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address - web_browser_app.target_url = f"http://arcd.com/" + web_browser_app.config.target_url = f"http://arcd.com/" assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING assert web_browser_app.get_webpage() is True @@ -66,7 +66,7 @@ def test_web_page_get_users_page_request_with_ip_address(web_client_and_web_serv web_browser_app, computer, web_server_service, server = web_client_and_web_server web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address - web_browser_app.target_url = f"http://{web_server_ip}/" + web_browser_app.config.target_url = f"http://{web_server_ip}/" assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING assert web_browser_app.get_webpage() is True @@ -81,7 +81,7 @@ def test_web_page_request_from_shut_down_server(web_client_and_web_server): web_browser_app, computer, web_server_service, server = web_client_and_web_server web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address - web_browser_app.target_url = f"http://arcd.com/" + web_browser_app.config.target_url = f"http://arcd.com/" assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING assert web_browser_app.get_webpage() is True @@ -94,7 +94,7 @@ def test_web_page_request_from_shut_down_server(web_client_and_web_server): server.power_off() - for i in range(server.shut_down_duration + 1): + for i in range(server.config.shut_down_duration + 1): server.apply_timestep(timestep=i) # node should be off @@ -108,7 +108,7 @@ def test_web_page_request_from_closed_web_browser(web_client_and_web_server): web_browser_app, computer, web_server_service, server = web_client_and_web_server assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING - web_browser_app.target_url = f"http://arcd.com/" + web_browser_app.config.target_url = f"http://arcd.com/" assert web_browser_app.get_webpage() is True # latest response should have status code 200 diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py index 5a765763..e8045ed9 100644 --- a/tests/integration_tests/system/test_web_client_server_and_database.py +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Tuple @@ -9,7 +9,6 @@ from primaite.simulator.network.hardware.base import Link from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService @@ -17,6 +16,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -24,17 +24,22 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer, # add rules to network router router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3 + ) # Create Computer computer: Computer = example_network.get_node_by_hostname("client_1") @@ -63,29 +68,29 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer, # Install DatabaseService on db server db_server.software_manager.install(DatabaseService) - db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = db_server.software_manager.software.get("database-service") db_service.start() # Install Web Browser on computer computer.software_manager.install(WebBrowser) - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") - web_browser.target_url = "http://arcd.com/users/" + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") + web_browser.config.target_url = "http://arcd.com/users/" web_browser.run() # Install DNS Client service on computer computer.software_manager.install(DNSClient) - dns_client: DNSClient = computer.software_manager.software.get("DNSClient") + dns_client: DNSClient = computer.software_manager.software.get("dns-client") # set dns server dns_client.dns_server = web_server.network_interfaces[next(iter(web_server.network_interfaces))].ip_address # Install Web Server service on web server web_server.software_manager.install(WebServer) - web_server_service: WebServer = web_server.software_manager.software.get("WebServer") + web_server_service: WebServer = web_server.software_manager.software.get("web-server") web_server_service.start() # Install DNS Server service on web server web_server.software_manager.install(DNSServer) - dns_server: DNSServer = web_server.software_manager.software.get("DNSServer") + dns_server: DNSServer = web_server.software_manager.software.get("dns-server") # register arcd.com to DNS dns_server.dns_register( domain_name="arcd.com", @@ -94,7 +99,7 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer, # Install DatabaseClient service on web server web_server.software_manager.install(DatabaseClient) - db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") + db_client: DatabaseClient = web_server.software_manager.software.get("database-client") db_client.server_ip_address = IPv4Address(db_server_nic.ip_address) # set IP address of Database Server db_client.run() assert dns_client.check_domain_exists("arcd.com") @@ -106,7 +111,7 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer, def test_web_client_requests_users(web_client_web_server_database): _, computer, _, _ = web_client_web_server_database - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") assert web_browser.get_webpage() @@ -116,8 +121,8 @@ def test_database_fix_disrupts_web_client(uc2_network): computer: Computer = uc2_network.get_node_by_hostname("client_1") db_server: Server = uc2_network.get_node_by_hostname("database_server") - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") - database_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") + database_service: DatabaseService = db_server.software_manager.software.get("database-service") # fix the database service database_service.fix() @@ -126,7 +131,7 @@ def test_database_fix_disrupts_web_client(uc2_network): assert web_browser.get_webpage() is False - for i in range(database_service.fixing_duration + 1): + for i in range(database_service.config.fixing_duration + 1): uc2_network.apply_timestep(i) assert database_service.health_state_actual == SoftwareHealthState.GOOD @@ -138,7 +143,7 @@ class TestWebBrowserHistory: def test_populating_history(self, web_client_web_server_database): network, computer, _, _ = web_client_web_server_database - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") assert web_browser.history == [] web_browser.get_webpage() assert len(web_browser.history) == 1 @@ -148,7 +153,9 @@ class TestWebBrowserHistory: assert web_browser.history[-1].response_code == 200 router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule( + action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0 + ) assert not web_browser.get_webpage() assert len(web_browser.history) == 3 # with current NIC behaviour, even if you block communication, you won't get SERVER_UNREACHABLE because @@ -158,17 +165,19 @@ class TestWebBrowserHistory: def test_history_in_state(self, web_client_web_server_database): network, computer, _, _ = web_client_web_server_database - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") state = computer.describe_state() - assert "history" in state["applications"]["WebBrowser"] - assert len(state["applications"]["WebBrowser"]["history"]) == 0 + assert "history" in state["applications"]["web-browser"] + assert len(state["applications"]["web-browser"]["history"]) == 0 web_browser.get_webpage() router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule( + action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0 + ) web_browser.get_webpage() state = computer.describe_state() - assert state["applications"]["WebBrowser"]["history"][0]["outcome"] == 200 - assert state["applications"]["WebBrowser"]["history"][1]["outcome"] == 404 + assert state["applications"]["web-browser"]["history"][0]["outcome"] == 200 + assert state["applications"]["web-browser"]["history"][1]["outcome"] == 404 diff --git a/tests/integration_tests/test_simulation/__init__.py b/tests/integration_tests/test_simulation/__init__.py index be6c00e7..836b79af 100644 --- a/tests/integration_tests/test_simulation/__init__.py +++ b/tests/integration_tests/test_simulation/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 95634cf1..ebff2893 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # some test cases: # 0. test that sending a request to a valid target results in a success # 1. test that sending a request to a component that doesn't exist results in a failure @@ -12,7 +12,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.port import PORT_LOOKUP from tests.conftest import DummyApplication, DummyService @@ -48,13 +48,13 @@ def test_successful_application_requests(example_network): client_1 = net.get_node_by_hostname("client_1") client_1.software_manager.install(DummyApplication) - client_1.software_manager.software.get("DummyApplication").run() + client_1.software_manager.software.get("dummy-application").run() - resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"]) + resp_1 = net.apply_request(["node", "client_1", "application", "dummy-application", "scan"]) assert resp_1 == RequestResponse(status="success", data={}) - resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"]) + resp_2 = net.apply_request(["node", "client_1", "application", "dummy-application", "fix"]) assert resp_2 == RequestResponse(status="success", data={}) - resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"]) + resp_3 = net.apply_request(["node", "client_1", "application", "dummy-application", "compromise"]) assert resp_3 == RequestResponse(status="success", data={}) @@ -77,7 +77,7 @@ def test_successful_service_requests(example_network): "scan", "fix", ]: - resp_1 = net.apply_request(["node", "server_1", "service", "DummyService", verb]) + resp_1 = net.apply_request(["node", "server_1", "service", "dummy-service", verb]) assert resp_1 == RequestResponse(status="success", data={}) server_1.apply_timestep(timestep=1) server_1.apply_timestep(timestep=1) @@ -93,7 +93,7 @@ def test_non_existent_requests(example_network): net = example_network resp_1 = net.apply_request(["fake"]) assert resp_1.status == "unreachable" - resp_2 = net.apply_request(["network", "node", "client_39", "application", "WebBrowser", "execute"]) + resp_2 = net.apply_request(["network", "node", "client_39", "application", "web-browser", "execute"]) assert resp_2.status == "unreachable" @@ -102,8 +102,8 @@ def test_non_existent_requests(example_network): [ ["node", "client_1", "file_system", "folder", "root", "scan"], ["node", "client_1", "os", "scan"], - ["node", "client_1", "service", "DNSClient", "stop"], - ["node", "client_1", "application", "WebBrowser", "scan"], + ["node", "client_1", "service", "dns-client", "stop"], + ["node", "client_1", "application", "web-browser", "scan"], ["node", "client_1", "network_interface", 1, "disable"], ], ) @@ -111,7 +111,7 @@ def test_request_fails_if_node_off(example_network, node_request): """Test that requests succeed when the node is on, and fail if the node is off.""" net = example_network client_1: HostNode = net.get_node_by_hostname("client_1") - client_1.shut_down_duration = 0 + client_1.config.shut_down_duration = 0 assert client_1.operating_state == NodeOperatingState.ON resp_1 = net.apply_request(node_request) @@ -128,10 +128,14 @@ class TestDataManipulationGreenRequests: """Test that green requests succeed when the node is on and fail if the node is off.""" net: Network = uc2_network - client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) - client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"]) - client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) - client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"]) + client_1_browser_execute = net.apply_request(["node", "client_1", "application", "web-browser", "execute"]) + client_1_db_client_execute = net.apply_request( + ["node", "client_1", "application", "database-client", "execute"] + ) + client_2_browser_execute = net.apply_request(["node", "client_2", "application", "web-browser", "execute"]) + client_2_db_client_execute = net.apply_request( + ["node", "client_2", "application", "database-client", "execute"] + ) assert client_1_browser_execute.status == "success" assert client_1_db_client_execute.status == "success" assert client_2_browser_execute.status == "success" @@ -140,18 +144,18 @@ class TestDataManipulationGreenRequests: client_1 = net.get_node_by_hostname("client_1") client_2 = net.get_node_by_hostname("client_2") - client_1.shut_down_duration = 0 + client_1.config.shut_down_duration = 0 client_1.power_off() - client_2.shut_down_duration = 0 + client_2.config.shut_down_duration = 0 client_2.power_off() - client_1_browser_execute_off = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) + client_1_browser_execute_off = net.apply_request(["node", "client_1", "application", "web-browser", "execute"]) client_1_db_client_execute_off = net.apply_request( - ["node", "client_1", "application", "DatabaseClient", "execute"] + ["node", "client_1", "application", "database-client", "execute"] ) - client_2_browser_execute_off = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) + client_2_browser_execute_off = net.apply_request(["node", "client_2", "application", "web-browser", "execute"]) client_2_db_client_execute_off = net.apply_request( - ["node", "client_2", "application", "DatabaseClient", "execute"] + ["node", "client_2", "application", "database-client", "execute"] ) assert client_1_browser_execute_off.status == "failure" assert client_1_db_client_execute_off.status == "failure" @@ -166,24 +170,34 @@ class TestDataManipulationGreenRequests: client_1: HostNode = net.get_node_by_hostname("client_1") client_2: HostNode = net.get_node_by_hostname("client_2") - client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) - client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) + client_1_browser_execute = net.apply_request(["node", "client_1", "application", "web-browser", "execute"]) + client_2_browser_execute = net.apply_request(["node", "client_2", "application", "web-browser", "execute"]) assert client_1_browser_execute.status == "success" assert client_2_browser_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) - client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) - client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) + router.acl.add_rule(ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3) + client_1_browser_execute = net.apply_request(["node", "client_1", "application", "web-browser", "execute"]) + client_2_browser_execute = net.apply_request(["node", "client_2", "application", "web-browser", "execute"]) assert client_1_browser_execute.status == "failure" assert client_2_browser_execute.status == "failure" - client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"]) - client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"]) + client_1_db_client_execute = net.apply_request( + ["node", "client_1", "application", "database-client", "execute"] + ) + client_2_db_client_execute = net.apply_request( + ["node", "client_2", "application", "database-client", "execute"] + ) assert client_1_db_client_execute.status == "success" assert client_2_db_client_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER) - client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"]) - client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"]) + router.acl.add_rule( + ACLAction.DENY, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"] + ) + client_1_db_client_execute = net.apply_request( + ["node", "client_1", "application", "database-client", "execute"] + ) + client_2_db_client_execute = net.apply_request( + ["node", "client_2", "application", "database-client", "execute"] + ) assert client_1_db_client_execute.status == "failure" assert client_2_db_client_execute.status == "failure" diff --git a/tests/mock_and_patch/__init__.py b/tests/mock_and_patch/__init__.py index be6c00e7..836b79af 100644 --- a/tests/mock_and_patch/__init__.py +++ b/tests/mock_and_patch/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py index f315fca4..073028a7 100644 --- a/tests/mock_and_patch/get_session_path_mock.py +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import tempfile from datetime import datetime from pathlib import Path diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/__init__.py +++ b/tests/unit_tests/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/__init__.py b/tests/unit_tests/_primaite/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/__init__.py +++ b/tests/unit_tests/_primaite/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_game/__init__.py b/tests/unit_tests/_primaite/_game/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_game/__init__.py +++ b/tests/unit_tests/_primaite/_game/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_game/_agent/__init__.py b/tests/unit_tests/_primaite/_game/_agent/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_game/_agent/__init__.py +++ b/tests/unit_tests/_primaite/_game/_agent/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_game/_agent/test_actions.py b/tests/unit_tests/_primaite/_game/_agent/test_actions.py index c2d31ee1..cef24bb1 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_actions.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_actions.py @@ -1,11 +1,12 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from unittest.mock import Mock import pytest +from pydantic import ValidationError -from primaite.game.agent.actions import ( - ActionManager, - DoNothingAction, +from primaite.game.agent.actions import ActionManager +from primaite.game.agent.actions.manager import DoNothingAction +from primaite.game.agent.actions.service import ( NodeServiceDisableAction, NodeServiceEnableAction, NodeServicePauseAction, @@ -18,14 +19,9 @@ from primaite.game.agent.actions import ( def test_do_nothing_action_form_request(): - """Test that the DoNothingAction can form a request and that it is correct.""" - manager = Mock() - - action = DoNothingAction(manager=manager) - - request = action.form_request() - - assert request == ["do_nothing"] + """Test that the do_nothingAction can form a request and that it is correct.""" + request = DoNothingAction.form_request(DoNothingAction.ConfigSchema()) + assert request == ["do-nothing"] @pytest.mark.parametrize( @@ -42,7 +38,7 @@ def test_do_nothing_action_form_request(): ], ) # flake8: noqa @pytest.mark.parametrize( - "node_name, service_name, expect_to_do_nothing", + "node_name, service_name, expect_failure", [ ("pc_1", "chrome", False), (None, "chrome", True), @@ -50,42 +46,15 @@ def test_do_nothing_action_form_request(): (None, None, True), ], ) # flake8: noqa -def test_service_action_form_request(node_name, service_name, expect_to_do_nothing, action_class, action_verb): +def test_service_action_form_request(node_name, service_name, expect_failure, action_class, action_verb): """Test that the ServiceScanAction can form a request and that it is correct.""" - manager: ActionManager = Mock() - manager.get_node_name_by_idx.return_value = node_name - manager.get_service_name_by_idx.return_value = service_name - - action = action_class(manager=manager, num_nodes=1, num_services=1) - - request = action.form_request(node_id=0, service_id=0) - - if expect_to_do_nothing: - assert request == ["do_nothing"] + if expect_failure: + with pytest.raises(ValidationError): + request = action_class.form_request( + config=action_class.ConfigSchema(node_name=node_name, service_name=service_name) + ) else: + request = action_class.form_request( + config=action_class.ConfigSchema(node_name=node_name, service_name=service_name) + ) assert request == ["network", "node", node_name, "service", service_name, action_verb] - - -@pytest.mark.parametrize( - "node_name, service_name, expect_to_do_nothing", - [ - ("pc_1", "chrome", False), - (None, "chrome", True), - ("pc_1", None, True), - (None, None, True), - ], -) # flake8: noqa -def test_service_scan_form_request(node_name, service_name, expect_to_do_nothing): - """Test that the ServiceScanAction can form a request and that it is correct.""" - manager: ActionManager = Mock() - manager.get_node_name_by_idx.return_value = node_name - manager.get_service_name_by_idx.return_value = service_name - - action = NodeServiceScanAction(manager=manager, num_nodes=1, num_services=1) - - request = action.form_request(node_id=0, service_id=0) - - if expect_to_do_nothing: - assert request == ["do_nothing"] - else: - assert request == ["network", "node", node_name, "service", service_name, "scan"] diff --git a/tests/unit_tests/_primaite/_game/_agent/test_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_agent.py new file mode 100644 index 00000000..b555f1b2 --- /dev/null +++ b/tests/unit_tests/_primaite/_game/_agent/test_agent.py @@ -0,0 +1,52 @@ +from primaite.game.agent.observations.file_system_observations import FileObservation +from primaite.game.agent.observations.observation_manager import NullObservation +from primaite.game.agent.scripted_agents.random_agent import RandomAgent + + +def test_creating_empty_agent(): + agent = RandomAgent(config={"ref": "Empty Agent"}) + assert len(agent.action_manager.action_map) == 0 + assert isinstance(agent.observation_manager.obs, NullObservation) + assert len(agent.reward_function.reward_components) == 0 + + +def test_creating_agent_from_dict(): + action_config = { + "action_map": { + 0: {"action": "do-nothing", "options": {}}, + 1: { + "action": "node-application-execute", + "options": {"node_name": "client", "application_name": "database"}, + }, + } + } + observation_config = { + "type": "file", + "options": { + "file_name": "dog.pdf", + "include_num_access": False, + "file_system_requires_scan": False, + }, + } + reward_config = { + "reward_components": [ + { + "type": "database-file-integrity", + "weight": 0.3, + "options": {"node_hostname": "server", "folder_name": "database", "file_name": "database.db"}, + } + ] + } + agent = RandomAgent( + config={ + "ref": "random_agent", + "team": "BLUE", + "action_space": action_config, + "observation_space": observation_config, + "reward_function": reward_config, + } + ) + + assert len(agent.action_manager.action_map) == 2 + assert isinstance(agent.observation_manager.obs, FileObservation) + assert len(agent.reward_function.reward_components) == 1 diff --git a/tests/unit_tests/_primaite/_game/_agent/test_agent_log.py b/tests/unit_tests/_primaite/_game/_agent/test_agent_log.py index d61e1a23..a7713437 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_agent_log.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_agent_log.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from uuid import uuid4 import pytest diff --git a/tests/unit_tests/_primaite/_game/_agent/test_observations.py b/tests/unit_tests/_primaite/_game/_agent/test_observations.py index 935bbdcf..3df6ca0a 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_observations.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_observations.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import json from typing import List @@ -22,17 +22,17 @@ class TestFileSystemRequiresScan: def test_obs_config(self, yaml_option_string, expected_val): """Check that the default behaviour is to set FileSystemRequiresScan to True.""" obs_cfg_yaml = f""" - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer + - service_name: web-server - hostname: database_server folders: - folder_name: database @@ -70,15 +70,15 @@ class TestFileSystemRequiresScan: wildcard_list: - 0.0.0.1 port_list: - - 80 - - 5432 + - HTTP + - POSTGRES_SERVER protocol_list: - ICMP - TCP - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -92,14 +92,14 @@ class TestFileSystemRequiresScan: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: "none" label: ICS options: {{}} """ cfg = yaml.safe_load(obs_cfg_yaml) - manager = ObservationManager.from_config(cfg) + manager = ObservationManager(config=cfg) hosts: List[HostObservation] = manager.obs.components["NODES"].hosts for i, host in enumerate(hosts): @@ -120,18 +120,24 @@ class TestFileSystemRequiresScan: assert obs_not_requiring_scan.observe(file_state)["health_status"] == 3 def test_folder_require_scan(self): - folder_state = {"health_status": 3, "visible_status": 1} + folder_state = {"health_status": 3, "visible_status": 1, "scanned_this_step": False} obs_requiring_scan = FolderObservation( [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True ) - assert obs_requiring_scan.observe(folder_state)["health_status"] == 1 + assert obs_requiring_scan.observe(folder_state)["health_status"] == 0 obs_not_requiring_scan = FolderObservation( [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=False ) assert obs_not_requiring_scan.observe(folder_state)["health_status"] == 3 + folder_state = {"health_status": 3, "visible_status": 1, "scanned_this_step": True} + obs_requiring_scan = FolderObservation( + [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True + ) + assert obs_requiring_scan.observe(folder_state)["health_status"] == 1 + class TestServicesRequiresScan: @pytest.mark.parametrize( @@ -145,18 +151,18 @@ class TestServicesRequiresScan: def test_obs_config(self, yaml_option_string, expected_val): """Check that the default behaviour is to set service_requires_scan to True.""" obs_cfg_yaml = f""" - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: - hostname: domain_controller - hostname: web_server services: - - service_name: WebServer - - service_name: DNSClient + - service_name: web-server + - service_name: dns-client - hostname: database_server folders: - folder_name: database @@ -164,7 +170,7 @@ class TestServicesRequiresScan: - file_name: database.db - hostname: backup_server services: - - service_name: FTPServer + - service_name: ftp-server - hostname: security_suite - hostname: client_1 - hostname: client_2 @@ -204,7 +210,7 @@ class TestServicesRequiresScan: - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -218,7 +224,7 @@ class TestServicesRequiresScan: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: none label: ICS options: {{}} @@ -257,10 +263,10 @@ class TestApplicationsRequiresScan: def test_obs_config(self, yaml_option_string, expected_val): """Check that the default behaviour is to set applications_requires_scan to True.""" obs_cfg_yaml = f""" - type: CUSTOM + type: custom options: components: - - type: NODES + - type: nodes label: NODES options: hosts: @@ -275,11 +281,11 @@ class TestApplicationsRequiresScan: - hostname: security_suite - hostname: client_1 applications: - - application_name: WebBrowser + - application_name: web-browser - hostname: client_2 applications: - - application_name: WebBrowser - - application_name: DatabaseClient + - application_name: web-browser + - application_name: database-client num_services: 0 num_applications: 3 num_folders: 1 @@ -316,7 +322,7 @@ class TestApplicationsRequiresScan: - UDP num_rules: 10 - - type: LINKS + - type: links label: LINKS options: link_references: @@ -330,7 +336,7 @@ class TestApplicationsRequiresScan: - switch_2:eth-1<->client_1:eth-1 - switch_2:eth-2<->client_2:eth-1 - switch_2:eth-7<->security_suite:eth-2 - - type: "NONE" + - type: none label: ICS options: {{}} diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index ec18f1fb..305375f9 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -1,8 +1,9 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent +from primaite.game.game import PrimaiteGame, PrimaiteGameOptions def test_probabilistic_agent(): @@ -16,69 +17,58 @@ def test_probabilistic_agent(): """ N_TRIALS = 10_000 P_DO_NOTHING = 0.1 - P_NODE_APPLICATION_EXECUTE = 0.3 - P_NODE_FILE_DELETE = 0.6 + P_node_application_execute = 0.3 + P_node_file_delete = 0.6 MIN_DO_NOTHING = 850 MAX_DO_NOTHING = 1150 - MIN_NODE_APPLICATION_EXECUTE = 2800 - MAX_NODE_APPLICATION_EXECUTE = 3200 - MIN_NODE_FILE_DELETE = 5750 - MAX_NODE_FILE_DELETE = 6250 + MIN_node_application_execute = 2800 + MAX_node_application_execute = 3200 + MIN_node_file_delete = 5750 + MAX_node_file_delete = 6250 - action_space = ActionManager( - actions=[ - {"type": "DONOTHING"}, - {"type": "NODE_APPLICATION_EXECUTE"}, - {"type": "NODE_FILE_DELETE"}, - ], - nodes=[ - { - "node_name": "client_1", - "applications": [{"application_name": "WebBrowser"}], - "folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}], + action_space_cfg = { + "action_map": { + 0: {"action": "do-nothing", "options": {}}, + 1: { + "action": "node-application-execute", + "options": {"node_name": "client_1", "application_name": "web-browser"}, + }, + 2: { + "action": "node-file-delete", + "options": {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}, }, - ], - max_folders_per_node=2, - max_files_per_folder=2, - max_services_per_node=2, - max_applications_per_node=2, - max_nics_per_node=2, - max_acl_rules=10, - protocols=["TCP", "UDP", "ICMP"], - ports=["HTTP", "DNS", "ARP"], - act_map={ - 0: {"action": "DONOTHING", "options": {}}, - 1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}}, - 2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}}, }, - ) - observation_space = ObservationManager(NestedObservation(components={})) - reward_function = RewardFunction() + } - pa = ProbabilisticAgent( - agent_name="test_agent", - action_space=action_space, - observation_space=observation_space, - reward_function=reward_function, - settings={ - "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, + game = PrimaiteGame() + game.options = PrimaiteGameOptions(ports=[], protocols=[]) + + pa_config = { + "type": "probabilistic-agent", + "ref": "probabilistic-agent", + "team": "BLUE", + "action_space": action_space_cfg, + "agent_settings": { + "action_probabilities": {0: P_DO_NOTHING, 1: P_node_application_execute, 2: P_node_file_delete}, }, - ) + } + + pa = ProbabilisticAgent.from_config(config=pa_config) do_nothing_count = 0 node_application_execute_count = 0 node_file_delete_count = 0 for _ in range(N_TRIALS): a = pa.get_action(0) - if a == ("DONOTHING", {}): + if a == ("do-nothing", {}): do_nothing_count += 1 - elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}): + elif a == ("node-application-execute", {"node_name": "client_1", "application_name": "web-browser"}): node_application_execute_count += 1 - elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}): + elif a == ("node-file-delete", {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}): node_file_delete_count += 1 else: raise AssertionError("Probabilistic agent produced an unexpected action.") assert MIN_DO_NOTHING < do_nothing_count < MAX_DO_NOTHING - assert MIN_NODE_APPLICATION_EXECUTE < node_application_execute_count < MAX_NODE_APPLICATION_EXECUTE - assert MIN_NODE_FILE_DELETE < node_file_delete_count < MAX_NODE_FILE_DELETE + assert MIN_node_application_execute < node_application_execute_count < MAX_node_application_execute + assert MIN_node_file_delete < node_file_delete_count < MAX_node_file_delete diff --git a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py index 58f0fcc1..935349d0 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.game.agent.interface import AgentHistoryItem from primaite.game.agent.rewards import ( @@ -11,7 +11,12 @@ from primaite.interface.request import RequestResponse class TestWebServer404PenaltySticky: def test_non_sticky(self): - reward = WebServer404Penalty("computer", "WebService", sticky=False) + schema = WebServer404Penalty.ConfigSchema( + node_hostname="computer", + service_name="WebService", + sticky=False, + ) + reward = WebServer404Penalty(config=schema) # no response codes yet, reward is 0 codes = [] @@ -38,7 +43,12 @@ class TestWebServer404PenaltySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebServer404Penalty("computer", "WebService", sticky=True) + schema = WebServer404Penalty.ConfigSchema( + node_hostname="computer", + service_name="WebService", + sticky=True, + ) + reward = WebServer404Penalty(config=schema) # no response codes yet, reward is 0 codes = [] @@ -67,25 +77,26 @@ class TestWebServer404PenaltySticky: class TestWebpageUnavailabilitySticky: def test_non_sticky(self): - reward = WebpageUnavailablePenalty("computer", sticky=False) + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=False) + reward = WebpageUnavailablePenalty(config=schema) # no response codes yet, reward is 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "do-nothing", {}, ["do-nothing"] response = RequestResponse(status="success", data={}) browser_history = [] - state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"web-browser": {"history": browser_history}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == 0 # agent did a successful fetch - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "web-browser"} + request = ["network", "node", "computer", "application", "web-browser", "execute"] response = RequestResponse(status="success", data={}) browser_history.append({"outcome": 200}) - state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"web-browser": {"history": browser_history}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) @@ -93,59 +104,60 @@ class TestWebpageUnavailabilitySticky: # THE IMPORTANT BIT # agent did nothing, because reward is not sticky, it goes back to 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "do-nothing", {}, ["do-nothing"] response = RequestResponse(status="success", data={}) browser_history = [] - state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"web-browser": {"history": browser_history}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == 0.0 # agent fails to fetch, get a -1.0 reward - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "web-browser"} + request = ["network", "node", "computer", "application", "web-browser", "execute"] response = RequestResponse(status="failure", data={}) browser_history.append({"outcome": 404}) - state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"web-browser": {"history": browser_history}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == -1.0 # agent fails again to fetch, get a -1.0 reward again - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "web-browser"} + request = ["network", "node", "computer", "application", "web-browser", "execute"] response = RequestResponse(status="failure", data={}) browser_history.append({"outcome": 404}) - state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"web-browser": {"history": browser_history}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebpageUnavailablePenalty("computer", sticky=True) + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=True) + reward = WebpageUnavailablePenalty(config=schema) # no response codes yet, reward is 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "do-nothing", {}, ["do-nothing"] response = RequestResponse(status="success", data={}) browser_history = [] - state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"web-browser": {"history": browser_history}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == 0 # agent did a successful fetch - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "web-browser"} + request = ["network", "node", "computer", "application", "web-browser", "execute"] response = RequestResponse(status="success", data={}) browser_history.append({"outcome": 200}) - state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"web-browser": {"history": browser_history}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) @@ -153,33 +165,33 @@ class TestWebpageUnavailabilitySticky: # THE IMPORTANT BIT # agent did nothing, because reward is sticky, it stays at 1.0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "do-nothing", {}, ["do-nothing"] response = RequestResponse(status="success", data={}) - state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"web-browser": {"history": browser_history}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == 1.0 # agent fails to fetch, get a -1.0 reward - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "web-browser"} + request = ["network", "node", "computer", "application", "web-browser", "execute"] response = RequestResponse(status="failure", data={}) browser_history.append({"outcome": 404}) - state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"web-browser": {"history": browser_history}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == -1.0 # agent fails again to fetch, get a -1.0 reward again - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "web-browser"} + request = ["network", "node", "computer", "application", "web-browser", "execute"] response = RequestResponse(status="failure", data={}) browser_history.append({"outcome": 404}) - state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"web-browser": {"history": browser_history}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) @@ -188,23 +200,27 @@ class TestWebpageUnavailabilitySticky: class TestGreenAdminDatabaseUnreachableSticky: def test_non_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False) + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema( + node_hostname="computer", + sticky=False, + ) + reward = GreenAdminDatabaseUnreachablePenalty(config=schema) # no response codes yet, reward is 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "do-nothing", {}, ["do-nothing"] response = RequestResponse(status="success", data={}) - state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"database-client": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == 0 # agent did a successful fetch - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "database-client"} + request = ["network", "node", "computer", "application", "database-client", "execute"] response = RequestResponse(status="success", data={}) - state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"database-client": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) @@ -212,55 +228,58 @@ class TestGreenAdminDatabaseUnreachableSticky: # THE IMPORTANT BIT # agent did nothing, because reward is not sticky, it goes back to 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "do-nothing", {}, ["do-nothing"] response = RequestResponse(status="success", data={}) - browser_history = [] - state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"database-client": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == 0.0 # agent fails to fetch, get a -1.0 reward - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "database-client"} + request = ["network", "node", "computer", "application", "database-client", "execute"] response = RequestResponse(status="failure", data={}) - state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"database-client": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == -1.0 # agent fails again to fetch, get a -1.0 reward again - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "database-client"} + request = ["network", "node", "computer", "application", "database-client", "execute"] response = RequestResponse(status="failure", data={}) - state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"database-client": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True) + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema( + node_hostname="computer", + sticky=True, + ) + reward = GreenAdminDatabaseUnreachablePenalty(config=schema) # no response codes yet, reward is 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "do-nothing", {}, ["do-nothing"] response = RequestResponse(status="success", data={}) - state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"database-client": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == 0 # agent did a successful fetch - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "database-client"} + request = ["network", "node", "computer", "application", "database-client", "execute"] response = RequestResponse(status="success", data={}) - state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"database-client": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) @@ -268,31 +287,31 @@ class TestGreenAdminDatabaseUnreachableSticky: # THE IMPORTANT BIT # agent did nothing, because reward is not sticky, it goes back to 0 - action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + action, params, request = "do-nothing", {}, ["do-nothing"] response = RequestResponse(status="success", data={}) - state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"database-client": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == 1.0 # agent fails to fetch, get a -1.0 reward - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "database-client"} + request = ["network", "node", "computer", "application", "database-client", "execute"] response = RequestResponse(status="failure", data={}) - state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"database-client": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) assert reward.calculate(state, last_action_response) == -1.0 # agent fails again to fetch, get a -1.0 reward again - action = "NODE_APPLICATION_EXECUTE" - params = {"node_id": 0, "application_id": 0} - request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + action = "node-application-execute" + params = {"node_name": "computer", "application_name": "database-client"} + request = ["network", "node", "computer", "application", "database-client", "execute"] response = RequestResponse(status="failure", data={}) - state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + state = {"network": {"nodes": {"computer": {"applications": {"database-client": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response ) diff --git a/tests/unit_tests/_primaite/_interface/__init__.py b/tests/unit_tests/_primaite/_interface/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_interface/__init__.py +++ b/tests/unit_tests/_primaite/_interface/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_interface/test_request.py b/tests/unit_tests/_primaite/_interface/test_request.py index 6067f9e4..d9fae083 100644 --- a/tests/unit_tests/_primaite/_interface/test_request.py +++ b/tests/unit_tests/_primaite/_interface/test_request.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from pydantic import ValidationError diff --git a/tests/unit_tests/_primaite/_session/__init__.py b/tests/unit_tests/_primaite/_session/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_session/__init__.py +++ b/tests/unit_tests/_primaite/_session/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_session/test_episode_schedule.py b/tests/unit_tests/_primaite/_session/test_episode_schedule.py index 21448339..ff26bb02 100644 --- a/tests/unit_tests/_primaite/_session/test_episode_schedule.py +++ b/tests/unit_tests/_primaite/_session/test_episode_schedule.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest import yaml diff --git a/tests/unit_tests/_primaite/_simulator/__init__.py b/tests/unit_tests/_primaite/_simulator/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/__init__.py +++ b/tests/unit_tests/_primaite/_simulator/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_domain/__init__.py b/tests/unit_tests/_primaite/_simulator/_domain/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/_domain/__init__.py +++ b/tests/unit_tests/_primaite/_simulator/_domain/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py index 8db68565..f5294844 100644 --- a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py +++ b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Test the account module of the simulator.""" import pytest diff --git a/tests/unit_tests/_primaite/_simulator/_domain/test_controller.py b/tests/unit_tests/_primaite/_simulator/_domain/test_controller.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/_domain/test_controller.py +++ b/tests/unit_tests/_primaite/_simulator/_domain/test_controller.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/__init__.py b/tests/unit_tests/_primaite/_simulator/_file_system/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/__init__.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file.py index 0b9bdc8e..e5e79e5f 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import warnings import pytest @@ -22,12 +22,12 @@ def test_file_scan(file_system): file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder") assert file.health_status == FileSystemItemHealthStatus.GOOD - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE file.corrupt() assert file.health_status == FileSystemItemHealthStatus.CORRUPT - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE file.scan() @@ -46,7 +46,7 @@ def test_file_reveal_to_red_scan(file_system): assert file.revealed_to_red is True -@pytest.mark.skip(reason="NODE_FILE_CHECKHASH not implemented") +@pytest.mark.skip(reason="node-file-checkhash not implemented") def test_simulated_file_check_hash(file_system): file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder") diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py index 594c7afe..13e3cbe2 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest @@ -24,7 +24,7 @@ def test_file_scan_request(populated_file_system): file.corrupt() assert file.health_status == FileSystemItemHealthStatus.CORRUPT - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE fs.apply_request(request=["folder", folder.name, "file", file.name, "scan"]) @@ -32,7 +32,7 @@ def test_file_scan_request(populated_file_system): assert file.visible_health_status == FileSystemItemHealthStatus.CORRUPT -@pytest.mark.skip(reason="NODE_FILE_CHECKHASH not implemented") +@pytest.mark.skip(reason="node-file-checkhash not implemented") def test_file_checkhash_request(populated_file_system): """Test that an agent can request a file hash check.""" fs, folder, file = populated_file_system @@ -94,7 +94,7 @@ def test_deleted_file_cannot_be_interacted_with(populated_file_system): assert fs.get_file(folder_name=folder.name, file_name=file.name).health_status == FileSystemItemHealthStatus.CORRUPT assert ( fs.get_file(folder_name=folder.name, file_name=file.name).visible_health_status - == FileSystemItemHealthStatus.GOOD + == FileSystemItemHealthStatus.NONE ) fs.apply_request(request=["delete", "file", folder.name, file.name]) diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 4eb0dd10..5554b9ef 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.file_system.file import File diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_actions.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_actions.py index 7d022ea4..44a4e22a 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_actions.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple import pytest diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder.py index 724d7903..fd581ea6 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.file_system.file import File @@ -44,25 +44,25 @@ def test_folder_scan(file_system): file2: File = folder.get_file_by_id(file_uuid=list(folder.files)[0]) assert folder.health_status == FileSystemItemHealthStatus.GOOD - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE + assert file1.visible_health_status == FileSystemItemHealthStatus.NONE + assert file2.visible_health_status == FileSystemItemHealthStatus.NONE folder.corrupt() assert folder.health_status == FileSystemItemHealthStatus.CORRUPT - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE + assert file1.visible_health_status == FileSystemItemHealthStatus.NONE + assert file2.visible_health_status == FileSystemItemHealthStatus.NONE folder.scan() folder.apply_timestep(timestep=0) assert folder.health_status == FileSystemItemHealthStatus.CORRUPT - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE + assert file1.visible_health_status == FileSystemItemHealthStatus.NONE + assert file2.visible_health_status == FileSystemItemHealthStatus.NONE folder.apply_timestep(timestep=1) folder.apply_timestep(timestep=2) @@ -120,7 +120,7 @@ def test_folder_corrupt_repair(file_system): assert file.health_status == FileSystemItemHealthStatus.GOOD -@pytest.mark.skip(reason="NODE_FILE_CHECKHASH not implemented") +@pytest.mark.skip(reason="node-file-checkhash not implemented") def test_simulated_folder_check_hash(file_system): folder: Folder = file_system.create_folder(folder_name="test_folder") file_system.create_file(file_name="test_file.txt", folder_name="test_folder") diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py index 4a561b97..1eba3e55 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import warnings from typing import Tuple @@ -29,18 +29,18 @@ def test_folder_scan_request(populated_file_system): folder.corrupt() assert folder.health_status == FileSystemItemHealthStatus.CORRUPT - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE + assert file1.visible_health_status == FileSystemItemHealthStatus.NONE + assert file2.visible_health_status == FileSystemItemHealthStatus.NONE fs.apply_request(request=["folder", folder.name, "scan"]) folder.apply_timestep(timestep=0) assert folder.health_status == FileSystemItemHealthStatus.CORRUPT - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD - assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE + assert file1.visible_health_status == FileSystemItemHealthStatus.NONE + assert file2.visible_health_status == FileSystemItemHealthStatus.NONE folder.apply_timestep(timestep=1) folder.apply_timestep(timestep=2) @@ -51,7 +51,7 @@ def test_folder_scan_request(populated_file_system): assert file2.visible_health_status == FileSystemItemHealthStatus.CORRUPT -@pytest.mark.skip(reason="NODE_FOLDER_CHECKHASH not implemented") +@pytest.mark.skip(reason="node-folder-checkhash not implemented") def test_folder_checkhash_request(populated_file_system): """Test that an agent can request a folder hash check.""" fs, folder, file = populated_file_system diff --git a/tests/unit_tests/_primaite/_simulator/_network/__init__.py b/tests/unit_tests/_primaite/_simulator/_network/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/__init__.py +++ b/tests/unit_tests/_primaite/_simulator/_network/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/__init__.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/__init__.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/__init__.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/__init__.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py index 9bc1abfd..ee7eb08f 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address import pytest @@ -7,8 +7,10 @@ from primaite.simulator.network.hardware.base import generate_mac_address from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.network_layer import IPPacket +from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -23,25 +25,26 @@ def router_with_acl_rules(): :return: A configured Router object with ACL rules. """ - router = Router("Router") + router_cfg = {"hostname": "router_1", "type": "router"} + router = Router.from_config(config=router_cfg) acl = router.acl # Add rules here as needed acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.1", - src_port=Port.HTTPS, + src_port=PORT_LOOKUP["HTTPS"], dst_ip_address="192.168.1.2", - dst_port=Port.HTTP, + dst_port=PORT_LOOKUP["HTTP"], position=1, ) acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.3", - src_port=Port(8080), + src_port=8080, dst_ip_address="192.168.1.4", - dst_port=Port(80), + dst_port=80, position=2, ) return router @@ -60,26 +63,27 @@ def router_with_wildcard_acl(): :return: A Router object with configured ACL rules, including rules with wildcard masking. """ - router = Router("Router") + router_cfg = {"hostname": "router_1", "type": "router"} + router = Router.from_config(config=router_cfg) acl = router.acl # Rule to permit traffic from a specific source IP and port to a specific destination IP and port acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.1", - src_port=Port(8080), + src_port=8080, dst_ip_address="10.1.1.2", - dst_port=Port(80), + dst_port=80, position=1, ) # Rule to deny traffic from an IP range to a specific destination IP and port acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.0", src_wildcard_mask="0.0.0.255", dst_ip_address="10.1.1.3", - dst_port=Port(443), + dst_port=443, position=2, ) # Rule to permit any traffic to a range of destination IPs @@ -109,11 +113,11 @@ def test_add_rule(router_with_acl_rules): acl = router_with_acl_rules.acl assert acl.acl[1].action == ACLAction.PERMIT - assert acl.acl[1].protocol == IPProtocol.TCP + assert acl.acl[1].protocol == PROTOCOL_LOOKUP["TCP"] assert acl.acl[1].src_ip_address == IPv4Address("192.168.1.1") - assert acl.acl[1].src_port == Port.HTTPS + assert acl.acl[1].src_port == PORT_LOOKUP["HTTPS"] assert acl.acl[1].dst_ip_address == IPv4Address("192.168.1.2") - assert acl.acl[1].dst_port == Port.HTTP + assert acl.acl[1].dst_port == PORT_LOOKUP["HTTP"] def test_remove_rule(router_with_acl_rules): @@ -136,8 +140,8 @@ def test_traffic_permitted_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.HTTPS, dst_port=Port.HTTP), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["HTTPS"], dst_port=PORT_LOOKUP["HTTP"]), ) is_permitted, _ = acl.is_permitted(permitted_frame) assert is_permitted @@ -153,8 +157,8 @@ def test_traffic_denied_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(80)), + ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=80), ) is_permitted, _ = acl.is_permitted(not_permitted_frame) assert not is_permitted @@ -173,8 +177,8 @@ def test_default_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.HTTPS, dst_port=Port.HTTP), + ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["HTTPS"], dst_port=PORT_LOOKUP["HTTP"]), ) is_permitted, rule = acl.is_permitted(not_permitted_frame) assert not is_permitted @@ -189,8 +193,8 @@ def test_direct_ip_match_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(80)), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=80), ) assert acl.is_permitted(frame)[0], "Direct IP match should be permitted." @@ -204,8 +208,8 @@ def test_ip_range_match_denied_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(443)), + ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=443), ) assert not acl.is_permitted(frame)[0], "IP range match with wildcard mask should be denied." @@ -219,8 +223,8 @@ def test_traffic_permitted_to_destination_range_with_acl(router_with_wildcard_ac acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port(1433), dst_port=Port(1433)), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=1433, dst_port=1433), ) assert acl.is_permitted(frame)[0], "Traffic to destination IP range should be permitted." @@ -241,7 +245,8 @@ def test_ip_traffic_from_specific_subnet(): - Traffic from outside the 192.168.1.0/24 subnet is denied. """ - router = Router("Router") + router_cfg = {"hostname": "router_1", "type": "router"} + router = Router.from_config(config=router_cfg) acl = router.acl # Add rules here as needed acl.add_rule( @@ -253,23 +258,23 @@ def test_ip_traffic_from_specific_subnet(): permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"]), ) assert acl.is_permitted(permitted_frame_1)[0] permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.NTP, dst_port=Port.NTP), + ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"]), ) assert acl.is_permitted(permitted_frame_2)[0] permitted_frame_3 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=PROTOCOL_LOOKUP["ICMP"]), icmp=ICMPPacket(identifier=1), ) @@ -277,16 +282,16 @@ def test_ip_traffic_from_specific_subnet(): not_permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER), + ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"]), ) assert not acl.is_permitted(not_permitted_frame_1)[0] not_permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.NTP, dst_port=Port.NTP), + ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"]), ) assert not acl.is_permitted(not_permitted_frame_2)[0] diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py index d4e38ded..e1a910b8 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -1,14 +1,13 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_wireless_router_from_config(): cfg = { - "ref": "router_1", "type": "router", "hostname": "router_1", "num_ports": 6, @@ -50,9 +49,9 @@ def test_wireless_router_from_config(): }, } - rt = Router.from_config(cfg=cfg) + rt = Router.from_config(config=cfg) - assert rt.num_ports == 6 + assert rt.config.num_ports == 6 assert rt.network_interface[1].ip_address == IPv4Address("192.168.1.1") assert rt.network_interface[1].subnet_mask == IPv4Address("255.255.255.0") @@ -67,12 +66,12 @@ def test_wireless_router_from_config(): r0 = rt.acl.acl[0] assert r0.action == ACLAction.PERMIT - assert r0.src_port == r0.dst_port == Port.POSTGRES_SERVER + assert r0.src_port == r0.dst_port == PORT_LOOKUP["POSTGRES_SERVER"] assert r0.src_ip_address == r0.dst_ip_address == r0.dst_wildcard_mask == r0.src_wildcard_mask == r0.protocol == None r1 = rt.acl.acl[1] assert r1.action == ACLAction.PERMIT - assert r1.protocol == IPProtocol.ICMP + assert r1.protocol == PROTOCOL_LOOKUP["ICMP"] assert ( r1.src_ip_address == r1.dst_ip_address diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py index 2613d536..94b1764d 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState @@ -7,7 +7,8 @@ from primaite.simulator.network.hardware.nodes.network.switch import Switch @pytest.fixture(scope="function") def switch() -> Switch: - switch: Switch = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) + switch_cfg = {"type": "switch", "hostname": "switch_1", "num_ports": 8, "start_up_duration": 0} + switch: Switch = Switch.from_config(config=switch_cfg) switch.power_on() switch.show() return switch @@ -16,4 +17,3 @@ def switch() -> Switch: def test_describe_state(switch): state = switch.describe_state() assert len(state.get("ports")) is 8 - assert state.get("num_ports") is 8 diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py index f35cf171..cb2d3935 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.network.hardware.base import NetworkInterface, Node @@ -7,7 +7,10 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer @pytest.fixture def node() -> Node: - return Computer(hostname="test", ip_address="192.168.1.2", subnet_mask="255.255.255.0") + computer_cfg = {"type": "computer", "hostname": "test", "ip_address": "192.168.1.2", "subnet_mask": "255.255.255.0"} + computer = Computer.from_config(config=computer_cfg) + + return computer def test_nic_enabled_validator(node): diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_nic.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_nic.py index 29d5ec67..f9ff0328 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_nic.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_nic.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import re from ipaddress import IPv4Address diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py index 44c5c781..425c0887 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.file_system.file import File @@ -12,7 +12,16 @@ from tests.conftest import DummyApplication, DummyService @pytest.fixture def node() -> Node: - return Computer(hostname="test", ip_address="192.168.1.2", subnet_mask="255.255.255.0") + computer_cfg = { + "type": "computer", + "hostname": "test", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "operating_state": "OFF", + } + computer = Computer.from_config(config=computer_cfg) + + return computer def test_node_startup(node): @@ -57,26 +66,26 @@ def test_node_os_scan(node): # add services to node node.software_manager.install(DummyService) - service = node.software_manager.software.get("DummyService") + service = node.software_manager.software.get("dummy-service") service.set_health_state(SoftwareHealthState.COMPROMISED) assert service.health_state_visible == SoftwareHealthState.UNUSED # add application to node node.software_manager.install(DummyApplication) - application = node.software_manager.software.get("DummyApplication") + application = node.software_manager.software.get("dummy-application") application.set_health_state(SoftwareHealthState.COMPROMISED) assert application.health_state_visible == SoftwareHealthState.UNUSED # add folder and file to node folder: Folder = node.file_system.create_folder(folder_name="test_folder") folder.corrupt() - assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.NONE file: File = node.file_system.create_file(folder_name="test_folder", file_name="file.txt") file2: File = node.file_system.create_file(folder_name="test_folder", file_name="file2.txt") file.corrupt() file2.corrupt() - assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + assert file.visible_health_status == FileSystemItemHealthStatus.NONE # run os scan node.apply_request(["os", "scan"]) @@ -103,12 +112,12 @@ def test_node_red_scan(node): # add services to node node.software_manager.install(DummyService) - service = node.software_manager.software.get("DummyService") + service = node.software_manager.software.get("dummy-service") assert service.revealed_to_red is False # add application to node node.software_manager.install(DummyApplication) - application = node.software_manager.software.get("DummyApplication") + application = node.software_manager.software.get("dummy-application") application.set_health_state(SoftwareHealthState.COMPROMISED) assert application.revealed_to_red is False @@ -166,7 +175,7 @@ def test_node_is_on_validator(node): """Test that the node is on validator.""" node.power_on() - for i in range(node.start_up_duration + 1): + for i in range(node.config.start_up_duration + 1): node.apply_timestep(i) validator = Node._NodeIsOnValidator(node=node) @@ -174,7 +183,7 @@ def test_node_is_on_validator(node): assert validator(request=[], context={}) node.power_off() - for i in range(node.shut_down_duration + 1): + for i in range(node.config.shut_down_duration + 1): node.apply_timestep(i) assert validator(request=[], context={}) is False @@ -184,7 +193,7 @@ def test_node_is_off_validator(node): """Test that the node is on validator.""" node.power_on() - for i in range(node.start_up_duration + 1): + for i in range(node.config.start_up_duration + 1): node.apply_timestep(i) validator = Node._NodeIsOffValidator(node=node) @@ -192,7 +201,7 @@ def test_node_is_off_validator(node): assert validator(request=[], context={}) is False node.power_off() - for i in range(node.shut_down_duration + 1): + for i in range(node.config.shut_down_duration + 1): node.apply_timestep(i) assert validator(request=[], context={}) diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/__init__.py b/tests/unit_tests/_primaite/_simulator/_network/_transmission/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/__init__.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_transmission/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py index 92618baa..161d9cb4 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py @@ -1,11 +1,13 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol, Precedence +from primaite.simulator.network.transmission.network_layer import IPPacket, Precedence from primaite.simulator.network.transmission.primaite_layer import AgentSource, DataStatus -from primaite.simulator.network.transmission.transport_layer import Port, TCPFlags, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.transport_layer import TCPFlags, TCPHeader, UDPHeader +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_frame_minimal_instantiation(): @@ -20,7 +22,7 @@ def test_frame_minimal_instantiation(): ) # Check network layer default values - assert frame.ip.protocol == IPProtocol.TCP + assert frame.ip.protocol == PROTOCOL_LOOKUP["TCP"] assert frame.ip.ttl == 64 assert frame.ip.precedence == Precedence.ROUTINE @@ -40,7 +42,7 @@ def test_frame_creation_fails_tcp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.TCP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["TCP"]), ) @@ -49,7 +51,7 @@ def test_frame_creation_fails_udp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.UDP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["UDP"]), ) @@ -58,7 +60,7 @@ def test_frame_creation_fails_tcp_with_udp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.TCP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["TCP"]), udp=UDPHeader(src_port=8080, dst_port=80), ) @@ -68,7 +70,7 @@ def test_frame_creation_fails_udp_with_tcp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.UDP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["UDP"]), udp=TCPHeader(src_port=8080, dst_port=80), ) @@ -77,7 +79,7 @@ def test_icmp_frame_creation(): """Tests Frame creation for ICMP.""" frame = Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["ICMP"]), icmp=ICMPPacket(), ) assert frame @@ -88,5 +90,5 @@ def test_icmp_frame_creation_fails_without_icmp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["ICMP"]), ) diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_network_layer.py b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_network_layer.py index 658726b5..990a0bbf 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_network_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_network_layer.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index f764f9b5..9a54f7b2 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import json import pytest @@ -61,12 +61,12 @@ def test_apply_timestep_to_nodes(network): client_1.power_off() assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN - for i in range(client_1.shut_down_duration + 1): + for i in range(client_1.config.shut_down_duration + 1): network.apply_timestep(timestep=i) assert client_1.operating_state is NodeOperatingState.OFF - network.apply_timestep(client_1.shut_down_duration + 2) + network.apply_timestep(client_1.config.shut_down_duration + 2) assert client_1.operating_state is NodeOperatingState.OFF @@ -74,7 +74,16 @@ def test_removing_node_that_does_not_exist(network): """Node that does not exist on network should not affect existing nodes.""" assert len(network.nodes) is 7 - network.remove_node(Computer(hostname="new_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0")) + network.remove_node( + Computer.from_config( + config={ + "type": "computer", + "hostname": "new_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + } + ) + ) assert len(network.nodes) is 7 diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_creation.py b/tests/unit_tests/_primaite/_simulator/_network/test_creation.py new file mode 100644 index 00000000..700780d0 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_network/test_creation.py @@ -0,0 +1,69 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +import pytest + +from primaite.simulator.network.container import Network +from primaite.simulator.network.creation import NetworkNodeAdder, OfficeLANAdder + +param_names = ("lan_name", "subnet_base", "pcs_ip_block_start", "num_pcs", "include_router", "bandwidth") +param_vals = ( + ("CORP-NETWORK", 3, 10, 6, True, 45), + ("OTHER-NETWORK", 10, 25, 26, True, 100), + ("OTHER-NETWORK", 10, 25, 55, False, 100), +) +param_dicts = [dict(zip(param_names, vals)) for vals in param_vals] + + +def _assert_valid_creation(net: Network, lan_name, subnet_base, pcs_ip_block_start, num_pcs, include_router, bandwidth): + """Assert that the network contains the correct nodes as described by config items""" + num_switches = 1 if num_pcs <= 23 else num_pcs // 23 + 2 + num_routers = 1 if include_router else 0 + total_nodes = num_pcs + num_switches + num_routers + + assert all((n.config.hostname.endswith(lan_name) for n in net.nodes.values())) + assert len(net.computer_nodes) == num_pcs + assert len(net.switch_nodes) == num_switches + assert len(net.router_nodes) == num_routers + assert len(net.nodes) == total_nodes + assert all( + [str(n.network_interface[1].ip_address).startswith(f"192.168.{subnet_base}") for n in net.computer_nodes] + ) + # check that computers occupy address range 192.168.3.10 - 192.168.3.16 + computer_ip_last_octets = {str(n.network_interface[1].ip_address).split(".")[-1] for n in net.computer_nodes} + assert computer_ip_last_octets == {str(i) for i in range(pcs_ip_block_start, pcs_ip_block_start + num_pcs)} + + +@pytest.mark.parametrize("kwargs", param_dicts) +def test_office_lan_adder(kwargs): + """Assert that adding an office lan via the python API works correctly.""" + net = Network() + + office_lan_config = OfficeLANAdder.ConfigSchema( + lan_name=kwargs["lan_name"], + subnet_base=kwargs["subnet_base"], + pcs_ip_block_start=kwargs["pcs_ip_block_start"], + num_pcs=kwargs["num_pcs"], + include_router=kwargs["include_router"], + bandwidth=kwargs["bandwidth"], + ) + OfficeLANAdder.add_nodes_to_net(config=office_lan_config, network=net) + + _assert_valid_creation(net=net, **kwargs) + + +@pytest.mark.parametrize("kwargs", param_dicts) +def test_office_lan_from_config(kwargs): + """Assert that the base class can add an office lan given a config dict.""" + net = Network() + + config = dict( + type="office-lan", + lan_name=kwargs["lan_name"], + subnet_base=kwargs["subnet_base"], + pcs_ip_block_start=kwargs["pcs_ip_block_start"], + num_pcs=kwargs["num_pcs"], + include_router=kwargs["include_router"], + bandwidth=kwargs["bandwidth"], + ) + + NetworkNodeAdder.from_config(config=config, network=net) + _assert_valid_creation(net=net, **kwargs) diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_utils.py b/tests/unit_tests/_primaite/_simulator/_network/test_utils.py index c80189c1..d86aa876 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_utils.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_utils.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.network.utils import convert_bytes_to_megabits, convert_megabits_to_bytes diff --git a/tests/unit_tests/_primaite/_simulator/_system/__init__.py b/tests/unit_tests/_primaite/_simulator/_system/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/__init__.py +++ b/tests/unit_tests/_primaite/_simulator/_system/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/__init__.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/__init__.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/__init__.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/__init__.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index 885a3cb6..64dbdd52 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -1,14 +1,14 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Command, C2Server +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -16,19 +16,27 @@ def basic_c2_network() -> Network: network = Network() # Creating two generic nodes for the C2 Server and the C2 Beacon. + computer_a_cfg = { + "type": "computer", + "hostname": "computer_a", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.252", + "start_up_duration": 0, + } + computer_a = Computer.from_config(config=computer_a_cfg) - computer_a = Computer( - hostname="computer_a", - ip_address="192.168.0.1", - subnet_mask="255.255.255.252", - start_up_duration=0, - ) computer_a.power_on() computer_a.software_manager.install(software_class=C2Server) - computer_b = Computer( - hostname="computer_b", ip_address="192.168.0.2", subnet_mask="255.255.255.252", start_up_duration=0 - ) + computer_b_cfg = { + "type": "computer", + "hostname": "computer_b", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.252", + "start_up_duration": 0, + } + + computer_b = Computer.from_config(config=computer_b_cfg) computer_b.power_on() computer_b.software_manager.install(software_class=C2Beacon) @@ -44,8 +52,8 @@ def setup_c2(given_network: Network): computer_a: Computer = network.get_node_by_hostname("computer_a") computer_b: Computer = network.get_node_by_hostname("computer_b") - c2_beacon: C2Beacon = computer_b.software_manager.software.get("C2Beacon") - c2_server: C2Server = computer_a.software_manager.software.get("C2Server") + c2_beacon: C2Beacon = computer_b.software_manager.software.get("c2-beacon") + c2_server: C2Server = computer_a.software_manager.software.get("c2-server") c2_beacon.configure(c2_server_ip_address="192.168.0.1", keep_alive_frequency=2) c2_server.run() @@ -128,20 +136,20 @@ def test_c2_handle_switching_port(basic_c2_network): assert c2_server.c2_connection_active is True # Assert to confirm that both the C2 server and the C2 beacon are configured correctly. - assert c2_beacon.c2_config.keep_alive_frequency is 2 - assert c2_beacon.c2_config.masquerade_port is Port.HTTP - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_beacon.config.keep_alive_frequency is 2 + assert c2_beacon.config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] - assert c2_server.c2_config.keep_alive_frequency is 2 - assert c2_server.c2_config.masquerade_port is Port.HTTP - assert c2_server.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_server.config.keep_alive_frequency is 2 + assert c2_server.config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] # Configuring the C2 Beacon. c2_beacon.configure( c2_server_ip_address="192.168.0.1", keep_alive_frequency=2, - masquerade_port=Port.FTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=PORT_LOOKUP["FTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) # Asserting that the c2 applications have established a c2 connection @@ -150,11 +158,11 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon # Have reconfigured their C2 settings. - assert c2_beacon.c2_config.masquerade_port is Port.FTP - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_beacon.config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] - assert c2_server.c2_config.masquerade_port is Port.FTP - assert c2_server.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_server.config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] def test_c2_handle_switching_frequency(basic_c2_network): @@ -174,8 +182,8 @@ def test_c2_handle_switching_frequency(basic_c2_network): assert c2_server.c2_connection_active is True # Assert to confirm that both the C2 server and the C2 beacon are configured correctly. - assert c2_beacon.c2_config.keep_alive_frequency is 2 - assert c2_server.c2_config.keep_alive_frequency is 2 + assert c2_beacon.config.keep_alive_frequency is 2 + assert c2_server.config.keep_alive_frequency is 2 # Configuring the C2 Beacon. c2_beacon.configure(c2_server_ip_address="192.168.0.1", keep_alive_frequency=10) @@ -186,8 +194,8 @@ def test_c2_handle_switching_frequency(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon # Have reconfigured their C2 settings. - assert c2_beacon.c2_config.keep_alive_frequency is 10 - assert c2_server.c2_config.keep_alive_frequency is 10 + assert c2_beacon.config.keep_alive_frequency is 10 + assert c2_server.config.keep_alive_frequency is 10 # Now skipping 9 time steps to confirm keep alive inactivity for i in range(9): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index 0811d2a0..35703e14 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -1,15 +1,15 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( DataManipulationAttackStage, DataManipulationBot, ) +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -20,15 +20,15 @@ def dm_client() -> Node: @pytest.fixture def dm_bot(dm_client) -> DataManipulationBot: - return dm_client.software_manager.software.get("DataManipulationBot") + return dm_client.software_manager.software.get("data-manipulation-bot") def test_create_dm_bot(dm_client): - data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot") + data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("data-manipulation-bot") - assert data_manipulation_bot.name == "DataManipulationBot" - assert data_manipulation_bot.port == Port.NONE - assert data_manipulation_bot.protocol == IPProtocol.NONE + assert data_manipulation_bot.name == "data-manipulation-bot" + assert data_manipulation_bot.port == PORT_LOOKUP["NONE"] + assert data_manipulation_bot.protocol == PROTOCOL_LOOKUP["NONE"] assert data_manipulation_bot.payload == "DELETE" @@ -75,8 +75,8 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot): def test_dm_bot_fails_without_db_client(dm_client): - dm_client.software_manager.uninstall("DatabaseClient") - dm_bot = dm_client.software_manager.software.get("DataManipulationBot") + dm_client.software_manager.uninstall("database-client") + dm_bot = dm_client.software_manager.software.get("data-manipulation-bot") assert dm_bot._host_db_client is None dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN dm_bot._perform_data_manipulation(p_of_success=1.0) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index 2acd991a..fffb7c84 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -1,24 +1,30 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dos_bot() -> DoSBot: - computer = Computer( - hostname="compromised_pc", ip_address="192.168.0.1", subnet_mask="255.255.255.0", start_up_duration=0 - ) + computer_cfg = { + "type": "computer", + "hostname": "compromised_pc", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + computer: Computer = Computer.from_config(config=computer_cfg) + computer.power_on() computer.software_manager.install(DoSBot) - dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") + dos_bot: DoSBot = computer.software_manager.software.get("dos-bot") dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1")) return dos_bot @@ -34,7 +40,7 @@ def test_dos_bot_cannot_run_when_node_offline(dos_bot): dos_bot_node.power_off() - for i in range(dos_bot_node.shut_down_duration + 1): + for i in range(dos_bot_node.config.shut_down_duration + 1): dos_bot_node.apply_timestep(timestep=i) assert dos_bot_node.operating_state is NodeOperatingState.OFF diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_actions.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_actions.py index 0e9c536c..a69dc844 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_actions.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.system.applications.application import Application, ApplicationOperatingState diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py index d8d7dfab..2a5d34ab 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py @@ -1,22 +1,22 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.system.applications.application import Application def test_adding_to_app_registry(): - class temp_application(Application, identifier="temp_app"): + class temp_application(Application, discriminator="temp-app"): pass - assert Application._application_registry["temp_app"] is temp_application + assert Application._registry["temp-app"] is temp_application with pytest.raises(ValueError): - class another_application(Application, identifier="temp_app"): + class another_application(Application, discriminator="temp-app"): pass # This is kinda evil... # Because pytest doesn't reimport classes from modules, registering this temporary test application will change the # state of the Application registry for all subsequently run tests. So, we have to delete and unregister the class. del temp_application - Application._application_registry.pop("temp_app") + Application._registry.pop("temp-app") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_applications.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_applications.py index aef5d6d1..6cccad91 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_applications.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_applications.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.software import SoftwareHealthState @@ -18,7 +18,7 @@ def test_scan(application): def test_run_application(application): assert application.operating_state == ApplicationOperatingState.CLOSED - assert application.health_state_actual == SoftwareHealthState.UNUSED + assert application.health_state_actual == SoftwareHealthState.GOOD application.run() assert application.operating_state == ApplicationOperatingState.RUNNING @@ -37,9 +37,9 @@ def test_close_application(application): def test_application_describe_states(application): assert application.operating_state == ApplicationOperatingState.CLOSED - assert application.health_state_actual == SoftwareHealthState.UNUSED + assert application.health_state_actual == SoftwareHealthState.GOOD - assert SoftwareHealthState.UNUSED.value == application.describe_state().get("health_state_actual") + assert SoftwareHealthState.GOOD.value == application.describe_state().get("health_state_actual") application.run() assert SoftwareHealthState.GOOD.value == application.describe_state().get("health_state_actual") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py index e456ed78..177f31b0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address from typing import Tuple from uuid import uuid4 @@ -17,18 +17,32 @@ from primaite.simulator.system.services.database.database_service import Databas def database_client_on_computer() -> Tuple[DatabaseClient, Computer]: network = Network() - db_server = Server(hostname="db_server", ip_address="192.168.0.1", subnet_mask="255.255.255.0", start_up_duration=0) + db_server: Server = Server.from_config( + config={ + "type": "server", + "hostname": "db_server", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) db_server.power_on() db_server.software_manager.install(DatabaseService) - db_server.software_manager.software["DatabaseService"].start() + db_server.software_manager.software["database-service"].start() - db_client = Computer( - hostname="db_client", ip_address="192.168.0.2", subnet_mask="255.255.255.0", start_up_duration=0 + db_client: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "db_client", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) db_client.power_on() db_client.software_manager.install(DatabaseClient) - database_client: DatabaseClient = db_client.software_manager.software.get("DatabaseClient") + database_client: DatabaseClient = db_client.software_manager.software.get("database-client") database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) database_client.run() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index ce98d164..8a901f2b 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -1,46 +1,54 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.http import HttpResponsePacket, HttpStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def web_browser() -> WebBrowser: - computer = Computer( - hostname="web_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "web_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + computer: Computer = Computer.from_config(config=computer_cfg) + computer.power_on() # Web Browser should be pre-installed in computer - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") web_browser.run() assert web_browser.operating_state is ApplicationOperatingState.RUNNING return web_browser def test_create_web_client(): - computer = Computer( - hostname="web_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "web_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + computer: Computer = Computer.from_config(config=computer_cfg) + computer.power_on() # Web Browser should be pre-installed in computer - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") - assert web_browser.name is "WebBrowser" - assert web_browser.port is Port.HTTP - assert web_browser.protocol is IPProtocol.TCP + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") + assert web_browser.name == "web-browser" + assert web_browser.port is PORT_LOOKUP["HTTP"] + assert web_browser.protocol is PROTOCOL_LOOKUP["TCP"] def test_receive_invalid_payload(web_browser): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/__init__.py b/tests/unit_tests/_primaite/_simulator/_system/_services/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/__init__.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py index 9e7ab1d2..2154ebf9 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.network.hardware.base import Node @@ -8,10 +8,18 @@ from primaite.simulator.system.services.database.database_service import Databas @pytest.fixture(scope="function") def database_server() -> Node: - node = Computer(hostname="db_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + node_cfg = { + "type": "computer", + "hostname": "db_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + + node = Computer.from_config(config=node_cfg) node.power_on() node.software_manager.install(DatabaseService) - node.software_manager.software.get("DatabaseService").start() + node.software_manager.software.get("database-service").start() return node diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index e9ce4884..195632ee 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address import pytest @@ -6,34 +6,46 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dns_client() -> Computer: - node = Computer( - hostname="dns_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - ) + node_cfg = { + "type": "computer", + "hostname": "dns_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + } + node = Computer.from_config(config=node_cfg) return node def test_create_dns_client(dns_client): assert dns_client is not None - dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") - assert dns_client_service.name is "DNSClient" - assert dns_client_service.port is Port.DNS - assert dns_client_service.protocol is IPProtocol.TCP + dns_client_service: DNSClient = dns_client.software_manager.software.get("dns-client") + assert dns_client_service.name == "dns-client" + assert dns_client_service.port is PORT_LOOKUP["DNS"] + assert dns_client_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_dns_client_add_domain_to_cache_when_not_running(dns_client): - dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") + dns_client_service: DNSClient = dns_client.software_manager.software.get("dns-client") + + # shutdown the dns_client + dns_client.power_off() + + # wait for dns_client to turn off + idx = 0 + while dns_client.operating_state == NodeOperatingState.SHUTTING_DOWN: + dns_client.apply_timestep(idx) + idx += 1 + assert dns_client.operating_state is NodeOperatingState.OFF assert dns_client_service.operating_state is ServiceOperatingState.STOPPED @@ -46,7 +58,7 @@ def test_dns_client_add_domain_to_cache_when_not_running(dns_client): def test_dns_client_check_domain_exists_when_not_running(dns_client): dns_client.operating_state = NodeOperatingState.ON - dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") + dns_client_service: DNSClient = dns_client.software_manager.software.get("dns-client") dns_client_service.start() assert dns_client.operating_state is NodeOperatingState.ON @@ -61,7 +73,7 @@ def test_dns_client_check_domain_exists_when_not_running(dns_client): dns_client.power_off() - for i in range(dns_client.shut_down_duration + 1): + for i in range(dns_client.config.shut_down_duration + 1): dns_client.apply_timestep(timestep=i) assert dns_client.operating_state is NodeOperatingState.OFF @@ -73,7 +85,7 @@ def test_dns_client_check_domain_exists_when_not_running(dns_client): def test_dns_client_check_domain_in_cache(dns_client): """Test to make sure that the check_domain_in_cache returns the correct values.""" dns_client.operating_state = NodeOperatingState.ON - dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") + dns_client_service: DNSClient = dns_client.software_manager.software.get("dns-client") dns_client_service.start() # add a domain to the dns client cache @@ -85,7 +97,7 @@ def test_dns_client_check_domain_in_cache(dns_client): def test_dns_client_receive(dns_client): """Test to make sure the DNS Client knows how to deal with request responses.""" - dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") + dns_client_service: DNSClient = dns_client.software_manager.software.get("dns-client") dns_client_service.receive( payload=DNSPacket( @@ -99,6 +111,6 @@ def test_dns_client_receive(dns_client): def test_dns_client_receive_non_dns_payload(dns_client): - dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") + dns_client_service: DNSClient = dns_client.software_manager.software.get("dns-client") assert dns_client_service.receive(payload=None) is False diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py index 4658fe76..006a7f2e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address import pytest @@ -8,21 +8,23 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dns_server() -> Node: - node = Server( - hostname="dns_server", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = { + "type": "server", + "hostname": "dns_server", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(software_class=DNSServer) return node @@ -30,15 +32,15 @@ def dns_server() -> Node: def test_create_dns_server(dns_server): assert dns_server is not None - dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer") - assert dns_server_service.name is "DNSServer" - assert dns_server_service.port is Port.DNS - assert dns_server_service.protocol is IPProtocol.TCP + dns_server_service: DNSServer = dns_server.software_manager.software.get("dns-server") + assert dns_server_service.name == "dns-server" + assert dns_server_service.port is PORT_LOOKUP["DNS"] + assert dns_server_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_dns_server_domain_name_registration(dns_server): """Test to check if the domain name registration works.""" - dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer") + dns_server_service: DNSServer = dns_server.software_manager.software.get("dns-server") # register the web server in the domain controller dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) @@ -50,17 +52,24 @@ def test_dns_server_domain_name_registration(dns_server): def test_dns_server_receive(dns_server): """Test to make sure that the DNS Server correctly responds to a DNS Client request.""" - dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer") + dns_server_service: DNSServer = dns_server.software_manager.software.get("dns-server") # register the web server in the domain controller dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) - client = Computer(hostname="client", ip_address="192.168.1.11", subnet_mask="255.255.255.0", start_up_duration=0) + client_cfg = { + "type": "computer", + "hostname": "client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + client = Computer.from_config(config=client_cfg) client.power_on() - client.dns_server = IPv4Address("192.168.1.10") + client.config.dns_server = IPv4Address("192.168.1.10") network = Network() network.connect(dns_server.network_interface[1], client.network_interface[1]) - dns_client: DNSClient = client.software_manager.software["DNSClient"] # noqa + dns_client: DNSClient = client.software_manager.software["dns-client"] # noqa dns_client.check_domain_exists("fake-domain.com") assert dns_client.check_domain_exists("fake-domain.com") is False diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 99bb42ed..81e05467 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address import pytest @@ -8,31 +8,33 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def ftp_client() -> Node: - node = Computer( - hostname="ftp_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = { + "type": "computer", + "hostname": "ftp_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + node = Computer.from_config(config=node_cfg) node.power_on() return node def test_create_ftp_client(ftp_client): assert ftp_client is not None - ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") - assert ftp_client_service.name is "FTPClient" - assert ftp_client_service.port is Port.FTP - assert ftp_client_service.protocol is IPProtocol.TCP + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client") + assert ftp_client_service.name == "ftp-client" + assert ftp_client_service.port is PORT_LOOKUP["FTP"] + assert ftp_client_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_ftp_client_store_file(ftp_client): @@ -51,7 +53,7 @@ def test_ftp_client_store_file(ftp_client): status_code=FTPStatusCode.OK, ) - ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client") ftp_client_service.receive(response) assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") @@ -61,11 +63,11 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client): """Method _process_ftp_command should return false if service is not running.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=PORT_LOOKUP["FTP"], status_code=FTPStatusCode.OK, ) - ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client") ftp_client_service.stop() assert ftp_client_service.operating_state is ServiceOperatingState.STOPPED assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR @@ -75,7 +77,7 @@ def test_ftp_tries_to_send_file__that_does_not_exist(ftp_client): """Method send_file should return false if no file to send.""" assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None - ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client") assert ftp_client_service.operating_state is ServiceOperatingState.RUNNING assert ( ftp_client_service.send_file( @@ -91,10 +93,10 @@ def test_ftp_tries_to_send_file__that_does_not_exist(ftp_client): def test_offline_ftp_client_receives_request(ftp_client): """Receive should return false if the node the ftp client is installed on is offline.""" - ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client") ftp_client.power_off() - for i in range(ftp_client.shut_down_duration + 1): + for i in range(ftp_client.config.shut_down_duration + 1): ftp_client.apply_timestep(timestep=i) assert ftp_client.operating_state is NodeOperatingState.OFF @@ -102,7 +104,7 @@ def test_offline_ftp_client_receives_request(ftp_client): payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=PORT_LOOKUP["FTP"], status_code=FTPStatusCode.OK, ) @@ -111,7 +113,7 @@ def test_offline_ftp_client_receives_request(ftp_client): def test_receive_should_fail_if_payload_is_not_ftp(ftp_client): """Receive should return false if the node the ftp client is installed on is not an FTPPacket.""" - ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client") assert ftp_client_service.receive(payload=None) is False @@ -119,8 +121,8 @@ def test_receive_should_ignore_payload_with_none_status_code(ftp_client): """Receive should ignore payload with no set status code to prevent infinite send/receive loops.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=PORT_LOOKUP["FTP"], status_code=None, ) - ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client") assert ftp_client_service.receive(payload=payload) is False diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index a1c2ba59..77ba5cd4 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus @@ -6,21 +6,23 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def ftp_server() -> Node: - node = Server( - hostname="ftp_server", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = { + "type": "server", + "hostname": "ftp_server", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(software_class=FTPServer) return node @@ -28,10 +30,10 @@ def ftp_server() -> Node: def test_create_ftp_server(ftp_server): assert ftp_server is not None - ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") - assert ftp_server_service.name is "FTPServer" - assert ftp_server_service.port is Port.FTP - assert ftp_server_service.protocol is IPProtocol.TCP + ftp_server_service: FTPServer = ftp_server.software_manager.software.get("ftp-server") + assert ftp_server_service.name == "ftp-server" + assert ftp_server_service.port is PORT_LOOKUP["FTP"] + assert ftp_server_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_ftp_server_store_file(ftp_server): @@ -49,7 +51,7 @@ def test_ftp_server_store_file(ftp_server): packet_payload_size=24, ) - ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") + ftp_server_service: FTPServer = ftp_server.software_manager.software.get("ftp-server") ftp_server_service.receive(response) assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt") @@ -63,7 +65,7 @@ def test_ftp_server_should_send_error_if_port_arg_is_invalid(ftp_server): packet_payload_size=24, ) - ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") + ftp_server_service: FTPServer = ftp_server.software_manager.software.get("ftp-server") assert ftp_server_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR @@ -71,7 +73,7 @@ def test_ftp_server_receives_non_ftp_packet(ftp_server): """Receive should return false if the service receives a non ftp packet.""" response: FTPPacket = None - ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") + ftp_server_service: FTPServer = ftp_server.software_manager.software.get("ftp-server") assert ftp_server_service.receive(response) is False @@ -87,7 +89,7 @@ def test_offline_ftp_server_receives_request(ftp_server): packet_payload_size=24, ) - ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") + ftp_server_service: FTPServer = ftp_server.software_manager.software.get("ftp-server") ftp_server_service.stop() assert ftp_server_service.operating_state is ServiceOperatingState.STOPPED assert ftp_server_service.receive(response) is False diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_service_actions.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_service_actions.py index 537beb8b..60cd2422 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_service_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_service_actions.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py index 8c12adaa..fe78aa65 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from uuid import uuid4 import pytest @@ -22,7 +22,7 @@ def test_scan(service): def test_start_service(service): assert service.operating_state == ServiceOperatingState.STOPPED - assert service.health_state_actual == SoftwareHealthState.UNUSED + assert service.health_state_actual == SoftwareHealthState.GOOD service.start() assert service.operating_state == ServiceOperatingState.RUNNING @@ -43,7 +43,7 @@ def test_pause_and_resume_service(service): assert service.operating_state == ServiceOperatingState.STOPPED service.resume() assert service.operating_state == ServiceOperatingState.STOPPED - assert service.health_state_actual == SoftwareHealthState.UNUSED + assert service.health_state_actual == SoftwareHealthState.GOOD service.start() assert service.health_state_actual == SoftwareHealthState.GOOD @@ -58,11 +58,11 @@ def test_pause_and_resume_service(service): def test_restart(service): assert service.operating_state == ServiceOperatingState.STOPPED - assert service.health_state_actual == SoftwareHealthState.UNUSED + assert service.health_state_actual == SoftwareHealthState.GOOD service.restart() # Service is STOPPED. Restart will only work if the service was PAUSED or RUNNING assert service.operating_state == ServiceOperatingState.STOPPED - assert service.health_state_actual == SoftwareHealthState.UNUSED + assert service.health_state_actual == SoftwareHealthState.GOOD service.start() assert service.operating_state == ServiceOperatingState.RUNNING @@ -148,7 +148,7 @@ def test_service_fixing(service): service.fix() assert service.health_state_actual == SoftwareHealthState.FIXING - for i in range(service.fixing_duration + 1): + for i in range(service.config.fixing_duration + 1): service.apply_timestep(i) assert service.health_state_actual == SoftwareHealthState.GOOD @@ -157,11 +157,11 @@ def test_service_fixing(service): def test_enable_disable(service): service.disable() assert service.operating_state == ServiceOperatingState.DISABLED - assert service.health_state_actual == SoftwareHealthState.UNUSED + assert service.health_state_actual == SoftwareHealthState.GOOD service.enable() assert service.operating_state == ServiceOperatingState.STOPPED - assert service.health_state_actual == SoftwareHealthState.UNUSED + assert service.health_state_actual == SoftwareHealthState.GOOD def test_overwhelm_service(service): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 55f89c04..3b2377e9 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Tuple from uuid import uuid4 @@ -13,28 +13,35 @@ from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter +from primaite.simulator.network.networks import arcd_uc2_network from primaite.simulator.network.protocols.ssh import ( SSHConnectionMessage, SSHPacket, SSHTransportMessage, SSHUserCredentials, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection, Terminal from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def terminal_on_computer() -> Tuple[Terminal, Computer]: - computer: Computer = Computer( - hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0 + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) computer.power_on() - terminal: Terminal = computer.software_manager.software.get("Terminal") + terminal: Terminal = computer.software_manager.software.get("terminal") return terminal, computer @@ -42,11 +49,27 @@ def terminal_on_computer() -> Tuple[Terminal, Computer]: @pytest.fixture(scope="function") def basic_network() -> Network: network = Network() - node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_a.power_on() node_a.software_manager.get_open_ports() - node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) @@ -58,18 +81,23 @@ def wireless_wan_network(): network = Network() # Configure PC A - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, - ) + pc_a_cfg = { + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } + + pc_a = Computer.from_config(config=pc_a_cfg) pc_a.power_on() network.add_node(pc_a) # Configure Router 1 - router_1 = WirelessRouter(hostname="router_1", start_up_duration=0, airspace=network.airspace) + router_1 = WirelessRouter.from_config( + config={"type": "wireless-router", "hostname": "router_1", "start_up_duration": 0}, airspace=network.airspace + ) router_1.power_on() network.add_node(router_1) @@ -78,47 +106,41 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) - # add ACL rule to allow SSH traffic - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + # add acl rule to allow SSH traffic + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21 + ) # Configure PC B - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, - ) + + pc_b_cfg = { + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.2.1", + "start_up_duration": 0, + } + + pc_b = Computer.from_config(config=pc_b_cfg) pc_b.power_on() network.add_node(pc_b) - # Configure Router 2 - router_2 = WirelessRouter(hostname="router_2", start_up_duration=0, airspace=network.airspace) - router_2.power_on() - network.add_node(router_2) - - # Configure the connection between PC B and Router 2 port 2 - router_2.configure_router_interface("192.168.2.1", "255.255.255.0") - network.connect(pc_b.network_interface[1], router_2.network_interface[2]) - # Configure Router 2 ACLs # Configure the wireless connection between Router 1 port 1 and Router 2 port 1 router_1.configure_wireless_access_point("192.168.1.1", "255.255.255.0") - router_2.configure_wireless_access_point("192.168.1.2", "255.255.255.0") router_1.route_table.add_route( address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2" ) - # Configure Route from Router 2 to PC A subnet - router_2.route_table.add_route( - address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1" - ) - - return pc_a, pc_b, router_1, router_2 + return network @pytest.fixture @@ -127,7 +149,7 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - client_1.start_up_duration = 3 + client_1.config.start_up_duration = 3 return game, agent @@ -139,24 +161,32 @@ def test_terminal_creation(terminal_on_computer): def test_terminal_install_default(): """Terminal should be auto installed onto Nodes""" - computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) computer.power_on() - assert computer.software_manager.software.get("Terminal") + assert computer.software_manager.software.get("terminal") def test_terminal_not_on_switch(): """Ensure terminal does not auto-install to switch""" - test_switch = Switch(hostname="Test") + test_switch = Switch.from_config(config={"type": "switch", "hostname": "Test"}) - assert not test_switch.software_manager.software.get("Terminal") + assert not test_switch.software_manager.software.get("terminal") def test_terminal_send(basic_network): - """Test that Terminal can send valid commands.""" + """Test that terminal can send valid commands.""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") - terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") computer_b: Computer = network.get_node_by_hostname("node_b") payload: SSHPacket = SSHPacket( @@ -174,7 +204,7 @@ def test_terminal_receive(basic_network): """Test that terminal can receive and process commands""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") - terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") computer_b: Computer = network.get_node_by_hostname("node_b") folder_name = "Downloads" @@ -195,14 +225,14 @@ def test_terminal_receive(basic_network): def test_terminal_install(basic_network): - """Test that Terminal can successfully process an INSTALL request""" + """Test that terminal can successfully process an INSTALL request""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") - terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") computer_b: Computer = network.get_node_by_hostname("node_b") payload: SSHPacket = SSHPacket( - payload=["software_manager", "application", "install", "RansomwareScript"], + payload=["software_manager", "application", "install", "ransomware-script"], transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, ) @@ -211,16 +241,16 @@ def test_terminal_install(basic_network): username="admin", password="admin", ip_address="192.168.0.11" ) - term_a_on_node_b.execute(["software_manager", "application", "install", "RansomwareScript"]) + term_a_on_node_b.execute(["software_manager", "application", "install", "ransomware-script"]) - assert computer_b.software_manager.software.get("RansomwareScript") + assert computer_b.software_manager.software.get("ransomware-script") def test_terminal_fail_when_closed(basic_network): - """Ensure Terminal won't attempt to send/receive when off""" + """Ensure terminal won't attempt to send/receive when off""" network: Network = basic_network computer: Computer = network.get_node_by_hostname("node_a") - terminal: Terminal = computer.software_manager.software.get("Terminal") + terminal: Terminal = computer.software_manager.software.get("terminal") computer_b: Computer = network.get_node_by_hostname("node_b") terminal.operating_state = ServiceOperatingState.STOPPED @@ -229,12 +259,12 @@ def test_terminal_fail_when_closed(basic_network): def test_terminal_disconnect(basic_network): - """Test Terminal disconnects""" + """Test terminal disconnects""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") - terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") computer_b: Computer = network.get_node_by_hostname("node_b") - terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") + terminal_b: Terminal = computer_b.software_manager.software.get("terminal") assert len(terminal_b._connections) == 0 @@ -252,10 +282,10 @@ def test_terminal_disconnect(basic_network): def test_terminal_ignores_when_off(basic_network): - """Terminal should ignore commands when not running""" + """terminal should ignore commands when not running""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") - terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") computer_b: Computer = network.get_node_by_hostname("node_b") @@ -265,14 +295,17 @@ def test_terminal_ignores_when_off(basic_network): terminal_a.operating_state = ServiceOperatingState.STOPPED - assert not term_a_on_term_b.execute(["software_manager", "application", "install", "RansomwareScript"]) + assert not term_a_on_term_b.execute(["software_manager", "application", "install", "ransomware-script"]) def test_computer_remote_login_to_router(wireless_wan_network): """Test to confirm that a computer can SSH into a router.""" - pc_a, _, router_1, _ = wireless_wan_network - pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal") + pc_a = wireless_wan_network.get_node_by_hostname("pc_a") + + router_1 = wireless_wan_network.get_node_by_hostname("router_1") + + pc_a_terminal: Terminal = pc_a.software_manager.software.get("terminal") assert len(pc_a_terminal._connections) == 0 @@ -280,18 +313,20 @@ def test_computer_remote_login_to_router(wireless_wan_network): assert len(pc_a_terminal._connections) == 1 - payload = ["software_manager", "application", "install", "RansomwareScript"] + payload = ["software_manager", "application", "install", "ransomware-script"] pc_a_on_router_1.execute(payload) - assert router_1.software_manager.software.get("RansomwareScript") + assert router_1.software_manager.software.get("ransomware-script") def test_router_remote_login_to_computer(wireless_wan_network): """Test to confirm that a router can ssh into a computer.""" - pc_a, _, router_1, _ = wireless_wan_network + pc_a = wireless_wan_network.get_node_by_hostname("pc_a") - router_1_terminal: Terminal = router_1.software_manager.software.get("Terminal") + router_1 = wireless_wan_network.get_node_by_hostname("router_1") + + router_1_terminal: Terminal = router_1.software_manager.software.get("terminal") assert len(router_1_terminal._connections) == 0 @@ -299,21 +334,23 @@ def test_router_remote_login_to_computer(wireless_wan_network): assert len(router_1_terminal._connections) == 1 - payload = ["software_manager", "application", "install", "RansomwareScript"] + payload = ["software_manager", "application", "install", "ransomware-script"] router_1_on_pc_a.execute(payload) - assert pc_a.software_manager.software.get("RansomwareScript") + assert pc_a.software_manager.software.get("ransomware-script") def test_router_blocks_SSH_traffic(wireless_wan_network): """Test to check that router will block SSH traffic if no ACL rule.""" - pc_a, _, router_1, _ = wireless_wan_network + pc_a = wireless_wan_network.get_node_by_hostname("pc_a") + + router_1 = wireless_wan_network.get_node_by_hostname("router_1") # Remove rule that allows SSH traffic. router_1.acl.remove_rule(position=21) - pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal") + pc_a_terminal: Terminal = pc_a.software_manager.software.get("terminal") assert len(pc_a_terminal._connections) == 0 @@ -322,18 +359,22 @@ def test_router_blocks_SSH_traffic(wireless_wan_network): assert len(pc_a_terminal._connections) == 0 -def test_SSH_across_network(wireless_wan_network): +def test_SSH_across_network(): """Test to show ability to SSH across a network.""" - pc_a, pc_b, router_1, router_2 = wireless_wan_network + network: Network = arcd_uc2_network() + pc_a = network.get_node_by_hostname("client_1") + router_1 = network.get_node_by_hostname("router_1") - terminal_a: Terminal = pc_a.software_manager.software.get("Terminal") - terminal_b: Terminal = pc_b.software_manager.software.get("Terminal") + terminal_a: Terminal = pc_a.software_manager.software.get("terminal") - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21 + ) assert len(terminal_a._connections) == 0 - terminal_b_on_terminal_a = terminal_b.login(username="admin", password="admin", ip_address="192.168.0.2") + # Login to the Domain Controller + terminal_a.login(username="admin", password="admin", ip_address="192.168.1.10") assert len(terminal_a._connections) == 1 @@ -342,7 +383,7 @@ def test_multiple_remote_terminals_same_node(basic_network): """Test to check that multiple remote terminals can be spawned by one node.""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") - terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") computer_b: Computer = network.get_node_by_hostname("node_b") assert len(terminal_a._connections) == 0 @@ -351,8 +392,6 @@ def test_multiple_remote_terminals_same_node(basic_network): for attempt in range(3): remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") - terminal_a.show() - assert len(terminal_a._connections) == 3 @@ -360,10 +399,10 @@ def test_terminal_rejects_commands_if_disconnect(basic_network): """Test to check terminal will ignore commands from disconnected connections""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") - terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") computer_b: Computer = network.get_node_by_hostname("node_b") - terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") + terminal_b: Terminal = computer_b.software_manager.software.get("terminal") remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") @@ -375,9 +414,9 @@ def test_terminal_rejects_commands_if_disconnect(basic_network): assert len(terminal_a._connections) == 0 assert len(terminal_b._connections) == 0 - assert remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) is False + assert remote_connection.execute(["software_manager", "application", "install", "ransomware-script"]) is False - assert not computer_b.software_manager.software.get("RansomwareScript") + assert not computer_b.software_manager.software.get("ransomware-script") assert remote_connection.is_active is False @@ -386,9 +425,9 @@ def test_terminal_connection_timeout(basic_network): """Test that terminal_connections are affected by UserSession timeout.""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") - terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") computer_b: Computer = network.get_node_by_hostname("node_b") - terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") + terminal_b: Terminal = computer_b.software_manager.software.get("terminal") remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") @@ -410,7 +449,7 @@ def test_terminal_last_response_updates(basic_network): """Test that the _last_response within Terminal correctly updates.""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") - terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") computer_b: Computer = network.get_node_by_hostname("node_b") assert terminal_a.last_response is None @@ -420,12 +459,12 @@ def test_terminal_last_response_updates(basic_network): # Last response should be a successful logon assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "Login Successful"}) - remote_connection.execute(command=["software_manager", "application", "install", "RansomwareScript"]) + remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"]) # Last response should now update following successful install assert terminal_a.last_response == RequestResponse(status="success", data={}) - remote_connection.execute(command=["software_manager", "application", "install", "RansomwareScript"]) + remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"]) # Last response should now update to success, but with supplied reason. assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "already installed"}) @@ -444,7 +483,7 @@ def test_terminal_last_response_updates(basic_network): remote_connection.execute( command=[ "service", - "FTPClient", + "ftp-client", "send", { "dest_ip_address": "192.168.0.2", diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index 9af176be..30e0f9eb 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState @@ -9,38 +9,40 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def web_server() -> Server: - node = Server( - hostname="web_server", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = { + "type": "server", + "hostname": "web_server", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(WebServer) - node.software_manager.software.get("WebServer").start() + node.software_manager.software.get("web-server").start() return node def test_create_web_server(web_server): assert web_server is not None - web_server_service: WebServer = web_server.software_manager.software.get("WebServer") - assert web_server_service.name is "WebServer" - assert web_server_service.port is Port.HTTP - assert web_server_service.protocol is IPProtocol.TCP + web_server_service: WebServer = web_server.software_manager.software.get("web-server") + assert web_server_service.name == "web-server" + assert web_server_service.port == PORT_LOOKUP["HTTP"] + assert web_server_service.protocol == PROTOCOL_LOOKUP["TCP"] def test_handling_get_request_not_found_path(web_server): payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/fake-path") - web_server_service: WebServer = web_server.software_manager.software.get("WebServer") + web_server_service: WebServer = web_server.software_manager.software.get("web-server") response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload) assert response.status_code == HttpStatusCode.NOT_FOUND @@ -49,7 +51,7 @@ def test_handling_get_request_not_found_path(web_server): def test_handling_get_request_home_page(web_server): payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/") - web_server_service: WebServer = web_server.software_manager.software.get("WebServer") + web_server_service: WebServer = web_server.software_manager.software.get("web-server") response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload) assert response.status_code == HttpStatusCode.OK diff --git a/tests/unit_tests/_primaite/_simulator/_system/core/test_sys_log.py b/tests/unit_tests/_primaite/_simulator/_system/core/test_sys_log.py index 053211cd..5a734b6e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/core/test_sys_log.py +++ b/tests/unit_tests/_primaite/_simulator/_system/core/test_sys_log.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from uuid import uuid4 import pytest diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index 4cf83370..9ad0dbcb 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -1,16 +1,24 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Dict import pytest +from pydantic import Field -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import IOSoftware, SoftwareHealthState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP -class TestSoftware(Service): +class TestSoftware(Service, discriminator="TestSoftware"): + class ConfigSchema(Service.ConfigSchema): + """ConfigSChema for TestSoftware.""" + + type: str = "test-software" + + config: "TestSoftware.ConfigSchema" = Field(default_factory=lambda: TestSoftware.ConfigSchema()) + def describe_state(self) -> Dict: pass @@ -18,11 +26,11 @@ class TestSoftware(Service): @pytest.fixture(scope="function") def software(file_system): return TestSoftware( - name="TestSoftware", - port=Port.ARP, + name="test-software", + port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="test_service"), - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], ) @@ -31,6 +39,6 @@ def test_software_creation(software): def test_software_set_health_state(software): - assert software.health_state_actual == SoftwareHealthState.UNUSED - software.set_health_state(SoftwareHealthState.GOOD) assert software.health_state_actual == SoftwareHealthState.GOOD + software.set_health_state(SoftwareHealthState.COMPROMISED) + assert software.health_state_actual == SoftwareHealthState.COMPROMISED diff --git a/tests/unit_tests/_primaite/_simulator/test_core.py b/tests/unit_tests/_primaite/_simulator/test_core.py index 02960978..271375eb 100644 --- a/tests/unit_tests/_primaite/_simulator/test_core.py +++ b/tests/unit_tests/_primaite/_simulator/test_core.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Callable, Dict, List, Literal, Tuple import pytest diff --git a/tests/unit_tests/_primaite/_simulator/test_sim_container.py b/tests/unit_tests/_primaite/_simulator/test_sim_container.py index fe702307..f482d7e6 100644 --- a/tests/unit_tests/_primaite/_simulator/test_sim_container.py +++ b/tests/unit_tests/_primaite/_simulator/test_sim_container.py @@ -1,4 +1,4 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.sim_container import Simulation diff --git a/tests/unit_tests/_primaite/_utils/__init__.py b/tests/unit_tests/_primaite/_utils/__init__.py index be6c00e7..836b79af 100644 --- a/tests/unit_tests/_primaite/_utils/__init__.py +++ b/tests/unit_tests/_primaite/_utils/__init__.py @@ -1 +1 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_utils/_validation/__init__.py b/tests/unit_tests/_primaite/_utils/_validation/__init__.py new file mode 100644 index 00000000..836b79af --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py b/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py new file mode 100644 index 00000000..7acbe4a7 --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py @@ -0,0 +1,23 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +import pytest + +from primaite.utils.validation.ip_protocol import IPProtocol, is_valid_protocol, PROTOCOL_LOOKUP, protocol_validator + + +def test_port_conversion(): + for proto_name, proto_val in PROTOCOL_LOOKUP.items(): + assert protocol_validator(proto_name) == proto_val + assert is_valid_protocol(proto_name) + + +def test_port_passthrough(): + for proto_val in PROTOCOL_LOOKUP.values(): + assert protocol_validator(proto_val) == proto_val + assert is_valid_protocol(proto_val) + + +def test_invalid_ports(): + for port in (123, "abcdefg", "NONEXISTENT_PROTO"): + with pytest.raises(ValueError): + protocol_validator(port) + assert not is_valid_protocol(port) diff --git a/tests/unit_tests/_primaite/_utils/_validation/test_port.py b/tests/unit_tests/_primaite/_utils/_validation/test_port.py new file mode 100644 index 00000000..2e30ab76 --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/test_port.py @@ -0,0 +1,25 @@ +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +import pytest + +from primaite.utils.validation.port import is_valid_port, Port, PORT_LOOKUP, port_validator + + +def test_port_conversion(): + valid_port_lookup = {k: v for k, v in PORT_LOOKUP.items() if k != "UNUSED"} + for port_name, port_val in valid_port_lookup.items(): + assert port_validator(port_name) == port_val + assert is_valid_port(port_name) + + +def test_port_passthrough(): + valid_port_lookup = {k: v for k, v in PORT_LOOKUP.items() if k != "UNUSED"} + for port_val in valid_port_lookup.values(): + assert port_validator(port_val) == port_val + assert is_valid_port(port_val) + + +def test_invalid_ports(): + for port in (999999, -20, 3.214, "NONEXISTENT_PORT"): + with pytest.raises(ValueError): + port_validator(port) + assert not is_valid_port(port) diff --git a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py index a8fb0a3a..d0a64ece 100644 --- a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py +++ b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py @@ -1,7 +1,7 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.utils.converters import convert_dict_enum_keys_to_enum_values +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_simple_conversion(): @@ -11,7 +11,7 @@ def test_simple_conversion(): The original dictionary contains one level of nested dictionary with enums as keys. The expected output should have string values of enums as keys. """ - original_dict = {IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0}}} + original_dict = {PROTOCOL_LOOKUP["UDP"]: {PORT_LOOKUP["ARP"]: {"inbound": 0, "outbound": 1016.0}}} expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -36,8 +36,8 @@ def test_mixed_keys(): The expected output should have string values of enums and original string keys. """ original_dict = { - IPProtocol.TCP: {"port": {"inbound": 0, "outbound": 1016.0}}, - "protocol": {Port.HTTP: {"inbound": 10, "outbound": 2020.0}}, + PROTOCOL_LOOKUP["TCP"]: {"port": {"inbound": 0, "outbound": 1016.0}}, + "protocol": {PORT_LOOKUP["HTTP"]: {"inbound": 10, "outbound": 2020.0}}, } expected_dict = { "tcp": {"port": {"inbound": 0, "outbound": 1016.0}}, @@ -66,7 +66,13 @@ def test_nested_dicts(): The expected output should have string values of enums as keys at all levels. """ original_dict = { - IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol.TCP: {"latency": "low"}}}} + PROTOCOL_LOOKUP["UDP"]: { + PORT_LOOKUP["ARP"]: { + "inbound": 0, + "outbound": 1016.0, + "details": {PROTOCOL_LOOKUP["TCP"]: {"latency": "low"}}, + } + } } expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0, "details": {"tcp": {"latency": "low"}}}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -79,6 +85,12 @@ def test_non_dict_values(): The original dictionary contains lists and tuples as values. The expected output should preserve these non-dictionary values while converting enum keys to string values. """ - original_dict = {IPProtocol.UDP: [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)} - expected_dict = {"udp": [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)} + original_dict = { + PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"], PORT_LOOKUP["HTTP"]], + "protocols": (PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]), + } + expected_dict = { + "udp": [PORT_LOOKUP["ARP"], PORT_LOOKUP["HTTP"]], + "protocols": (PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]), + } assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict