diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f87f54e..871b9923 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [4.0.0] = TBC ### Added +- Log observation space data by episode and step. +- Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `do-nothing` actions are omitted. +- New ``node-send-local-command`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only) +- Added ability to set the observation threshold for NMNE, file access and application executions ### Changed - Agents now follow a common configuration format, simplifying the configuration of agents and their extensibilty. @@ -24,6 +28,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Updated tests that don't use YAMLs to still use the new action and agent schemas - Nodes now use a config schema and are extensible, allowing for plugin support. - Node tests have been updated to use the new node config schemas when not using YAML files. +- ACLs are no longer applied to layer-2 traffic. +- Random number seed values are recorded in simulation/seed.log if the seed is set in the config file + or `generate_seed_value` is set to `true`. +- ARP .show() method will now include the port number associated with each entry. +- Added `services_requires_scan` and `applications_requires_scan` to agent observation space config to allow the agents to be able to see actual health states of services and applications without requiring scans (Default `True`, set to `False` to allow agents to see actual health state without scanning). +- Updated the `Terminal` class to provide response information when sending remote command execution. ### Fixed - DNS client no longer fails to check its cache if a DNS server address is missing. diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index dce8da3a..ee84aede 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -21,7 +21,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo team: GREEN type: probabilistic-agent observation_space: - type: UC2GreenObservation + type: UC2GreenObservation # TODO: what action_space: reward_function: reward_components: @@ -160,3 +160,4 @@ If ``True``, gymnasium flattening will be performed on the observation space bef ----------------- Agents will record their action log for each step. This is a summary of what the agent did, along with response information from requests within the simulation. +A summary of the actions taken by the agent can be viewed using the `show_history()` function. By default, this will display all actions taken apart from ``DONOTHING``. diff --git a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst index 542b817b..e6d5da67 100644 --- a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst +++ b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst @@ -54,6 +54,39 @@ Optional. Default value is ``3``. The number of time steps required to occur in order for the node to cycle from ``ON`` to ``SHUTTING_DOWN`` and then finally ``OFF``. +``file_system`` +--------------- + +Optional. + +The file system of the node. This configuration allows nodes to be initialised with files and/or folders. + +The file system takes a list of folders and files. + +Example: + +.. code-block:: yaml + + simulation: + network: + nodes: + - hostname: client_1 + type: computer + ip_address: 192.168.10.11 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + file_system: + - empty_folder # example of an empty folder + - downloads: + - "test_1.txt" # files in the downloads folder + - "test_2.txt" + - root: + - passwords: # example of file with size and type + size: 69 # size in bytes + type: TXT # See FileType for list of available file types + +List of file types: :py:mod:`primaite.simulator.file_system.file_type.FileType` + ``users`` --------- diff --git a/docs/source/configuration/simulation/nodes/network_examples.rst b/docs/source/configuration/simulation/nodes/network_examples.rst index 4616139e..84ee4c60 100644 --- a/docs/source/configuration/simulation/nodes/network_examples.rst +++ b/docs/source/configuration/simulation/nodes/network_examples.rst @@ -1177,8 +1177,8 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv = network.get_node_by_hostname("some_tech_storage_srv") some_tech_storage_srv.file_system.create_file(file_name="test.png") - pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["FTPClient"] - pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["FTPClient"] + pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["ftp-client"] + pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["ftp-client"] assert not pc_1_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1224,7 +1224,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. web_server: Server = network.get_node_by_hostname("some_tech_web_srv") - web_ftp_client: FTPClient = web_server.software_manager.software["FTPClient"] + web_ftp_client: FTPClient = web_server.software_manager.software["ftp-client"] assert not web_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1269,7 +1269,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_snr_dev_pc: Computer = network.get_node_by_hostname("some_tech_snr_dev_pc") - snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["FTPClient"] + snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["ftp-client"] assert snr_dev_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1294,7 +1294,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_jnr_dev_pc: Computer = network.get_node_by_hostname("some_tech_jnr_dev_pc") - jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["FTPClient"] + jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["ftp-client"] assert not jnr_dev_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1337,7 +1337,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_hr_pc: Computer = network.get_node_by_hostname("some_tech_hr_1") - hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["FTPClient"] + hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["ftp-client"] assert not hr_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, diff --git a/docs/source/configuration/simulation/nodes/router.rst b/docs/source/configuration/simulation/nodes/router.rst index 4b41784c..ee278a98 100644 --- a/docs/source/configuration/simulation/nodes/router.rst +++ b/docs/source/configuration/simulation/nodes/router.rst @@ -74,7 +74,7 @@ The subnet mask setting for the port. ``acl`` ------- -Sets up the ACL rules for the router. +Sets up the ACL rules for the router to apply to layer-3 traffic. These are not applied to layer-2 traffic such as ARP. e.g. @@ -85,10 +85,6 @@ e.g. ... acl: 1: - action: PERMIT - src_port: ARP - dst_port: ARP - 2: action: PERMIT protocol: ICMP diff --git a/docs/source/how_to_guides/extensible_agents.rst b/docs/source/how_to_guides/extensible_agents.rst index 1d765417..3236c21a 100644 --- a/docs/source/how_to_guides/extensible_agents.rst +++ b/docs/source/how_to_guides/extensible_agents.rst @@ -46,17 +46,13 @@ The core features that should be implemented in any new agent are detailed below - ref: example_green_agent team: GREEN - type: ExampleAgent + type: example-agent action_space: action_map: 0: action: do-nothing options: {} - reward_function: - reward_components: - - type: dummy - agent_settings: start_step: 25 frequency: 20 diff --git a/docs/source/how_to_guides/extensible_nodes.rst b/docs/source/how_to_guides/extensible_nodes.rst index 043d0f06..18d64ca8 100644 --- a/docs/source/how_to_guides/extensible_nodes.rst +++ b/docs/source/how_to_guides/extensible_nodes.rst @@ -26,9 +26,9 @@ class Router(NetworkNode, identifier="router"): """ Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces.""" SYSTEM_SOFTWARE: ClassVar[Dict] = { - "UserSessionManager": UserSessionManager, - "UserManager": UserManager, - "Terminal": Terminal, + "user-session-manager": UserSessionManager, + "user-manager": UserManager, + "terminal": Terminal, } network_interfaces: Dict[str, RouterInterface] = {} @@ -52,4 +52,4 @@ class Router(NetworkNode, identifier="router"): Changes to YAML file. ===================== -While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed. +While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed. diff --git a/docs/source/request_system.rst b/docs/source/request_system.rst index f0437705..93fc2a9f 100644 --- a/docs/source/request_system.rst +++ b/docs/source/request_system.rst @@ -2,6 +2,8 @@ © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +.. _request_system: + Request System ************** diff --git a/docs/source/simulation_components/network/network.rst b/docs/source/simulation_components/network/network.rst index 152b74b8..a6fe4070 100644 --- a/docs/source/simulation_components/network/network.rst +++ b/docs/source/simulation_components/network/network.rst @@ -97,19 +97,19 @@ we'll use the following Network that has a client, server, two switches, and a r network.connect(endpoint_a=switch_2.network_interface[1], endpoint_b=client_1.network_interface[1]) network.connect(endpoint_a=switch_1.network_interface[1], endpoint_b=server_1.network_interface[1]) -8. Add ACL rules on the Router to allow ARP and ICMP traffic. +8. Add an ACL rule on the Router to allow ICMP traffic. .. code-block:: python router_1.acl.add_rule( action=ACLAction.PERMIT, - src_port=Port["ARP"], - dst_port=Port["ARP"], + src_port=PORT_LOOKUP["ARP"], + dst_port=PORT_LOOKUP["ARP"], position=22 ) router_1.acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol["ICMP"], + protocol=PROTOCOL_LOOKUP["ICMP"], position=23 ) diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index d7207846..4078ffda 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -102,8 +102,8 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. network.connect(pc_a.network_interface[1], router_1.router_interface) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B pc_b = Computer( diff --git a/docs/source/simulation_components/system/applications/c2_suite.rst b/docs/source/simulation_components/system/applications/c2_suite.rst index 34175fc3..c780485a 100644 --- a/docs/source/simulation_components/system/applications/c2_suite.rst +++ b/docs/source/simulation_components/system/applications/c2_suite.rst @@ -183,7 +183,7 @@ Python # Example command: Installing and configuring Ransomware: ransomware_installation_command = { "commands": [ - ["software_manager","application","install","RansomwareScript"], + ["software_manager","application","install","ransomware-script"], ], "username": "admin", "password": "admin", diff --git a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst index 8e008504..3ddb8bca 100644 --- a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst @@ -77,7 +77,7 @@ Python network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(DataManipulationBot) - data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") + data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot") data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") data_manipulation_bot.run() @@ -98,7 +98,7 @@ If not using the data manipulation bot manually, it needs to be used with a data type: red-database-corrupting-agent observation_space: - type: UC2RedObservation + type: uc2-red-observation #TODO what options: nodes: - node_name: client_1 diff --git a/docs/source/simulation_components/system/applications/database_client.rst b/docs/source/simulation_components/system/applications/database_client.rst index 7087dedf..472b504c 100644 --- a/docs/source/simulation_components/system/applications/database_client.rst +++ b/docs/source/simulation_components/system/applications/database_client.rst @@ -59,7 +59,7 @@ Python # install DatabaseClient client.software_manager.install(DatabaseClient) - database_client: DatabaseClient = client.software_manager.software.get("DatabaseClient") + database_client: DatabaseClient = client.software_manager.software.get("database-sclient") # Configure the DatabaseClient database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) # address of the DatabaseService diff --git a/docs/source/simulation_components/system/applications/ransomware_script.rst b/docs/source/simulation_components/system/applications/ransomware_script.rst index 192618fc..a8975f32 100644 --- a/docs/source/simulation_components/system/applications/ransomware_script.rst +++ b/docs/source/simulation_components/system/applications/ransomware_script.rst @@ -62,7 +62,7 @@ Python network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(RansomwareScript) - RansomwareScript: RansomwareScript = client_1.software_manager.software.get("RansomwareScript") + RansomwareScript: RansomwareScript = client_1.software_manager.software.get("ransomware-script") RansomwareScript.configure(server_ip_address=IPv4Address("192.168.1.14")) RansomwareScript.execute() diff --git a/docs/source/simulation_components/system/applications/web_browser.rst b/docs/source/simulation_components/system/applications/web_browser.rst index c04e60af..659caa09 100644 --- a/docs/source/simulation_components/system/applications/web_browser.rst +++ b/docs/source/simulation_components/system/applications/web_browser.rst @@ -61,7 +61,7 @@ The :ref:`DNSClient` must be configured to use the :ref:`DNSServer`. The :ref:`D # Install WebBrowser on computer computer.software_manager.install(WebBrowser) - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") web_browser.run() # configure the WebBrowser diff --git a/docs/source/simulation_components/system/services/database_service.rst b/docs/source/simulation_components/system/services/database_service.rst index 961f2e45..c819a0f7 100644 --- a/docs/source/simulation_components/system/services/database_service.rst +++ b/docs/source/simulation_components/system/services/database_service.rst @@ -66,7 +66,7 @@ Python # Install DatabaseService on server server.software_manager.install(DatabaseService) - db_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = server.software_manager.software.get("database-service") db_service.start() # configure DatabaseService diff --git a/docs/source/simulation_components/system/services/dns_client.rst b/docs/source/simulation_components/system/services/dns_client.rst index 17a1ed25..40762bfc 100644 --- a/docs/source/simulation_components/system/services/dns_client.rst +++ b/docs/source/simulation_components/system/services/dns_client.rst @@ -56,7 +56,7 @@ Python # Install DNSClient on server server.software_manager.install(DNSClient) - dns_client: DNSClient = server.software_manager.software.get("DNSClient") + dns_client: DNSClient = server.software_manager.software.get("dns-client") dns_client.start() # configure DatabaseService diff --git a/docs/source/simulation_components/system/services/dns_server.rst b/docs/source/simulation_components/system/services/dns_server.rst index 633221d5..ca0e3691 100644 --- a/docs/source/simulation_components/system/services/dns_server.rst +++ b/docs/source/simulation_components/system/services/dns_server.rst @@ -53,7 +53,7 @@ Python # Install DNSServer on server server.software_manager.install(DNSServer) - dns_server: DNSServer = server.software_manager.software.get("DNSServer") + dns_server: DNSServer = server.software_manager.software.get("dns-server") dns_server.start() # configure DatabaseService diff --git a/docs/source/simulation_components/system/services/ftp_client.rst b/docs/source/simulation_components/system/services/ftp_client.rst index d4375069..530b5aff 100644 --- a/docs/source/simulation_components/system/services/ftp_client.rst +++ b/docs/source/simulation_components/system/services/ftp_client.rst @@ -60,7 +60,7 @@ Python # Install FTPClient on server server.software_manager.install(FTPClient) - ftp_client: FTPClient = server.software_manager.software.get("FTPClient") + ftp_client: FTPClient = server.software_manager.software.get("ftp-client") ftp_client.start() diff --git a/docs/source/simulation_components/system/services/ftp_server.rst b/docs/source/simulation_components/system/services/ftp_server.rst index a5ad32fe..20dd6707 100644 --- a/docs/source/simulation_components/system/services/ftp_server.rst +++ b/docs/source/simulation_components/system/services/ftp_server.rst @@ -55,7 +55,7 @@ Python # Install FTPServer on server server.software_manager.install(FTPServer) - ftp_server: FTPServer = server.software_manager.software.get("FTPServer") + ftp_server: FTPServer = server.software_manager.software.get("ftp-server") ftp_server.start() ftp_server.server_password = "test" diff --git a/docs/source/simulation_components/system/services/ntp_client.rst b/docs/source/simulation_components/system/services/ntp_client.rst index 8c011cad..5406d9fc 100644 --- a/docs/source/simulation_components/system/services/ntp_client.rst +++ b/docs/source/simulation_components/system/services/ntp_client.rst @@ -53,7 +53,7 @@ Python # Install NTPClient on server server.software_manager.install(NTPClient) - ntp_client: NTPClient = server.software_manager.software.get("NTPClient") + ntp_client: NTPClient = server.software_manager.software.get("ntp-client") ntp_client.start() ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.10")) diff --git a/docs/source/simulation_components/system/services/ntp_server.rst b/docs/source/simulation_components/system/services/ntp_server.rst index c1d16d61..2c01dcaf 100644 --- a/docs/source/simulation_components/system/services/ntp_server.rst +++ b/docs/source/simulation_components/system/services/ntp_server.rst @@ -55,7 +55,7 @@ Python # Install NTPServer on server server.software_manager.install(NTPServer) - ntp_server: NTPServer = server.software_manager.software.get("NTPServer") + ntp_server: NTPServer = server.software_manager.software.get("ntp-server") ntp_server.start() diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index bc5cee48..5c9bad79 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -23,6 +23,14 @@ Key capabilities - Simulates common Terminal processes/commands. - Leverages the Service base class for install/uninstall, status tracking etc. +Usage +""""" + + - Pre-Installs on any `Node` component (with the exception of `Switches`). + - Terminal Clients connect, execute commands and disconnect from remote nodes. + - Ensures that users are logged in to the component before executing any commands. + - Service runs on SSH port 22 by default. + - Enables Agents to send commands both remotely and locally. Implementation """""""""""""" @@ -30,19 +38,112 @@ Implementation - Manages remote connections in a dictionary by session ID. - Processes commands, forwarding to the ``RequestManager`` or ``SessionManager`` where appropriate. - Extends Service class. - - A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook. + +A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook. + +Command Format +^^^^^^^^^^^^^^ + +Terminals implement their commands through leveraging the pre-existing :ref:`request_system`. + +Due to this Terminals will only accept commands passed within the ``RequestFormat``. + +:py:class:`primaite.game.interface.RequestFormat` + +For example, ``terminal`` command actions when used in ``yaml`` format are formatted as follows: + +.. code-block:: yaml + + command: + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False + +This is then loaded from yaml into a dictionary containing the terminal command: + +.. code-block:: python + + {"command":["file_system", "create", "file", "downloads", "cat.png", "False"]} + +Which is then passed to the ``Terminals`` Request Manager to be executed. + +Game Layer Usage (Agents) +======================== + +The below code examples demonstrate how to use terminal related actions in yaml files. + +yaml +"""" + +``node-send-local-command`` +""""""""""""""""""""""""""" + +Agents can execute local commands without needing to perform a separate remote login action (``node-session-remote-login``). + +.. code-block:: yaml + + ... + ... + action: node-send-local-command + options: + node_id: 0 + username: admin + password: admin + command: # Example command - Creates a file called 'cat.png' in the downloads folder. + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False" -Usage -""""" +``node-session-remote-login`` +""""""""""""""""" - - Pre-Installs on all ``Nodes`` (with the exception of ``Switches``). - - Terminal Clients connect, execute commands and disconnect from remote nodes. - - Ensures that users are logged in to the component before executing any commands. - - Service runs on SSH port 22 by default. +Agents are able to use the terminal to login into remote nodes via ``SSH`` which allows for agents to execute commands on remote hosts. + +.. code-block:: yaml + + ... + ... + action: node-session-remote-login + options: + node_id: 0 + username: admin + password: admin + remote_ip: 192.168.0.10 # Example Ip Address. (The remote host's IP that will be used by ssh) + + +``node-send-remote-command`` +"""""""""""""""""""""""""""" + +After remotely logging into another host, an agent can use the ``node-send-remote-command`` to execute commands across the network remotely. + +.. code-block:: yaml + + ... + ... + action: node-send-remote-command + options: + node_id: 0 + remote_ip: 192.168.0.10 + command: + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False" + + + +Simulation Layer Usage +====================== -Usage -===== The below code examples demonstrate how to create a terminal, a remote terminal, and how to send a basic application install command to a remote node. @@ -65,7 +166,7 @@ Python operating_state=NodeOperatingState.ON, ) - terminal: Terminal = client.software_manager.software.get("Terminal") + terminal: Terminal = client.software_manager.software.get("terminal") Creating Remote Terminal Connection """"""""""""""""""""""""""""""""""" @@ -86,7 +187,7 @@ Creating Remote Terminal Connection node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") @@ -112,12 +213,12 @@ Executing a basic application install command node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") - term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) + term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "ransomware-script"]) @@ -140,7 +241,7 @@ Creating a folder on a remote node node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") @@ -167,7 +268,7 @@ Disconnect from Remote Node node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") diff --git a/docs/source/simulation_components/system/services/web_server.rst b/docs/source/simulation_components/system/services/web_server.rst index bce42791..9d7f4d2f 100644 --- a/docs/source/simulation_components/system/services/web_server.rst +++ b/docs/source/simulation_components/system/services/web_server.rst @@ -56,7 +56,7 @@ Python # Install WebServer on server server.software_manager.install(WebServer) - web_server: WebServer = server.software_manager.software.get("WebServer") + web_server: WebServer = server.software_manager.software.get("web-server") web_server.start() Via Configuration diff --git a/docs/source/simulation_components/system/software.rst b/docs/source/simulation_components/system/software.rst index d28815bb..c2f3066b 100644 --- a/docs/source/simulation_components/system/software.rst +++ b/docs/source/simulation_components/system/software.rst @@ -30,7 +30,7 @@ See :ref:`Node Start up and Shut down` node.software_manager.install(WebServer) - web_server: WebServer = node.software_manager.software.get("WebServer") + web_server: WebServer = node.software_manager.software.get("web-server") assert web_server.operating_state is ServiceOperatingState.RUNNING # service is immediately ran after install node.power_off() diff --git a/src/primaite/game/agent/actions/node.py b/src/primaite/game/agent/actions/node.py index 19639c21..b1b6ec12 100644 --- a/src/primaite/game/agent/actions/node.py +++ b/src/primaite/game/agent/actions/node.py @@ -1,6 +1,6 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from abc import ABC, abstractmethod -from typing import ClassVar, List, Optional, Union +from typing import ClassVar, List, Literal, Optional, Union from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat @@ -153,8 +153,6 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, discriminator="node-nmap-po class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-network-service-recon"): """Action which performs an nmap network service recon (ping scan followed by port scan).""" - config: "NodeNetworkServiceReconAction.ConfigSchema" - class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): """Configuration schema for NodeNetworkServiceReconAction.""" @@ -179,3 +177,70 @@ class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node- "show": config.show, }, ] + + +class NodeAccountsAddUserAction(AbstractAction, discriminator="node-account-add-user"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-account-add-user"] = "node-account-add-user" + node_name: str + username: str + password: str + is_admin: bool + + @classmethod + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "user-manager", + "add_user", + config.username, + config.password, + config.is_admin, + ] + + +class NodeAccountsDisableUserAction(AbstractAction, discriminator="node-account-disable-user"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-account-disable-user"] = "node-account-disable-user" + node_name: str + username: str + + @classmethod + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "user-manager", + "disable_user", + config.username, + ] + + +class NodeSendLocalCommandAction(AbstractAction, discriminator="node-send-local-command"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-send-local-command"] = "node-send-local-command" + node_name: str + username: str + password: str + command: RequestFormat + + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "terminal", + "send_local_command", + config.username, + config.password, + {"command": config.command}, + ] diff --git a/src/primaite/game/agent/actions/session.py b/src/primaite/game/agent/actions/session.py index 58a8a555..63a45c5e 100644 --- a/src/primaite/game/agent/actions/session.py +++ b/src/primaite/game/agent/actions/session.py @@ -34,8 +34,6 @@ class NodeSessionAbstractAction(AbstractAction, ABC): class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="node-session-remote-login"): """Action which performs a remote session login.""" - config: "NodeSessionsRemoteLoginAction.ConfigSchema" - class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): """Configuration schema for NodeSessionsRemoteLoginAction.""" @@ -53,7 +51,7 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="no config.node_name, "service", "terminal", - "node-session-remote-login", + "node_session_remote_login", config.username, config.password, config.remote_ip, @@ -63,8 +61,6 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="no class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, discriminator="node-session-remote-logoff"): """Action which performs a remote session logout.""" - config: "NodeSessionsRemoteLogoutAction.ConfigSchema" - class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): """Configuration schema for NodeSessionsRemoteLogoutAction.""" @@ -78,14 +74,13 @@ class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, discriminator="n return ["network", "node", config.node_name, "service", "terminal", config.verb, config.remote_ip] -class NodeAccountChangePasswordAction(NodeSessionAbstractAction, discriminator="node-account-change-password"): +class NodeAccountChangePasswordAction(AbstractAction, discriminator="node-account-change-password"): """Action which changes the password for a user.""" - config: "NodeAccountChangePasswordAction.ConfigSchema" - - class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): + class ConfigSchema(AbstractAction.ConfigSchema): """Configuration schema for NodeAccountsChangePasswordAction.""" + node_name: str username: str current_password: str new_password: str diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index a55cd3ff..d06bd1d0 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING from gymnasium.core import ActType, ObsType +from prettytable import PrettyTable from pydantic import BaseModel, ConfigDict, Field from primaite.game.agent.actions import ActionManager @@ -42,6 +43,9 @@ class AgentHistoryItem(BaseModel): reward_info: Dict[str, Any] = {} + observation: Optional[ObsType] = None + """The observation space data for this step.""" + class AbstractAgent(BaseModel, ABC): """Base class for scripted and RL agents.""" @@ -67,6 +71,9 @@ class AbstractAgent(BaseModel, ABC): default_factory=lambda: ObservationManager.ConfigSchema() ) reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema()) + thresholds: Optional[Dict] = {} + # TODO: this is only relevant to some observations, need to refactor the way thresholds are dealt with (#3085) + """A dict containing the observation thresholds.""" config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema()) @@ -90,10 +97,42 @@ class AbstractAgent(BaseModel, ABC): def model_post_init(self, __context: Any) -> None: """Overwrite the default empty action, observation, and rewards with ones defined through the config.""" self.action_manager = ActionManager(config=self.config.action_space) + self.config.observation_space.options.thresholds = self.config.thresholds self.observation_manager = ObservationManager(config=self.config.observation_space) self.reward_function = RewardFunction(config=self.config.reward_function) return super().model_post_init(__context) + def add_agent_action(self, item: AgentHistoryItem, table: PrettyTable) -> PrettyTable: + """Update the given table with information from given AgentHistoryItem.""" + node, application = "unknown", "unknown" + if (node_id := item.parameters.get("node_id")) is not None: + node = self.action_manager.node_names[node_id] + if (application_id := item.parameters.get("application_id")) is not None: + application = self.action_manager.application_names[node_id][application_id] + if (application_name := item.parameters.get("application_name")) is not None: + application = application_name + table.add_row([item.timestep, item.action, node, application, item.response.status]) + return table + + def show_history(self, ignored_actions: Optional[list] = None): + """ + Print an agent action provided it's not the DONOTHING action. + + :param ignored_actions: OPTIONAL: List of actions to be ignored when displaying the history. + If not provided, defaults to ignore DONOTHING actions. + """ + if not ignored_actions: + ignored_actions = ["DONOTHING"] + table = PrettyTable() + table.field_names = ["Step", "Action", "Node", "Application", "Response"] + print(f"Actions for '{self.agent_name}':") + for item in self.history: + if item.action in ignored_actions: + pass + else: + table = self.add_agent_action(item=item, table=table) + print(table) + def update_observation(self, state: Dict) -> ObsType: """ Convert a state from the simulator into an observation for the agent using the observation space. @@ -140,12 +179,23 @@ class AbstractAgent(BaseModel, ABC): return request def process_action_response( - self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse + self, + timestep: int, + action: str, + parameters: Dict[str, Any], + request: RequestFormat, + response: RequestResponse, + observation: ObsType, ) -> None: """Process the response from the most recent action.""" self.history.append( AgentHistoryItem( - timestep=timestep, action=action, parameters=parameters, request=request, response=response + timestep=timestep, + action=action, + parameters=parameters, + request=request, + response=response, + observation=observation, ) ) diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index ed9dcd8f..a9e3a9aa 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, discriminator="file"): file_system_requires_scan: Optional[bool] = None """If True, the file must be scanned to update the health state. Tf False, the true state is always shown.""" - def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None: + def __init__( + self, + where: WhereType, + include_num_access: bool, + file_system_requires_scan: bool, + thresholds: Optional[Dict] = {}, + ) -> None: """ Initialise a file observation instance. @@ -48,10 +54,36 @@ class FileObservation(AbstractObservation, discriminator="file"): if self.include_num_access: self.default_observation["num_access"] = 0 - # TODO: allow these to be configured in yaml - self.high_threshold = 10 - self.med_threshold = 5 - self.low_threshold = 0 + if thresholds.get("file_access") is None: + self.low_file_access_threshold = 0 + self.med_file_access_threshold = 5 + self.high_file_access_threshold = 10 + else: + self._set_file_access_threshold( + thresholds=[ + thresholds.get("file_access")["low"], + thresholds.get("file_access")["medium"], + thresholds.get("file_access")["high"], + ] + ) + + def _set_file_access_threshold(self, thresholds: List[int]): + """ + Method that validates and then sets the file access threshold. + + :param: thresholds: The file access threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=[ + thresholds[0], + thresholds[1], + thresholds[2], + ], + threshold_identifier="file_access", + ): + self.low_file_access_threshold = thresholds[0] + self.med_file_access_threshold = thresholds[1] + self.high_file_access_threshold = thresholds[2] def _categorise_num_access(self, num_access: int) -> int: """ @@ -60,11 +92,11 @@ class FileObservation(AbstractObservation, discriminator="file"): :param num_access: Number of file accesses. :return: Bin number corresponding to the number of accesses. """ - if num_access > self.high_threshold: + if num_access > self.high_file_access_threshold: return 3 - elif num_access > self.med_threshold: + elif num_access > self.med_file_access_threshold: return 2 - elif num_access > self.low_threshold: + elif num_access > self.low_file_access_threshold: return 1 return 0 @@ -122,6 +154,7 @@ class FileObservation(AbstractObservation, discriminator="file"): where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + thresholds=config.thresholds, ) @@ -149,6 +182,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"): num_files: int, include_num_access: bool, file_system_requires_scan: bool, + thresholds: Optional[Dict] = {}, ) -> None: """ Initialise a folder observation instance. @@ -177,6 +211,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"): where=None, include_num_access=include_num_access, file_system_requires_scan=self.file_system_requires_scan, + thresholds=thresholds, ) ) while len(self.files) > num_files: @@ -253,6 +288,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"): for file_config in config.files: file_config.include_num_access = config.include_num_access file_config.file_system_requires_scan = config.file_system_requires_scan + file_config.thresholds = config.thresholds files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] return cls( @@ -261,4 +297,5 @@ class FolderObservation(AbstractObservation, discriminator="folder"): num_files=config.num_files, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + thresholds=config.thresholds, ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 17bcb983..9b979063 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -54,7 +54,15 @@ class HostObservation(AbstractObservation, discriminator="host"): """ If True, files and folders must be scanned to update the health state. If False, true state is always shown. """ - include_users: Optional[bool] = None + services_requires_scan: Optional[bool] = None + """ + If True, services must be scanned to update the health state. If False, true state is always shown. + """ + applications_requires_scan: Optional[bool] = None + """ + If True, applications must be scanned to update the health state. If False, true state is always shown. + """ + include_users: Optional[bool] = True """If True, report user session information.""" def __init__( @@ -73,6 +81,8 @@ class HostObservation(AbstractObservation, discriminator="host"): monitored_traffic: Optional[Dict], include_num_access: bool, file_system_requires_scan: bool, + services_requires_scan: bool, + applications_requires_scan: bool, include_users: bool, ) -> None: """ @@ -108,6 +118,12 @@ class HostObservation(AbstractObservation, discriminator="host"): :param file_system_requires_scan: If True, the files and folders must be scanned to update the health state. If False, the true state is always shown. :type file_system_requires_scan: bool + :param services_requires_scan: If True, services must be scanned to update the health state. + If False, the true state is always shown. + :type services_requires_scan: bool + :param applications_requires_scan: If True, applications must be scanned to update the health state. + If False, the true state is always shown. + :type applications_requires_scan: bool :param include_users: If True, report user session information. :type include_users: bool """ @@ -121,7 +137,7 @@ class HostObservation(AbstractObservation, discriminator="host"): # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services while len(self.services) < num_services: - self.services.append(ServiceObservation(where=None)) + self.services.append(ServiceObservation(where=None, services_requires_scan=services_requires_scan)) while len(self.services) > num_services: truncated_service = self.services.pop() msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" @@ -129,7 +145,9 @@ class HostObservation(AbstractObservation, discriminator="host"): self.applications: List[ApplicationObservation] = applications while len(self.applications) < num_applications: - self.applications.append(ApplicationObservation(where=None)) + self.applications.append( + ApplicationObservation(where=None, applications_requires_scan=applications_requires_scan) + ) while len(self.applications) > num_applications: truncated_application = self.applications.pop() msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}" @@ -153,7 +171,13 @@ class HostObservation(AbstractObservation, discriminator="host"): self.nics: List[NICObservation] = network_interfaces while len(self.nics) < num_nics: - self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic)) + self.nics.append( + NICObservation( + where=None, + include_nmne=include_nmne, + monitored_traffic=monitored_traffic, + ) + ) while len(self.nics) > num_nics: truncated_nic = self.nics.pop() msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}" @@ -269,8 +293,15 @@ class HostObservation(AbstractObservation, discriminator="host"): folder_config.include_num_access = config.include_num_access folder_config.num_files = config.num_files folder_config.file_system_requires_scan = config.file_system_requires_scan + folder_config.thresholds = config.thresholds for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne + nic_config.thresholds = config.thresholds + for service_config in config.services: + service_config.services_requires_scan = config.services_requires_scan + for application_config in config.applications: + application_config.applications_requires_scan = config.applications_requires_scan + application_config.thresholds = config.thresholds services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] @@ -281,7 +312,10 @@ class HostObservation(AbstractObservation, discriminator="host"): count = 1 while len(nics) < config.num_nics: nic_config = NICObservation.ConfigSchema( - nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic + nic_num=count, + include_nmne=config.include_nmne, + monitored_traffic=config.monitored_traffic, + thresholds=config.thresholds, ) nics.append(NICObservation.from_config(config=nic_config, parent_where=where)) count += 1 @@ -301,5 +335,7 @@ class HostObservation(AbstractObservation, discriminator="host"): monitored_traffic=config.monitored_traffic, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + services_requires_scan=config.services_requires_scan, + applications_requires_scan=config.applications_requires_scan, include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 1aa6470d..8faeb906 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,13 +1,14 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, List, Optional +from typing import ClassVar, Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import NMNEConfig from primaite.utils.validation.ip_protocol import IPProtocol from primaite.utils.validation.port import Port @@ -15,6 +16,9 @@ from primaite.utils.validation.port import Port class NICObservation(AbstractObservation, discriminator="network-interface"): """Status information about a network interface within the simulation environment.""" + capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne + "A Boolean specifying whether malicious network events should be captured." + class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for NICObservation.""" @@ -25,7 +29,13 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" - def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: + def __init__( + self, + where: WhereType, + include_nmne: bool, + monitored_traffic: Optional[Dict] = None, + thresholds: Dict = {}, + ) -> None: """ Initialise a network interface observation instance. @@ -45,10 +55,18 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): self.nmne_inbound_last_step: int = 0 self.nmne_outbound_last_step: int = 0 - # TODO: allow these to be configured in yaml - self.high_nmne_threshold = 10 - self.med_nmne_threshold = 5 - self.low_nmne_threshold = 0 + if thresholds.get("nmne") is None: + self.low_nmne_threshold = 0 + self.med_nmne_threshold = 5 + self.high_nmne_threshold = 10 + else: + self._set_nmne_threshold( + thresholds=[ + thresholds.get("nmne")["low"], + thresholds.get("nmne")["medium"], + thresholds.get("nmne")["high"], + ] + ) self.monitored_traffic = monitored_traffic if self.monitored_traffic: @@ -105,6 +123,20 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): bandwidth_utilisation = traffic_value / nic_max_bandwidth return int(bandwidth_utilisation * 9) + 1 + def _set_nmne_threshold(self, thresholds: List[int]): + """ + Method that validates and then sets the NMNE threshold. + + :param: thresholds: The NMNE threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=thresholds, + threshold_identifier="nmne", + ): + self.low_nmne_threshold = thresholds[0] + self.med_nmne_threshold = thresholds[1] + self.high_nmne_threshold = thresholds[2] + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. @@ -116,7 +148,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): """ nic_state = access_from_nested_dict(state, self.where) - if nic_state is NOT_PRESENT_IN_STATE: + if nic_state is NOT_PRESENT_IN_STATE or self.where is None: return self.default_observation obs = {"nic_status": 1 if nic_state["enabled"] else 2} @@ -164,7 +196,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): for port in self.monitored_traffic[protocol]: obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0} - if self.include_nmne: + if self.capture_nmne and self.include_nmne: obs.update({"NMNE": {}}) direction_dict = nic_state["nmne"].get("direction", {}) inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) @@ -224,6 +256,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic, + thresholds=config.thresholds, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 3a3283a2..260fac68 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -48,7 +48,13 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" file_system_requires_scan: bool = True - """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" + """If True, the folder must be scanned to update the health state. If False, the true state is always shown.""" + services_requires_scan: bool = True + """If True, the services must be scanned to update the health state. + If False, the true state is always shown.""" + applications_requires_scan: bool = True + """If True, the applications must be scanned to update the health state. + If False, the true state is always shown.""" include_users: Optional[bool] = True """If True, report user session information.""" num_ports: Optional[int] = None @@ -196,8 +202,14 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): host_config.include_num_access = config.include_num_access if host_config.file_system_requires_scan is None: host_config.file_system_requires_scan = config.file_system_requires_scan + if host_config.services_requires_scan is None: + host_config.services_requires_scan = config.services_requires_scan + if host_config.applications_requires_scan is None: + host_config.applications_requires_scan = config.applications_requires_scan if host_config.include_users is None: host_config.include_users = config.include_users + if not host_config.thresholds: + host_config.thresholds = config.thresholds for router_config in config.routers: if router_config.num_ports is None: @@ -214,6 +226,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): router_config.num_rules = config.num_rules if router_config.include_users is None: router_config.include_users = config.include_users + if not router_config.thresholds: + router_config.thresholds = config.thresholds for firewall_config in config.firewalls: if firewall_config.ip_list is None: @@ -228,6 +242,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): firewall_config.num_rules = config.num_rules if firewall_config.include_users is None: firewall_config.include_users = config.include_users + if not firewall_config.thresholds: + firewall_config.thresholds = config.thresholds hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 032435b8..e8cb18aa 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -114,7 +114,9 @@ class NestedObservation(AbstractObservation, discriminator="custom"): instances = dict() for component in config.components: obs_class = AbstractObservation._registry[component.type] - obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options)) + obs_instance = obs_class.from_config( + config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds) + ) instances[component.label] = obs_instance return cls(components=instances) @@ -242,8 +244,5 @@ class ObservationManager(BaseModel): """ if config is None: return cls(NullObservation()) - obs_type = config["type"] - obs_class = AbstractObservation._registry[obs_type] - observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"])) - obs_manager = cls(observation) + obs_manager = cls(config=config) return obs_manager diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index da81d2ad..8558b75c 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Manages the observation space for the agent.""" from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Optional, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Type, Union from gymnasium import spaces from gymnasium.core import ObsType @@ -19,6 +19,9 @@ class AbstractObservation(ABC): class ConfigSchema(ABC, BaseModel): """Config schema for observations.""" + thresholds: Optional[Dict] = {} + """A dict containing the observation thresholds.""" + model_config = ConfigDict(extra="forbid") _registry: Dict[str, Type["AbstractObservation"]] = {} @@ -69,3 +72,34 @@ class AbstractObservation(ABC): def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation": """Create this observation space component form a serialised format.""" return cls() + + def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool: + """ + Method that checks if the thresholds are non overlapping and in the correct (ascending) order. + + Pass in the thresholds from low to high e.g. + thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold] + + Throws an error if the threshold is not valid + + :param: thresholds: List of thresholds in ascending order. + :type: List[int] + :param: threshold_identifier: The name of the threshold option. + :type: Optional[str] + + :returns: bool + """ + if thresholds is None or len(thresholds) < 2: + raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}") + for idx in range(1, len(thresholds)): + if not isinstance(thresholds[idx], int): + raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.") + if not isinstance(thresholds[idx - 1], int): + raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.") + + if thresholds[idx] <= thresholds[idx - 1]: + raise Exception( + f"{threshold_identifier} threshold ({thresholds[idx - 1]}) " + f"is greater than or equal to ({thresholds[idx]}.)" + ) + return True diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 07ec1abf..dac6b362 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -19,7 +19,10 @@ class ServiceObservation(AbstractObservation, discriminator="service"): service_name: str """Name of the service, used for querying simulation state dictionary""" - def __init__(self, where: WhereType) -> None: + services_requires_scan: Optional[bool] = None + """If True, services must be scanned to update the health state. If False, true state is always shown.""" + + def __init__(self, where: WhereType, services_requires_scan: bool) -> None: """ Initialise a service observation instance. @@ -28,6 +31,7 @@ class ServiceObservation(AbstractObservation, discriminator="service"): :type where: WhereType """ self.where = where + self.services_requires_scan = services_requires_scan self.default_observation = {"operating_status": 0, "health_status": 0} def observe(self, state: Dict) -> ObsType: @@ -44,7 +48,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"): return self.default_observation return { "operating_status": service_state["operating_state"], - "health_status": service_state["health_state_visible"], + "health_status": service_state["health_state_visible"] + if self.services_requires_scan + else service_state["health_state_actual"], } @property @@ -70,7 +76,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"): :return: Constructed service observation instance. :rtype: ServiceObservation """ - return cls(where=parent_where + ["services", config.service_name]) + return cls( + where=parent_where + ["services", config.service_name], services_requires_scan=config.services_requires_scan + ) class ApplicationObservation(AbstractObservation, discriminator="application"): @@ -82,7 +90,12 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): application_name: str """Name of the application, used for querying simulation state dictionary""" - def __init__(self, where: WhereType) -> None: + applications_requires_scan: Optional[bool] = None + """ + If True, applications must be scanned to update the health state. If False, true state is always shown. + """ + + def __init__(self, where: WhereType, applications_requires_scan: bool, thresholds: Optional[Dict] = {}) -> None: """ Initialise an application observation instance. @@ -92,25 +105,52 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): :type where: WhereType """ self.where = where + self.applications_requires_scan = applications_requires_scan self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} - # TODO: allow these to be configured in yaml - self.high_threshold = 10 - self.med_threshold = 5 - self.low_threshold = 0 + if thresholds.get("app_executions") is None: + self.low_app_execution_threshold = 0 + self.med_app_execution_threshold = 5 + self.high_app_execution_threshold = 10 + else: + self._set_application_execution_thresholds( + thresholds=[ + thresholds.get("app_executions")["low"], + thresholds.get("app_executions")["medium"], + thresholds.get("app_executions")["high"], + ] + ) + + def _set_application_execution_thresholds(self, thresholds: List[int]): + """ + Method that validates and then sets the application execution threshold. + + :param: thresholds: The application execution threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=[ + thresholds[0], + thresholds[1], + thresholds[2], + ], + threshold_identifier="app_executions", + ): + self.low_app_execution_threshold = thresholds[0] + self.med_app_execution_threshold = thresholds[1] + self.high_app_execution_threshold = thresholds[2] def _categorise_num_executions(self, num_executions: int) -> int: """ - Represent number of file accesses as a categorical variable. + Represent number of application executions as a categorical variable. - :param num_access: Number of file accesses. + :param num_access: Number of application executions. :return: Bin number corresponding to the number of accesses. """ - if num_executions > self.high_threshold: + if num_executions > self.high_app_execution_threshold: return 3 - elif num_executions > self.med_threshold: + elif num_executions > self.med_app_execution_threshold: return 2 - elif num_executions > self.low_threshold: + elif num_executions > self.low_app_execution_threshold: return 1 return 0 @@ -128,7 +168,9 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): return self.default_observation return { "operating_status": application_state["operating_state"], - "health_status": application_state["health_state_visible"], + "health_status": application_state["health_state_visible"] + if self.applications_requires_scan + else application_state["health_state_actual"], "num_executions": self._categorise_num_executions(application_state["num_executions"]), } @@ -161,4 +203,8 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): :return: Constructed application observation instance. :rtype: ApplicationObservation """ - return cls(where=parent_where + ["applications", config.application_name]) + return cls( + where=parent_where + ["applications", config.application_name], + applications_requires_scan=config.applications_requires_scan, + thresholds=config.thresholds, + ) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 1427776e..a3b77ec3 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict from primaite import DEFAULT_BANDWIDTH, getLogger from primaite.game.agent.interface import AbstractAgent, ProxyAgent +from primaite.game.agent.observations import NICObservation from primaite.game.agent.rewards import SharedReward from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT @@ -44,15 +45,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) SERVICE_TYPES_MAPPING = { - "DNSClient": DNSClient, - "DNSServer": DNSServer, - "DatabaseService": DatabaseService, - "WebServer": WebServer, - "FTPClient": FTPClient, - "FTPServer": FTPServer, - "NTPClient": NTPClient, - "NTPServer": NTPServer, - "Terminal": Terminal, + "dns-client": DNSClient, + "dns-server": DNSServer, + "database-service": DatabaseService, + "web-server": WebServer, + "ftp-client": FTPClient, + "ftp-server": FTPServer, + "ntp-client": NTPClient, + "ntp-server": NTPServer, + "terminal": Terminal, } """List of available services that can be installed on nodes in the PrimAITE Simulation.""" @@ -68,6 +69,8 @@ class PrimaiteGameOptions(BaseModel): seed: int = None """Random number seed for RNGs.""" + generate_seed_value: bool = False + """Internally generated seed value.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" ports: List[Port] @@ -175,6 +178,7 @@ class PrimaiteGame: parameters=parameters, request=request, response=response, + observation=obs, ) def pre_timestep(self) -> None: @@ -263,6 +267,7 @@ class PrimaiteGame: node_sets_cfg = network_config.get("node_sets", []) # Set the NMNE capture config NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {})) + NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne for node_cfg in nodes_cfg: n_type = node_cfg["type"] @@ -293,6 +298,7 @@ class PrimaiteGame: if "users" in node_cfg and new_node.software_manager.software.get("user-manager"): user_manager: UserManager = new_node.software_manager.software["user-manager"] # noqa + for user_cfg in node_cfg["users"]: user_manager.add_user(**user_cfg, bypass_can_perform_action=True) @@ -407,6 +413,7 @@ class PrimaiteGame: agents_cfg = cfg.get("agents", []) for agent_cfg in agents_cfg: + agent_cfg = {**agent_cfg, "thresholds": game.options.thresholds} new_agent = AbstractAgent.from_config(agent_cfg) game.agents[agent_cfg["ref"]] = new_agent if isinstance(new_agent, ProxyAgent): diff --git a/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb index ef4e75dd..52499ea6 100644 --- a/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb @@ -50,40 +50,22 @@ "custom_c2_agent = \"\"\"\n", " - ref: CustomC2Agent\n", " team: RED\n", - " type: ProxyAgent\n", + " type: proxy-a.gent\n", "\n", " action_space:\n", - " options:\n", - " nodes:\n", - " - node_name: web_server\n", - " applications:\n", - " - application_name: C2Beacon\n", - " - node_name: client_1\n", - " applications:\n", - " - application_name: C2Server\n", - " max_folders_per_node: 1\n", - " max_files_per_folder: 1\n", - " max_services_per_node: 2\n", - " max_nics_per_node: 8\n", - " max_acl_rules: 10\n", - " ip_list:\n", - " - 192.168.1.21\n", - " - 192.168.1.14\n", - " wildcard_list:\n", - " - 0.0.0.1\n", " action_map:\n", " 0:\n", " action: do_nothing\n", " options: {}\n", " 1:\n", - " action: node_application_install\n", + " action: node-application-install\n", " options:\n", - " node_id: 0\n", - " application_name: C2Beacon\n", + " node_name: web_server\n", + " application_name: c2-beacon\n", " 2:\n", - " action: configure_c2_beacon\n", + " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", + " node_name: web_server\n", " config:\n", " c2_server_ip_address: 192.168.10.21\n", " keep_alive_frequency:\n", @@ -92,10 +74,10 @@ " 3:\n", " action: node_application_execute\n", " options:\n", - " node_id: 0\n", - " application_id: 0\n", + " node_name: web_server\n", + " application_name: c2-beacon\n", " 4:\n", - " action: c2_server_terminal_command\n", + " action: c2-server-terminal-command\n", " options:\n", " node_id: 1\n", " ip_address:\n", @@ -111,14 +93,14 @@ " 5:\n", " action: c2-server-ransomware-configure\n", " options:\n", - " node_id: 1\n", + " node_name: client_1\n", " config:\n", " server_ip_address: 192.168.1.14\n", " payload: ENCRYPT\n", " 6:\n", - " action: c2_server_data_exfiltrate\n", + " action: c2-server-data-exfiltrate\n", " options:\n", - " node_id: 1\n", + " node_name: client_1\n", " target_file_name: \"database.db\"\n", " target_folder_name: \"database\"\n", " exfiltration_folder_name: \"spoils\"\n", @@ -128,31 +110,27 @@ " password: admin\n", "\n", " 7:\n", - " action: c2_server_ransomware_launch\n", + " action: c2-server-ransomware-launch\n", " options:\n", - " node_id: 1\n", + " node_name: client_1\n", " 8:\n", - " action: configure_c2_beacon\n", + " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", + " node_name: web_server\n", " config:\n", " c2_server_ip_address: 192.168.10.21\n", " keep_alive_frequency: 10\n", " masquerade_protocol: TCP\n", " masquerade_port: DNS\n", " 9:\n", - " action: configure_c2_beacon\n", + " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", + " node_name: web_server\n", " config:\n", " c2_server_ip_address: 192.168.10.22\n", " keep_alive_frequency:\n", " masquerade_protocol:\n", " masquerade_port:\n", - "\n", - " reward_function:\n", - " reward_components:\n", - " - type: DUMMY\n", "\"\"\"\n", "c2_agent_yaml = yaml.safe_load(custom_c2_agent)" ] @@ -225,7 +203,7 @@ " nodes: # Node List\n", " - node_name: web_server\n", " applications: \n", - " - application_name: C2Beacon\n", + " - application_name: c2-beacon\n", " ...\n", " ...\n", " action_map:\n", @@ -233,7 +211,7 @@ " action: node_application_install \n", " options:\n", " node_id: 0 # Index 0 at the node list.\n", - " application_name: C2Beacon\n", + " application_name: c2-beacon\n", "```" ] }, @@ -268,7 +246,7 @@ " action_map:\n", " ...\n", " 2:\n", - " action: configure_c2_beacon\n", + " action: configure-c2-beacon\n", " options:\n", " node_id: 0 # Node Index\n", " config: # Further information about these config options can be found at the bottom of this notebook.\n", @@ -286,7 +264,7 @@ "outputs": [], "source": [ "env.step(2)\n", - "c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n", + "c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n", "web_server.software_manager.show()\n", "c2_beacon.show()" ] @@ -307,13 +285,13 @@ " nodes: # Node List\n", " - node_name: web_server\n", " applications: \n", - " - application_name: C2Beacon\n", + " - application_name: c2-beacon\n", " ...\n", " ...\n", " action_map:\n", " ...\n", " 3:\n", - " action: node_application_execute\n", + " action: node-application-execute\n", " options:\n", " node_id: 0\n", " application_id: 0\n", @@ -374,11 +352,11 @@ " ...\n", " - node_name: client_1\n", " applications: \n", - " - application_name: C2Server\n", + " - application_name: c2-server\n", " ...\n", " action_map:\n", " 4:\n", - " action: C2_SERVER_TERMINAL_COMMAND\n", + " action: c2-server-terminal-command\n", " options:\n", " node_id: 1\n", " ip_address:\n", @@ -431,7 +409,7 @@ " ...\n", " - node_name: client_1\n", " applications: \n", - " - application_name: C2Server\n", + " - application_name: c2-server\n", " ...\n", " action_map:\n", " 5:\n", @@ -459,7 +437,7 @@ "metadata": {}, "outputs": [], "source": [ - "ransomware_script: RansomwareScript = web_server.software_manager.software[\"RansomwareScript\"]\n", + "ransomware_script: RansomwareScript = web_server.software_manager.software[\"ransomware-script\"]\n", "web_server.software_manager.show()\n", "ransomware_script.show()" ] @@ -483,11 +461,11 @@ " ...\n", " - node_name: client_1\n", " applications: \n", - " - application_name: C2Server\n", + " - application_name: c2-server\n", " ...\n", " action_map:\n", " 6:\n", - " action: c2_server_data_exfiltrate\n", + " action: c2-server-data-exfiltrate\n", " options:\n", " node_id: 1\n", " target_file_name: \"database.db\"\n", @@ -549,11 +527,11 @@ " ...\n", " - node_name: client_1\n", " applications: \n", - " - application_name: C2Server\n", + " - application_name: c2-server\n", " ...\n", " action_map:\n", " 7:\n", - " action: c2_server_ransomware_launch\n", + " action: c2-server-ransomware-launch\n", " options:\n", " node_id: 1\n", "```\n" @@ -598,20 +576,20 @@ "custom_blue_agent_yaml = \"\"\"\n", " - ref: defender\n", " team: BLUE\n", - " type: ProxyAgent\n", + " type: proxy-agent\n", "\n", " observation_space:\n", - " type: CUSTOM\n", + " type: custom\n", " options:\n", " components:\n", - " - type: NODES\n", + " - type: nodes\n", " label: NODES\n", " options:\n", " hosts:\n", " - hostname: web_server\n", " applications:\n", - " - application_name: C2Beacon\n", - " - application_name: RansomwareScript\n", + " - application_name: c2-beacon\n", + " - application_name: ransomware-script\n", " folders:\n", " - folder_name: exfiltration_folder\n", " files:\n", @@ -661,7 +639,7 @@ " - UDP\n", " num_rules: 10\n", "\n", - " - type: LINKS\n", + " - type: links\n", " label: LINKS\n", " options:\n", " link_references:\n", @@ -675,7 +653,7 @@ " - switch_2:eth-1<->client_1:eth-1\n", " - switch_2:eth-2<->client_2:eth-1\n", " - switch_2:eth-7<->security_suite:eth-2\n", - " - type: \"NONE\"\n", + " - type: \"none\"\n", " label: ICS\n", " options: {}\n", "\n", @@ -685,16 +663,16 @@ " action: do_nothing\n", " options: {}\n", " 1:\n", - " action: node_application_remove\n", + " action: node-application-remove\n", " options:\n", - " node_id: 0\n", + " node_name: web-server\n", " application_name: C2Beacon\n", " 2:\n", - " action: node_shutdown\n", + " action: node-shutdown\n", " options:\n", - " node_id: 0\n", + " node_name: web-server\n", " 3:\n", - " action: router_acl_add_rule\n", + " action: router-acl-add-rule\n", " options:\n", " target_router: router_1\n", " position: 1\n", @@ -707,36 +685,6 @@ " source_wildcard_id: 0\n", " dest_wildcard_id: 0\n", "\n", - "\n", - " options:\n", - " nodes:\n", - " - node_name: web_server\n", - " applications:\n", - " - application_name: C2Beacon\n", - "\n", - " - node_name: database_server\n", - " folders:\n", - " - folder_name: database\n", - " files:\n", - " - file_name: database.db\n", - " services:\n", - " - service_name: DatabaseService\n", - " - node_name: router_1\n", - "\n", - " max_folders_per_node: 2\n", - " max_files_per_folder: 2\n", - " max_services_per_node: 2\n", - " max_nics_per_node: 8\n", - " max_acl_rules: 10\n", - " ip_list:\n", - " - 192.168.10.21\n", - " - 192.168.1.12\n", - " wildcard_list:\n", - " - 0.0.0.1\n", - " reward_function:\n", - " reward_components:\n", - " - type: DUMMY\n", - "\n", " agent_settings:\n", " flatten_obs: False\n", "\"\"\"\n", @@ -875,7 +823,7 @@ "outputs": [], "source": [ "# Installing RansomwareScript via C2 Terminal Commands\n", - "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"]],\n", + "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"ransomware-script\"]],\n", " \"username\": \"admin\",\n", " \"password\": \"admin\"}\n", "c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)\n" @@ -1034,11 +982,11 @@ " web_server: Server = given_env.game.simulation.network.get_node_by_hostname(\"web_server\")\n", "\n", " client_1.software_manager.install(C2Server)\n", - " c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + " c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n", " c2_server.run()\n", "\n", " web_server.software_manager.install(C2Beacon)\n", - " c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n", + " c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n", " c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\")\n", " c2_beacon.establish()\n", "\n", @@ -1132,11 +1080,11 @@ "outputs": [], "source": [ "# Attempting to install the C2 RansomwareScript\n", - "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"]],\n", + "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"ransomware-script\"]],\n", " \"username\": \"admin\",\n", " \"password\": \"admin\"}\n", "\n", - "c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + "c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n", "c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)" ] }, @@ -1220,11 +1168,11 @@ "outputs": [], "source": [ "# Attempting to install the C2 RansomwareScript\n", - "ransomware_install_command = {\"commands\":[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"],\n", + "ransomware_install_command = {\"commands\":[\"software_manager\", \"application\", \"install\", \"ransomware-script\"],\n", " \"username\": \"admin\",\n", " \"password\": \"admin\"}\n", "\n", - "c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + "c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n", "c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)" ] }, @@ -1345,7 +1293,7 @@ "metadata": {}, "outputs": [], "source": [ - "database_server: Server = blue_env.game.simulation.network.get_node_by_hostname(\"database_server\")\n", + "database_server: Server = blue_env.game.simulation.network.get_node_by_hostname(\"database-server\")\n", "database_server.software_manager.file_system.show(full=True)" ] }, @@ -1391,7 +1339,7 @@ "\n", "``` YAML\n", "...\n", - " action: configure_c2_beacon\n", + " action: configure-c2-beacon\n", " options:\n", " node_id: 0\n", " config:\n", @@ -1446,16 +1394,16 @@ "source": [ "web_server: Server = c2_config_env.game.simulation.network.get_node_by_hostname(\"web_server\")\n", "web_server.software_manager.install(C2Beacon)\n", - "c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n", + "c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n", "\n", "client_1: Computer = c2_config_env.game.simulation.network.get_node_by_hostname(\"client_1\")\n", "client_1.software_manager.install(C2Server)\n", - "c2_server_1: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + "c2_server_1: C2Server = client_1.software_manager.software[\"c2-server\"]\n", "c2_server_1.run()\n", "\n", "client_2: Computer = c2_config_env.game.simulation.network.get_node_by_hostname(\"client_2\")\n", "client_2.software_manager.install(C2Server)\n", - "c2_server_2: C2Server = client_2.software_manager.software[\"C2Server\"]\n", + "c2_server_2: C2Server = client_2.software_manager.software[\"c2-server\"]\n", "c2_server_2.run()" ] }, @@ -1759,6 +1707,16 @@ "\n", "display_obs_diffs(tcp_c2_obs, udp_c2_obs, blue_config_env.game.step_counter)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "env.game.agents[\"CustomC2Agent\"].show_history()" + ] } ], "metadata": { diff --git a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb index 756fc44f..2dbf750e 100644 --- a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb @@ -47,7 +47,7 @@ "source": [ "def make_cfg_have_flat_obs(cfg):\n", " for agent in cfg['agents']:\n", - " if agent['type'] == \"ProxyAgent\":\n", + " if agent['type'] == \"proxy-agent\":\n", " agent['agent_settings']['flatten_obs'] = False" ] }, @@ -76,9 +76,9 @@ " # parse the info dict form step output and write out what the red agent is doing\n", " red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", " red_action = red_info.action\n", - " if red_action == 'do_nothing':\n", + " if red_action == 'do-nothing':\n", " red_str = 'DO NOTHING'\n", - " elif red_action == 'node_application_execute':\n", + " elif red_action == 'node-application-execute':\n", " client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n", " red_str = f\"ATTACK from {client}\"\n", " return red_str" @@ -147,36 +147,14 @@ "```yaml\n", " - ref: data_manipulation_attacker # name of agent\n", " team: RED # not used, just for human reference\n", - " type: RedDatabaseCorruptingAgent # type of agent - this lets primaite know which agent class to use\n", + " type: red-database-corrupting-agent # type of agent - this lets primaite know which agent class to use\n", "\n", " # Since the agent does not need to react to what is happening in the environment, the observation space is empty.\n", " observation_space:\n", - " type: UC2RedObservation\n", + " type: uc2-red-observation # TODO: what\n", " options:\n", " nodes: {}\n", "\n", - " action_space:\n", - " \n", - " # The agent has access to the DataManipulationBoth on clients 1 and 2.\n", - " options:\n", - " nodes:\n", - " - node_name: client_1 # The network should have a node called client_1\n", - " applications:\n", - " - application_name: DataManipulationBot # The node client_1 should have DataManipulationBot configured on it\n", - " - node_name: client_2 # The network should have a node called client_2\n", - " applications:\n", - " - application_name: DataManipulationBot # The node client_2 should have DataManipulationBot configured on it\n", - "\n", - " # not important\n", - " max_folders_per_node: 1\n", - " max_files_per_folder: 1\n", - " max_services_per_node: 1\n", - "\n", - " # red agent does not need a reward function\n", - " reward_function:\n", - " reward_components:\n", - " - type: DUMMY\n", - "\n", " # These actions are passed to the RedDatabaseCorruptingAgent init method, they dictate the schedule of attacks\n", " agent_settings:\n", " start_settings:\n", @@ -211,15 +189,13 @@ " \n", " # \n", " applications:\n", - " - ref: data_manipulation_bot\n", - " type: DataManipulationBot\n", + " - type: data-manipulation-bot\n", " options:\n", " port_scan_p_of_success: 0.8 # Probability that port scan is successful\n", " data_manipulation_p_of_success: 0.8 # Probability that SQL attack is successful\n", " payload: \"DELETE\" # The SQL query which causes the attack (this has to be DELETE)\n", " server_ip: 192.168.1.14 # IP address of server hosting the database\n", - " - ref: client_1_database_client\n", - " type: DatabaseClient # Database client must be installed in order for DataManipulationBot to function\n", + " - type: database-client # Database client must be installed in order for DataManipulationBot to function\n", " options:\n", " db_server_ip: 192.168.1.14 # IP address of server hosting the database\n", "```" @@ -354,19 +330,16 @@ "# Make attack always succeed.\n", "change = yaml.safe_load(\"\"\"\n", " applications:\n", - " - ref: data_manipulation_bot\n", - " type: DataManipulationBot\n", + " - type: data-manipulation-bot\n", " options:\n", " port_scan_p_of_success: 1.0\n", " data_manipulation_p_of_success: 1.0\n", " payload: \"DELETE\"\n", " server_ip: 192.168.1.14\n", - " - ref: client_1_web_browser\n", - " type: WebBrowser\n", + " - type: web-browser\n", " options:\n", " target_url: http://arcd.com/users/\n", - " - ref: client_1_database_client\n", - " type: DatabaseClient\n", + " - type: database-client\n", " options:\n", " db_server_ip: 192.168.1.14\n", "\"\"\")\n", @@ -399,19 +372,16 @@ "# Make attack always fail.\n", "change = yaml.safe_load(\"\"\"\n", " applications:\n", - " - ref: data_manipulation_bot\n", - " type: DataManipulationBot\n", + " - type: data-manipulation-bot\n", " options:\n", " port_scan_p_of_success: 0.0\n", " data_manipulation_p_of_success: 0.0\n", " payload: \"DELETE\"\n", " server_ip: 192.168.1.14\n", - " - ref: client_1_web_browser\n", - " type: WebBrowser\n", + " - type: web-browser\n", " options:\n", " target_url: http://arcd.com/users/\n", - " - ref: client_1_database_client\n", - " type: DatabaseClient\n", + " - type: database-client\n", " options:\n", " db_server_ip: 192.168.1.14\n", "\"\"\")\n", diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index dbc6f0c1..2070e03c 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -684,6 +684,15 @@ " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.game.agents[\"data_manipulation_attacker\"].show_history()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -717,7 +726,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb index f8691d7d..58573ac6 100644 --- a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb +++ b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb @@ -153,6 +153,49 @@ "PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = was_enabled\n", "PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = was_syslogs_enabled" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Viewing Agent history\n", + "\n", + "It's possible to view the actions carried out by an agent for a given training session using the `show_history()` method. By default, this will be all actions apart from DONOTHING actions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(data_manipulation_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "env = PrimaiteGymEnv(env_config=cfg)\n", + "\n", + "# Run the training session to generate some resultant data.\n", + "for i in range(100):\n", + " env.step(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calling `.show_history()` should show us when the Data Manipulation used the `NODE_APPLICATION_EXECUTE` action." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "attacker = env.game.agents[\"data_manipulation_attacker\"]\n", + "\n", + "attacker.show_history()" + ] } ], "metadata": { @@ -171,7 +214,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb b/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb new file mode 100644 index 00000000..8f8ec24b --- /dev/null +++ b/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb @@ -0,0 +1,479 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PrimAITE Developer mode\n", + "\n", + "PrimAITE has built in developer tools.\n", + "\n", + "The dev-mode is designed to help make the development of PrimAITE easier.\n", + "\n", + "`NOTE: For the purposes of the notebook, the commands are preceeded by \"!\". When running the commands, run it without the \"!\".`\n", + "\n", + "To display the available dev-mode options, run the command below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode --help" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save the current PRIMAITE_CONFIG to restore after the notebook runs\n", + "\n", + "from primaite import PRIMAITE_CONFIG\n", + "\n", + "temp_config = PRIMAITE_CONFIG.copy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dev mode options" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### enable\n", + "\n", + "Enables the dev mode for PrimAITE.\n", + "\n", + "This will enable the developer mode for PrimAITE.\n", + "\n", + "By default, when developer mode is enabled, session logs will be generated in the PRIMAITE_ROOT/sessions folder unless configured to be generated in another location." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode enable" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### disable\n", + "\n", + "Disables the dev mode for PrimAITE.\n", + "\n", + "This will disable the developer mode for PrimAITE." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode disable" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### show\n", + "\n", + "Shows if PrimAITE is running in dev mode or production mode.\n", + "\n", + "The command will also show the developer mode configuration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode show" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### config\n", + "\n", + "Configure the PrimAITE developer mode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --help" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### path\n", + "\n", + "Set the path where generated session files will be output.\n", + "\n", + "By default, this value will be in PRIMAITE_ROOT/sessions.\n", + "\n", + "To reset the path to default, run:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config path -root\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config path --default" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --sys-log-level or -slevel\n", + "\n", + "Set the system log level.\n", + "\n", + "This will override the system log level in configurations and will make PrimAITE include the set log level and above.\n", + "\n", + "Available options are:\n", + "- `DEBUG`\n", + "- `INFO`\n", + "- `WARNING`\n", + "- `ERROR`\n", + "- `CRITICAL`\n", + "\n", + "Default value is `DEBUG`\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --sys-log-level DEBUG\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -slevel DEBUG" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --agent-log-level or -alevel\n", + "\n", + "Set the agent log level.\n", + "\n", + "This will override the agent log level in configurations and will make PrimAITE include the set log level and above.\n", + "\n", + "Available options are:\n", + "- `DEBUG`\n", + "- `INFO`\n", + "- `WARNING`\n", + "- `ERROR`\n", + "- `CRITICAL`\n", + "\n", + "Default value is `DEBUG`\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --agent-log-level DEBUG\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -alevel DEBUG" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-sys-logs or -sys\n", + "\n", + "If enabled, developer mode will output system logs.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-sys-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -sys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable outputting sys logs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-sys-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -nsys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-agent-logs or -agent\n", + "\n", + "If enabled, developer mode will output agent action logs.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-agent-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -agent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable outputting agent action logs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-agent-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -nagent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-pcap-logs or -pcap\n", + "\n", + "If enabled, developer mode will output PCAP logs.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-pcap-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -pcap" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable outputting PCAP logs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-pcap-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -npcap" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-to-terminal or -t\n", + "\n", + "If enabled, developer mode will output logs to the terminal.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-to-terminal\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -t" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable terminal outputs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-terminal\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -nt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Combining commands\n", + "\n", + "It is possible to combine commands to set the configuration.\n", + "\n", + "This saves having to enter multiple commands and allows for a much more efficient setting of PrimAITE developer mode configurations." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Example of setting system log level and enabling the system logging:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config -slevel WARNING -sys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Another example where the system log and agent action log levels are set and enabled and should be printed to terminal:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config -slevel ERROR -sys -alevel ERROR -agent -t" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Restore PRIMAITE_CONFIG\n", + "from primaite.utils.cli.primaite_config_utils import update_primaite_application_config\n", + "\n", + "\n", + "global PRIMAITE_CONFIG\n", + "PRIMAITE_CONFIG[\"developer_mode\"] = temp_config[\"developer_mode\"]\n", + "update_primaite_application_config(config=PRIMAITE_CONFIG)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/notebooks/Requests-and-Responses.ipynb b/src/primaite/notebooks/Requests-and-Responses.ipynb index 83aed07c..9260b29d 100644 --- a/src/primaite/notebooks/Requests-and-Responses.ipynb +++ b/src/primaite/notebooks/Requests-and-Responses.ipynb @@ -114,7 +114,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(f\"DNS Client state: {client.software_manager.software.get('DNSClient').operating_state.name}\")" + "print(f\"DNS Client state: {client.software_manager.software.get('dns-client').operating_state.name}\")" ] }, { diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 9aa4e96a..7c94d432 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -9,6 +9,13 @@ "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simulation Layer Implementation." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -67,9 +74,9 @@ "source": [ "network: Network = basic_network()\n", "computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n", - "terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n", + "terminal_a: Terminal = computer_a.software_manager.software.get(\"terminal\")\n", "computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n", - "terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")" + "terminal_b: Terminal = computer_b.software_manager.software.get(\"terminal\")" ] }, { @@ -121,7 +128,7 @@ "metadata": {}, "outputs": [], "source": [ - "term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])" + "term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"ransomware-script\"])" ] }, { @@ -169,6 +176,22 @@ "computer_b.file_system.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Information about the latest response when executing a remote command can be seen by calling the `last_response` attribute within `Terminal`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(terminal_a.last_response)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -207,6 +230,263 @@ "source": [ "computer_b.user_session_manager.show(include_historic=True, include_session_id=True)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Game Layer Implementation\n", + "\n", + "This notebook section will detail the implementation of how the game layer utilises the terminal to support different agent actions.\n", + "\n", + "The ``Terminal`` is used in a variety of different ways in the game layer. Specifically, the terminal is leveraged to implement the following actions:\n", + "\n", + "\n", + "| Game Layer Action | Simulation Layer |\n", + "|-----------------------------------|--------------------------|\n", + "| ``node-send-local-command`` | Uses the given user credentials, creates a ``LocalTerminalSession`` and executes the given command and returns the ``RequestResponse``.\n", + "| ``node-session-remote-login`` | Uses the given user credentials and remote IP to create a ``RemoteTerminalSession``.\n", + "| ``node-send-remote-command`` | Uses the given remote IP to locate the correct ``RemoteTerminalSession``, executes the given command and returns the ``RequestsResponse``." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Game Layer Setup\n", + "\n", + "Similar to other notebooks, the next code cells create a custom proxy agent to demonstrate how these commands can be leveraged by agents in the ``UC2`` network environment.\n", + "\n", + "If you're unfamiliar with ``UC2`` then please refer to the [UC2-E2E-Demo notebook for further reference](./Data-Manipulation-E2E-Demonstration.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import yaml\n", + "from primaite.config.load import data_manipulation_config_path\n", + "from primaite.session.environment import PrimaiteGymEnv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "custom_terminal_agent = \"\"\"\n", + " - ref: CustomC2Agent\n", + " team: RED\n", + " type: proxy-agent\n", + " observation_space: null\n", + " action_space:\n", + " options:\n", + " nodes:\n", + " - node_name: client_1\n", + " max_folders_per_node: 1\n", + " max_files_per_folder: 1\n", + " max_services_per_node: 2\n", + " max_nics_per_node: 8\n", + " max_acl_rules: 10\n", + " ip_list:\n", + " - 192.168.1.21\n", + " - 192.168.1.14\n", + " wildcard_list:\n", + " - 0.0.0.1\n", + " action_map:\n", + " 0:\n", + " action: do-nothing\n", + " options: {}\n", + " 1:\n", + " action: node-send-local-command\n", + " options:\n", + " node_name: client_1\n", + " username: admin\n", + " password: admin\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - dog.png\n", + " - False\n", + " 2:\n", + " action: node-session-remote-login\n", + " options:\n", + " node_name: client_1\n", + " username: admin\n", + " password: admin\n", + " remote_ip: 192.168.10.22\n", + " 3:\n", + " action: node-send-remote-command\n", + " options:\n", + " node_name: client_1\n", + " remote_ip: 192.168.10.22\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - cat.png\n", + " - False\n", + "\"\"\"\n", + "custom_terminal_agent_yaml = yaml.safe_load(custom_terminal_agent)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(data_manipulation_config_path()) as f:\n", + " cfg = yaml.safe_load(f)\n", + " # removing all agents & adding the custom agent.\n", + " cfg['agents'] = {}\n", + " cfg['agents'] = custom_terminal_agent_yaml\n", + "\n", + "env = PrimaiteGymEnv(env_config=cfg)\n", + "\n", + "client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n", + "client_2: Computer = env.game.simulation.network.get_node_by_hostname(\"client_2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``node-send-local-command`` \n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: node-send-local-command\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 1:\n", + " action: node-send-local-command\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " username: admin\n", + " password: admin\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - dog.png\n", + " - False\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(1)\n", + "client_1.file_system.show(full=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``node-session-remote-login`` \n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: node-session-remote-login\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 2:\n", + " action: node-session-remote-login\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " username: admin\n", + " password: admin\n", + " remote_ip: 192.168.10.22 # client_2's ip address.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(2)\n", + "client_2.session_manager.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``node-send-remote-command``\n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: node-send-remote-command\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 1:\n", + " action: node-send-remote-command\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " remote_ip: 192.168.10.22\n", + " commands:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - cat.png\n", + " - False\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(3)\n", + "client_2.file_system.show(full=True)" + ] } ], "metadata": { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index b7a9a042..fa545dbc 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -26,14 +26,26 @@ except ModuleNotFoundError: _LOGGER.debug("Torch not available for importing") -def set_random_seed(seed: int) -> Union[None, int]: +def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]: """ Set random number generators. + If seed is None or -1 and generate_seed_value is True randomly generate a + seed value. + If seed is > -1 and generate_seed_value is True ignore the latter and use + the provide seed value. + :param seed: int + :param generate_seed_value: bool + :return: None or the int representing the seed used. """ if seed is None or seed == -1: - return None + if generate_seed_value: + rng = np.random.default_rng() + # 2**32-1 is highest value for python RNG seed. + seed = int(rng.integers(low=0, high=2**32 - 1)) + else: + return None elif seed < -1: raise ValueError("Invalid random number seed") # Seed python RNG @@ -50,6 +62,13 @@ def set_random_seed(seed: int) -> Union[None, int]: return seed +def log_seed_value(seed: int): + """Log the selected seed value to file.""" + path = SIM_OUTPUT.path / "seed.log" + with open(path, "w") as file: + file.write(f"Seed value = {seed}") + + class PrimaiteGymEnv(gymnasium.Env): """ Thin wrapper env to provide agents with a gymnasium API. @@ -65,7 +84,8 @@ class PrimaiteGymEnv(gymnasium.Env): """Object that returns a config corresponding to the current episode.""" self.seed = self.episode_scheduler(0).get("game", {}).get("seed") """Get RNG seed from config file. NB: Must be before game instantiation.""" - self.seed = set_random_seed(self.seed) + self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value") + self.seed = set_random_seed(self.seed, self.generate_seed_value) self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0)) @@ -79,6 +99,8 @@ class PrimaiteGymEnv(gymnasium.Env): _LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}") + log_seed_value(self.seed) + def action_masks(self) -> np.ndarray: """ Return the action mask for the agent. @@ -146,7 +168,7 @@ class PrimaiteGymEnv(gymnasium.Env): f"avg. reward: {self.agent.reward_function.total_reward}" ) if seed is not None: - set_random_seed(seed) + set_random_seed(seed, self.generate_seed_value) self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward if self.io.settings.save_agent_actions: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 8653359a..50a7d2a4 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -864,7 +864,21 @@ class UserManager(Service, discriminator="user-manager"): """ rm = super()._init_request_manager() - # todo add doc about requeest schemas + # todo add doc about request schemas + rm.add_request( + "add_user", + RequestType( + func=lambda request, context: RequestResponse.from_bool( + self.add_user(username=request[0], password=request[1], is_admin=request[2]) + ) + ), + ) + rm.add_request( + "disable_user", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.disable_user(username=request[0])) + ), + ) rm.add_request( "change_password", RequestType( @@ -1572,7 +1586,7 @@ class Node(SimComponent, ABC): operating_state: Any = None - users: Any = None # Temporary to appease "extra=forbid" + users: List[Dict] = [] # Temporary to appease "extra=forbid" config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) """Configuration items within Node""" @@ -1638,6 +1652,8 @@ class Node(SimComponent, ABC): self._install_system_software() self.session_manager.node = self self.session_manager.software_manager = self.software_manager + for user in self.config.users: + self.user_manager.add_user(**user, bypass_can_perform_action=True) @property def user_manager(self) -> Optional[UserManager]: @@ -1769,7 +1785,7 @@ class Node(SimComponent, ABC): """ application_name = request[0] if self.software_manager.software.get(application_name): - self.sys_log.warning(f"Can't install {application_name}. It's already installed.") + self.sys_log.info(f"Can't install {application_name}. It's already installed.") return RequestResponse(status="success", data={"reason": "already installed"}) application_class = Application._registry[application_name] self.software_manager.install(application_class) diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 76d9167c..86ac790c 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -2,11 +2,12 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, ClassVar, Dict, Literal, Optional +from typing import Any, ClassVar, Dict, List, Literal, Optional from pydantic import Field from primaite import getLogger +from primaite.simulator.file_system.file_type import FileType from primaite.simulator.network.hardware.base import ( IPWiredNetworkInterface, Link, @@ -313,7 +314,7 @@ class HostNode(Node, discriminator="host-node"): """ SYSTEM_SOFTWARE: ClassVar[Dict] = { - "HostARP": HostARP, + "host-arp": HostARP, "icmp": ICMP, "dns-client": DNSClient, "ntp-client": NTPClient, @@ -339,7 +340,7 @@ class HostNode(Node, discriminator="host-node"): ip_address: IPV4Address services: Any = None # temporarily unset to appease extra="forbid" applications: Any = None # temporarily unset to appease extra="forbid" - folders: Any = None # temporarily unset to appease extra="forbid" + folders: List[Dict] = {} # temporarily unset to appease extra="forbid" network_interfaces: Any = None # temporarily unset to appease extra="forbid" config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema()) @@ -348,6 +349,18 @@ class HostNode(Node, discriminator="host-node"): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask)) + for folder in self.config.folders: + # handle empty foler defined by just a string + self.file_system.create_folder(folder["folder_name"]) + + for file in folder.get("files", []): + self.file_system.create_file( + folder_name=folder["folder_name"], + file_name=file["file_name"], + size=file.get("size", 0), + file_type=FileType[file.get("type", "UNKNOWN").upper()], + ) + @property def nmap(self) -> Optional[NMAP]: """ diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 2cbc23d2..c872b8b3 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -49,7 +49,7 @@ class Firewall(Router, discriminator="firewall"): Example: >>> from primaite.simulator.network.transmission.network_layer import IPProtocol - >>> from primaite.simulator.network.transmission.transport_layer import Port + >>> from primaite.utils.validation.port import Port >>> firewall = Firewall(hostname="Firewall1") >>> firewall.configure_internal_port(ip_address="192.168.1.1", subnet_mask="255.255.255.0") >>> firewall.configure_external_port(ip_address="10.0.0.1", subnet_mask="255.255.255.0") diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 3b35600b..47ee3169 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -467,6 +467,7 @@ class AccessControlList(SimComponent): """Check if a packet with the given properties is permitted through the ACL.""" permitted = False rule: ACLRule = None + for _rule in self._acl: if not _rule: continue @@ -1215,9 +1216,9 @@ class Router(NetworkNode, discriminator="router"): config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema()) SYSTEM_SOFTWARE: ClassVar[Dict] = { - "UserSessionManager": UserSessionManager, - "UserManager": UserManager, - "Terminal": Terminal, + "user-session-manager": UserSessionManager, + "user-manager": UserManager, + "terminal": Terminal, } network_interfaces: Dict[str, RouterInterface] = {} @@ -1385,6 +1386,12 @@ class Router(NetworkNode, discriminator="router"): return False + def subject_to_acl(self, frame: Frame) -> bool: + """Check that frame is subject to ACL rules.""" + if frame.ip.protocol == "udp" and frame.is_arp: + return False + return True + def receive_frame(self, frame: Frame, from_network_interface: RouterInterface): """ Processes an incoming frame received on one of the router's interfaces. @@ -1398,8 +1405,12 @@ class Router(NetworkNode, discriminator="router"): if self.operating_state != NodeOperatingState.ON: return - # Check if it's permitted - permitted, rule = self.acl.is_permitted(frame) + if self.subject_to_acl(frame=frame): + # Check if it's permitted + permitted, rule = self.acl.is_permitted(frame) + else: + permitted = True + rule = None if not permitted: at_port = self._get_port_of_nic(from_network_interface) diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index e7c2a124..a07194a4 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -163,7 +163,7 @@ class Frame(BaseModel): """ Checks if the Frame is an ARP (Address Resolution Protocol) packet. - This is determined by checking if the destination port of the TCP header is equal to the ARP port. + This is determined by checking if the destination and source port of the UDP header is equal to the ARP port. :return: True if the Frame is an ARP packet, otherwise False. """ diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index b0630d5d..c6b687ce 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -55,7 +55,7 @@ class ARP(Service, discriminator="arp"): :param markdown: If True, format the output as Markdown. Otherwise, use plain text. """ - table = PrettyTable(["IP Address", "MAC Address", "Via"]) + table = PrettyTable(["IP Address", "MAC Address", "Via", "Port"]) if markdown: table.set_style(MARKDOWN) table.align = "l" @@ -66,6 +66,7 @@ class ARP(Service, discriminator="arp"): str(ip), arp.mac_address, self.software_manager.node.network_interfaces[arp.network_interface_uuid].mac_address, + self.software_manager.node.network_interfaces[arp.network_interface_uuid].port_num, ] ) print(table) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 2ce7d176..112f6abc 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -142,12 +142,20 @@ class Terminal(Service, discriminator="terminal"): _client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {} """Dictionary of connect requests made to remote nodes.""" + _last_response: Optional[RequestResponse] = None + """Last response received from RequestManager, for returning remote RequestResponse.""" + def __init__(self, **kwargs): kwargs["name"] = "terminal" kwargs["port"] = PORT_LOOKUP["SSH"] kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) + @property + def last_response(self) -> Optional[RequestResponse]: + """Public version of _last_response attribute.""" + return self._last_response + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -186,7 +194,7 @@ class Terminal(Service, discriminator="terminal"): return RequestResponse(status="failure", data={}) rm.add_request( - "node-session-remote-login", + "node_session_remote_login", request_type=RequestType(func=_remote_login), ) @@ -209,28 +217,45 @@ class Terminal(Service, discriminator="terminal"): command: str = request[1]["command"] remote_connection = self._get_connection_from_ip(ip_address=ip_address) if remote_connection: - outcome = remote_connection.execute(command) - if outcome: - return RequestResponse( - status="success", - data={}, - ) - else: - return RequestResponse( - status="failure", - data={}, - ) + remote_connection.execute(command) + return self.last_response if not None else RequestResponse(status="failure", data={}) + return RequestResponse( + status="failure", + data={"reason": "Failed to execute command."}, + ) rm.add_request( "send_remote_command", request_type=RequestType(func=remote_execute_request), ) + def local_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: + """Executes a command using a local terminal session.""" + command: str = request[2]["command"] + local_connection = self._process_local_login(username=request[0], password=request[1]) + if local_connection: + outcome = local_connection.execute(command) + if outcome: + return RequestResponse( + status="success", + data={"reason": outcome}, + ) + return RequestResponse( + status="success", + data={"reason": "Local Terminal failed to resolve command. Potentially invalid credentials?"}, + ) + + rm.add_request( + "send_local_command", + request_type=RequestType(func=local_execute_request), + ) + return rm def execute(self, command: List[Any]) -> Optional[RequestResponse]: """Execute a passed ssh command via the request manager.""" - return self.parent.apply_request(command) + self._last_response = self.parent.apply_request(command) + return self._last_response def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]: """Find Remote Terminal Connection from a given IP.""" @@ -409,6 +434,8 @@ class Terminal(Service, discriminator="terminal"): """ source_ip = kwargs["frame"].ip.src_ip_address self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}") + self._last_response = None # Clear last response + if isinstance(payload, SSHPacket): if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: # validate & add connection @@ -457,6 +484,9 @@ class Terminal(Service, discriminator="terminal"): session_id=session_id, source_ip=source_ip, ) + self._last_response: RequestResponse = RequestResponse( + status="success", data={"reason": "Login Successful"} + ) elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: # Requesting a command to be executed @@ -468,12 +498,32 @@ class Terminal(Service, discriminator="terminal"): payload.connection_uuid ) remote_session.last_active_step = self.software_manager.node.user_session_manager.current_timestep - self.execute(command) + self._last_response: RequestResponse = self.execute(command) + + if self._last_response.status == "success": + transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS + else: + transport_message = SSHTransportMessage.SSH_MSG_SERVICE_FAILED + + payload: SSHPacket = SSHPacket( + payload=self._last_response, + transport_message=transport_message, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA, + ) + self.software_manager.send_payload_to_session_manager( + payload=payload, dest_port=self.port, session_id=session_id + ) return True else: self.sys_log.error( f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command." ) + elif ( + payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS + or SSHTransportMessage.SSH_MSG_SERVICE_FAILED + ): + # Likely receiving command ack from remote. + self._last_response = payload.payload if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "disconnect": diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 2eddefc1..3f8760c4 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -117,37 +117,44 @@ class WebServer(Service, discriminator="web-server"): :type: payload: HttpRequestPacket """ response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND, payload=payload) - try: - parsed_url = urlparse(payload.request_url) - path = parsed_url.path.strip("/") - if len(path) < 1: + parsed_url = urlparse(payload.request_url) + path = parsed_url.path.strip("/") if parsed_url and parsed_url.path else "" + + if len(path) < 1: + # query succeeded + response.status_code = HttpStatusCode.OK + + if path.startswith("users"): + # get data from DatabaseServer + # get all users + if not self._establish_db_connection(): + # unable to create a db connection + response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + return response + + if self.db_connection.query("SELECT"): # query succeeded + self.set_health_state(SoftwareHealthState.GOOD) response.status_code = HttpStatusCode.OK + else: + self.set_health_state(SoftwareHealthState.COMPROMISED) + return response - if path.startswith("users"): - # get data from DatabaseServer - # get all users - if not self.db_connection: - self._establish_db_connection() - - if self.db_connection.query("SELECT"): - # query succeeded - self.set_health_state(SoftwareHealthState.GOOD) - response.status_code = HttpStatusCode.OK - else: - self.set_health_state(SoftwareHealthState.COMPROMISED) - - return response - except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 ) - # something went wrong on the server - response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR - return response - - def _establish_db_connection(self) -> None: + def _establish_db_connection(self) -> bool: """Establish a connection to db.""" + # if active db connection, return true + if self.db_connection: + return True + + # otherwise, try to create db connection db_client = self.software_manager.software.get("database-client") + + if db_client is None: + return False # database client not installed + self.db_connection: DatabaseClientConnection = db_client.get_new_connection() + return self.db_connection is not None def send( self, diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index c9ac5f8d..7ffe4c08 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -25,7 +25,19 @@ game: - ICMP - TCP - UDP - + thresholds: + nmne: + high: 100 + medium: 25 + low: 5 + file_access: + high: 10 + medium: 5 + low: 2 + app_executions: + high: 5 + medium: 3 + low: 2 agents: - ref: client_2_green_user team: GREEN @@ -64,10 +76,16 @@ agents: options: hosts: - hostname: client_1 + applications: + - application_name: WebBrowser + folders: + - folder_name: root + files: + - file_name: "test.txt" - hostname: client_2 - hostname: client_3 num_services: 1 - num_applications: 0 + num_applications: 1 num_folders: 1 num_files: 1 num_nics: 2 @@ -182,6 +200,10 @@ simulation: options: ntp_server_ip: 192.168.1.10 - type: ntp-server + folders: + - folder_name: root + files: + - file_name: test.txt - hostname: client_2 type: computer ip_address: 192.168.10.22 diff --git a/tests/assets/configs/nodes_with_initial_files.yaml b/tests/assets/configs/nodes_with_initial_files.yaml new file mode 100644 index 00000000..d4c6406b --- /dev/null +++ b/tests/assets/configs/nodes_with_initial_files.yaml @@ -0,0 +1,226 @@ +# Basic Switched network +# +# -------------- -------------- -------------- +# | client_1 |------| switch_1 |------| client_2 | +# -------------- -------------- -------------- +# +io_settings: + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + sys_log_level: WARNING + agent_log_level: INFO + save_agent_logs: true + write_agent_log_to_terminal: True + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_2_green_user + team: GREEN + type: periodic-agent + action_space: + action_map: + 0: + action: do-nothing + options: {} + 1: + action: node-application-execute + options: + node_id: 0 + application_id: 0 + + agent_settings: + possible_start_nodes: [client_2,] + target_application: web-browser + start_step: 5 + frequency: 4 + variance: 3 + + + + - ref: defender + team: BLUE + type: proxy-agent + + observation_space: + type: custom + options: + components: + - type: nodes + label: NODES + options: + hosts: + - hostname: client_1 + - hostname: client_2 + - hostname: client_3 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + include_nmne: false + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.23 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: links + label: LINKS + options: + link_references: + - switch_1:eth-1<->client_1:eth-1 + - switch_1:eth-2<->client_2:eth-1 + - type: none + label: ICS + options: {} + + action_space: + action_map: + 0: + action: do-nothing + options: {} + + reward_function: + reward_components: + - type: database-file-integrity + weight: 0.5 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + - type: web-server-404-penalty + weight: 0.5 + options: + node_hostname: web_server + service_name: web_server_web_service + + + agent_settings: + flatten_obs: true + +simulation: + network: + nodes: + + - type: switch + hostname: switch_1 + num_ports: 8 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: ransomware-script + - type: web-browser + options: + target_url: http://arcd.com/users/ + - type: database-client + options: + db_server_ip: 192.168.1.10 + server_password: arcd + - type: data-manipulation-bot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.21 + server_password: arcd + - type: dos-bot + options: + target_ip_address: 192.168.10.21 + payload: SPOOF DATA + port_scan_p_of_success: 0.8 + services: + - type: dns-client + options: + dns_server: 192.168.1.10 + - type: dns-server + options: + domain_mapping: + arcd.com: 192.168.1.10 + - type: database-service + options: + backup_server_ip: 192.168.1.10 + - type: web-server + - type: ftp-server + options: + server_password: arcd + - type: ntp-client + options: + ntp_server_ip: 192.168.1.10 + - type: ntp-server + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + folders: + - folder_name: empty_folder + - folder_name: downloads + files: + - file_name: "test.txt" + - file_name: "another_file.pwtwoti" + - folder_name: root + files: + - file_name: passwords + size: 663 + type: TXT + # pre installed services and applications + - hostname: client_3 + type: computer + ip_address: 192.168.10.23 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + start_up_duration: 0 + shut_down_duration: 0 + operating_state: "OFF" + # pre installed services and applications + + links: + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + bandwidth: 200 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + bandwidth: 200 diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py index 4153adc0..627fc53b 100644 --- a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -8,7 +8,7 @@ from primaite.config.load import data_manipulation_config_path from primaite.game.game import PrimaiteGame from tests import TEST_ASSETS_ROOT -BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" +BASIC_SWITCHED_NETWORK_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" def load_config(config_path: Union[str, Path]) -> PrimaiteGame: @@ -24,3 +24,42 @@ def test_thresholds(): game = load_config(data_manipulation_config_path()) assert game.options.thresholds is not None + + +def test_nmne_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["nmne"] is not None + + # get NIC observation + nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0] + assert nic_obs.low_nmne_threshold == 5 + assert nic_obs.med_nmne_threshold == 25 + assert nic_obs.high_nmne_threshold == 100 + + +def test_file_access_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["file_access"] is not None + + # get file observation + file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0] + assert file_obs.low_file_access_threshold == 2 + assert file_obs.med_file_access_threshold == 5 + assert file_obs.high_file_access_threshold == 10 + + +def test_app_executions_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["app_executions"] is not None + + # get application observation + app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0] + assert app_obs.low_app_execution_threshold == 2 + assert app_obs.med_app_execution_threshold == 3 + assert app_obs.high_app_execution_threshold == 5 diff --git a/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py new file mode 100644 index 00000000..4c99a39f --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py @@ -0,0 +1,64 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from pathlib import Path +from typing import Union + +import yaml + +from primaite.game.game import PrimaiteGame +from primaite.simulator.file_system.file_type import FileType +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/nodes_with_initial_files.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +def test_node_file_system_from_config(): + """Test that the appropriate files are instantiated in nodes when loaded from config.""" + game = load_config(BASIC_CONFIG) + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + assert client_1.software_manager.software.get("database-service") # database service should be installed + assert client_1.file_system.get_file(folder_name="database", file_name="database.db") # database files should exist + + assert client_1.software_manager.software.get("web-server") # web server should be installed + assert client_1.file_system.get_file(folder_name="primaite", file_name="index.html") # web files should exist + + client_2 = game.simulation.network.get_node_by_hostname("client_2") + + # database service should not be installed + assert client_2.software_manager.software.get("database-service") is None + # database files should not exist + assert client_2.file_system.get_file(folder_name="database", file_name="database.db") is None + + # web server should not be installed + assert client_2.software_manager.software.get("web-server") is None + # web files should not exist + assert client_2.file_system.get_file(folder_name="primaite", file_name="index.html") is None + + empty_folder = client_2.file_system.get_folder(folder_name="empty_folder") + assert empty_folder + assert len(empty_folder.files) == 0 # should have no files + + password_file = client_2.file_system.get_file(folder_name="root", file_name="passwords.txt") + assert password_file # should exist + assert password_file.file_type is FileType.TXT + assert password_file.size == 663 + + downloads_folder = client_2.file_system.get_folder(folder_name="downloads") + assert downloads_folder # downloads folder should exist + + test_txt = downloads_folder.get_file(file_name="test.txt") + assert test_txt # test.txt should exist + assert test_txt.file_type is FileType.TXT + + unknown_file_type = downloads_folder.get_file(file_name="another_file.pwtwoti") + assert unknown_file_type # unknown_file_type should exist + assert unknown_file_type.file_type is FileType.UNKNOWN diff --git a/tests/integration_tests/extensions/nodes/giga_switch.py b/tests/integration_tests/extensions/nodes/giga_switch.py index d9599618..5c202ed2 100644 --- a/tests/integration_tests/extensions/nodes/giga_switch.py +++ b/tests/integration_tests/extensions/nodes/giga_switch.py @@ -49,7 +49,7 @@ class GigaSwitch(NetworkNode, discriminator="gigaswitch"): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Switch Ports" + table.title = f"{self.config.hostname} Switch Ports" for port_num, port in self.network_interface.items(): table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"]) print(table) diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index c39d8263..3ee97fb7 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -106,7 +106,6 @@ def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame "username": "user123", "current_password": "password", "new_password": "different_password", - "remote_ip": str(server_1.network_interface[1].ip_address), }, ) agent.store_action(action) @@ -146,7 +145,6 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam "username": "user123", "current_password": "password", "new_password": "different_password", - "remote_ip": str(server_1.network_interface[1].ip_address), }, ) agent.store_action(action) @@ -166,3 +164,55 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam assert server_1.file_system.get_folder("folder123") is None assert server_1.file_system.get_file("folder123", "doggo.pdf") is None + + +def test_local_terminal(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + # create a new user account on server_1 that will be logged into remotely + client_1_usm: UserManager = client_1.software_manager.software["user-manager"] + client_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "node-send-local-command", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert client_1.file_system.get_folder("folder123") + assert client_1.file_system.get_file("folder123", "doggo.pdf") + + # Change password + action = ( + "node-account-change-password", + { + "node_name": "client_1", + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + + action = ( + "node-send-local-command", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "command": ["file_system", "create", "file", "folder123", "cat.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert client_1.file_system.get_file("folder123", "cat.pdf") is None + client_1.session_manager.show() diff --git a/tests/integration_tests/game_layer/actions/test_user_account_actions.py b/tests/integration_tests/game_layer/actions/test_user_account_actions.py new file mode 100644 index 00000000..26b871db --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_user_account_actions.py @@ -0,0 +1,176 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.utils.validation.port import Port, PORT_LOOKUP + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_user_account_add_user_action(game_and_agent_fixture): + """Tests the add user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + assert len(client_1.user_manager.users) == 1 # admin is created by default + assert len(client_1.user_manager.admins) == 1 + + # add admin account + action = ( + "node-account-add-user", + {"node_name": "client_1", "username": "admin_2", "password": "e-tronic-boogaloo", "is_admin": True}, + ) + agent.store_action(action) + game.step() + + assert len(client_1.user_manager.users) == 2 # new user added + assert len(client_1.user_manager.admins) == 2 + + # add non admin account + action = ( + "node-account-add-user", + {"node_name": "client_1", "username": "leeroy.jenkins", "password": "no_plan_needed", "is_admin": False}, + ) + agent.store_action(action) + game.step() + + assert len(client_1.user_manager.users) == 3 # new user added + assert len(client_1.user_manager.admins) == 2 + + +def test_user_account_disable_user_action(game_and_agent_fixture): + """Tests the disable user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test", password="password", is_admin=True) + assert len(client_1.user_manager.users) == 2 # new user added + assert len(client_1.user_manager.admins) == 2 + + test_user = client_1.user_manager.users.get("test") + assert test_user + assert test_user.disabled is not True + + # disable test account + action = ( + "node-account-disable-user", + { + "node_name": "client_1", + "username": "test", + }, + ) + agent.store_action(action) + game.step() + assert test_user.disabled + + +def test_user_account_change_password_action(game_and_agent_fixture): + """Tests the change password user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test", password="password", is_admin=True) + + test_user = client_1.user_manager.users.get("test") + assert test_user.password == "password" + + # change account password + action = ( + "node-account-change-password", + {"node_name": "client_1", "username": "test", "current_password": "password", "new_password": "2Hard_2_Hack"}, + ) + agent.store_action(action) + game.step() + + assert test_user.password == "2Hard_2_Hack" + + +def test_user_account_create_terminal_action(game_and_agent_fixture): + """Tests that agents can use the terminal to create new users.""" + game, agent = game_and_agent_fixture + + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) + + server_1 = game.simulation.network.get_node_by_hostname("server_1") + server_1_usm = server_1.software_manager.software["user-manager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "node-session-remote-login", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + # Create a new user account via terminal. + action = ( + "node-send-remote-command", + { + "node_name": "client_1", + "remote_ip": str(server_1.network_interface[1].ip_address), + "command": ["service", "user-manager", "add_user", "new_user", "new_pass", True], + }, + ) + agent.store_action(action) + game.step() + new_user = server_1.user_manager.users.get("new_user") + assert new_user + assert new_user.password == "new_pass" + assert new_user.disabled is not True + + +def test_user_account_disable_terminal_action(game_and_agent_fixture): + """Tests that agents can use the terminal to disable users.""" + game, agent = game_and_agent_fixture + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) + + server_1 = game.simulation.network.get_node_by_hostname("server_1") + server_1_usm = server_1.software_manager.software["user-manager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "node-session-remote-login", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + # Disable a user via terminal + action = ( + "node-send-remote-command", + { + "node_name": "client_1", + "remote_ip": str(server_1.network_interface[1].ip_address), + "command": ["service", "user-manager", "disable_user", "user123"], + }, + ) + agent.store_action(action) + game.step() + + new_user = server_1.user_manager.users.get("user123") + assert new_user + assert new_user.disabled is True diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 7323461c..722fd294 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -44,6 +44,38 @@ def test_file_observation(simulation): assert observation_state.get("health_status") == 3 # corrupted +def test_config_file_access_categories(simulation): + pc: Computer = simulation.network.get_node_by_hostname("client_1") + file_obs = FileObservation( + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}}, + ) + + assert file_obs.high_file_access_threshold == 9 + assert file_obs.med_file_access_threshold == 6 + assert file_obs.low_file_access_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + FileObservation( + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 9, "medium": 6, "high": 9}}, + ) + + with pytest.raises(Exception): + # should throw an error + FileObservation( + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 3, "medium": 9, "high": 9}}, + ) + + def test_folder_observation(simulation): """Test the folder observation.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index b5e5ca81..30eccb06 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -77,6 +77,14 @@ def test_nic(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True) + # The Simulation object created by the fixture also creates the + # NICObservation class with the NICObservation.capture_nmnme class variable + # set to False. Under normal (non-test) circumstances this class variable + # is set from a config file such as data_manipulation.yaml. So although + # capture_nmne is set to True in the NetworkInterface class it's still False + # in the NICObservation class so we set it now. + nic_obs.capture_nmne = True + # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs nmne_config = { "capture_nmne": True, # Enable the capture of MNEs @@ -115,14 +123,11 @@ def test_nic_categories(simulation): assert nic_obs.low_nmne_threshold == 0 # default -@pytest.mark.skip(reason="Feature not implemented yet") def test_config_nic_categories(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") nic_obs = NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=3, - med_nmne_threshold=6, - high_nmne_threshold=9, + where=["network", "nodes", pc.config.hostname, "NICs", 1], + thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}}, include_nmne=True, ) @@ -133,20 +138,16 @@ def test_config_nic_categories(simulation): with pytest.raises(Exception): # should throw an error NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=9, - med_nmne_threshold=6, - high_nmne_threshold=9, + where=["network", "nodes", pc.config.hostname, "NICs", 1], + thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}}, include_nmne=True, ) with pytest.raises(Exception): # should throw an error NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=3, - med_nmne_threshold=9, - high_nmne_threshold=9, + where=["network", "nodes", pc.config.hostname, "NICs", 1], + thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}}, include_nmne=True, ) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 09eb3fe4..aef60bc2 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -39,6 +39,8 @@ def test_host_observation(simulation): folders=[], network_interfaces=[], file_system_requires_scan=True, + services_requires_scan=True, + applications_requires_scan=True, include_users=False, ) diff --git a/tests/integration_tests/game_layer/observations/test_obs_data_capture.py b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py new file mode 100644 index 00000000..e8bdea22 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py @@ -0,0 +1,28 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import json + +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.io import PrimaiteIO +from tests import TEST_ASSETS_ROOT + +DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" + + +def test_obs_data_in_log_file(): + """Create a log file of AgentHistoryItems and check observation data is + included. Assumes that data_manipulation.yaml has an agent labelled + 'defender' with a non-null observation space. + The log file will be in: + primaite/VERSION/sessions/YYYY-MM-DD/HH-MM-SS/agent_actions + """ + env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) + env.reset() + for _ in range(10): + env.step(0) + env.reset() + io = PrimaiteIO() + path = io.generate_agent_actions_save_path(episode=1) + with open(path, "r") as f: + j = json.load(f) + + assert type(j["0"]["defender"]["observation"]) == dict diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 28cdaf01..1ebff10c 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -29,7 +29,9 @@ def test_service_observation(simulation): ntp_server = pc.software_manager.software.get("ntp-server") assert ntp_server - service_obs = ServiceObservation(where=["network", "nodes", pc.config.hostname, "services", "ntp-server"]) + service_obs = ServiceObservation( + where=["network", "nodes", pc.config.hostname, "services", "ntp-server"], services_requires_scan=True + ) assert service_obs.space["operating_status"] == spaces.Discrete(7) assert service_obs.space["health_status"] == spaces.Discrete(5) @@ -54,7 +56,9 @@ def test_application_observation(simulation): web_browser: WebBrowser = pc.software_manager.software.get("web-browser") assert web_browser - app_obs = ApplicationObservation(where=["network", "nodes", pc.config.hostname, "applications", "web-browser"]) + app_obs = ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "web-browser"], applications_requires_scan=True + ) web_browser.close() observation_state = app_obs.observe(simulation.describe_state()) @@ -69,3 +73,33 @@ def test_application_observation(simulation): assert observation_state.get("health_status") == 1 assert observation_state.get("operating_status") == 1 # running assert observation_state.get("num_executions") == 1 + + +def test_application_executions_categories(simulation): + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + app_obs = ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], + applications_requires_scan=False, + thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}}, + ) + + assert app_obs.high_app_execution_threshold == 9 + assert app_obs.med_app_execution_threshold == 6 + assert app_obs.low_app_execution_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], + applications_requires_scan=False, + thresholds={"app_executions": {"low": 9, "medium": 6, "high": 9}}, + ) + + with pytest.raises(Exception): + # should throw an error + ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], + applications_requires_scan=False, + thresholds={"app_executions": {"low": 3, "medium": 9, "high": 9}}, + ) diff --git a/tests/integration_tests/game_layer/test_RNG_seed.py b/tests/integration_tests/game_layer/test_RNG_seed.py index 45fa445d..2b80e153 100644 --- a/tests/integration_tests/game_layer/test_RNG_seed.py +++ b/tests/integration_tests/game_layer/test_RNG_seed.py @@ -7,6 +7,7 @@ import yaml from primaite.config.load import data_manipulation_config_path from primaite.game.agent.interface import AgentHistoryItem from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator import SIM_OUTPUT @pytest.fixture() @@ -33,6 +34,11 @@ def test_rng_seed_set(create_env): assert a == b + # Check that seed log file was created. + path = SIM_OUTPUT.path / "seed.log" + with open(path, "r") as file: + assert file + def test_rng_seed_unset(create_env): """Test with no RNG seed.""" @@ -48,3 +54,19 @@ def test_rng_seed_unset(create_env): b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do-nothing"] assert a != b + + +def test_for_generated_seed(): + """ + Show that setting generate_seed_value to true producess a valid seed. + """ + with open(data_manipulation_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + cfg["game"]["generate_seed_value"] = True + PrimaiteGymEnv(env_config=cfg) + path = SIM_OUTPUT.path / "seed.log" + with open(path, "r") as file: + data = file.read() + + assert data.split(" ")[3] != None diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 03b94ab7..59bee385 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -22,6 +22,7 @@ from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.software import SoftwareHealthState @@ -107,7 +108,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox """ Test that the RouterACLAddRuleAction can form a request and that it is accepted by the simulation. - The acl starts off with 4 rules, and we add a rule, and check that the acl now has 5 rules. + The ACL starts off with 4 rules, and we add a rule, and check that the ACL now has 5 rules. """ game, agent = game_and_agent @@ -164,11 +165,9 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox }, ) agent.store_action(action) - print(agent.most_recent_action) game.step() - print(agent.most_recent_action) + # 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2 - print(router.acl.show()) assert router.acl.num_rules == 6 assert server_1.ping("10.0.2.3") # Can ping server_2 @@ -180,7 +179,8 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P # 1: Check that http traffic is going across the network nicely. client_1 = game.simulation.network.get_node_by_hostname("client_1") server_1 = game.simulation.network.get_node_by_hostname("server_1") - router = game.simulation.network.get_node_by_hostname("router") + router: Router = game.simulation.network.get_node_by_hostname("router") + assert router.acl.num_rules == 4 browser: WebBrowser = client_1.software_manager.software.get("web-browser") browser.run() diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index ea7fbc99..80e7c3b3 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,5 +1,11 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from itertools import product + +import yaml + +from primaite.config.load import data_manipulation_config_path from primaite.game.agent.observations.nic_observations import NICObservation +from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server @@ -277,3 +283,19 @@ def test_capture_nmne_observations(uc2_network: Network): assert web_nic_obs["outbound"] == expected_nmne assert db_nic_obs["inbound"] == expected_nmne uc2_network.apply_timestep(timestep=0) + + +def test_nmne_parameter_settings(): + """ + Check that the four permutations of the values of capture_nmne and + include_nmne work as expected. + """ + + with open(data_manipulation_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + DEFENDER = 3 + for capture, include in product([True, False], [True, False]): + cfg["simulation"]["network"]["nmne_config"]["capture_nmne"] = capture + cfg["agents"][DEFENDER]["observation_space"]["options"]["components"][0]["options"]["include_nmne"] = include + PrimaiteGymEnv(env_config=cfg) diff --git a/tests/integration_tests/system/test_arp.py b/tests/integration_tests/system/test_arp.py index 055d58c6..b9a92255 100644 --- a/tests/integration_tests/system/test_arp.py +++ b/tests/integration_tests/system/test_arp.py @@ -1,6 +1,7 @@ -# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK -from primaite.simulator.network.hardware.nodes.network.router import RouterARP +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterARP from primaite.simulator.system.services.arp.arp import ARP +from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.network.test_routing import multi_hop_network @@ -48,3 +49,19 @@ def test_arp_fails_for_network_address_between_routers(multi_hop_network): actual_result = router_1_arp.get_arp_cache_mac_address(router_1.network_interface[1].ip_network.network_address) assert actual_result == expected_result + + +def test_arp_not_affected_by_acl(multi_hop_network): + pc_a = multi_hop_network.get_node_by_hostname("pc_a") + router_1: Router = multi_hop_network.get_node_by_hostname("router_1") + + # Add explicit rule to block ARP traffic. This shouldn't actually stop ARP traffic + # as it operates a different layer within the network. + router_1.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=23) + + pc_a_arp: ARP = pc_a.software_manager.arp + + expected_result = router_1.network_interface[2].mac_address + actual_result = pc_a_arp.get_arp_cache_mac_address(router_1.network_interface[2].ip_address) + + assert actual_result == expected_result diff --git a/tests/unit_tests/_primaite/_game/_agent/test_observations.py b/tests/unit_tests/_primaite/_game/_agent/test_observations.py index 5d5921a9..3df6ca0a 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_observations.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_observations.py @@ -1,10 +1,11 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +import json from typing import List import pytest import yaml -from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.observations import ApplicationObservation, ObservationManager, ServiceObservation from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation from primaite.game.agent.observations.host_observations import HostObservation @@ -136,3 +137,227 @@ class TestFileSystemRequiresScan: [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True ) assert obs_requiring_scan.observe(folder_state)["health_status"] == 1 + + +class TestServicesRequiresScan: + @pytest.mark.parametrize( + ("yaml_option_string", "expected_val"), + ( + ("services_requires_scan: true", True), + ("services_requires_scan: false", False), + (" ", True), + ), + ) + def test_obs_config(self, yaml_option_string, expected_val): + """Check that the default behaviour is to set service_requires_scan to True.""" + obs_cfg_yaml = f""" + type: custom + options: + components: + - type: nodes + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: web-server + - service_name: dns-client + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + services: + - service_name: ftp-server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 3 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + {yaml_option_string} + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: links + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: none + label: ICS + options: {{}} + + """ + + cfg = yaml.safe_load(obs_cfg_yaml) + manager = ObservationManager.from_config(cfg) + + hosts: List[HostObservation] = manager.obs.components["NODES"].hosts + for i, host in enumerate(hosts): + services: List[ServiceObservation] = host.services + for j, service in enumerate(services): + val = service.services_requires_scan + print(f"host {i} service {j} {val}") + assert val == expected_val # Make sure services require scan by default + + def test_services_requires_scan(self): + state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1} + + obs_requiring_scan = ServiceObservation([], services_requires_scan=True) + assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value + + obs_not_requiring_scan = ServiceObservation([], services_requires_scan=False) + assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value + + +class TestApplicationsRequiresScan: + @pytest.mark.parametrize( + ("yaml_option_string", "expected_val"), + ( + ("applications_requires_scan: true", True), + ("applications_requires_scan: false", False), + (" ", True), + ), + ) + def test_obs_config(self, yaml_option_string, expected_val): + """Check that the default behaviour is to set applications_requires_scan to True.""" + obs_cfg_yaml = f""" + type: custom + options: + components: + - type: nodes + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + applications: + - application_name: web-browser + - hostname: client_2 + applications: + - application_name: web-browser + - application_name: database-client + num_services: 0 + num_applications: 3 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + {yaml_option_string} + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: links + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: none + label: ICS + options: {{}} + + """ + + cfg = yaml.safe_load(obs_cfg_yaml) + manager = ObservationManager.from_config(cfg) + + hosts: List[HostObservation] = manager.obs.components["NODES"].hosts + for i, host in enumerate(hosts): + services: List[ServiceObservation] = host.services + for j, service in enumerate(services): + val = service.services_requires_scan + print(f"host {i} service {j} {val}") + assert val == expected_val # Make sure applications require scan by default + + def test_applications_requires_scan(self): + state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1, "num_executions": 1} + + obs_requiring_scan = ApplicationObservation([], applications_requires_scan=True) + assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value + + obs_not_requiring_scan = ApplicationObservation([], applications_requires_scan=False) + assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 91369f6c..81e05467 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -73,7 +73,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client): assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR -def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client): +def test_ftp_tries_to_send_file__that_does_not_exist(ftp_client): """Method send_file should return false if no file to send.""" assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 32fcae9a..3b2377e9 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -6,6 +6,7 @@ import pytest from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame +from primaite.interface.request import RequestResponse from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server @@ -442,3 +443,59 @@ def test_terminal_connection_timeout(basic_network): assert len(computer_b.user_session_manager.remote_sessions) == 0 assert not remote_connection.is_active + + +def test_terminal_last_response_updates(basic_network): + """Test that the _last_response within Terminal correctly updates.""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + + assert terminal_a.last_response is None + + remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") + + # Last response should be a successful logon + assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "Login Successful"}) + + remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"]) + + # Last response should now update following successful install + assert terminal_a.last_response == RequestResponse(status="success", data={}) + + remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"]) + + # Last response should now update to success, but with supplied reason. + assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "already installed"}) + + remote_connection.execute(command=["file_system", "create", "file", "folder123", "doggo.pdf", False]) + + # Check file was created. + assert computer_b.file_system.access_file(folder_name="folder123", file_name="doggo.pdf") + + # Last response should be confirmation of file creation. + assert terminal_a.last_response == RequestResponse( + status="success", + data={"file_name": "doggo.pdf", "folder_name": "folder123", "file_type": "PDF", "file_size": 102400}, + ) + + remote_connection.execute( + command=[ + "service", + "ftp-client", + "send", + { + "dest_ip_address": "192.168.0.2", + "src_folder": "folder123", + "src_file_name": "cat.pdf", + "dest_folder": "root", + "dest_file_name": "cat.pdf", + }, + ] + ) + + assert terminal_a.last_response == RequestResponse( + status="failure", + data={"reason": "Unable to locate given file on local file system. Perhaps given options are invalid?"}, + )