diff --git a/CHANGELOG.md b/CHANGELOG.md index e4fb0172..6f1c2324 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [4.0.0] = TBC ### Added +- Log observation space data by episode and step. +- Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `do-nothing` actions are omitted. +- New ``node-send-local-command`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only) +- Added ability to set the observation threshold for NMNE, file access and application executions ### Changed - Agents now follow a common configuration format, simplifying the configuration of agents and their extensibilty. @@ -25,6 +29,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Nodes now use a config schema and are extensible, allowing for plugin support. - Node tests have been updated to use the new node config schemas when not using YAML files. - Documentation has been updated to include details of extensability with PrimAITE. +- ACLs are no longer applied to layer-2 traffic. +- Random number seed values are recorded in simulation/seed.log if the seed is set in the config file + or `generate_seed_value` is set to `true`. +- ARP .show() method will now include the port number associated with each entry. +- Added `services_requires_scan` and `applications_requires_scan` to agent observation space config to allow the agents to be able to see actual health states of services and applications without requiring scans (Default `True`, set to `False` to allow agents to see actual health state without scanning). +- Updated the `Terminal` class to provide response information when sending remote command execution. ### Fixed - DNS client no longer fails to check its cache if a DNS server address is missing. @@ -33,6 +43,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [3.3.0] - 2024-09-04 +## [3.4.0] + +### Added +- Log observation space data by episode and step. +- Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `DONOTHING` actions are omitted. +- New ``NODE_SEND_LOCAL_COMMAND`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only) +- Added ability to set the observation threshold for NMNE, file access and application executions + +### Changed +- ACL's are no longer applied to layer-2 traffic. +- Random number seed values are recorded in simulation/seed.log if the seed is set in the config file + or `generate_seed_value` is set to `true`. +- ARP .show() method will now include the port number associated with each entry. +- Added `services_requires_scan` and `applications_requires_scan` to agent observation space config to allow the agents to be able to see actual health states of services and applications without requiring scans (Default `True`, set to `False` to allow agents to see actual health state without scanning). +- Updated the `Terminal` class to provide response information when sending remote command execution. + +## [3.3.0] - 2024-09-04 ### Added - Random Number Generator Seeding by specifying a random number seed in the config file. - Implemented Terminal service class, providing a generic terminal simulation. diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index dce8da3a..c2674e31 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -20,12 +20,6 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo - ref: green_agent_example team: GREEN type: probabilistic-agent - observation_space: - type: UC2GreenObservation - action_space: - reward_function: - reward_components: - - type: dummy agent_settings: start_settings: @@ -160,3 +154,4 @@ If ``True``, gymnasium flattening will be performed on the observation space bef ----------------- Agents will record their action log for each step. This is a summary of what the agent did, along with response information from requests within the simulation. +A summary of the actions taken by the agent can be viewed using the `show_history()` function. By default, this will display all actions taken apart from ``DONOTHING``. diff --git a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst index 542b817b..e6d5da67 100644 --- a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst +++ b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst @@ -54,6 +54,39 @@ Optional. Default value is ``3``. The number of time steps required to occur in order for the node to cycle from ``ON`` to ``SHUTTING_DOWN`` and then finally ``OFF``. +``file_system`` +--------------- + +Optional. + +The file system of the node. This configuration allows nodes to be initialised with files and/or folders. + +The file system takes a list of folders and files. + +Example: + +.. code-block:: yaml + + simulation: + network: + nodes: + - hostname: client_1 + type: computer + ip_address: 192.168.10.11 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + file_system: + - empty_folder # example of an empty folder + - downloads: + - "test_1.txt" # files in the downloads folder + - "test_2.txt" + - root: + - passwords: # example of file with size and type + size: 69 # size in bytes + type: TXT # See FileType for list of available file types + +List of file types: :py:mod:`primaite.simulator.file_system.file_type.FileType` + ``users`` --------- diff --git a/docs/source/configuration/simulation/nodes/network_examples.rst b/docs/source/configuration/simulation/nodes/network_examples.rst index 4616139e..84ee4c60 100644 --- a/docs/source/configuration/simulation/nodes/network_examples.rst +++ b/docs/source/configuration/simulation/nodes/network_examples.rst @@ -1177,8 +1177,8 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv = network.get_node_by_hostname("some_tech_storage_srv") some_tech_storage_srv.file_system.create_file(file_name="test.png") - pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["FTPClient"] - pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["FTPClient"] + pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["ftp-client"] + pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["ftp-client"] assert not pc_1_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1224,7 +1224,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. web_server: Server = network.get_node_by_hostname("some_tech_web_srv") - web_ftp_client: FTPClient = web_server.software_manager.software["FTPClient"] + web_ftp_client: FTPClient = web_server.software_manager.software["ftp-client"] assert not web_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1269,7 +1269,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_snr_dev_pc: Computer = network.get_node_by_hostname("some_tech_snr_dev_pc") - snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["FTPClient"] + snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["ftp-client"] assert snr_dev_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1294,7 +1294,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_jnr_dev_pc: Computer = network.get_node_by_hostname("some_tech_jnr_dev_pc") - jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["FTPClient"] + jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["ftp-client"] assert not jnr_dev_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1337,7 +1337,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_hr_pc: Computer = network.get_node_by_hostname("some_tech_hr_1") - hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["FTPClient"] + hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["ftp-client"] assert not hr_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, diff --git a/docs/source/configuration/simulation/nodes/router.rst b/docs/source/configuration/simulation/nodes/router.rst index 4b41784c..ee278a98 100644 --- a/docs/source/configuration/simulation/nodes/router.rst +++ b/docs/source/configuration/simulation/nodes/router.rst @@ -74,7 +74,7 @@ The subnet mask setting for the port. ``acl`` ------- -Sets up the ACL rules for the router. +Sets up the ACL rules for the router to apply to layer-3 traffic. These are not applied to layer-2 traffic such as ARP. e.g. @@ -85,10 +85,6 @@ e.g. ... acl: 1: - action: PERMIT - src_port: ARP - dst_port: ARP - 2: action: PERMIT protocol: ICMP diff --git a/docs/source/how_to_guides/extensible_agents.rst b/docs/source/how_to_guides/extensible_agents.rst index 5225e63c..57cc39a0 100644 --- a/docs/source/how_to_guides/extensible_agents.rst +++ b/docs/source/how_to_guides/extensible_agents.rst @@ -46,17 +46,13 @@ The core features that should be implemented in any new agent are detailed below - ref: example_green_agent team: GREEN - type: ExampleAgent + type: example-agent action_space: action_map: 0: action: do-nothing options: {} - reward_function: - reward_components: - - type: dummy - agent_settings: start_step: 25 frequency: 20 diff --git a/docs/source/how_to_guides/extensible_nodes.rst b/docs/source/how_to_guides/extensible_nodes.rst index 9219ca31..6cc4b845 100644 --- a/docs/source/how_to_guides/extensible_nodes.rst +++ b/docs/source/how_to_guides/extensible_nodes.rst @@ -26,9 +26,9 @@ class Router(NetworkNode, identifier="router"): """ Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces.""" SYSTEM_SOFTWARE: ClassVar[Dict] = { - "UserSessionManager": UserSessionManager, - "UserManager": UserManager, - "Terminal": Terminal, + "user-session-manager": UserSessionManager, + "user-manager": UserManager, + "terminal": Terminal, } network_interfaces: Dict[str, RouterInterface] = {} diff --git a/docs/source/request_system.rst b/docs/source/request_system.rst index f0437705..93fc2a9f 100644 --- a/docs/source/request_system.rst +++ b/docs/source/request_system.rst @@ -2,6 +2,8 @@ © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +.. _request_system: + Request System ************** diff --git a/docs/source/simulation_components/network/network.rst b/docs/source/simulation_components/network/network.rst index 95bfb9fe..6da610b2 100644 --- a/docs/source/simulation_components/network/network.rst +++ b/docs/source/simulation_components/network/network.rst @@ -99,19 +99,19 @@ we'll use the following Network that has a client, server, two switches, and a r network.connect(endpoint_a=switch_2.network_interface[1], endpoint_b=client_1.network_interface[1]) network.connect(endpoint_a=switch_1.network_interface[1], endpoint_b=server_1.network_interface[1]) -8. Add ACL rules on the Router to allow ARP and ICMP traffic. +8. Add an ACL rule on the Router to allow ICMP traffic. .. code-block:: python router_1.acl.add_rule( action=ACLAction.PERMIT, - src_port=Port["ARP"], - dst_port=Port["ARP"], + src_port=PORT_LOOKUP["ARP"], + dst_port=PORT_LOOKUP["ARP"], position=22 ) router_1.acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol["ICMP"], + protocol=PROTOCOL_LOOKUP["ICMP"], position=23 ) diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index d7207846..4078ffda 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -102,8 +102,8 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. network.connect(pc_a.network_interface[1], router_1.router_interface) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B pc_b = Computer( diff --git a/docs/source/simulation_components/system/applications/c2_suite.rst b/docs/source/simulation_components/system/applications/c2_suite.rst index ce083406..c2654cea 100644 --- a/docs/source/simulation_components/system/applications/c2_suite.rst +++ b/docs/source/simulation_components/system/applications/c2_suite.rst @@ -183,7 +183,7 @@ Python # Example command: Installing and configuring Ransomware: ransomware_installation_command = { "commands": [ - ["software_manager","application","install","RansomwareScript"], + ["software_manager","application","install","ransomware-script"], ], "username": "admin", "password": "admin", diff --git a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst index 053ad8d7..b2229840 100644 --- a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst @@ -78,7 +78,7 @@ Python network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(DataManipulationBot) - data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") + data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot") data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") data_manipulation_bot.run() @@ -98,26 +98,6 @@ If not using the data manipulation bot manually, it needs to be used with a data team: RED type: red-database-corrupting-agent - observation_space: - type: UC2RedObservation - options: - nodes: - - node_name: client_1 - observations: - - logon_status - - operating_status - applications: - - application_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} - - action_space: - reward_function: - reward_components: - - type: dummy - agent_settings: start_settings: start_step: 25 diff --git a/docs/source/simulation_components/system/applications/database_client.rst b/docs/source/simulation_components/system/applications/database_client.rst index 341876fe..324db17e 100644 --- a/docs/source/simulation_components/system/applications/database_client.rst +++ b/docs/source/simulation_components/system/applications/database_client.rst @@ -60,7 +60,7 @@ Python # install DatabaseClient client.software_manager.install(DatabaseClient) - database_client: DatabaseClient = client.software_manager.software.get("DatabaseClient") + database_client: DatabaseClient = client.software_manager.software.get("database-client") # Configure the DatabaseClient database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) # address of the DatabaseService diff --git a/docs/source/simulation_components/system/applications/ransomware_script.rst b/docs/source/simulation_components/system/applications/ransomware_script.rst index 517167fd..e1a385a7 100644 --- a/docs/source/simulation_components/system/applications/ransomware_script.rst +++ b/docs/source/simulation_components/system/applications/ransomware_script.rst @@ -63,7 +63,7 @@ Python network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(RansomwareScript) - RansomwareScript: RansomwareScript = client_1.software_manager.software.get("RansomwareScript") + RansomwareScript: RansomwareScript = client_1.software_manager.software.get("ransomware-script") RansomwareScript.configure(server_ip_address=IPv4Address("192.168.1.14")) RansomwareScript.execute() diff --git a/docs/source/simulation_components/system/applications/web_browser.rst b/docs/source/simulation_components/system/applications/web_browser.rst index 36290235..3ff12b58 100644 --- a/docs/source/simulation_components/system/applications/web_browser.rst +++ b/docs/source/simulation_components/system/applications/web_browser.rst @@ -62,7 +62,7 @@ The :ref:`DNSClient` must be configured to use the :ref:`DNSServer`. The :ref:`D # Install WebBrowser on computer computer.software_manager.install(WebBrowser) - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") web_browser.run() # configure the WebBrowser diff --git a/docs/source/simulation_components/system/services/database_service.rst b/docs/source/simulation_components/system/services/database_service.rst index d071a5c4..58e8c667 100644 --- a/docs/source/simulation_components/system/services/database_service.rst +++ b/docs/source/simulation_components/system/services/database_service.rst @@ -67,7 +67,7 @@ Python # Install DatabaseService on server server.software_manager.install(DatabaseService) - db_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = server.software_manager.software.get("database-service") db_service.start() # configure DatabaseService diff --git a/docs/source/simulation_components/system/services/dns_client.rst b/docs/source/simulation_components/system/services/dns_client.rst index aa99bc0e..9d131288 100644 --- a/docs/source/simulation_components/system/services/dns_client.rst +++ b/docs/source/simulation_components/system/services/dns_client.rst @@ -57,7 +57,7 @@ Python # Install DNSClient on server server.software_manager.install(DNSClient) - dns_client: DNSClient = server.software_manager.software.get("DNSClient") + dns_client: DNSClient = server.software_manager.software.get("dns-client") dns_client.start() # configure DatabaseService diff --git a/docs/source/simulation_components/system/services/dns_server.rst b/docs/source/simulation_components/system/services/dns_server.rst index 787b2a70..179d0837 100644 --- a/docs/source/simulation_components/system/services/dns_server.rst +++ b/docs/source/simulation_components/system/services/dns_server.rst @@ -54,7 +54,7 @@ Python # Install DNSServer on server server.software_manager.install(DNSServer) - dns_server: DNSServer = server.software_manager.software.get("DNSServer") + dns_server: DNSServer = server.software_manager.software.get("dns-server") dns_server.start() # configure DatabaseService diff --git a/docs/source/simulation_components/system/services/ftp_client.rst b/docs/source/simulation_components/system/services/ftp_client.rst index b8b493bd..db425b92 100644 --- a/docs/source/simulation_components/system/services/ftp_client.rst +++ b/docs/source/simulation_components/system/services/ftp_client.rst @@ -61,7 +61,7 @@ Python # Install FTPClient on server server.software_manager.install(FTPClient) - ftp_client: FTPClient = server.software_manager.software.get("FTPClient") + ftp_client: FTPClient = server.software_manager.software.get("ftp-client") ftp_client.start() diff --git a/docs/source/simulation_components/system/services/ftp_server.rst b/docs/source/simulation_components/system/services/ftp_server.rst index df080125..74d38945 100644 --- a/docs/source/simulation_components/system/services/ftp_server.rst +++ b/docs/source/simulation_components/system/services/ftp_server.rst @@ -56,7 +56,7 @@ Python # Install FTPServer on server server.software_manager.install(FTPServer) - ftp_server: FTPServer = server.software_manager.software.get("FTPServer") + ftp_server: FTPServer = server.software_manager.software.get("ftp-server") ftp_server.start() ftp_server.server_password = "test" diff --git a/docs/source/simulation_components/system/services/ntp_client.rst b/docs/source/simulation_components/system/services/ntp_client.rst index 22ae4516..29215a2d 100644 --- a/docs/source/simulation_components/system/services/ntp_client.rst +++ b/docs/source/simulation_components/system/services/ntp_client.rst @@ -54,7 +54,7 @@ Python # Install NTPClient on server server.software_manager.install(NTPClient) - ntp_client: NTPClient = server.software_manager.software.get("NTPClient") + ntp_client: NTPClient = server.software_manager.software.get("ntp-client") ntp_client.start() ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.10")) diff --git a/docs/source/simulation_components/system/services/ntp_server.rst b/docs/source/simulation_components/system/services/ntp_server.rst index 784aac0c..57efab4d 100644 --- a/docs/source/simulation_components/system/services/ntp_server.rst +++ b/docs/source/simulation_components/system/services/ntp_server.rst @@ -56,7 +56,7 @@ Python # Install NTPServer on server server.software_manager.install(NTPServer) - ntp_server: NTPServer = server.software_manager.software.get("NTPServer") + ntp_server: NTPServer = server.software_manager.software.get("ntp-server") ntp_server.start() diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index 1fec3b53..75f50e98 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -23,6 +23,14 @@ Key capabilities - Simulates common Terminal processes/commands. - Leverages the Service base class for install/uninstall, status tracking etc. +Usage +""""" + + - Pre-Installs on any `Node` component (with the exception of `Switches`). + - Terminal Clients connect, execute commands and disconnect from remote nodes. + - Ensures that users are logged in to the component before executing any commands. + - Service runs on SSH port 22 by default. + - Enables Agents to send commands both remotely and locally. Implementation """""""""""""" @@ -30,19 +38,112 @@ Implementation - Manages remote connections in a dictionary by session ID. - Processes commands, forwarding to the ``RequestManager`` or ``SessionManager`` where appropriate. - Extends Service class. - - A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook. + +A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook. + +Command Format +^^^^^^^^^^^^^^ + +Terminals implement their commands through leveraging the pre-existing :ref:`request_system`. + +Due to this Terminals will only accept commands passed within the ``RequestFormat``. + +:py:class:`primaite.game.interface.RequestFormat` + +For example, ``terminal`` command actions when used in ``yaml`` format are formatted as follows: + +.. code-block:: yaml + + command: + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False + +This is then loaded from yaml into a dictionary containing the terminal command: + +.. code-block:: python + + {"command":["file_system", "create", "file", "downloads", "cat.png", "False"]} + +Which is then passed to the ``Terminals`` Request Manager to be executed. + +Game Layer Usage (Agents) +======================== + +The below code examples demonstrate how to use terminal related actions in yaml files. + +yaml +"""" + +``node-send-local-command`` +""""""""""""""""""""""""""" + +Agents can execute local commands without needing to perform a separate remote login action (``node-session-remote-login``). + +.. code-block:: yaml + + ... + ... + action: node-send-local-command + options: + node_id: 0 + username: admin + password: admin + command: # Example command - Creates a file called 'cat.png' in the downloads folder. + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False" -Usage -""""" +``node-session-remote-login`` +""""""""""""""""" - - Pre-Installs on all ``Nodes`` (with the exception of ``Switches``). - - Terminal Clients connect, execute commands and disconnect from remote nodes. - - Ensures that users are logged in to the component before executing any commands. - - Service runs on SSH port 22 by default. +Agents are able to use the terminal to login into remote nodes via ``SSH`` which allows for agents to execute commands on remote hosts. + +.. code-block:: yaml + + ... + ... + action: node-session-remote-login + options: + node_id: 0 + username: admin + password: admin + remote_ip: 192.168.0.10 # Example Ip Address. (The remote host's IP that will be used by ssh) + + +``node-send-remote-command`` +"""""""""""""""""""""""""""" + +After remotely logging into another host, an agent can use the ``node-send-remote-command`` to execute commands across the network remotely. + +.. code-block:: yaml + + ... + ... + action: node-send-remote-command + options: + node_id: 0 + remote_ip: 192.168.0.10 + command: + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False" + + + +Simulation Layer Usage +====================== -Usage -===== The below code examples demonstrate how to create a terminal, a remote terminal, and how to send a basic application install command to a remote node. @@ -66,7 +167,7 @@ Python } ) - terminal: Terminal = client.software_manager.software.get("Terminal") + terminal: Terminal = client.software_manager.software.get("terminal") Creating Remote Terminal Connection """"""""""""""""""""""""""""""""""" @@ -87,7 +188,7 @@ Creating Remote Terminal Connection node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") @@ -113,12 +214,12 @@ Executing a basic application install command node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") - term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) + term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "ransomware-script"]) @@ -141,7 +242,7 @@ Creating a folder on a remote node node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") @@ -168,7 +269,7 @@ Disconnect from Remote Node node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") diff --git a/docs/source/simulation_components/system/services/web_server.rst b/docs/source/simulation_components/system/services/web_server.rst index dc7ebd4b..1794d967 100644 --- a/docs/source/simulation_components/system/services/web_server.rst +++ b/docs/source/simulation_components/system/services/web_server.rst @@ -57,7 +57,7 @@ Python # Install WebServer on server server.software_manager.install(WebServer) - web_server: WebServer = server.software_manager.software.get("WebServer") + web_server: WebServer = server.software_manager.software.get("web-server") web_server.start() Via Configuration diff --git a/docs/source/simulation_components/system/software.rst b/docs/source/simulation_components/system/software.rst index be284af1..66408109 100644 --- a/docs/source/simulation_components/system/software.rst +++ b/docs/source/simulation_components/system/software.rst @@ -30,7 +30,7 @@ See :ref:`Node Start up and Shut down` node.software_manager.install(WebServer) - web_server: WebServer = node.software_manager.software.get("WebServer") + web_server: WebServer = node.software_manager.software.get("web-server") assert web_server.operating_state is ServiceOperatingState.RUNNING # service is immediately ran after install node.power_off() diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 21a8fb4d..d9b058f1 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -4.0.0a1-dev +4.0.0-dev diff --git a/src/primaite/game/agent/actions/node.py b/src/primaite/game/agent/actions/node.py index 19639c21..b1b6ec12 100644 --- a/src/primaite/game/agent/actions/node.py +++ b/src/primaite/game/agent/actions/node.py @@ -1,6 +1,6 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from abc import ABC, abstractmethod -from typing import ClassVar, List, Optional, Union +from typing import ClassVar, List, Literal, Optional, Union from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat @@ -153,8 +153,6 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, discriminator="node-nmap-po class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-network-service-recon"): """Action which performs an nmap network service recon (ping scan followed by port scan).""" - config: "NodeNetworkServiceReconAction.ConfigSchema" - class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): """Configuration schema for NodeNetworkServiceReconAction.""" @@ -179,3 +177,70 @@ class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node- "show": config.show, }, ] + + +class NodeAccountsAddUserAction(AbstractAction, discriminator="node-account-add-user"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-account-add-user"] = "node-account-add-user" + node_name: str + username: str + password: str + is_admin: bool + + @classmethod + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "user-manager", + "add_user", + config.username, + config.password, + config.is_admin, + ] + + +class NodeAccountsDisableUserAction(AbstractAction, discriminator="node-account-disable-user"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-account-disable-user"] = "node-account-disable-user" + node_name: str + username: str + + @classmethod + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "user-manager", + "disable_user", + config.username, + ] + + +class NodeSendLocalCommandAction(AbstractAction, discriminator="node-send-local-command"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-send-local-command"] = "node-send-local-command" + node_name: str + username: str + password: str + command: RequestFormat + + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "terminal", + "send_local_command", + config.username, + config.password, + {"command": config.command}, + ] diff --git a/src/primaite/game/agent/actions/session.py b/src/primaite/game/agent/actions/session.py index b7b56d17..63a45c5e 100644 --- a/src/primaite/game/agent/actions/session.py +++ b/src/primaite/game/agent/actions/session.py @@ -34,8 +34,6 @@ class NodeSessionAbstractAction(AbstractAction, ABC): class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="node-session-remote-login"): """Action which performs a remote session login.""" - config: "NodeSessionsRemoteLoginAction.ConfigSchema" - class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): """Configuration schema for NodeSessionsRemoteLoginAction.""" @@ -53,7 +51,7 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="no config.node_name, "service", "terminal", - "node-session-remote-login", + "node_session_remote_login", config.username, config.password, config.remote_ip, diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 07363012..a6e9739f 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING from gymnasium.core import ActType, ObsType +from prettytable import PrettyTable from pydantic import BaseModel, ConfigDict, Field from primaite.game.agent.actions import ActionManager @@ -42,6 +43,9 @@ class AgentHistoryItem(BaseModel): reward_info: Dict[str, Any] = {} + observation: Optional[ObsType] = None + """The observation space data for this step.""" + class AbstractAgent(BaseModel, ABC): """Base class for scripted and RL agents.""" @@ -67,6 +71,9 @@ class AbstractAgent(BaseModel, ABC): default_factory=lambda: ObservationManager.ConfigSchema() ) reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema()) + thresholds: Optional[Dict] = {} + # TODO: this is only relevant to some observations, need to refactor the way thresholds are dealt with (#3085) + """A dict containing the observation thresholds.""" config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema()) @@ -95,10 +102,34 @@ class AbstractAgent(BaseModel, ABC): def model_post_init(self, __context: Any) -> None: """Overwrite the default empty action, observation, and rewards with ones defined through the config.""" self.action_manager = ActionManager(config=self.config.action_space) + self.config.observation_space.options.thresholds = self.config.thresholds self.observation_manager = ObservationManager(config=self.config.observation_space) self.reward_function = RewardFunction(config=self.config.reward_function) return super().model_post_init(__context) + def show_history(self, ignored_actions: Optional[list] = None): + """ + Print an agent action provided it's not the do-nothing action. + + :param ignored_actions: OPTIONAL: List of actions to be ignored when displaying the history. + If not provided, defaults to ignore do-nothing actions. + """ + if not ignored_actions: + ignored_actions = ["do-nothing"] + table = PrettyTable() + table.field_names = ["Step", "Action", "Params", "Response", "Response Data"] + print(f"Actions for '{self.config.ref}':") + for item in self.history: + if item.action in ignored_actions: + pass + else: + # format dict by putting each key-value entry on a separate line and putting a blank line on the end. + param_string = "\n".join([*[f"{k}: {v:.30}" for k, v in item.parameters.items()], ""]) + data_string = "\n".join([*[f"{k}: {v:.30}" for k, v in item.response.data], ""]) + + table.add_row([item.timestep, item.action, param_string, item.response.status, data_string]) + print(table) + def update_observation(self, state: Dict) -> ObsType: """ Convert a state from the simulator into an observation for the agent using the observation space. @@ -145,12 +176,23 @@ class AbstractAgent(BaseModel, ABC): return request def process_action_response( - self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse + self, + timestep: int, + action: str, + parameters: Dict[str, Any], + request: RequestFormat, + response: RequestResponse, + observation: ObsType, ) -> None: """Process the response from the most recent action.""" self.history.append( AgentHistoryItem( - timestep=timestep, action=action, parameters=parameters, request=request, response=response + timestep=timestep, + action=action, + parameters=parameters, + request=request, + response=response, + observation=observation, ) ) diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index ed9dcd8f..a9e3a9aa 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, discriminator="file"): file_system_requires_scan: Optional[bool] = None """If True, the file must be scanned to update the health state. Tf False, the true state is always shown.""" - def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None: + def __init__( + self, + where: WhereType, + include_num_access: bool, + file_system_requires_scan: bool, + thresholds: Optional[Dict] = {}, + ) -> None: """ Initialise a file observation instance. @@ -48,10 +54,36 @@ class FileObservation(AbstractObservation, discriminator="file"): if self.include_num_access: self.default_observation["num_access"] = 0 - # TODO: allow these to be configured in yaml - self.high_threshold = 10 - self.med_threshold = 5 - self.low_threshold = 0 + if thresholds.get("file_access") is None: + self.low_file_access_threshold = 0 + self.med_file_access_threshold = 5 + self.high_file_access_threshold = 10 + else: + self._set_file_access_threshold( + thresholds=[ + thresholds.get("file_access")["low"], + thresholds.get("file_access")["medium"], + thresholds.get("file_access")["high"], + ] + ) + + def _set_file_access_threshold(self, thresholds: List[int]): + """ + Method that validates and then sets the file access threshold. + + :param: thresholds: The file access threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=[ + thresholds[0], + thresholds[1], + thresholds[2], + ], + threshold_identifier="file_access", + ): + self.low_file_access_threshold = thresholds[0] + self.med_file_access_threshold = thresholds[1] + self.high_file_access_threshold = thresholds[2] def _categorise_num_access(self, num_access: int) -> int: """ @@ -60,11 +92,11 @@ class FileObservation(AbstractObservation, discriminator="file"): :param num_access: Number of file accesses. :return: Bin number corresponding to the number of accesses. """ - if num_access > self.high_threshold: + if num_access > self.high_file_access_threshold: return 3 - elif num_access > self.med_threshold: + elif num_access > self.med_file_access_threshold: return 2 - elif num_access > self.low_threshold: + elif num_access > self.low_file_access_threshold: return 1 return 0 @@ -122,6 +154,7 @@ class FileObservation(AbstractObservation, discriminator="file"): where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + thresholds=config.thresholds, ) @@ -149,6 +182,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"): num_files: int, include_num_access: bool, file_system_requires_scan: bool, + thresholds: Optional[Dict] = {}, ) -> None: """ Initialise a folder observation instance. @@ -177,6 +211,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"): where=None, include_num_access=include_num_access, file_system_requires_scan=self.file_system_requires_scan, + thresholds=thresholds, ) ) while len(self.files) > num_files: @@ -253,6 +288,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"): for file_config in config.files: file_config.include_num_access = config.include_num_access file_config.file_system_requires_scan = config.file_system_requires_scan + file_config.thresholds = config.thresholds files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] return cls( @@ -261,4 +297,5 @@ class FolderObservation(AbstractObservation, discriminator="folder"): num_files=config.num_files, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + thresholds=config.thresholds, ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 17bcb983..9b979063 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -54,7 +54,15 @@ class HostObservation(AbstractObservation, discriminator="host"): """ If True, files and folders must be scanned to update the health state. If False, true state is always shown. """ - include_users: Optional[bool] = None + services_requires_scan: Optional[bool] = None + """ + If True, services must be scanned to update the health state. If False, true state is always shown. + """ + applications_requires_scan: Optional[bool] = None + """ + If True, applications must be scanned to update the health state. If False, true state is always shown. + """ + include_users: Optional[bool] = True """If True, report user session information.""" def __init__( @@ -73,6 +81,8 @@ class HostObservation(AbstractObservation, discriminator="host"): monitored_traffic: Optional[Dict], include_num_access: bool, file_system_requires_scan: bool, + services_requires_scan: bool, + applications_requires_scan: bool, include_users: bool, ) -> None: """ @@ -108,6 +118,12 @@ class HostObservation(AbstractObservation, discriminator="host"): :param file_system_requires_scan: If True, the files and folders must be scanned to update the health state. If False, the true state is always shown. :type file_system_requires_scan: bool + :param services_requires_scan: If True, services must be scanned to update the health state. + If False, the true state is always shown. + :type services_requires_scan: bool + :param applications_requires_scan: If True, applications must be scanned to update the health state. + If False, the true state is always shown. + :type applications_requires_scan: bool :param include_users: If True, report user session information. :type include_users: bool """ @@ -121,7 +137,7 @@ class HostObservation(AbstractObservation, discriminator="host"): # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services while len(self.services) < num_services: - self.services.append(ServiceObservation(where=None)) + self.services.append(ServiceObservation(where=None, services_requires_scan=services_requires_scan)) while len(self.services) > num_services: truncated_service = self.services.pop() msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" @@ -129,7 +145,9 @@ class HostObservation(AbstractObservation, discriminator="host"): self.applications: List[ApplicationObservation] = applications while len(self.applications) < num_applications: - self.applications.append(ApplicationObservation(where=None)) + self.applications.append( + ApplicationObservation(where=None, applications_requires_scan=applications_requires_scan) + ) while len(self.applications) > num_applications: truncated_application = self.applications.pop() msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}" @@ -153,7 +171,13 @@ class HostObservation(AbstractObservation, discriminator="host"): self.nics: List[NICObservation] = network_interfaces while len(self.nics) < num_nics: - self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic)) + self.nics.append( + NICObservation( + where=None, + include_nmne=include_nmne, + monitored_traffic=monitored_traffic, + ) + ) while len(self.nics) > num_nics: truncated_nic = self.nics.pop() msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}" @@ -269,8 +293,15 @@ class HostObservation(AbstractObservation, discriminator="host"): folder_config.include_num_access = config.include_num_access folder_config.num_files = config.num_files folder_config.file_system_requires_scan = config.file_system_requires_scan + folder_config.thresholds = config.thresholds for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne + nic_config.thresholds = config.thresholds + for service_config in config.services: + service_config.services_requires_scan = config.services_requires_scan + for application_config in config.applications: + application_config.applications_requires_scan = config.applications_requires_scan + application_config.thresholds = config.thresholds services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] @@ -281,7 +312,10 @@ class HostObservation(AbstractObservation, discriminator="host"): count = 1 while len(nics) < config.num_nics: nic_config = NICObservation.ConfigSchema( - nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic + nic_num=count, + include_nmne=config.include_nmne, + monitored_traffic=config.monitored_traffic, + thresholds=config.thresholds, ) nics.append(NICObservation.from_config(config=nic_config, parent_where=where)) count += 1 @@ -301,5 +335,7 @@ class HostObservation(AbstractObservation, discriminator="host"): monitored_traffic=config.monitored_traffic, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + services_requires_scan=config.services_requires_scan, + applications_requires_scan=config.applications_requires_scan, include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 1aa6470d..8faeb906 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,13 +1,14 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, List, Optional +from typing import ClassVar, Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import NMNEConfig from primaite.utils.validation.ip_protocol import IPProtocol from primaite.utils.validation.port import Port @@ -15,6 +16,9 @@ from primaite.utils.validation.port import Port class NICObservation(AbstractObservation, discriminator="network-interface"): """Status information about a network interface within the simulation environment.""" + capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne + "A Boolean specifying whether malicious network events should be captured." + class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for NICObservation.""" @@ -25,7 +29,13 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" - def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: + def __init__( + self, + where: WhereType, + include_nmne: bool, + monitored_traffic: Optional[Dict] = None, + thresholds: Dict = {}, + ) -> None: """ Initialise a network interface observation instance. @@ -45,10 +55,18 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): self.nmne_inbound_last_step: int = 0 self.nmne_outbound_last_step: int = 0 - # TODO: allow these to be configured in yaml - self.high_nmne_threshold = 10 - self.med_nmne_threshold = 5 - self.low_nmne_threshold = 0 + if thresholds.get("nmne") is None: + self.low_nmne_threshold = 0 + self.med_nmne_threshold = 5 + self.high_nmne_threshold = 10 + else: + self._set_nmne_threshold( + thresholds=[ + thresholds.get("nmne")["low"], + thresholds.get("nmne")["medium"], + thresholds.get("nmne")["high"], + ] + ) self.monitored_traffic = monitored_traffic if self.monitored_traffic: @@ -105,6 +123,20 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): bandwidth_utilisation = traffic_value / nic_max_bandwidth return int(bandwidth_utilisation * 9) + 1 + def _set_nmne_threshold(self, thresholds: List[int]): + """ + Method that validates and then sets the NMNE threshold. + + :param: thresholds: The NMNE threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=thresholds, + threshold_identifier="nmne", + ): + self.low_nmne_threshold = thresholds[0] + self.med_nmne_threshold = thresholds[1] + self.high_nmne_threshold = thresholds[2] + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. @@ -116,7 +148,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): """ nic_state = access_from_nested_dict(state, self.where) - if nic_state is NOT_PRESENT_IN_STATE: + if nic_state is NOT_PRESENT_IN_STATE or self.where is None: return self.default_observation obs = {"nic_status": 1 if nic_state["enabled"] else 2} @@ -164,7 +196,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): for port in self.monitored_traffic[protocol]: obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0} - if self.include_nmne: + if self.capture_nmne and self.include_nmne: obs.update({"NMNE": {}}) direction_dict = nic_state["nmne"].get("direction", {}) inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) @@ -224,6 +256,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic, + thresholds=config.thresholds, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 3a3283a2..260fac68 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -48,7 +48,13 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" file_system_requires_scan: bool = True - """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" + """If True, the folder must be scanned to update the health state. If False, the true state is always shown.""" + services_requires_scan: bool = True + """If True, the services must be scanned to update the health state. + If False, the true state is always shown.""" + applications_requires_scan: bool = True + """If True, the applications must be scanned to update the health state. + If False, the true state is always shown.""" include_users: Optional[bool] = True """If True, report user session information.""" num_ports: Optional[int] = None @@ -196,8 +202,14 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): host_config.include_num_access = config.include_num_access if host_config.file_system_requires_scan is None: host_config.file_system_requires_scan = config.file_system_requires_scan + if host_config.services_requires_scan is None: + host_config.services_requires_scan = config.services_requires_scan + if host_config.applications_requires_scan is None: + host_config.applications_requires_scan = config.applications_requires_scan if host_config.include_users is None: host_config.include_users = config.include_users + if not host_config.thresholds: + host_config.thresholds = config.thresholds for router_config in config.routers: if router_config.num_ports is None: @@ -214,6 +226,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): router_config.num_rules = config.num_rules if router_config.include_users is None: router_config.include_users = config.include_users + if not router_config.thresholds: + router_config.thresholds = config.thresholds for firewall_config in config.firewalls: if firewall_config.ip_list is None: @@ -228,6 +242,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): firewall_config.num_rules = config.num_rules if firewall_config.include_users is None: firewall_config.include_users = config.include_users + if not firewall_config.thresholds: + firewall_config.thresholds = config.thresholds hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 032435b8..c979c132 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -114,7 +114,9 @@ class NestedObservation(AbstractObservation, discriminator="custom"): instances = dict() for component in config.components: obs_class = AbstractObservation._registry[component.type] - obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options)) + obs_instance = obs_class.from_config( + config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds) + ) instances[component.label] = obs_instance return cls(components=instances) @@ -228,7 +230,7 @@ class ObservationManager(BaseModel): return self.obs.space @classmethod - def from_config(cls, config: Optional[Dict]) -> "ObservationManager": + def from_config(cls, config: Optional[Dict], thresholds: Optional[Dict] = {}) -> "ObservationManager": """ Create observation space from a config. @@ -239,11 +241,10 @@ class ObservationManager(BaseModel): AbstractObservation options: this must adhere to the chosen observation type's ConfigSchema nested class. :type config: Dict + :param thresholds: Dictionary containing the observation thresholds. + :type thresholds: Optional[Dict] """ if config is None: return cls(NullObservation()) - obs_type = config["type"] - obs_class = AbstractObservation._registry[obs_type] - observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"])) - obs_manager = cls(observation) + obs_manager = cls(config=config) return obs_manager diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index da81d2ad..8558b75c 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Manages the observation space for the agent.""" from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Optional, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Type, Union from gymnasium import spaces from gymnasium.core import ObsType @@ -19,6 +19,9 @@ class AbstractObservation(ABC): class ConfigSchema(ABC, BaseModel): """Config schema for observations.""" + thresholds: Optional[Dict] = {} + """A dict containing the observation thresholds.""" + model_config = ConfigDict(extra="forbid") _registry: Dict[str, Type["AbstractObservation"]] = {} @@ -69,3 +72,34 @@ class AbstractObservation(ABC): def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation": """Create this observation space component form a serialised format.""" return cls() + + def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool: + """ + Method that checks if the thresholds are non overlapping and in the correct (ascending) order. + + Pass in the thresholds from low to high e.g. + thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold] + + Throws an error if the threshold is not valid + + :param: thresholds: List of thresholds in ascending order. + :type: List[int] + :param: threshold_identifier: The name of the threshold option. + :type: Optional[str] + + :returns: bool + """ + if thresholds is None or len(thresholds) < 2: + raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}") + for idx in range(1, len(thresholds)): + if not isinstance(thresholds[idx], int): + raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.") + if not isinstance(thresholds[idx - 1], int): + raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.") + + if thresholds[idx] <= thresholds[idx - 1]: + raise Exception( + f"{threshold_identifier} threshold ({thresholds[idx - 1]}) " + f"is greater than or equal to ({thresholds[idx]}.)" + ) + return True diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 07ec1abf..dac6b362 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -19,7 +19,10 @@ class ServiceObservation(AbstractObservation, discriminator="service"): service_name: str """Name of the service, used for querying simulation state dictionary""" - def __init__(self, where: WhereType) -> None: + services_requires_scan: Optional[bool] = None + """If True, services must be scanned to update the health state. If False, true state is always shown.""" + + def __init__(self, where: WhereType, services_requires_scan: bool) -> None: """ Initialise a service observation instance. @@ -28,6 +31,7 @@ class ServiceObservation(AbstractObservation, discriminator="service"): :type where: WhereType """ self.where = where + self.services_requires_scan = services_requires_scan self.default_observation = {"operating_status": 0, "health_status": 0} def observe(self, state: Dict) -> ObsType: @@ -44,7 +48,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"): return self.default_observation return { "operating_status": service_state["operating_state"], - "health_status": service_state["health_state_visible"], + "health_status": service_state["health_state_visible"] + if self.services_requires_scan + else service_state["health_state_actual"], } @property @@ -70,7 +76,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"): :return: Constructed service observation instance. :rtype: ServiceObservation """ - return cls(where=parent_where + ["services", config.service_name]) + return cls( + where=parent_where + ["services", config.service_name], services_requires_scan=config.services_requires_scan + ) class ApplicationObservation(AbstractObservation, discriminator="application"): @@ -82,7 +90,12 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): application_name: str """Name of the application, used for querying simulation state dictionary""" - def __init__(self, where: WhereType) -> None: + applications_requires_scan: Optional[bool] = None + """ + If True, applications must be scanned to update the health state. If False, true state is always shown. + """ + + def __init__(self, where: WhereType, applications_requires_scan: bool, thresholds: Optional[Dict] = {}) -> None: """ Initialise an application observation instance. @@ -92,25 +105,52 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): :type where: WhereType """ self.where = where + self.applications_requires_scan = applications_requires_scan self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} - # TODO: allow these to be configured in yaml - self.high_threshold = 10 - self.med_threshold = 5 - self.low_threshold = 0 + if thresholds.get("app_executions") is None: + self.low_app_execution_threshold = 0 + self.med_app_execution_threshold = 5 + self.high_app_execution_threshold = 10 + else: + self._set_application_execution_thresholds( + thresholds=[ + thresholds.get("app_executions")["low"], + thresholds.get("app_executions")["medium"], + thresholds.get("app_executions")["high"], + ] + ) + + def _set_application_execution_thresholds(self, thresholds: List[int]): + """ + Method that validates and then sets the application execution threshold. + + :param: thresholds: The application execution threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=[ + thresholds[0], + thresholds[1], + thresholds[2], + ], + threshold_identifier="app_executions", + ): + self.low_app_execution_threshold = thresholds[0] + self.med_app_execution_threshold = thresholds[1] + self.high_app_execution_threshold = thresholds[2] def _categorise_num_executions(self, num_executions: int) -> int: """ - Represent number of file accesses as a categorical variable. + Represent number of application executions as a categorical variable. - :param num_access: Number of file accesses. + :param num_access: Number of application executions. :return: Bin number corresponding to the number of accesses. """ - if num_executions > self.high_threshold: + if num_executions > self.high_app_execution_threshold: return 3 - elif num_executions > self.med_threshold: + elif num_executions > self.med_app_execution_threshold: return 2 - elif num_executions > self.low_threshold: + elif num_executions > self.low_app_execution_threshold: return 1 return 0 @@ -128,7 +168,9 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): return self.default_observation return { "operating_status": application_state["operating_state"], - "health_status": application_state["health_state_visible"], + "health_status": application_state["health_state_visible"] + if self.applications_requires_scan + else application_state["health_state_actual"], "num_executions": self._categorise_num_executions(application_state["num_executions"]), } @@ -161,4 +203,8 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): :return: Constructed application observation instance. :rtype: ApplicationObservation """ - return cls(where=parent_where + ["applications", config.application_name]) + return cls( + where=parent_where + ["applications", config.application_name], + applications_requires_scan=config.applications_requires_scan, + thresholds=config.thresholds, + ) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 1427776e..a3b77ec3 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict from primaite import DEFAULT_BANDWIDTH, getLogger from primaite.game.agent.interface import AbstractAgent, ProxyAgent +from primaite.game.agent.observations import NICObservation from primaite.game.agent.rewards import SharedReward from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT @@ -44,15 +45,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) SERVICE_TYPES_MAPPING = { - "DNSClient": DNSClient, - "DNSServer": DNSServer, - "DatabaseService": DatabaseService, - "WebServer": WebServer, - "FTPClient": FTPClient, - "FTPServer": FTPServer, - "NTPClient": NTPClient, - "NTPServer": NTPServer, - "Terminal": Terminal, + "dns-client": DNSClient, + "dns-server": DNSServer, + "database-service": DatabaseService, + "web-server": WebServer, + "ftp-client": FTPClient, + "ftp-server": FTPServer, + "ntp-client": NTPClient, + "ntp-server": NTPServer, + "terminal": Terminal, } """List of available services that can be installed on nodes in the PrimAITE Simulation.""" @@ -68,6 +69,8 @@ class PrimaiteGameOptions(BaseModel): seed: int = None """Random number seed for RNGs.""" + generate_seed_value: bool = False + """Internally generated seed value.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" ports: List[Port] @@ -175,6 +178,7 @@ class PrimaiteGame: parameters=parameters, request=request, response=response, + observation=obs, ) def pre_timestep(self) -> None: @@ -263,6 +267,7 @@ class PrimaiteGame: node_sets_cfg = network_config.get("node_sets", []) # Set the NMNE capture config NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {})) + NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne for node_cfg in nodes_cfg: n_type = node_cfg["type"] @@ -293,6 +298,7 @@ class PrimaiteGame: if "users" in node_cfg and new_node.software_manager.software.get("user-manager"): user_manager: UserManager = new_node.software_manager.software["user-manager"] # noqa + for user_cfg in node_cfg["users"]: user_manager.add_user(**user_cfg, bypass_can_perform_action=True) @@ -407,6 +413,7 @@ class PrimaiteGame: agents_cfg = cfg.get("agents", []) for agent_cfg in agents_cfg: + agent_cfg = {**agent_cfg, "thresholds": game.options.thresholds} new_agent = AbstractAgent.from_config(agent_cfg) game.agents[agent_cfg["ref"]] = new_agent if isinstance(new_agent, ProxyAgent): diff --git a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb index 5982407b..9ac1da9b 100644 --- a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb @@ -47,7 +47,7 @@ "source": [ "def make_cfg_have_flat_obs(cfg):\n", " for agent in cfg['agents']:\n", - " if agent['type'] == \"ProxyAgent\":\n", + " if agent['type'] == \"proxy-agent\":\n", " agent['agent_settings']['flatten_obs'] = False" ] }, diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 9d9fb654..d9ca6411 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -684,6 +684,15 @@ " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.game.agents[\"data_manipulation_attacker\"].show_history()" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb index c7fdafad..40250d0c 100644 --- a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb +++ b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb @@ -153,6 +153,49 @@ "PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = was_enabled\n", "PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = was_syslogs_enabled" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Viewing Agent history\n", + "\n", + "It's possible to view the actions carried out by an agent for a given training session using the `show_history()` method. By default, this will be all actions apart from DONOTHING actions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(data_manipulation_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "env = PrimaiteGymEnv(env_config=cfg)\n", + "\n", + "# Run the training session to generate some resultant data.\n", + "for i in range(100):\n", + " env.step(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calling `.show_history()` should show us when the Data Manipulation used the `NODE_APPLICATION_EXECUTE` action." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "attacker = env.game.agents[\"data_manipulation_attacker\"]\n", + "\n", + "attacker.show_history()" + ] } ], "metadata": { diff --git a/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb b/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb index 3bb242d9..85356c64 100644 --- a/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb +++ b/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb @@ -15,15 +15,6 @@ "To display the available dev-mode options, run the command below:" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!primaite setup" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 5077732b..755b0184 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -9,6 +9,13 @@ "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simulation Layer Implementation." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -186,6 +193,22 @@ "computer_b.file_system.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Information about the latest response when executing a remote command can be seen by calling the `last_response` attribute within `Terminal`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(terminal_a.last_response)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -224,11 +247,254 @@ "source": [ "computer_b.user_session_manager.show(include_historic=True, include_session_id=True)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Game Layer Implementation\n", + "\n", + "This notebook section will detail the implementation of how the game layer utilises the terminal to support different agent actions.\n", + "\n", + "The ``Terminal`` is used in a variety of different ways in the game layer. Specifically, the terminal is leveraged to implement the following actions:\n", + "\n", + "\n", + "| Game Layer Action | Simulation Layer |\n", + "|-----------------------------------|--------------------------|\n", + "| ``node-send-local-command`` | Uses the given user credentials, creates a ``LocalTerminalSession`` and executes the given command and returns the ``RequestResponse``.\n", + "| ``node-session-remote-login`` | Uses the given user credentials and remote IP to create a ``RemoteTerminalSession``.\n", + "| ``node-send-remote-command`` | Uses the given remote IP to locate the correct ``RemoteTerminalSession``, executes the given command and returns the ``RequestsResponse``." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Game Layer Setup\n", + "\n", + "Similar to other notebooks, the next code cells create a custom proxy agent to demonstrate how these commands can be leveraged by agents in the ``UC2`` network environment.\n", + "\n", + "If you're unfamiliar with ``UC2`` then please refer to the [UC2-E2E-Demo notebook for further reference](./Data-Manipulation-E2E-Demonstration.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import yaml\n", + "from primaite.config.load import data_manipulation_config_path\n", + "from primaite.session.environment import PrimaiteGymEnv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "custom_terminal_agent = \"\"\"\n", + " - ref: CustomC2Agent\n", + " team: RED\n", + " type: proxy-agent\n", + " action_space:\n", + " action_map:\n", + " 0:\n", + " action: do-nothing\n", + " options: {}\n", + " 1:\n", + " action: node-send-local-command\n", + " options:\n", + " node_name: client_1\n", + " username: admin\n", + " password: admin\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - dog.png\n", + " - False\n", + " 2:\n", + " action: node-session-remote-login\n", + " options:\n", + " node_name: client_1\n", + " username: admin\n", + " password: admin\n", + " remote_ip: 192.168.10.22\n", + " 3:\n", + " action: node-send-remote-command\n", + " options:\n", + " node_name: client_1\n", + " remote_ip: 192.168.10.22\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - cat.png\n", + " - False\n", + "\"\"\"\n", + "custom_terminal_agent_yaml = yaml.safe_load(custom_terminal_agent)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(data_manipulation_config_path()) as f:\n", + " cfg = yaml.safe_load(f)\n", + " # removing all agents & adding the custom agent.\n", + " cfg['agents'] = {}\n", + " cfg['agents'] = custom_terminal_agent_yaml\n", + "\n", + "env = PrimaiteGymEnv(env_config=cfg)\n", + "\n", + "client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n", + "client_2: Computer = env.game.simulation.network.get_node_by_hostname(\"client_2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``node-send-local-command`` \n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: node-send-local-command\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 1:\n", + " action: node-send-local-command\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " username: admin\n", + " password: admin\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - dog.png\n", + " - False\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(1)\n", + "client_1.file_system.show(full=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``node-session-remote-login`` \n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: node-session-remote-login\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 2:\n", + " action: node-session-remote-login\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " username: admin\n", + " password: admin\n", + " remote_ip: 192.168.10.22 # client_2's ip address.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(2)\n", + "client_2.session_manager.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``node-send-remote-command``\n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: node-send-remote-command\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 1:\n", + " action: node-send-remote-command\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " remote_ip: 192.168.10.22\n", + " commands:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - cat.png\n", + " - False\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(3)\n", + "client_2.file_system.show(full=True)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", "name": "python3" }, diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index b7a9a042..fa545dbc 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -26,14 +26,26 @@ except ModuleNotFoundError: _LOGGER.debug("Torch not available for importing") -def set_random_seed(seed: int) -> Union[None, int]: +def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]: """ Set random number generators. + If seed is None or -1 and generate_seed_value is True randomly generate a + seed value. + If seed is > -1 and generate_seed_value is True ignore the latter and use + the provide seed value. + :param seed: int + :param generate_seed_value: bool + :return: None or the int representing the seed used. """ if seed is None or seed == -1: - return None + if generate_seed_value: + rng = np.random.default_rng() + # 2**32-1 is highest value for python RNG seed. + seed = int(rng.integers(low=0, high=2**32 - 1)) + else: + return None elif seed < -1: raise ValueError("Invalid random number seed") # Seed python RNG @@ -50,6 +62,13 @@ def set_random_seed(seed: int) -> Union[None, int]: return seed +def log_seed_value(seed: int): + """Log the selected seed value to file.""" + path = SIM_OUTPUT.path / "seed.log" + with open(path, "w") as file: + file.write(f"Seed value = {seed}") + + class PrimaiteGymEnv(gymnasium.Env): """ Thin wrapper env to provide agents with a gymnasium API. @@ -65,7 +84,8 @@ class PrimaiteGymEnv(gymnasium.Env): """Object that returns a config corresponding to the current episode.""" self.seed = self.episode_scheduler(0).get("game", {}).get("seed") """Get RNG seed from config file. NB: Must be before game instantiation.""" - self.seed = set_random_seed(self.seed) + self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value") + self.seed = set_random_seed(self.seed, self.generate_seed_value) self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0)) @@ -79,6 +99,8 @@ class PrimaiteGymEnv(gymnasium.Env): _LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}") + log_seed_value(self.seed) + def action_masks(self) -> np.ndarray: """ Return the action mask for the agent. @@ -146,7 +168,7 @@ class PrimaiteGymEnv(gymnasium.Env): f"avg. reward: {self.agent.reward_function.total_reward}" ) if seed is not None: - set_random_seed(seed) + set_random_seed(seed, self.generate_seed_value) self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward if self.io.settings.save_agent_actions: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index dfc282ca..9cc39848 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -862,7 +862,21 @@ class UserManager(Service, discriminator="user-manager"): """ rm = super()._init_request_manager() - # todo add doc about requeest schemas + # todo add doc about request schemas + rm.add_request( + "add_user", + RequestType( + func=lambda request, context: RequestResponse.from_bool( + self.add_user(username=request[0], password=request[1], is_admin=request[2]) + ) + ), + ) + rm.add_request( + "disable_user", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.disable_user(username=request[0])) + ), + ) rm.add_request( "change_password", RequestType( @@ -1570,7 +1584,7 @@ class Node(SimComponent, ABC): operating_state: Any = None - users: Any = None # Temporary to appease "extra=forbid" + users: List[Dict] = [] # Temporary to appease "extra=forbid" config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) """Configuration items within Node""" @@ -1636,6 +1650,8 @@ class Node(SimComponent, ABC): self._install_system_software() self.session_manager.node = self self.session_manager.software_manager = self.software_manager + for user in self.config.users: + self.user_manager.add_user(**user, bypass_can_perform_action=True) @property def user_manager(self) -> Optional[UserManager]: @@ -1767,7 +1783,7 @@ class Node(SimComponent, ABC): """ application_name = request[0] if self.software_manager.software.get(application_name): - self.sys_log.warning(f"Can't install {application_name}. It's already installed.") + self.sys_log.info(f"Can't install {application_name}. It's already installed.") return RequestResponse(status="success", data={"reason": "already installed"}) application_class = Application._registry[application_name] self.software_manager.install(application_class) diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 31f88c80..fcd0eb80 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -2,11 +2,12 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, ClassVar, Dict, Literal, Optional +from typing import Any, ClassVar, Dict, List, Literal, Optional from pydantic import Field from primaite import getLogger +from primaite.simulator.file_system.file_type import FileType from primaite.simulator.network.hardware.base import ( IPWiredNetworkInterface, Link, @@ -339,7 +340,7 @@ class HostNode(Node, discriminator="host-node"): ip_address: IPV4Address services: Any = None # temporarily unset to appease extra="forbid" applications: Any = None # temporarily unset to appease extra="forbid" - folders: Any = None # temporarily unset to appease extra="forbid" + folders: List[Dict] = {} # temporarily unset to appease extra="forbid" network_interfaces: Any = None # temporarily unset to appease extra="forbid" config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema()) @@ -348,6 +349,18 @@ class HostNode(Node, discriminator="host-node"): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask)) + for folder in self.config.folders: + # handle empty foler defined by just a string + self.file_system.create_folder(folder["folder_name"]) + + for file in folder.get("files", []): + self.file_system.create_file( + folder_name=folder["folder_name"], + file_name=file["file_name"], + size=file.get("size", 0), + file_type=FileType[file.get("type", "UNKNOWN").upper()], + ) + @property def nmap(self) -> Optional[NMAP]: """ diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 2cbc23d2..c872b8b3 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -49,7 +49,7 @@ class Firewall(Router, discriminator="firewall"): Example: >>> from primaite.simulator.network.transmission.network_layer import IPProtocol - >>> from primaite.simulator.network.transmission.transport_layer import Port + >>> from primaite.utils.validation.port import Port >>> firewall = Firewall(hostname="Firewall1") >>> firewall.configure_internal_port(ip_address="192.168.1.1", subnet_mask="255.255.255.0") >>> firewall.configure_external_port(ip_address="10.0.0.1", subnet_mask="255.255.255.0") diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index f2a4652c..33b55b9d 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -467,6 +467,7 @@ class AccessControlList(SimComponent): """Check if a packet with the given properties is permitted through the ACL.""" permitted = False rule: ACLRule = None + for _rule in self._acl: if not _rule: continue @@ -1215,9 +1216,9 @@ class Router(NetworkNode, discriminator="router"): config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema()) SYSTEM_SOFTWARE: ClassVar[Dict] = { - "UserSessionManager": UserSessionManager, - "UserManager": UserManager, - "Terminal": Terminal, + "user-session-manager": UserSessionManager, + "user-manager": UserManager, + "terminal": Terminal, } network_interfaces: Dict[str, RouterInterface] = {} @@ -1384,6 +1385,12 @@ class Router(NetworkNode, discriminator="router"): return False + def subject_to_acl(self, frame: Frame) -> bool: + """Check that frame is subject to ACL rules.""" + if frame.ip.protocol == "udp" and frame.is_arp: + return False + return True + def receive_frame(self, frame: Frame, from_network_interface: RouterInterface): """ Processes an incoming frame received on one of the router's interfaces. @@ -1397,8 +1404,12 @@ class Router(NetworkNode, discriminator="router"): if self.operating_state != NodeOperatingState.ON: return - # Check if it's permitted - permitted, rule = self.acl.is_permitted(frame) + if self.subject_to_acl(frame=frame): + # Check if it's permitted + permitted, rule = self.acl.is_permitted(frame) + else: + permitted = True + rule = None if not permitted: at_port = self._get_port_of_nic(from_network_interface) diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index e7c2a124..a07194a4 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -163,7 +163,7 @@ class Frame(BaseModel): """ Checks if the Frame is an ARP (Address Resolution Protocol) packet. - This is determined by checking if the destination port of the TCP header is equal to the ARP port. + This is determined by checking if the destination and source port of the UDP header is equal to the ARP port. :return: True if the Frame is an ARP packet, otherwise False. """ diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index d6617efa..26e3be79 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -415,5 +415,5 @@ class SessionManager: table.align = "l" table.title = f"{self.sys_log.hostname} Session Manager" for session in self.sessions_by_key.values(): - table.add_row([session.dst_ip_address, session.dst_port, session.protocol]) + table.add_row([session.with_ip_address, session.dst_port, session.protocol]) print(table) diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index 348cc03e..3302041d 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -55,7 +55,7 @@ class ARP(Service, discriminator="arp"): :param markdown: If True, format the output as Markdown. Otherwise, use plain text. """ - table = PrettyTable(["IP Address", "MAC Address", "Via"]) + table = PrettyTable(["IP Address", "MAC Address", "Via", "Port"]) if markdown: table.set_style(MARKDOWN) table.align = "l" @@ -66,6 +66,7 @@ class ARP(Service, discriminator="arp"): str(ip), arp.mac_address, self.software_manager.node.network_interfaces[arp.network_interface_uuid].mac_address, + self.software_manager.node.network_interfaces[arp.network_interface_uuid].port_num, ] ) print(table) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 2ce7d176..112f6abc 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -142,12 +142,20 @@ class Terminal(Service, discriminator="terminal"): _client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {} """Dictionary of connect requests made to remote nodes.""" + _last_response: Optional[RequestResponse] = None + """Last response received from RequestManager, for returning remote RequestResponse.""" + def __init__(self, **kwargs): kwargs["name"] = "terminal" kwargs["port"] = PORT_LOOKUP["SSH"] kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) + @property + def last_response(self) -> Optional[RequestResponse]: + """Public version of _last_response attribute.""" + return self._last_response + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -186,7 +194,7 @@ class Terminal(Service, discriminator="terminal"): return RequestResponse(status="failure", data={}) rm.add_request( - "node-session-remote-login", + "node_session_remote_login", request_type=RequestType(func=_remote_login), ) @@ -209,28 +217,45 @@ class Terminal(Service, discriminator="terminal"): command: str = request[1]["command"] remote_connection = self._get_connection_from_ip(ip_address=ip_address) if remote_connection: - outcome = remote_connection.execute(command) - if outcome: - return RequestResponse( - status="success", - data={}, - ) - else: - return RequestResponse( - status="failure", - data={}, - ) + remote_connection.execute(command) + return self.last_response if not None else RequestResponse(status="failure", data={}) + return RequestResponse( + status="failure", + data={"reason": "Failed to execute command."}, + ) rm.add_request( "send_remote_command", request_type=RequestType(func=remote_execute_request), ) + def local_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: + """Executes a command using a local terminal session.""" + command: str = request[2]["command"] + local_connection = self._process_local_login(username=request[0], password=request[1]) + if local_connection: + outcome = local_connection.execute(command) + if outcome: + return RequestResponse( + status="success", + data={"reason": outcome}, + ) + return RequestResponse( + status="success", + data={"reason": "Local Terminal failed to resolve command. Potentially invalid credentials?"}, + ) + + rm.add_request( + "send_local_command", + request_type=RequestType(func=local_execute_request), + ) + return rm def execute(self, command: List[Any]) -> Optional[RequestResponse]: """Execute a passed ssh command via the request manager.""" - return self.parent.apply_request(command) + self._last_response = self.parent.apply_request(command) + return self._last_response def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]: """Find Remote Terminal Connection from a given IP.""" @@ -409,6 +434,8 @@ class Terminal(Service, discriminator="terminal"): """ source_ip = kwargs["frame"].ip.src_ip_address self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}") + self._last_response = None # Clear last response + if isinstance(payload, SSHPacket): if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: # validate & add connection @@ -457,6 +484,9 @@ class Terminal(Service, discriminator="terminal"): session_id=session_id, source_ip=source_ip, ) + self._last_response: RequestResponse = RequestResponse( + status="success", data={"reason": "Login Successful"} + ) elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: # Requesting a command to be executed @@ -468,12 +498,32 @@ class Terminal(Service, discriminator="terminal"): payload.connection_uuid ) remote_session.last_active_step = self.software_manager.node.user_session_manager.current_timestep - self.execute(command) + self._last_response: RequestResponse = self.execute(command) + + if self._last_response.status == "success": + transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS + else: + transport_message = SSHTransportMessage.SSH_MSG_SERVICE_FAILED + + payload: SSHPacket = SSHPacket( + payload=self._last_response, + transport_message=transport_message, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA, + ) + self.software_manager.send_payload_to_session_manager( + payload=payload, dest_port=self.port, session_id=session_id + ) return True else: self.sys_log.error( f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command." ) + elif ( + payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS + or SSHTransportMessage.SSH_MSG_SERVICE_FAILED + ): + # Likely receiving command ack from remote. + self._last_response = payload.payload if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "disconnect": diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 2eddefc1..3f8760c4 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -117,37 +117,44 @@ class WebServer(Service, discriminator="web-server"): :type: payload: HttpRequestPacket """ response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND, payload=payload) - try: - parsed_url = urlparse(payload.request_url) - path = parsed_url.path.strip("/") - if len(path) < 1: + parsed_url = urlparse(payload.request_url) + path = parsed_url.path.strip("/") if parsed_url and parsed_url.path else "" + + if len(path) < 1: + # query succeeded + response.status_code = HttpStatusCode.OK + + if path.startswith("users"): + # get data from DatabaseServer + # get all users + if not self._establish_db_connection(): + # unable to create a db connection + response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + return response + + if self.db_connection.query("SELECT"): # query succeeded + self.set_health_state(SoftwareHealthState.GOOD) response.status_code = HttpStatusCode.OK + else: + self.set_health_state(SoftwareHealthState.COMPROMISED) + return response - if path.startswith("users"): - # get data from DatabaseServer - # get all users - if not self.db_connection: - self._establish_db_connection() - - if self.db_connection.query("SELECT"): - # query succeeded - self.set_health_state(SoftwareHealthState.GOOD) - response.status_code = HttpStatusCode.OK - else: - self.set_health_state(SoftwareHealthState.COMPROMISED) - - return response - except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 ) - # something went wrong on the server - response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR - return response - - def _establish_db_connection(self) -> None: + def _establish_db_connection(self) -> bool: """Establish a connection to db.""" + # if active db connection, return true + if self.db_connection: + return True + + # otherwise, try to create db connection db_client = self.software_manager.software.get("database-client") + + if db_client is None: + return False # database client not installed + self.db_connection: DatabaseClientConnection = db_client.get_new_connection() + return self.db_connection is not None def send( self, diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index b57ed3e0..76b4dfb9 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -28,7 +28,19 @@ game: - ICMP - TCP - UDP - + thresholds: + nmne: + high: 100 + medium: 25 + low: 5 + file_access: + high: 10 + medium: 5 + low: 2 + app_executions: + high: 5 + medium: 3 + low: 2 agents: - ref: client_2_green_user team: GREEN @@ -67,10 +79,16 @@ agents: options: hosts: - hostname: client_1 + applications: + - application_name: WebBrowser + folders: + - folder_name: root + files: + - file_name: "test.txt" - hostname: client_2 - hostname: client_3 num_services: 1 - num_applications: 0 + num_applications: 1 num_folders: 1 num_files: 1 num_nics: 2 @@ -185,6 +203,10 @@ simulation: options: ntp_server_ip: 192.168.1.10 - type: ntp-server + folders: + - folder_name: root + files: + - file_name: test.txt - hostname: client_2 type: computer ip_address: 192.168.10.22 diff --git a/tests/assets/configs/nodes_with_initial_files.yaml b/tests/assets/configs/nodes_with_initial_files.yaml new file mode 100644 index 00000000..d4c6406b --- /dev/null +++ b/tests/assets/configs/nodes_with_initial_files.yaml @@ -0,0 +1,226 @@ +# Basic Switched network +# +# -------------- -------------- -------------- +# | client_1 |------| switch_1 |------| client_2 | +# -------------- -------------- -------------- +# +io_settings: + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + sys_log_level: WARNING + agent_log_level: INFO + save_agent_logs: true + write_agent_log_to_terminal: True + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_2_green_user + team: GREEN + type: periodic-agent + action_space: + action_map: + 0: + action: do-nothing + options: {} + 1: + action: node-application-execute + options: + node_id: 0 + application_id: 0 + + agent_settings: + possible_start_nodes: [client_2,] + target_application: web-browser + start_step: 5 + frequency: 4 + variance: 3 + + + + - ref: defender + team: BLUE + type: proxy-agent + + observation_space: + type: custom + options: + components: + - type: nodes + label: NODES + options: + hosts: + - hostname: client_1 + - hostname: client_2 + - hostname: client_3 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + include_nmne: false + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.23 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: links + label: LINKS + options: + link_references: + - switch_1:eth-1<->client_1:eth-1 + - switch_1:eth-2<->client_2:eth-1 + - type: none + label: ICS + options: {} + + action_space: + action_map: + 0: + action: do-nothing + options: {} + + reward_function: + reward_components: + - type: database-file-integrity + weight: 0.5 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + - type: web-server-404-penalty + weight: 0.5 + options: + node_hostname: web_server + service_name: web_server_web_service + + + agent_settings: + flatten_obs: true + +simulation: + network: + nodes: + + - type: switch + hostname: switch_1 + num_ports: 8 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: ransomware-script + - type: web-browser + options: + target_url: http://arcd.com/users/ + - type: database-client + options: + db_server_ip: 192.168.1.10 + server_password: arcd + - type: data-manipulation-bot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.21 + server_password: arcd + - type: dos-bot + options: + target_ip_address: 192.168.10.21 + payload: SPOOF DATA + port_scan_p_of_success: 0.8 + services: + - type: dns-client + options: + dns_server: 192.168.1.10 + - type: dns-server + options: + domain_mapping: + arcd.com: 192.168.1.10 + - type: database-service + options: + backup_server_ip: 192.168.1.10 + - type: web-server + - type: ftp-server + options: + server_password: arcd + - type: ntp-client + options: + ntp_server_ip: 192.168.1.10 + - type: ntp-server + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + folders: + - folder_name: empty_folder + - folder_name: downloads + files: + - file_name: "test.txt" + - file_name: "another_file.pwtwoti" + - folder_name: root + files: + - file_name: passwords + size: 663 + type: TXT + # pre installed services and applications + - hostname: client_3 + type: computer + ip_address: 192.168.10.23 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + start_up_duration: 0 + shut_down_duration: 0 + operating_state: "OFF" + # pre installed services and applications + + links: + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + bandwidth: 200 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + bandwidth: 200 diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py index 4153adc0..627fc53b 100644 --- a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -8,7 +8,7 @@ from primaite.config.load import data_manipulation_config_path from primaite.game.game import PrimaiteGame from tests import TEST_ASSETS_ROOT -BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" +BASIC_SWITCHED_NETWORK_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" def load_config(config_path: Union[str, Path]) -> PrimaiteGame: @@ -24,3 +24,42 @@ def test_thresholds(): game = load_config(data_manipulation_config_path()) assert game.options.thresholds is not None + + +def test_nmne_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["nmne"] is not None + + # get NIC observation + nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0] + assert nic_obs.low_nmne_threshold == 5 + assert nic_obs.med_nmne_threshold == 25 + assert nic_obs.high_nmne_threshold == 100 + + +def test_file_access_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["file_access"] is not None + + # get file observation + file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0] + assert file_obs.low_file_access_threshold == 2 + assert file_obs.med_file_access_threshold == 5 + assert file_obs.high_file_access_threshold == 10 + + +def test_app_executions_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["app_executions"] is not None + + # get application observation + app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0] + assert app_obs.low_app_execution_threshold == 2 + assert app_obs.med_app_execution_threshold == 3 + assert app_obs.high_app_execution_threshold == 5 diff --git a/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py new file mode 100644 index 00000000..4c99a39f --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py @@ -0,0 +1,64 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from pathlib import Path +from typing import Union + +import yaml + +from primaite.game.game import PrimaiteGame +from primaite.simulator.file_system.file_type import FileType +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/nodes_with_initial_files.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +def test_node_file_system_from_config(): + """Test that the appropriate files are instantiated in nodes when loaded from config.""" + game = load_config(BASIC_CONFIG) + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + assert client_1.software_manager.software.get("database-service") # database service should be installed + assert client_1.file_system.get_file(folder_name="database", file_name="database.db") # database files should exist + + assert client_1.software_manager.software.get("web-server") # web server should be installed + assert client_1.file_system.get_file(folder_name="primaite", file_name="index.html") # web files should exist + + client_2 = game.simulation.network.get_node_by_hostname("client_2") + + # database service should not be installed + assert client_2.software_manager.software.get("database-service") is None + # database files should not exist + assert client_2.file_system.get_file(folder_name="database", file_name="database.db") is None + + # web server should not be installed + assert client_2.software_manager.software.get("web-server") is None + # web files should not exist + assert client_2.file_system.get_file(folder_name="primaite", file_name="index.html") is None + + empty_folder = client_2.file_system.get_folder(folder_name="empty_folder") + assert empty_folder + assert len(empty_folder.files) == 0 # should have no files + + password_file = client_2.file_system.get_file(folder_name="root", file_name="passwords.txt") + assert password_file # should exist + assert password_file.file_type is FileType.TXT + assert password_file.size == 663 + + downloads_folder = client_2.file_system.get_folder(folder_name="downloads") + assert downloads_folder # downloads folder should exist + + test_txt = downloads_folder.get_file(file_name="test.txt") + assert test_txt # test.txt should exist + assert test_txt.file_type is FileType.TXT + + unknown_file_type = downloads_folder.get_file(file_name="another_file.pwtwoti") + assert unknown_file_type # unknown_file_type should exist + assert unknown_file_type.file_type is FileType.UNKNOWN diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index 6d30644c..3ee97fb7 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -164,3 +164,55 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam assert server_1.file_system.get_folder("folder123") is None assert server_1.file_system.get_file("folder123", "doggo.pdf") is None + + +def test_local_terminal(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + # create a new user account on server_1 that will be logged into remotely + client_1_usm: UserManager = client_1.software_manager.software["user-manager"] + client_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "node-send-local-command", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert client_1.file_system.get_folder("folder123") + assert client_1.file_system.get_file("folder123", "doggo.pdf") + + # Change password + action = ( + "node-account-change-password", + { + "node_name": "client_1", + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + + action = ( + "node-send-local-command", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "command": ["file_system", "create", "file", "folder123", "cat.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert client_1.file_system.get_file("folder123", "cat.pdf") is None + client_1.session_manager.show() diff --git a/tests/integration_tests/game_layer/actions/test_user_account_actions.py b/tests/integration_tests/game_layer/actions/test_user_account_actions.py new file mode 100644 index 00000000..26b871db --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_user_account_actions.py @@ -0,0 +1,176 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.utils.validation.port import Port, PORT_LOOKUP + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_user_account_add_user_action(game_and_agent_fixture): + """Tests the add user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + assert len(client_1.user_manager.users) == 1 # admin is created by default + assert len(client_1.user_manager.admins) == 1 + + # add admin account + action = ( + "node-account-add-user", + {"node_name": "client_1", "username": "admin_2", "password": "e-tronic-boogaloo", "is_admin": True}, + ) + agent.store_action(action) + game.step() + + assert len(client_1.user_manager.users) == 2 # new user added + assert len(client_1.user_manager.admins) == 2 + + # add non admin account + action = ( + "node-account-add-user", + {"node_name": "client_1", "username": "leeroy.jenkins", "password": "no_plan_needed", "is_admin": False}, + ) + agent.store_action(action) + game.step() + + assert len(client_1.user_manager.users) == 3 # new user added + assert len(client_1.user_manager.admins) == 2 + + +def test_user_account_disable_user_action(game_and_agent_fixture): + """Tests the disable user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test", password="password", is_admin=True) + assert len(client_1.user_manager.users) == 2 # new user added + assert len(client_1.user_manager.admins) == 2 + + test_user = client_1.user_manager.users.get("test") + assert test_user + assert test_user.disabled is not True + + # disable test account + action = ( + "node-account-disable-user", + { + "node_name": "client_1", + "username": "test", + }, + ) + agent.store_action(action) + game.step() + assert test_user.disabled + + +def test_user_account_change_password_action(game_and_agent_fixture): + """Tests the change password user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test", password="password", is_admin=True) + + test_user = client_1.user_manager.users.get("test") + assert test_user.password == "password" + + # change account password + action = ( + "node-account-change-password", + {"node_name": "client_1", "username": "test", "current_password": "password", "new_password": "2Hard_2_Hack"}, + ) + agent.store_action(action) + game.step() + + assert test_user.password == "2Hard_2_Hack" + + +def test_user_account_create_terminal_action(game_and_agent_fixture): + """Tests that agents can use the terminal to create new users.""" + game, agent = game_and_agent_fixture + + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) + + server_1 = game.simulation.network.get_node_by_hostname("server_1") + server_1_usm = server_1.software_manager.software["user-manager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "node-session-remote-login", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + # Create a new user account via terminal. + action = ( + "node-send-remote-command", + { + "node_name": "client_1", + "remote_ip": str(server_1.network_interface[1].ip_address), + "command": ["service", "user-manager", "add_user", "new_user", "new_pass", True], + }, + ) + agent.store_action(action) + game.step() + new_user = server_1.user_manager.users.get("new_user") + assert new_user + assert new_user.password == "new_pass" + assert new_user.disabled is not True + + +def test_user_account_disable_terminal_action(game_and_agent_fixture): + """Tests that agents can use the terminal to disable users.""" + game, agent = game_and_agent_fixture + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) + + server_1 = game.simulation.network.get_node_by_hostname("server_1") + server_1_usm = server_1.software_manager.software["user-manager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "node-session-remote-login", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + # Disable a user via terminal + action = ( + "node-send-remote-command", + { + "node_name": "client_1", + "remote_ip": str(server_1.network_interface[1].ip_address), + "command": ["service", "user-manager", "disable_user", "user123"], + }, + ) + agent.store_action(action) + game.step() + + new_user = server_1.user_manager.users.get("user123") + assert new_user + assert new_user.disabled is True diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 7323461c..722fd294 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -44,6 +44,38 @@ def test_file_observation(simulation): assert observation_state.get("health_status") == 3 # corrupted +def test_config_file_access_categories(simulation): + pc: Computer = simulation.network.get_node_by_hostname("client_1") + file_obs = FileObservation( + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}}, + ) + + assert file_obs.high_file_access_threshold == 9 + assert file_obs.med_file_access_threshold == 6 + assert file_obs.low_file_access_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + FileObservation( + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 9, "medium": 6, "high": 9}}, + ) + + with pytest.raises(Exception): + # should throw an error + FileObservation( + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 3, "medium": 9, "high": 9}}, + ) + + def test_folder_observation(simulation): """Test the folder observation.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index b5e5ca81..046db4eb 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -77,6 +77,22 @@ def test_nic(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True) + # The Simulation object created by the fixture also creates the + # NICObservation class with the NICObservation.capture_nmnme class variable + # set to False. Under normal (non-test) circumstances this class variable + # is set from a config file such as data_manipulation.yaml. So although + # capture_nmne is set to True in the NetworkInterface class it's still False + # in the NICObservation class so we set it now. + nic_obs.capture_nmne = True + + # The Simulation object created by the fixture also creates the + # NICObservation class with the NICObservation.capture_nmnme class variable + # set to False. Under normal (non-test) circumstances this class variable + # is set from a config file such as data_manipulation.yaml. So although + # capture_nmne is set to True in the NetworkInterface class it's still False + # in the NICObservation class so we set it now. + nic_obs.capture_nmne = True + # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs nmne_config = { "capture_nmne": True, # Enable the capture of MNEs @@ -115,14 +131,11 @@ def test_nic_categories(simulation): assert nic_obs.low_nmne_threshold == 0 # default -@pytest.mark.skip(reason="Feature not implemented yet") def test_config_nic_categories(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") nic_obs = NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=3, - med_nmne_threshold=6, - high_nmne_threshold=9, + where=["network", "nodes", pc.config.hostname, "NICs", 1], + thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}}, include_nmne=True, ) @@ -133,20 +146,16 @@ def test_config_nic_categories(simulation): with pytest.raises(Exception): # should throw an error NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=9, - med_nmne_threshold=6, - high_nmne_threshold=9, + where=["network", "nodes", pc.config.hostname, "NICs", 1], + thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}}, include_nmne=True, ) with pytest.raises(Exception): # should throw an error NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=3, - med_nmne_threshold=9, - high_nmne_threshold=9, + where=["network", "nodes", pc.config.hostname, "NICs", 1], + thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}}, include_nmne=True, ) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 09eb3fe4..aef60bc2 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -39,6 +39,8 @@ def test_host_observation(simulation): folders=[], network_interfaces=[], file_system_requires_scan=True, + services_requires_scan=True, + applications_requires_scan=True, include_users=False, ) diff --git a/tests/integration_tests/game_layer/observations/test_obs_data_capture.py b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py new file mode 100644 index 00000000..e8bdea22 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py @@ -0,0 +1,28 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import json + +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.io import PrimaiteIO +from tests import TEST_ASSETS_ROOT + +DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" + + +def test_obs_data_in_log_file(): + """Create a log file of AgentHistoryItems and check observation data is + included. Assumes that data_manipulation.yaml has an agent labelled + 'defender' with a non-null observation space. + The log file will be in: + primaite/VERSION/sessions/YYYY-MM-DD/HH-MM-SS/agent_actions + """ + env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) + env.reset() + for _ in range(10): + env.step(0) + env.reset() + io = PrimaiteIO() + path = io.generate_agent_actions_save_path(episode=1) + with open(path, "r") as f: + j = json.load(f) + + assert type(j["0"]["defender"]["observation"]) == dict diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 28cdaf01..1ebff10c 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -29,7 +29,9 @@ def test_service_observation(simulation): ntp_server = pc.software_manager.software.get("ntp-server") assert ntp_server - service_obs = ServiceObservation(where=["network", "nodes", pc.config.hostname, "services", "ntp-server"]) + service_obs = ServiceObservation( + where=["network", "nodes", pc.config.hostname, "services", "ntp-server"], services_requires_scan=True + ) assert service_obs.space["operating_status"] == spaces.Discrete(7) assert service_obs.space["health_status"] == spaces.Discrete(5) @@ -54,7 +56,9 @@ def test_application_observation(simulation): web_browser: WebBrowser = pc.software_manager.software.get("web-browser") assert web_browser - app_obs = ApplicationObservation(where=["network", "nodes", pc.config.hostname, "applications", "web-browser"]) + app_obs = ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "web-browser"], applications_requires_scan=True + ) web_browser.close() observation_state = app_obs.observe(simulation.describe_state()) @@ -69,3 +73,33 @@ def test_application_observation(simulation): assert observation_state.get("health_status") == 1 assert observation_state.get("operating_status") == 1 # running assert observation_state.get("num_executions") == 1 + + +def test_application_executions_categories(simulation): + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + app_obs = ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], + applications_requires_scan=False, + thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}}, + ) + + assert app_obs.high_app_execution_threshold == 9 + assert app_obs.med_app_execution_threshold == 6 + assert app_obs.low_app_execution_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], + applications_requires_scan=False, + thresholds={"app_executions": {"low": 9, "medium": 6, "high": 9}}, + ) + + with pytest.raises(Exception): + # should throw an error + ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], + applications_requires_scan=False, + thresholds={"app_executions": {"low": 3, "medium": 9, "high": 9}}, + ) diff --git a/tests/integration_tests/game_layer/test_RNG_seed.py b/tests/integration_tests/game_layer/test_RNG_seed.py index 45fa445d..2b80e153 100644 --- a/tests/integration_tests/game_layer/test_RNG_seed.py +++ b/tests/integration_tests/game_layer/test_RNG_seed.py @@ -7,6 +7,7 @@ import yaml from primaite.config.load import data_manipulation_config_path from primaite.game.agent.interface import AgentHistoryItem from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator import SIM_OUTPUT @pytest.fixture() @@ -33,6 +34,11 @@ def test_rng_seed_set(create_env): assert a == b + # Check that seed log file was created. + path = SIM_OUTPUT.path / "seed.log" + with open(path, "r") as file: + assert file + def test_rng_seed_unset(create_env): """Test with no RNG seed.""" @@ -48,3 +54,19 @@ def test_rng_seed_unset(create_env): b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do-nothing"] assert a != b + + +def test_for_generated_seed(): + """ + Show that setting generate_seed_value to true producess a valid seed. + """ + with open(data_manipulation_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + cfg["game"]["generate_seed_value"] = True + PrimaiteGymEnv(env_config=cfg) + path = SIM_OUTPUT.path / "seed.log" + with open(path, "r") as file: + data = file.read() + + assert data.split(" ")[3] != None diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 03b94ab7..764c207e 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -22,6 +22,7 @@ from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.software import SoftwareHealthState @@ -107,7 +108,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox """ Test that the RouterACLAddRuleAction can form a request and that it is accepted by the simulation. - The acl starts off with 4 rules, and we add a rule, and check that the acl now has 5 rules. + The ACL starts off with 3 rules, and we add a rule, and check that the ACL now has 4 rules. """ game, agent = game_and_agent @@ -139,7 +140,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox agent.store_action(action) game.step() - # 3: Check that the acl now has 5 rules, and that client 1 cannot ping server 2 + # 3: Check that the acl now has 6 rules, and that client 1 cannot ping server 2 assert router.acl.num_rules == 5 assert not client_1.ping("10.0.2.3") # Cannot ping server_2 assert client_1.ping("10.0.2.2") # Can ping server_1 @@ -164,11 +165,9 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox }, ) agent.store_action(action) - print(agent.most_recent_action) game.step() - print(agent.most_recent_action) + # 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2 - print(router.acl.show()) assert router.acl.num_rules == 6 assert server_1.ping("10.0.2.3") # Can ping server_2 @@ -180,7 +179,8 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P # 1: Check that http traffic is going across the network nicely. client_1 = game.simulation.network.get_node_by_hostname("client_1") server_1 = game.simulation.network.get_node_by_hostname("server_1") - router = game.simulation.network.get_node_by_hostname("router") + router: Router = game.simulation.network.get_node_by_hostname("router") + assert router.acl.num_rules == 4 browser: WebBrowser = client_1.software_manager.software.get("web-browser") browser.run() @@ -198,7 +198,7 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P agent.store_action(action) game.step() - # 3: Check that the ACL now has 3 rules, and that client 1 cannot access example.com + # 3: Check that the ACL now has 2 rules, and that client 1 cannot access example.com assert router.acl.num_rules == 3 assert not browser.get_webpage() client_1.software_manager.software.get("dns-client").dns_cache.clear() diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index ea7fbc99..80e7c3b3 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,5 +1,11 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from itertools import product + +import yaml + +from primaite.config.load import data_manipulation_config_path from primaite.game.agent.observations.nic_observations import NICObservation +from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server @@ -277,3 +283,19 @@ def test_capture_nmne_observations(uc2_network: Network): assert web_nic_obs["outbound"] == expected_nmne assert db_nic_obs["inbound"] == expected_nmne uc2_network.apply_timestep(timestep=0) + + +def test_nmne_parameter_settings(): + """ + Check that the four permutations of the values of capture_nmne and + include_nmne work as expected. + """ + + with open(data_manipulation_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + DEFENDER = 3 + for capture, include in product([True, False], [True, False]): + cfg["simulation"]["network"]["nmne_config"]["capture_nmne"] = capture + cfg["agents"][DEFENDER]["observation_space"]["options"]["components"][0]["options"]["include_nmne"] = include + PrimaiteGymEnv(env_config=cfg) diff --git a/tests/integration_tests/system/test_arp.py b/tests/integration_tests/system/test_arp.py index 055d58c6..b9a92255 100644 --- a/tests/integration_tests/system/test_arp.py +++ b/tests/integration_tests/system/test_arp.py @@ -1,6 +1,7 @@ -# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK -from primaite.simulator.network.hardware.nodes.network.router import RouterARP +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterARP from primaite.simulator.system.services.arp.arp import ARP +from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.network.test_routing import multi_hop_network @@ -48,3 +49,19 @@ def test_arp_fails_for_network_address_between_routers(multi_hop_network): actual_result = router_1_arp.get_arp_cache_mac_address(router_1.network_interface[1].ip_network.network_address) assert actual_result == expected_result + + +def test_arp_not_affected_by_acl(multi_hop_network): + pc_a = multi_hop_network.get_node_by_hostname("pc_a") + router_1: Router = multi_hop_network.get_node_by_hostname("router_1") + + # Add explicit rule to block ARP traffic. This shouldn't actually stop ARP traffic + # as it operates a different layer within the network. + router_1.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=23) + + pc_a_arp: ARP = pc_a.software_manager.arp + + expected_result = router_1.network_interface[2].mac_address + actual_result = pc_a_arp.get_arp_cache_mac_address(router_1.network_interface[2].ip_address) + + assert actual_result == expected_result diff --git a/tests/unit_tests/_primaite/_game/_agent/test_observations.py b/tests/unit_tests/_primaite/_game/_agent/test_observations.py index 5d5921a9..3df6ca0a 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_observations.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_observations.py @@ -1,10 +1,11 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +import json from typing import List import pytest import yaml -from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.observations import ApplicationObservation, ObservationManager, ServiceObservation from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation from primaite.game.agent.observations.host_observations import HostObservation @@ -136,3 +137,227 @@ class TestFileSystemRequiresScan: [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True ) assert obs_requiring_scan.observe(folder_state)["health_status"] == 1 + + +class TestServicesRequiresScan: + @pytest.mark.parametrize( + ("yaml_option_string", "expected_val"), + ( + ("services_requires_scan: true", True), + ("services_requires_scan: false", False), + (" ", True), + ), + ) + def test_obs_config(self, yaml_option_string, expected_val): + """Check that the default behaviour is to set service_requires_scan to True.""" + obs_cfg_yaml = f""" + type: custom + options: + components: + - type: nodes + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: web-server + - service_name: dns-client + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + services: + - service_name: ftp-server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 3 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + {yaml_option_string} + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: links + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: none + label: ICS + options: {{}} + + """ + + cfg = yaml.safe_load(obs_cfg_yaml) + manager = ObservationManager.from_config(cfg) + + hosts: List[HostObservation] = manager.obs.components["NODES"].hosts + for i, host in enumerate(hosts): + services: List[ServiceObservation] = host.services + for j, service in enumerate(services): + val = service.services_requires_scan + print(f"host {i} service {j} {val}") + assert val == expected_val # Make sure services require scan by default + + def test_services_requires_scan(self): + state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1} + + obs_requiring_scan = ServiceObservation([], services_requires_scan=True) + assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value + + obs_not_requiring_scan = ServiceObservation([], services_requires_scan=False) + assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value + + +class TestApplicationsRequiresScan: + @pytest.mark.parametrize( + ("yaml_option_string", "expected_val"), + ( + ("applications_requires_scan: true", True), + ("applications_requires_scan: false", False), + (" ", True), + ), + ) + def test_obs_config(self, yaml_option_string, expected_val): + """Check that the default behaviour is to set applications_requires_scan to True.""" + obs_cfg_yaml = f""" + type: custom + options: + components: + - type: nodes + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + applications: + - application_name: web-browser + - hostname: client_2 + applications: + - application_name: web-browser + - application_name: database-client + num_services: 0 + num_applications: 3 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + {yaml_option_string} + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: links + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: none + label: ICS + options: {{}} + + """ + + cfg = yaml.safe_load(obs_cfg_yaml) + manager = ObservationManager.from_config(cfg) + + hosts: List[HostObservation] = manager.obs.components["NODES"].hosts + for i, host in enumerate(hosts): + services: List[ServiceObservation] = host.services + for j, service in enumerate(services): + val = service.services_requires_scan + print(f"host {i} service {j} {val}") + assert val == expected_val # Make sure applications require scan by default + + def test_applications_requires_scan(self): + state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1, "num_executions": 1} + + obs_requiring_scan = ApplicationObservation([], applications_requires_scan=True) + assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value + + obs_not_requiring_scan = ApplicationObservation([], applications_requires_scan=False) + assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 91369f6c..81e05467 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -73,7 +73,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client): assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR -def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client): +def test_ftp_tries_to_send_file__that_does_not_exist(ftp_client): """Method send_file should return false if no file to send.""" assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 32fcae9a..3b2377e9 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -6,6 +6,7 @@ import pytest from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame +from primaite.interface.request import RequestResponse from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server @@ -442,3 +443,59 @@ def test_terminal_connection_timeout(basic_network): assert len(computer_b.user_session_manager.remote_sessions) == 0 assert not remote_connection.is_active + + +def test_terminal_last_response_updates(basic_network): + """Test that the _last_response within Terminal correctly updates.""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + + assert terminal_a.last_response is None + + remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") + + # Last response should be a successful logon + assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "Login Successful"}) + + remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"]) + + # Last response should now update following successful install + assert terminal_a.last_response == RequestResponse(status="success", data={}) + + remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"]) + + # Last response should now update to success, but with supplied reason. + assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "already installed"}) + + remote_connection.execute(command=["file_system", "create", "file", "folder123", "doggo.pdf", False]) + + # Check file was created. + assert computer_b.file_system.access_file(folder_name="folder123", file_name="doggo.pdf") + + # Last response should be confirmation of file creation. + assert terminal_a.last_response == RequestResponse( + status="success", + data={"file_name": "doggo.pdf", "folder_name": "folder123", "file_type": "PDF", "file_size": 102400}, + ) + + remote_connection.execute( + command=[ + "service", + "ftp-client", + "send", + { + "dest_ip_address": "192.168.0.2", + "src_folder": "folder123", + "src_file_name": "cat.pdf", + "dest_folder": "root", + "dest_file_name": "cat.pdf", + }, + ] + ) + + assert terminal_a.last_response == RequestResponse( + status="failure", + data={"reason": "Unable to locate given file on local file system. Perhaps given options are invalid?"}, + )