diff --git a/CHANGELOG.md b/CHANGELOG.md index 01e45d2e..dcff5934 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +- Made environment reset completely recreate the game object. - Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack. - Changed the data manipulation scenario to include a second green agent on client 1. - Refactored actions and observations to be configurable via object name, instead of UUID. @@ -82,7 +83,8 @@ SessionManager. - `AirSpace` class to simulate wireless communications, managing wireless interfaces and facilitating the transmission of frames within specified frequencies. - `AirSpaceFrequency` enum for defining standard wireless frequencies, including 2.4 GHz and 5 GHz bands, to support realistic wireless network simulations. - `WirelessRouter` class, extending the `Router` class, to incorporate wireless networking capabilities alongside traditional wired connections. This class allows the configuration of wireless access points with specific IP settings and operating frequencies. - +- NMNE capturing capabilities to `NetworkInterface` class for detecting and logging Malicious Network Events. +- New `nmne_config` settings in the simulation configuration to enable NMNE capturing and specify keywords such as "DELETE". ### Changed - Integrated the RouteTable into the Routers frame processing. @@ -94,7 +96,8 @@ SessionManager. - Refactored all tests to utilise new `Node` subclasses (`Computer`, `Server`, `Router`, `Switch`) instead of creating generic `Node` instances and manually adding network interfaces. This change aligns test setups more closely with the intended use cases and hierarchies within the network simulation framework. - Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios. - **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules. - +- Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them. +- Integration of NMNE capturing functionality within the `NicObservation` class. ### Removed - Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol` diff --git a/docs/index.rst b/docs/index.rst index 9eae8adc..08e0ac21 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -108,6 +108,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! source/simulation source/game_layer source/config + source/environment .. toctree:: :caption: Developer information: diff --git a/docs/source/environment.rst b/docs/source/environment.rst new file mode 100644 index 00000000..2b76572d --- /dev/null +++ b/docs/source/environment.rst @@ -0,0 +1,10 @@ +RL Environments +*************** + +RL environments are the objects that directly interface with RL libraries such as Stable-Baselines3 and Ray RLLib. The PrimAITE simulation is exposed via three different environment APIs: + +* Gymnasium API - this is the standard interface that works with many RL libraries like SB3, Ray, Tianshou, etc. ``PrimaiteGymEnv`` adheres to the `Official Gymnasium documentation `_. +* Ray Single agent API - For training a single Ray RLLib agent +* Ray MARL API - For training multi-agent systems with Ray RLLib. ``PrimaiteRayMARLEnv`` adheres to the `Official Ray documentation `_. + +There are Jupyter notebooks which demonstrate integration with each of these three environments. They are located in ``~/primaite//notebooks/example_notebooks``. diff --git a/docs/source/game_layer.rst b/docs/source/game_layer.rst index cdae17dd..1f2921fe 100644 --- a/docs/source/game_layer.rst +++ b/docs/source/game_layer.rst @@ -20,6 +20,11 @@ The game layer is responsible for managing agents and getting them to interface PrimAITE Session ^^^^^^^^^^^^^^^ +.. admonition:: Deprecated + :class: deprecated + + PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The `session` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality. + ``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types. Agents diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 706397b6..87a3f03d 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -4,6 +4,11 @@ .. _run a primaite session: +.. admonition:: Deprecated + :class: deprecated + + PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The ``session`` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality. + Run a PrimAITE Session ====================== diff --git a/docs/source/simulation_components/network/network_interfaces.rst b/docs/source/simulation_components/network/network_interfaces.rst index 9e1ad80a..2bb8dda4 100644 --- a/docs/source/simulation_components/network/network_interfaces.rst +++ b/docs/source/simulation_components/network/network_interfaces.rst @@ -65,9 +65,14 @@ Network Interface Classes **NetworkInterface (Base Layer)** -Abstract base class defining core interface properties like MAC address, speed, MTU. -Requires subclasses implement key methods like send/receive frames, enable/disable interface. -Establishes universal network interface capabilities. +- Abstract base class defining core interface properties like MAC address, speed, MTU. +- Requires subclasses implement key methods like send/receive frames, enable/disable interface. +- Establishes universal network interface capabilities. +- Malicious Network Events Monitoring: + + * Enhances network interfaces with the capability to monitor and capture Malicious Network Events (MNEs) based on predefined criteria such as specific keywords or traffic patterns. + * Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies. + * Offers an additional layer of security and data analysis, crucial for identifying and mitigating malicious activities within the network infrastructure. Provides vital information for network security analysis and reinforcement learning algorithms. **WiredNetworkInterface (Connection Type Layer)** diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 6813161d..a3fbf561 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -14,7 +14,7 @@ io_settings: save_checkpoints: true checkpoint_interval: 5 save_step_metadata: false - save_pcap_logs: true + save_pcap_logs: false save_sys_logs: true @@ -627,6 +627,10 @@ agents: simulation: network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE nodes: - ref: router_1 @@ -696,12 +700,13 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index df6130d1..d6d3f044 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -1,15 +1,23 @@ training_config: rl_framework: RLLIB_multi_agent - # rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 256 + deterministic_eval: false n_agents: 2 agent_references: - defender_1 - defender_2 + io_settings: save_checkpoints: true checkpoint_interval: 5 save_step_metadata: false + save_pcap_logs: false + save_sys_logs: true game: @@ -36,9 +44,9 @@ agents: - type: NODE_APPLICATION_EXECUTE options: nodes: - - node_ref: client_2 + - node_name: client_2 applications: - - application_ref: client_2_web_browser + - application_name: WebBrowser max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 @@ -54,6 +62,31 @@ agents: frequency: 4 variance: 3 + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + reward_function: + reward_components: + - type: DUMMY + + + + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -72,7 +105,7 @@ agents: - type: NODE_OS_SCAN options: nodes: - - node_ref: client_1 + - node_name: client_1 applications: - application_name: DataManipulationBot - node_name: client_2 @@ -104,25 +137,21 @@ agents: num_files_per_folder: 1 num_nics_per_node: 2 nodes: - - node_ref: domain_controller + - node_hostname: domain_controller services: - - service_ref: domain_controller_dns_server - - node_ref: web_server + - service_name: DNSServer + - node_hostname: web_server services: - - service_ref: web_server_database_client - - node_ref: database_server - services: - - service_ref: database_service + - service_name: WebServer + - node_hostname: database_server folders: - folder_name: database files: - file_name: database.db - - node_ref: backup_server - # services: - # - service_ref: backup_service - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 links: - link_ref: router_1___switch_1 - link_ref: router_1___switch_2 @@ -137,23 +166,23 @@ agents: acl: options: max_acl_rules: 10 - router_node_ref: router_1 + router_hostname: router_1 ip_address_order: - - node_ref: domain_controller + - node_hostname: domain_controller nic_num: 1 - - node_ref: web_server + - node_hostname: web_server nic_num: 1 - - node_ref: database_server + - node_hostname: database_server nic_num: 1 - - node_ref: backup_server + - node_hostname: backup_server nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 1 - - node_ref: client_1 + - node_hostname: client_1 nic_num: 1 - - node_ref: client_2 + - node_hostname: client_2 nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 2 ics: null @@ -184,10 +213,10 @@ agents: - type: NODE_RESET - type: NETWORK_ACL_ADDRULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_ACL_REMOVERULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_NIC_ENABLE - type: NETWORK_NIC_DISABLE @@ -242,25 +271,25 @@ agents: action: "NODE_FILE_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 11: action: "NODE_FILE_DELETE" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 12: action: "NODE_FILE_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 13: action: "NODE_SERVICE_PATCH" @@ -271,22 +300,22 @@ agents: action: "NODE_FOLDER_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: node_id: 2 - folder_id: 1 + folder_id: 0 18: action: "NODE_OS_SCAN" options: @@ -303,63 +332,63 @@ agents: action: "NODE_RESET" options: node_id: 5 - 22: + 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 7 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 23: + 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 24: # block tcp traffic from client 1 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 3 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 25: # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 26: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 27: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 @@ -407,123 +436,148 @@ agents: action: "NETWORK_NIC_DISABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 39: action: "NETWORK_NIC_ENABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 40: action: "NETWORK_NIC_DISABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 41: action: "NETWORK_NIC_ENABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 42: action: "NETWORK_NIC_DISABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 43: action: "NETWORK_NIC_ENABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 44: action: "NETWORK_NIC_DISABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 45: action: "NETWORK_NIC_ENABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 46: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 47: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 48: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 49: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 50: action: "NETWORK_NIC_DISABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 51: action: "NETWORK_NIC_ENABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 52: action: "NETWORK_NIC_DISABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 53: action: "NETWORK_NIC_ENABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 options: nodes: - - node_ref: domain_controller - - node_ref: web_server + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient services: - - service_ref: web_server_web_service - - node_ref: database_server + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db services: - - service_ref: database_service - - node_ref: backup_server - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.34 options: - node_ref: database_server + node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 options: - node_ref: web_server - service_ref: web_server_web_service + node_hostname: client_1 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 + options: + node_hostname: client_2 agent_settings: - # ... - + flatten_obs: true - ref: defender_2 team: BLUE @@ -537,25 +591,21 @@ agents: num_files_per_folder: 1 num_nics_per_node: 2 nodes: - - node_ref: domain_controller + - node_hostname: domain_controller services: - - service_ref: domain_controller_dns_server - - node_ref: web_server + - service_name: DNSServer + - node_hostname: web_server services: - - service_ref: web_server_database_client - - node_ref: database_server - services: - - service_ref: database_service + - service_name: WebServer + - node_hostname: database_server folders: - folder_name: database files: - file_name: database.db - - node_ref: backup_server - # services: - # - service_ref: backup_service - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 links: - link_ref: router_1___switch_1 - link_ref: router_1___switch_2 @@ -570,23 +620,23 @@ agents: acl: options: max_acl_rules: 10 - router_node_ref: router_1 + router_hostname: router_1 ip_address_order: - - node_ref: domain_controller + - node_hostname: domain_controller nic_num: 1 - - node_ref: web_server + - node_hostname: web_server nic_num: 1 - - node_ref: database_server + - node_hostname: database_server nic_num: 1 - - node_ref: backup_server + - node_hostname: backup_server nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 1 - - node_ref: client_1 + - node_hostname: client_1 nic_num: 1 - - node_ref: client_2 + - node_hostname: client_2 nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 2 ics: null @@ -617,10 +667,10 @@ agents: - type: NODE_RESET - type: NETWORK_ACL_ADDRULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_ACL_REMOVERULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_NIC_ENABLE - type: NETWORK_NIC_DISABLE @@ -675,25 +725,25 @@ agents: action: "NODE_FILE_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 11: action: "NODE_FILE_DELETE" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 12: action: "NODE_FILE_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 13: action: "NODE_SERVICE_PATCH" @@ -704,22 +754,22 @@ agents: action: "NODE_FOLDER_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: node_id: 2 - folder_id: 1 + folder_id: 0 18: action: "NODE_OS_SCAN" options: @@ -736,63 +786,63 @@ agents: action: "NODE_RESET" options: node_id: 5 - 22: + 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 7 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 23: + 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 24: # block tcp traffic from client 1 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 3 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 25: # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 26: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 27: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 @@ -840,122 +890,148 @@ agents: action: "NETWORK_NIC_DISABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 39: action: "NETWORK_NIC_ENABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 40: action: "NETWORK_NIC_DISABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 41: action: "NETWORK_NIC_ENABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 42: action: "NETWORK_NIC_DISABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 43: action: "NETWORK_NIC_ENABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 44: action: "NETWORK_NIC_DISABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 45: action: "NETWORK_NIC_ENABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 46: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 47: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 48: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 49: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 50: action: "NETWORK_NIC_DISABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 51: action: "NETWORK_NIC_ENABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 52: action: "NETWORK_NIC_DISABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 53: action: "NETWORK_NIC_ENABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 options: nodes: - - node_ref: domain_controller - - node_ref: web_server + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient services: - - service_ref: web_server_web_service - - node_ref: database_server + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db services: - - service_ref: database_service - - node_ref: backup_server - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.34 options: - node_ref: database_server + node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 options: - node_ref: web_server - service_ref: web_server_web_service + node_hostname: client_1 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 + options: + node_hostname: client_2 agent_settings: - # ... + flatten_obs: true @@ -963,6 +1039,10 @@ agents: simulation: network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE nodes: - ref: router_1 @@ -1032,12 +1112,13 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -1089,10 +1170,14 @@ simulation: - ref: data_manipulation_bot type: DataManipulationBot options: - port_scan_p_of_success: 0.1 - data_manipulation_p_of_success: 0.1 + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - ref: client_1_web_browser + type: WebBrowser + options: + target_url: http://arcd.com/users/ services: - ref: client_1_dns_client type: DNSClient @@ -1109,6 +1194,13 @@ simulation: type: WebBrowser options: target_url: http://arcd.com/users/ + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 services: - ref: client_2_dns_client type: DNSClient diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index dfee2543..82e11fe0 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -8,6 +8,7 @@ from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import CAPTURE_NMNE _LOGGER = getLogger(__name__) @@ -346,7 +347,14 @@ class FolderObservation(AbstractObservation): class NicObservation(AbstractObservation): """Observation of a Network Interface Card (NIC) in the network.""" - default_observation: spaces.Space = {"nic_status": 0} + @property + def default_observation(self) -> Dict: + """The default NIC observation dict.""" + data = {"nic_status": 0} + if CAPTURE_NMNE: + data.update({"nmne": {"inbound": 0, "outbound": 0}}) + + return data def __init__(self, where: Optional[Tuple[str]] = None) -> None: """Initialise NIC observation. @@ -360,6 +368,29 @@ class NicObservation(AbstractObservation): super().__init__() self.where: Optional[Tuple[str]] = where + def _categorise_mne_count(self, nmne_count: int) -> int: + """ + Categorise the number of Malicious Network Events (NMNEs) into discrete bins. + + This helps in classifying the severity or volume of MNEs into manageable levels for the agent. + + Bins are defined as follows: + - 0: No MNEs detected (0 events). + - 1: Low number of MNEs (1-5 events). + - 2: Moderate number of MNEs (6-10 events). + - 3: High number of MNEs (more than 10 events). + + :param nmne_count: Number of MNEs detected. + :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. + """ + if nmne_count > 10: + return 3 + elif nmne_count > 5: + return 2 + elif nmne_count > 0: + return 1 + return 0 + def observe(self, state: Dict) -> Dict: """Generate observation based on the current state of the simulation. @@ -371,15 +402,31 @@ class NicObservation(AbstractObservation): if self.where is None: return self.default_observation nic_state = access_from_nested_dict(state, self.where) + if nic_state is NOT_PRESENT_IN_STATE: return self.default_observation else: - return {"nic_status": 1 if nic_state["enabled"] else 2} + obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} + if CAPTURE_NMNE: + obs_dict.update({"nmne": {}}) + direction_dict = nic_state["nmne"].get("direction", {}) + inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) + inbound_count = inbound_keywords.get("*", 0) + outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) + outbound_count = outbound_keywords.get("*", 0) + obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count) + obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count) + return obs_dict @property def space(self) -> spaces.Space: """Gymnasium space object describing the observation space shape.""" - return spaces.Dict({"nic_status": spaces.Discrete(3)}) + return spaces.Dict( + { + "nic_status": spaces.Discrete(3), + "nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}), + } + ) @classmethod def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index b44abe16..eeb0d007 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -18,6 +18,7 @@ from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.nmne import set_nmne_config from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient @@ -80,18 +81,15 @@ class PrimaiteGame: self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" - self.agents: List[AbstractAgent] = [] - """List of agents.""" + self.agents: Dict[str, AbstractAgent] = {} + """Mapping from agent name to agent object.""" - self.rl_agents: List[ProxyAgent] = [] - """Subset of agent list including only the reinforcement learning agents.""" + self.rl_agents: Dict[str, ProxyAgent] = {} + """Subset of agents which are intended for reinforcement learning.""" self.step_counter: int = 0 """Current timestep within the episode.""" - self.episode_counter: int = 0 - """Current episode number.""" - self.options: PrimaiteGameOptions """Special options that apply for the entire game.""" @@ -137,7 +135,7 @@ class PrimaiteGame: self.update_agents(sim_state) # Apply all actions to simulation as requests - agent_actions = self.apply_agent_actions() # noqa + self.apply_agent_actions() # Advance timestep self.advance_timestep() @@ -148,7 +146,7 @@ class PrimaiteGame: def update_agents(self, state: Dict) -> None: """Update agents' observations and rewards based on the current state.""" - for agent in self.agents: + for _, agent in self.agents.items(): agent.update_observation(state) agent.update_reward(state) agent.reward_function.total_reward += agent.reward_function.current_reward @@ -162,7 +160,7 @@ class PrimaiteGame: """ agent_actions = {} - for agent in self.agents: + for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation rew = agent.reward_function.current_reward action_choice, options = agent.get_action(obs, rew, timestep=self.step_counter) @@ -185,20 +183,14 @@ class PrimaiteGame: return True return False - def reset(self) -> None: - """Reset the game, this will reset the simulation.""" - self.episode_counter += 1 - self.step_counter = 0 - _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") - self.simulation.reset_component_for_episode(episode=self.episode_counter) - for agent in self.agents: - agent.reward_function.total_reward = 0.0 - agent.reset_agent_for_episode() - def close(self) -> None: """Close the game, this will close the simulation.""" return NotImplemented + def setup_for_episode(self, episode: int) -> None: + """Perform any final configuration of components to make them ready for the game to start.""" + self.simulation.setup_for_episode(episode=episode) + @classmethod def from_config(cls, cfg: Dict) -> "PrimaiteGame": """Create a PrimaiteGame object from a config dictionary. @@ -272,7 +264,9 @@ class PrimaiteGame: new_service = new_node.software_manager.software[service_type] game.ref_map_services[service_ref] = new_service.uuid else: - _LOGGER.warning(f"service type not found {service_type}") + msg = f"Configuration contains an invalid service type: {service_type}" + _LOGGER.error(msg) + raise ValueError(msg) # service-dependent options if service_type == "DNSClient": if "options" in service_cfg: @@ -311,7 +305,9 @@ class PrimaiteGame: new_application = new_node.software_manager.software[application_type] game.ref_map_applications[application_ref] = new_application.uuid else: - _LOGGER.warning(f"application type not found {application_type}") + msg = f"Configuration contains an invalid application type: {application_type}" + _LOGGER.error(msg) + raise ValueError(msg) if application_type == "DataManipulationBot": if "options" in application_cfg: @@ -403,7 +399,6 @@ class PrimaiteGame: reward_function=reward_function, settings=settings, ) - game.agents.append(new_agent) elif agent_type == "ProxyAgent": new_agent = ProxyAgent( agent_name=agent_cfg["ref"], @@ -412,8 +407,7 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) - game.rl_agents.append(new_agent) + game.rl_agents[agent_cfg["ref"]] = new_agent elif agent_type == "RedDatabaseCorruptingAgent": new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], @@ -422,10 +416,13 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) else: - _LOGGER.warning(f"agent type {agent_type} not found") + msg(f"Configuration error: {agent_type} is not a valid agent type.") + _LOGGER.error(msg) + raise ValueError(msg) + game.agents[agent_cfg["ref"]] = new_agent - game.simulation.set_original_state() + # Set the NMNE capture config + set_nmne_config(cfg["simulation"]["network"].get("nmne_config", {})) return game diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb index 0d4b6d0e..4ef02443 100644 --- a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb @@ -60,7 +60,7 @@ " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", " )\n", - " .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n", " .rollouts(num_rollout_workers=0)\n", " .training(train_batch_size=128)\n", " )\n" @@ -88,6 +88,13 @@ " param_space=config\n", ").fit()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb index ea006ae9..3c27bdc6 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -54,7 +54,7 @@ "metadata": {}, "outputs": [], "source": [ - "env_config = {\"cfg\":cfg}\n", + "env_config = cfg\n", "\n", "config = (\n", " PPOConfig()\n", diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb index e5085c5e..0472854e 100644 --- a/src/primaite/notebooks/training_example_sb3.ipynb +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -27,9 +27,7 @@ "outputs": [], "source": [ "with open(example_config_path(), 'r') as f:\n", - " cfg = yaml.safe_load(f)\n", - "\n", - "game = PrimaiteGame.from_config(cfg)" + " cfg = yaml.safe_load(f)\n" ] }, { @@ -38,7 +36,7 @@ "metadata": {}, "outputs": [], "source": [ - "gym = PrimaiteGymEnv(game=game)" + "gym = PrimaiteGymEnv(game_config=cfg)" ] }, { @@ -65,7 +63,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.learn(total_timesteps=1000)\n" + "model.learn(total_timesteps=10)\n" ] }, { @@ -76,6 +74,13 @@ "source": [ "model.save(\"deleteme\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index fa4a28a4..cf973905 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -130,6 +130,9 @@ " - NETWORK_INTERFACES\n", " - \n", " - nic_status\n", + " - nmne\n", + " - inbound\n", + " - outbound\n", " - operating_status\n", "- LINKS\n", " - \n", @@ -220,6 +223,14 @@ "|1|ENABLED|\n", "|2|DISABLED|\n", "\n", + "NMNE (number of malicious network events) means, for inbound or outbound traffic, means:\n", + "|value|NMNEs|\n", + "|--|--|\n", + "|0|None|\n", + "|1|1 - 5|\n", + "|2|6 - 10|\n", + "|3|More than 10|\n", + "\n", "Link load has the following meaning:\n", "|load|percent utilisation|\n", "|--|--|\n", @@ -371,150 +382,9 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-02-27 09:43:39,312::WARNING::primaite.game.game::275::service type not found DatabaseClient\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Resetting environment, episode 0, avg. reward: 0.0\n", - "env created successfully\n", - "{'ACL': {1: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 0,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 2: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 1,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 3: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 2,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 4: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 3,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 5: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 4,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 6: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 5,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 7: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 6,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 8: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 7,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 9: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 8,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 10: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 9,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0}},\n", - " 'ICS': 0,\n", - " 'LINKS': {1: {'PROTOCOLS': {'ALL': 1}},\n", - " 2: {'PROTOCOLS': {'ALL': 1}},\n", - " 3: {'PROTOCOLS': {'ALL': 1}},\n", - " 4: {'PROTOCOLS': {'ALL': 1}},\n", - " 5: {'PROTOCOLS': {'ALL': 1}},\n", - " 6: {'PROTOCOLS': {'ALL': 1}},\n", - " 7: {'PROTOCOLS': {'ALL': 1}},\n", - " 8: {'PROTOCOLS': {'ALL': 1}},\n", - " 9: {'PROTOCOLS': {'ALL': 1}},\n", - " 10: {'PROTOCOLS': {'ALL': 0}}},\n", - " 'NODES': {1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", - " 2: {'nic_status': 0}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", - " 'operating_status': 1},\n", - " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", - " 2: {'nic_status': 0}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", - " 'operating_status': 1},\n", - " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}},\n", - " 'health_status': 1}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", - " 2: {'nic_status': 0}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", - " 2: {'nic_status': 0}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", - " 2: {'nic_status': 0}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", - " 2: {'nic_status': 0}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", - " 2: {'nic_status': 0}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1}}}\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# create the env\n", "with open(example_config_path(), 'r') as f:\n", @@ -524,10 +394,10 @@ " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n", " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", - "game = PrimaiteGame.from_config(cfg)\n", - "env = PrimaiteGymEnv(game = game)\n", - "# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n", - "env.agent.flatten_obs = False\n", + " # don't flatten observations so that we can see what is going on\n", + " cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n", + "\n", + "env = PrimaiteGymEnv(game_config = cfg)\n", "obs, info = env.reset()\n", "print('env created successfully')\n", "pprint(obs)" @@ -560,53 +430,9 @@ }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "step: 211, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 212, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 213, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 214, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 215, Red action: DO NOTHING, Blue reward:-0.42\n", - "step: 216, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 217, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 218, Red action: DO NOTHING, Blue reward:-0.42\n", - "step: 219, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 220, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 221, Red action: ATTACK from client 2, Blue reward:-0.32\n", - "step: 222, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 223, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 224, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 225, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 226, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 227, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 228, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 229, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 230, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 231, Red action: DO NOTHING, Blue reward:-0.42\n", - "step: 232, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 233, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 234, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 235, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 236, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 237, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 238, Red action: ATTACK from client 2, Blue reward:-0.32\n", - "step: 239, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 240, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 241, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 242, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 243, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 244, Red action: DO NOTHING, Blue reward:-0.32\n", - "step: 245, Red action: DO NOTHING, Blue reward:-0.32\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "for step in range(35):\n", " obs, reward, terminated, truncated, info = env.step(0)\n", @@ -623,9 +449,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "pprint(obs['NODES'])" @@ -641,9 +465,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(9) # scan database file\n", @@ -669,9 +491,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", @@ -696,9 +516,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", @@ -721,9 +539,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "env.step(13) # Patch the database\n", @@ -763,6 +579,22 @@ "obs['ACL']" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Reset the environment, you can rerun the other cells to verify that the attack works the same every episode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.reset()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -800,7 +632,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -818,5 +650,5 @@ } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 2 } diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a3831bc1..f8dbab9d 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple +from typing import Any, Dict, Optional, SupportsFloat, Tuple import gymnasium from gymnasium.core import ActType, ObsType @@ -18,11 +18,23 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. """ - def __init__(self, game: PrimaiteGame): + def __init__(self, game_config: Dict): """Initialise the environment.""" super().__init__() - self.game: "PrimaiteGame" = game - self.agent: ProxyAgent = self.game.rl_agents[0] + self.game_config: Dict = game_config + """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) + """Current game.""" + self._agent_name = next(iter(self.game.rl_agents)) + """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key.""" + + self.episode_counter: int = 0 + """Current episode number.""" + + @property + def agent(self) -> ProxyAgent: + """Grab a fresh reference to the agent object because it will be reinstantiated each episode.""" + return self.game.rl_agents[self._agent_name] def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" @@ -45,13 +57,13 @@ class PrimaiteGymEnv(gymnasium.Env): return next_obs, reward, terminated, truncated, info def _write_step_metadata_json(self, action: int, state: Dict, reward: int): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "action": int(action), "reward": int(reward), @@ -63,10 +75,12 @@ class PrimaiteGymEnv(gymnasium.Env): def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" print( - f"Resetting environment, episode {self.game.episode_counter}, " - f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}" + f"Resetting environment, episode {self.episode_counter}, " + f"avg. reward: {self.agent.reward_function.total_reward}" ) - self.game.reset() + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.game.setup_for_episode(episode=self.episode_counter) + self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() @@ -88,12 +102,12 @@ class PrimaiteGymEnv(gymnasium.Env): def _get_obs(self) -> ObsType: """Return the current observation.""" - if not self.agent.flatten_obs: - return self.agent.observation_manager.current_observation - else: + if self.agent.flatten_obs: unflat_space = self.agent.observation_manager.space unflat_obs = self.agent.observation_manager.current_observation return gymnasium.spaces.flatten(unflat_space, unflat_obs) + else: + return self.agent.observation_manager.current_observation class PrimaiteRayEnv(gymnasium.Env): @@ -102,12 +116,11 @@ class PrimaiteRayEnv(gymnasium.Env): def __init__(self, env_config: Dict) -> None: """Initialise the environment. - :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` - which is the PrimaiteGame instance. - :type env_config: Dict[str, PrimaiteGame] + :param env_config: A dictionary containing the environment configuration. + :type env_config: Dict """ - self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"])) - self.env.game.episode_counter -= 1 + self.env = PrimaiteGymEnv(game_config=env_config) + self.env.episode_counter -= 1 self.action_space = self.env.action_space self.observation_space = self.env.observation_space @@ -128,13 +141,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` which is the PrimaiteGame instance. - :type env_config: Dict[str, PrimaiteGame] + :type env_config: Dict """ - self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"]) + self.game_config: Dict = env_config + """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) """Reference to the primaite game""" - self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents} - """List of all possible agents in the environment. This list should not change!""" - self._agent_ids = list(self.agents.keys()) + self._agent_ids = list(self.game.rl_agents.keys()) + """Agent ids. This is a list of strings of agent names.""" + self.episode_counter: int = 0 + """Current episode number.""" self.terminateds = set() self.truncateds = set() @@ -149,9 +165,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): ) super().__init__() + @property + def agents(self) -> Dict[str, ProxyAgent]: + """Grab a fresh reference to the agents from this episode's game object.""" + return {name: self.game.rl_agents[name] for name in self._agent_ids} + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" - self.game.reset() + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.game.setup_for_episode(episode=self.episode_counter) + self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() @@ -172,7 +195,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - agent_actions = self.game.apply_agent_actions() + self.game.apply_agent_actions() # 2. Advance timestep self.game.advance_timestep() @@ -186,7 +209,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} terminateds = {name: False for name, _ in self.agents.items()} truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} - infos = {"agent_actions": agent_actions} + infos = {name: {} for name, _ in self.agents.items()} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: @@ -194,13 +217,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "actions": {agent_name: int(action) for agent_name, action in actions.items()}, "reward": rewards, @@ -212,8 +235,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" obs = {} - for name, agent in self.agents.items(): + for agent_name in self._agent_ids: + agent = self.game.rl_agents[agent_name] unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation - obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) + obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) return obs diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 5c663cfd..b8f80e95 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -101,11 +101,11 @@ class PrimaiteSession: # CREATE ENVIRONMENT if sess.training_options.rl_framework == "RLLIB_single_agent": - sess.env = PrimaiteRayEnv(env_config={"cfg": cfg}) + sess.env = PrimaiteRayEnv(env_config=cfg) elif sess.training_options.rl_framework == "RLLIB_multi_agent": - sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg}) + sess.env = PrimaiteRayMARLEnv(env_config=cfg) elif sess.training_options.rl_framework == "SB3": - sess.env = PrimaiteGymEnv(game=game) + sess.env = PrimaiteGymEnv(game_config=cfg) sess.policy = PolicyABC.from_config(sess.training_options, session=sess) if agent_load_path: diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 964dac01..99e9be7f 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -153,22 +153,18 @@ class SimComponent(BaseModel): uuid: str = Field(default_factory=lambda: str(uuid4())) """The component UUID.""" - _original_state: Dict = {} - def __init__(self, **kwargs): super().__init__(**kwargs) self._request_manager: RequestManager = self._init_request_manager() self._parent: Optional["SimComponent"] = None - # @abstractmethod - def set_original_state(self): - """Sets the original state.""" - pass + def setup_for_episode(self, episode: int): + """ + Perform any additional setup on this component that can't happen during __init__. - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - for key, value in self._original_state.items(): - self.__setattr__(key, value) + For instance, some components may require for the entire network to exist before some configuration can be set. + """ + pass def _init_request_manager(self) -> RequestManager: """ diff --git a/src/primaite/simulator/domain/account.py b/src/primaite/simulator/domain/account.py index d9dad06a..186caf5b 100644 --- a/src/primaite/simulator/domain/account.py +++ b/src/primaite/simulator/domain/account.py @@ -42,19 +42,6 @@ class Account(SimComponent): "Account Type, currently this can be service account (used by apps) or user account." enabled: bool = True - def set_original_state(self): - """Sets the original state.""" - vals_to_include = { - "num_logons", - "num_logoffs", - "num_group_changes", - "username", - "password", - "account_type", - "enabled", - } - self._original_state = self.model_dump(include=vals_to_include) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index 608a1d78..d9b02e8e 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -73,20 +73,6 @@ class File(FileSystemItemABC): self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})") - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting File ({self.path}) original state on node {self.sys_log.hostname}") - super().set_original_state() - vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}") - super().reset_component_for_episode(episode) - @property def path(self) -> str: """ diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index ee80587d..8fd4e5d7 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -34,43 +34,6 @@ class FileSystem(SimComponent): if not self.folders: self.create_folder("root") - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FileSystem original state on node {self.sys_log.hostname}") - for folder in self.folders.values(): - folder.set_original_state() - # Capture a list of all 'original' file uuids - original_keys = list(self.folders.keys()) - vals_to_include = {"sim_root"} - self._original_state.update(self.model_dump(include=vals_to_include)) - self._original_state["original_folder_uuids"] = original_keys - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}") - # Move any 'original' folder that have been deleted back to folders - original_folder_uuids = self._original_state["original_folder_uuids"] - for uuid in original_folder_uuids: - if uuid in self.deleted_folders: - folder = self.deleted_folders[uuid] - self.deleted_folders.pop(uuid) - self.folders[uuid] = folder - - # Clear any other deleted folders that aren't original (have been created by agent) - self.deleted_folders.clear() - - # Now clear all non-original folders created by agent - current_folder_uuids = list(self.folders.keys()) - for uuid in current_folder_uuids: - if uuid not in original_folder_uuids: - folder = self.folders[uuid] - self.folders.pop(uuid) - - # Now reset all remaining folders - for folder in self.folders.values(): - folder.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index c3e1426b..fbe5f4b3 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -85,11 +85,6 @@ class FileSystemItemABC(SimComponent): deleted: bool = False "If true, the FileSystemItem was deleted." - def set_original_state(self): - """Sets the original state.""" - vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"} - self._original_state = self.model_dump(include=vals_to_keep) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 13fdc597..771dc7a0 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -49,49 +49,6 @@ class Folder(FileSystemItemABC): self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting Folder ({self.name}) original state on node {self.sys_log.hostname}") - for file in self.files.values(): - file.set_original_state() - super().set_original_state() - vals_to_include = { - "scan_duration", - "scan_countdown", - "red_scan_duration", - "red_scan_countdown", - "restore_duration", - "restore_countdown", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - self._original_state["original_file_uuids"] = list(self.files.keys()) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}") - # Move any 'original' file that have been deleted back to files - original_file_uuids = self._original_state["original_file_uuids"] - for uuid in original_file_uuids: - if uuid in self.deleted_files: - file = self.deleted_files[uuid] - self.deleted_files.pop(uuid) - self.files[uuid] = file - - # Clear any other deleted files that aren't original (have been created by agent) - self.deleted_files.clear() - - # Now clear all non-original files created by agent - current_file_uuids = list(self.files.keys()) - for uuid in current_file_uuids: - if uuid not in original_file_uuids: - file = self.files[uuid] - self.files.pop(uuid) - - # Now reset all remaining files - for file in self.files.values(): - file.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 724b8728..d264f751 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -273,11 +273,6 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC) return state - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def enable(self): """ Enables this wired network interface and attempts to send a "hello" message to the default gateway. diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index b32d2630..b5a16430 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -45,19 +45,12 @@ class Network(SimComponent): self._nx_graph = MultiGraph() - def set_original_state(self): - """Sets the original state.""" - for node in self.nodes.values(): - node.set_original_state() - for link in self.links.values(): - link.set_original_state() - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" for node in self.nodes.values(): - node.reset_component_for_episode(episode) + node.setup_for_episode(episode=episode) for link in self.links.values(): - link.reset_component_for_episode(episode) + link.setup_for_episode(episode=episode) for node in self.nodes.values(): node.power_on() @@ -179,7 +172,7 @@ class Network(SimComponent): def clear_links(self): """Clear all the links in the network by resetting their component state for the episode.""" for link in self.links.values(): - link.reset_component_for_episode() + link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here. def draw(self, seed: int = 123): """ diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index fa135674..ff79f314 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -17,6 +17,15 @@ from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.nmne import ( + CAPTURE_BY_DIRECTION, + CAPTURE_BY_IP_ADDRESS, + CAPTURE_BY_KEYWORD, + CAPTURE_BY_PORT, + CAPTURE_BY_PROTOCOL, + CAPTURE_NMNE, + NMNE_CAPTURE_KEYWORDS, +) from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.packet_capture import PacketCapture @@ -88,6 +97,18 @@ class NetworkInterface(SimComponent, ABC): pcap: Optional[PacketCapture] = None "A PacketCapture instance for capturing and analysing packets passing through this interface." + nmne: Dict = Field(default_factory=lambda: {}) + "A dict containing details of the number of malicious network events captured." + + def setup_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + super().setup_for_episode(episode=episode) + self.nmne = {} + if episode and self.pcap: + self.pcap.current_episode = episode + self.pcap.setup_logger() + self.enable() + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -111,16 +132,10 @@ class NetworkInterface(SimComponent, ABC): "enabled": self.enabled, } ) + if CAPTURE_NMNE: + state.update({"nmne": self.nmne}) return state - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) - if episode and self.pcap: - self.pcap.current_episode = episode - self.pcap.setup_logger() - self.enable() - @abstractmethod def enable(self): """Enable the interface.""" @@ -131,6 +146,82 @@ class NetworkInterface(SimComponent, ABC): """Disable the interface.""" pass + def _capture_nmne(self, frame: Frame, inbound: bool = True) -> None: + """ + Processes and captures network frame data based on predefined global NMNE settings. + + This method updates the NMNE structure with counts of malicious network events based on the frame content and + direction. The structure is dynamically adjusted according to the enabled capture settings. + + .. note:: + While there is a lot of logic in this code that defines a multi-level hierarchical NMNE structure, + most of it is unused for now as a result of all `CAPTURE_BY_<>` variables in + ``primaite.simulator.network.nmne`` being hardcoded and set as final. Once they're 'released' and made + configurable, this function will be updated to properly explain the dynamic data structure. + + :param frame: The network frame to process, containing IP, TCP/UDP, and payload information. + :param inbound: Boolean indicating if the frame direction is inbound. Defaults to True. + """ + # Exit function if NMNE capturing is disabled + if not CAPTURE_NMNE: + return + + # Initialise basic frame data variables + direction = "inbound" if inbound else "outbound" # Direction of the traffic + ip_address = str(frame.ip.src_ip_address if inbound else frame.ip.dst_ip_address) # Source or destination IP + protocol = frame.ip.protocol.name # Network protocol used in the frame + + # Initialise port variable; will be determined based on protocol type + port = None + + # Determine the source or destination port based on the protocol (TCP/UDP) + if frame.tcp: + port = frame.tcp.src_port.value if inbound else frame.tcp.dst_port.value + elif frame.udp: + port = frame.udp.src_port.value if inbound else frame.udp.dst_port.value + + # Convert frame payload to string for keyword checking + frame_str = str(frame.payload) + + # Proceed only if any NMNE keyword is present in the frame payload + if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS): + # Start with the root of the NMNE capture structure + current_level = self.nmne + + # Update NMNE structure based on enabled settings + if CAPTURE_BY_DIRECTION: + # Set or get the dictionary for the current direction + current_level = current_level.setdefault("direction", {}) + current_level = current_level.setdefault(direction, {}) + + if CAPTURE_BY_IP_ADDRESS: + # Set or get the dictionary for the current IP address + current_level = current_level.setdefault("ip_address", {}) + current_level = current_level.setdefault(ip_address, {}) + + if CAPTURE_BY_PROTOCOL: + # Set or get the dictionary for the current protocol + current_level = current_level.setdefault("protocol", {}) + current_level = current_level.setdefault(protocol, {}) + + if CAPTURE_BY_PORT: + # Set or get the dictionary for the current port + current_level = current_level.setdefault("port", {}) + current_level = current_level.setdefault(port, {}) + + # Ensure 'KEYWORD' level is present in the structure + keyword_level = current_level.setdefault("keywords", {}) + + # Increment the count for detected keywords in the payload + if CAPTURE_BY_KEYWORD: + for keyword in NMNE_CAPTURE_KEYWORDS: + if keyword in frame_str: + # Update the count for each keyword found + keyword_level[keyword] = keyword_level.get(keyword, 0) + 1 + else: + # Increment a generic counter if keyword capturing is not enabled + keyword_level["*"] = keyword_level.get("*", 0) + 1 + @abstractmethod def send_frame(self, frame: Frame) -> bool: """ @@ -139,7 +230,7 @@ class NetworkInterface(SimComponent, ABC): :param frame: The network frame to be sent. :return: A boolean indicating whether the frame was successfully sent. """ - pass + self._capture_nmne(frame, inbound=False) @abstractmethod def receive_frame(self, frame: Frame) -> bool: @@ -149,7 +240,7 @@ class NetworkInterface(SimComponent, ABC): :param frame: The network frame being received. :return: A boolean indicating whether the frame was successfully received. """ - pass + self._capture_nmne(frame, inbound=True) def __str__(self) -> str: """ @@ -263,6 +354,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): :param frame: The network frame to be sent. :return: True if the frame is sent, False if the Network Interface is disabled or not connected to a link. """ + super().send_frame(frame) if self.enabled: frame.set_sent_timestamp() self.pcap.capture_outbound(frame) @@ -279,7 +371,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): :param frame: The network frame being received. :return: A boolean indicating whether the frame was successfully received. """ - pass + return super().receive_frame(frame) class Layer3Interface(BaseModel, ABC): @@ -409,7 +501,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC): except AttributeError: pass - # @abstractmethod + @abstractmethod def receive_frame(self, frame: Frame) -> bool: """ Receives a network frame on the network interface. @@ -417,7 +509,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC): :param frame: The network frame being received. :return: A boolean indicating whether the frame was successfully received. """ - pass + return super().receive_frame(frame) class Link(SimComponent): @@ -455,14 +547,6 @@ class Link(SimComponent): self.endpoint_b.connect_link(self) self.endpoint_up() - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"bandwidth", "current_load"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -648,50 +732,20 @@ class Node(SimComponent): self.session_manager.node = self self.session_manager.software_manager = self.software_manager self._install_system_software() - self.set_original_state() - def set_original_state(self): - """Sets the original state.""" - for software in self.software_manager.software.values(): - software.set_original_state() - - self.file_system.set_original_state() - - for network_interface in self.network_interfaces.values(): - network_interface.set_original_state() - - vals_to_include = { - "hostname", - "default_gateway", - "operating_state", - "revealed_to_red", - "start_up_duration", - "start_up_countdown", - "shut_down_duration", - "shut_down_countdown", - "is_resetting", - "node_scan_duration", - "node_scan_countdown", - "red_scan_countdown", - } - self._original_state = self.model_dump(include=vals_to_include) - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) - - # Reset Session Manager - self.session_manager.clear() + super().setup_for_episode(episode=episode) # Reset File System - self.file_system.reset_component_for_episode(episode) + self.file_system.setup_for_episode(episode=episode) # Reset all Nics for network_interface in self.network_interfaces.values(): - network_interface.reset_component_for_episode(episode) + network_interface.setup_for_episode(episode=episode) for software in self.software_manager.software.values(): - software.reset_component_for_episode(episode) + software.setup_for_episode(episode=episode) if episode and self.sys_log: self.sys_log.current_episode = episode diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 3f34f736..977380be 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -205,19 +205,10 @@ class NIC(IPWiredNetworkInterface): state = super().describe_state() # Update the state with NIC-specific information - state.update( - { - "wake_on_lan": self.wake_on_lan, - } - ) + state.update({"wake_on_lan": self.wake_on_lan}) return state - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def receive_frame(self, frame: Frame) -> bool: """ Attempt to receive and process a network frame from the connected Link. @@ -248,6 +239,7 @@ class NIC(IPWiredNetworkInterface): accept_frame = True if accept_frame: + super().receive_frame(frame) self._connected_node.receive_frame(frame=frame, from_network_interface=self) return True return False diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 22effa2a..f2305652 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -109,24 +109,6 @@ class Firewall(Router): sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound" ) - self.set_original_state() - - def set_original_state(self): - """Set the original state for the Firewall.""" - super().set_original_state() - vals_to_include = { - "internal_port", - "external_port", - "dmz_port", - "internal_inbound_acl", - "internal_outbound_acl", - "dmz_inbound_acl", - "dmz_outbound_acl", - "external_inbound_acl", - "external_outbound_acl", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - def describe_state(self) -> Dict: """ Describes the current state of the Firewall. diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 774aae7c..aa6eec3a 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -136,11 +136,6 @@ class ACLRule(SimComponent): rule_strings.append(f"{key}={value}") return ", ".join(rule_strings) - def set_original_state(self): - """Sets the original state.""" - vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"} - self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) - def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. @@ -296,48 +291,6 @@ class AccessControlList(SimComponent): super().__init__(**kwargs) self._acl = [None] * (self.max_acl_rules - 1) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - self.implicit_rule.set_original_state() - vals_to_keep = {"implicit_action", "max_acl_rules", "acl"} - self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) - - for i, rule in enumerate(self._acl): - if not rule: - continue - self._default_config[i] = {"action": rule.action.name} - if rule.src_ip_address: - self._default_config[i]["src_ip"] = str(rule.src_ip_address) - if rule.dst_ip_address: - self._default_config[i]["dst_ip"] = str(rule.dst_ip_address) - if rule.src_port: - self._default_config[i]["src_port"] = rule.src_port.name - if rule.dst_port: - self._default_config[i]["dst_port"] = rule.dst_port.name - if rule.protocol: - self._default_config[i]["protocol"] = rule.protocol.name - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.implicit_rule.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - self._reset_rules_to_default() - - def _reset_rules_to_default(self) -> None: - """Clear all ACL rules and set them to the default rules config.""" - self._acl = [None] * (self.max_acl_rules - 1) - for r_num, r_cfg in self._default_config.items(): - self.add_rule( - action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], - src_ip_address=r_cfg.get("src_ip"), - dst_ip_address=r_cfg.get("dst_ip"), - position=r_num, - ) def _init_request_manager(self) -> RequestManager: # TODO: Add src and dst wildcard masks as positional args in this request. @@ -616,11 +569,6 @@ class RouteEntry(SimComponent): metric: float = 0.0 "The cost metric for this route. Default is 0.0." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"} - self._original_values = self.model_dump(include=vals_to_include) - def describe_state(self) -> Dict: """ Describes the current state of the RouteEntry. @@ -653,17 +601,6 @@ class RouteTable(SimComponent): default_route: Optional[RouteEntry] = None sys_log: SysLog - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - self._original_state["routes_orig"] = self.routes - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.routes.clear() - self.routes = self._original_state["routes_orig"] - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the RouteTable. @@ -1104,8 +1041,6 @@ class Router(NetworkNode): self._set_default_acl() - self.set_original_state() - def _install_system_software(self): """ Installs essential system software and network services on the router. @@ -1131,20 +1066,7 @@ class Router(NetworkNode): self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - def set_original_state(self): - """ - Sets or resets the router to its original configuration state. - - This method is called to initialize the router's state or to revert it to a known good configuration during - network simulations or after configuration changes. - """ - self.acl.set_original_state() - self.route_table.set_original_state() - super().set_original_state() - vals_to_include = {"num_ports"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """ Resets the router's components for a new network simulation episode. @@ -1154,13 +1076,10 @@ class Router(NetworkNode): :param episode: The episode number for which the router is being reset. """ self.software_manager.arp.clear() - self.acl.reset_component_for_episode(episode) - self.route_table.reset_component_for_episode(episode) - for i, network_interface in self.network_interface.items(): - network_interface.reset_component_for_episode(episode) + for i, _ in self.network_interface.items(): self.enable_port(i) - super().reset_component_for_episode(episode) + super().setup_for_episode(episode=episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -1391,7 +1310,6 @@ class Router(NetworkNode): network_interface.ip_address = ip_address network_interface.subnet_mask = subnet_mask self.sys_log.info(f"Configured Network Interface {network_interface}") - self.set_original_state() def enable_port(self, port: int): """ @@ -1493,6 +1411,14 @@ class Router(NetworkNode): subnet_mask=port_cfg["subnet_mask"], ) if "acl" in cfg: - new.acl._default_config = cfg["acl"] # save the config to allow resetting - new.acl._reset_rules_to_default() # read the config and apply rules + for r_num, r_cfg in cfg["acl"].items(): + new.acl.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) return new diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 33e6ee9a..557ea287 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -32,12 +32,6 @@ class SwitchPort(WiredNetworkInterface): _connected_node: Optional[Switch] = None "The Switch to which the SwitchPort is connected." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index dd0b58d3..91833d6a 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -122,8 +122,6 @@ class WirelessRouter(Router): self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")) - self.set_original_state() - @property def wireless_access_point(self) -> WirelessAccessPoint: """ @@ -166,7 +164,6 @@ class WirelessRouter(Router): network_interface.ip_address = ip_address network_interface.subnet_mask = subnet_mask self.sys_log.info(f"Configured WAP {network_interface}") - self.set_original_state() self.wireless_access_point.frequency = frequency # Set operating frequency self.wireless_access_point.enable() # Re-enable the WAP with new settings diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py new file mode 100644 index 00000000..87839712 --- /dev/null +++ b/src/primaite/simulator/network/nmne.py @@ -0,0 +1,47 @@ +from typing import Dict, Final, List + +CAPTURE_NMNE: bool = True +"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True.""" + +NMNE_CAPTURE_KEYWORDS: List[str] = [] +"""List of keywords to identify malicious network events.""" + +# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically +CAPTURE_BY_DIRECTION: Final[bool] = True +"""Flag to determine if captures should be organized by traffic direction (inbound/outbound).""" +CAPTURE_BY_IP_ADDRESS: Final[bool] = False +"""Flag to determine if captures should be organized by source or destination IP address.""" +CAPTURE_BY_PROTOCOL: Final[bool] = False +"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP).""" +CAPTURE_BY_PORT: Final[bool] = False +"""Flag to determine if captures should be organized by source or destination port.""" +CAPTURE_BY_KEYWORD: Final[bool] = False +"""Flag to determine if captures should be filtered and categorised based on specific keywords.""" + + +def set_nmne_config(nmne_config: Dict): + """ + Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary. + + This function updates global settings related to NMNE capture, including whether to capture NMNEs and what + keywords to use for identifying NMNEs. + + The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary, + and maintains type integrity by checking the types of the provided values. + + :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include: + "capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings) + to specify keywords for NMNE identification. + """ + global NMNE_CAPTURE_KEYWORDS + global CAPTURE_NMNE + + # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect + CAPTURE_NMNE = nmne_config.get("capture_nmne", False) + if not isinstance(CAPTURE_NMNE, bool): + CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean + + # Update the NMNE capture keywords, appending new keywords if provided + NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", []) + if not isinstance(NMNE_CAPTURE_KEYWORDS, list): + NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 896861e6..a2285d92 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -21,13 +21,9 @@ class Simulation(SimComponent): super().__init__(**kwargs) - def set_original_state(self): - """Sets the original state.""" - self.network.set_original_state() - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - self.network.reset_component_for_episode(episode) + self.network.setup_for_episode(episode=episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 322ac808..513606a9 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -38,12 +38,6 @@ class Application(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 67c0c9b4..fe8180d7 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -34,7 +34,6 @@ class DatabaseClient(Application): kwargs["port"] = Port.POSTGRES_SERVER kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) - self.set_original_state() def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -50,19 +49,6 @@ class DatabaseClient(Application): self._connections_status.append(can_connect) return can_connect - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"server_ip_address", "server_password", "connected", "_query_success_tracker"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - self._query_success_tracker.clear() - def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index a844f059..5fe951b7 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -49,26 +49,6 @@ class DataManipulationBot(DatabaseClient): super().__init__(**kwargs) self.name = "DataManipulationBot" - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DataManipulationBot original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "server_ip_address", - "payload", - "server_password", - "port_scan_p_of_success", - "data_manipulation_p_of_success", - "attack_stage", - "repeat", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index dfc48dd3..9dac6b25 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -57,27 +57,6 @@ class DoSBot(DatabaseClient, Application): self.name = "DoSBot" self.max_sessions = 1000 # override normal max sessions - def set_original_state(self): - """Set the original state of the Denial of Service Bot.""" - _LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "target_ip_address", - "target_port", - "payload", - "repeat", - "attack_stage", - "max_sessions", - "port_scan_p_of_success", - "dos_intensity", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index eef0ed5d..6f2c479c 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -47,21 +47,8 @@ class WebBrowser(Application): kwargs["port"] = Port.HTTP super().__init__(**kwargs) - self.set_original_state() self.run() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting WebBrowser original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -80,9 +67,6 @@ class WebBrowser(Application): state["history"] = [hist_item.state() for hist_item in self.history] return state - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - def get_webpage(self, url: Optional[str] = None) -> bool: """ Retrieve the webpage. diff --git a/src/primaite/simulator/system/processes/process.py b/src/primaite/simulator/system/processes/process.py index b753e3ad..458a6b5c 100644 --- a/src/primaite/simulator/system/processes/process.py +++ b/src/primaite/simulator/system/processes/process.py @@ -24,12 +24,6 @@ class Process(Software): operating_state: ProcessOperatingState "The current operating state of the Process." - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 0b9554d5..439d2b78 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -40,25 +40,6 @@ class DatabaseService(Service): super().__init__(**kwargs) self._create_db_file() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "password", - "connections", - "backup_server_ip", - "latest_backup_directory", - "latest_backup_file_name", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}") - self.clear_connections() - super().reset_component_for_episode(episode) - def configure_backup(self, backup_server: IPv4Address): """ Set up the database backup. diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 2d3879ff..967af6b2 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -29,18 +29,6 @@ class DNSClient(Service): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DNSClient original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"dns_server"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.dns_cache.clear() - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the software. diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 8decf7e9..4d0ebbb8 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -28,20 +28,6 @@ class DNSServer(Service): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DNSServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"dns_table"} - self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"] - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.dns_table.clear() - for key, value in self._original_state["dns_table_orig"].items(): - self.dns_table[key] = value - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the software. diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 39bc57f0..7c334ced 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -27,18 +27,6 @@ class FTPClient(FTPServiceABC): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FTPClient original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"connected"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index a82b0919..c5330de2 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -27,19 +27,6 @@ class FTPServer(FTPServiceABC): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FTPServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"server_password"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}") - self.clear_connections() - super().reset_component_for_episode(episode) - def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 43d1d783..ad00065c 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -49,21 +49,12 @@ class NTPClient(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def send( self, payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: [Port] = Port.NTP, + dest_port: Port = Port.NTP, **kwargs, ) -> bool: """Requests NTP data from NTP server. diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 3ae80936..f9d9ee7c 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -34,16 +34,6 @@ class NTPServer(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including - resetting any stateful properties or statistics, and clearing any message - queues. - """ - pass - def receive( self, payload: NTPPacket, diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 162678a0..4102657c 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -78,12 +78,6 @@ class Service(IOSoftware): """ return super().receive(payload=payload, session_id=session_id, **kwargs) - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state", "restart_duration", "restart_countdown"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index eaea6bb1..5e7591e9 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -23,18 +23,6 @@ class WebServer(Service): last_response_status_code: Optional[HttpStatusCode] = None - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting WebServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"last_response_status_code"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -130,7 +118,7 @@ class WebServer(Service): self.set_health_state(SoftwareHealthState.COMPROMISED) return response - except Exception: + except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 ) # something went wrong on the server response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR return response diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index ce39930b..8864659c 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -3,7 +3,7 @@ from abc import abstractmethod from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv4Network -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder @@ -13,6 +13,9 @@ from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog +if TYPE_CHECKING: + from primaite.simulator.system.core.software_manager import SoftwareManager + class SoftwareType(Enum): """ @@ -84,7 +87,7 @@ class Software(SimComponent): "The count of times the software has been scanned, defaults to 0." revealed_to_red: bool = False "Indicates if the software has been revealed to red agent, defaults is False." - software_manager: Any = None + software_manager: "SoftwareManager" = None "An instance of Software Manager that is used by the parent node." sys_log: SysLog = None "An instance of SysLog that is used by the parent node." @@ -97,19 +100,6 @@ class Software(SimComponent): _patching_countdown: Optional[int] = None "Current number of ticks left to patch the software." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = { - "name", - "health_state_actual", - "health_state_visible", - "criticality", - "patching_count", - "scanning_count", - "revealed_to_red", - } - self._original_state = self.model_dump(include=vals_to_include) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -248,12 +238,6 @@ class IOSoftware(Software): _connections: Dict[str, Dict] = {} "Active connections." - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 892e6af7..017492ad 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -589,15 +589,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 9b668686..e70814f5 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -593,15 +593,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -624,7 +625,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 5a7d8366..6401bcda 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -1043,16 +1043,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - - ref: database_server type: server @@ -1074,7 +1074,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 42dd27fb..c2616001 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -599,15 +599,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -630,7 +631,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index 8a4a1178..8ef4b8fd 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -600,15 +600,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -631,7 +632,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/conftest.py b/tests/conftest.py index 2add835f..5425deee 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -509,6 +509,6 @@ def game_and_agent(): reward_function=reward_function, ) - game.agents.append(test_agent) + game.agents["test_agent"] = test_agent return (game, test_agent) diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index 91cf5c1e..c48ddbc9 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -11,14 +11,12 @@ from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv -# @pytest.mark.skip(reason="no way of currently testing this") def test_sb3_compatibility(): """Test that the Gymnasium environment can be used with an SB3 agent.""" with open(example_config_path(), "r") as f: cfg = yaml.safe_load(f) - game = PrimaiteGame.from_config(cfg) - gym = PrimaiteGymEnv(game=game) + gym = PrimaiteGymEnv(game_config=cfg) model = PPO("MlpPolicy", gym) model.learn(total_timesteps=1000) diff --git a/tests/integration_tests/game_configuration.py b/tests/integration_tests/game_configuration.py index 3bd870e3..f3dc51bd 100644 --- a/tests/integration_tests/game_configuration.py +++ b/tests/integration_tests/game_configuration.py @@ -42,20 +42,20 @@ def test_example_config(): assert len(game.agents) == 4 # red, blue and 2 green agents # green agent 1 - assert game.agents[0].agent_name == "client_2_green_user" - assert isinstance(game.agents[0], RandomAgent) + assert "client_2_green_user" in game.agents + assert isinstance(game.agents["client_2_green_user"], RandomAgent) # green agent 2 - assert game.agents[1].agent_name == "client_1_green_user" - assert isinstance(game.agents[1], RandomAgent) + assert "client_1_green_user" in game.agents + assert isinstance(game.agents["client_1_green_user"], RandomAgent) # red agent - assert game.agents[2].agent_name == "client_1_data_manipulation_red_bot" - assert isinstance(game.agents[2], DataManipulationAgent) + assert "client_1_data_manipulation_red_bot" in game.agents + assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent) # blue agent - assert game.agents[3].agent_name == "defender" - assert isinstance(game.agents[3], ProxyAgent) + assert "defender" in game.agents + assert isinstance(game.agents["defender"], ProxyAgent) network: Network = game.simulation.network diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py new file mode 100644 index 00000000..85ac23e8 --- /dev/null +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -0,0 +1,120 @@ +from primaite.game.agent.observations import NicObservation +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.nmne import set_nmne_config +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.applications.database_client import DatabaseClient + + +def test_capture_nmne(uc2_network): + """ + Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured. + + This test involves a web server querying a database server and checks if the MNEs are captured + based on predefined keywords in the network configuration. Specifically, it checks the capture + of the "DELETE" SQL command as a malicious network event. + """ + web_server: Server = uc2_network.get_node_by_hostname("web_server") # noqa + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] # noqa + db_client.connect() + + db_server: Server = uc2_network.get_node_by_hostname("database_server") # noqa + + web_server_nic = web_server.network_interface[1] + db_server_nic = db_server.network_interface[1] + + # Set the NMNE configuration to capture DELETE queries as MNEs + nmne_config = { + "capture_nmne": True, # Enable the capture of MNEs + "nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection + } + + # Apply the NMNE configuration settings + set_nmne_config(nmne_config) + + # Assert that initially, there are no captured MNEs on both web and database servers + assert web_server_nic.describe_state()["nmne"] == {} + assert db_server_nic.describe_state()["nmne"] == {} + + # Perform a "SELECT" query + db_client.query("SELECT") + + # Check that it does not trigger an MNE capture. + assert web_server_nic.describe_state()["nmne"] == {} + assert db_server_nic.describe_state()["nmne"] == {} + + # Perform a "DELETE" query + db_client.query("DELETE") + + # Check that the web server's outbound interface and the database server's inbound interface register the MNE + assert web_server_nic.describe_state()["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}} + assert db_server_nic.describe_state()["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}} + + # Perform another "SELECT" query + db_client.query("SELECT") + + # Check that no additional MNEs are captured + assert web_server_nic.describe_state()["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}} + assert db_server_nic.describe_state()["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}} + + # Perform another "DELETE" query + db_client.query("DELETE") + + # Check that the web server and database server interfaces register an additional MNE + assert web_server_nic.describe_state()["nmne"] == {"direction": {"outbound": {"keywords": {"*": 2}}}} + assert db_server_nic.describe_state()["nmne"] == {"direction": {"inbound": {"keywords": {"*": 2}}}} + + +def test_capture_nmne_observations(uc2_network): + """ + Tests the NicObservation class's functionality within a simulated network environment. + + This test ensures the observation space, as defined by instances of NicObservation, accurately reflects the + number of MNEs detected based on network activities over multiple iterations. + + The test employs a series of "DELETE" SQL operations, considered as MNEs, to validate the dynamic update + and accuracy of the observation space related to network interface conditions. It confirms that the + observed NIC states match expected MNE activity levels. + """ + # Initialise a new Simulation instance and assign the test network to it. + sim = Simulation() + sim.network = uc2_network + + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client.connect() + + # Set the NMNE configuration to capture DELETE queries as MNEs + nmne_config = { + "capture_nmne": True, # Enable the capture of MNEs + "nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection + } + + # Apply the NMNE configuration settings + set_nmne_config(nmne_config) + + # Define observations for the NICs of the database and web servers + db_server_nic_obs = NicObservation(where=["network", "nodes", "database_server", "NICs", 1]) + web_server_nic_obs = NicObservation(where=["network", "nodes", "web_server", "NICs", 1]) + + # Iterate through a set of test cases to simulate multiple DELETE queries + for i in range(1, 20): + # Perform a "DELETE" query each iteration + db_client.query("DELETE") + + # Observe the current state of NMNEs from the NICs of both the database and web servers + db_nic_obs = db_server_nic_obs.observe(sim.describe_state())["nmne"] + web_nic_obs = web_server_nic_obs.observe(sim.describe_state())["nmne"] + + # Define expected NMNE values based on the iteration count + if i > 10: + expected_nmne = 3 # High level of detected MNEs after 10 iterations + elif i > 5: + expected_nmne = 2 # Moderate level after more than 5 iterations + elif i > 0: + expected_nmne = 1 # Low level detected after just starting + else: + expected_nmne = 0 # No MNEs detected + + # Assert that the observed NMNEs match the expected values for both NICs + assert web_nic_obs["outbound"] == expected_nmne + assert db_nic_obs["inbound"] == expected_nmne diff --git a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py index 01ad3871..786fe851 100644 --- a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py +++ b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py @@ -7,26 +7,11 @@ from primaite.simulator.domain.account import Account, AccountType @pytest.fixture(scope="function") def account() -> Account: acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER) - acct.set_original_state() return acct def test_original_state(account): """Test the original state - see if it resets properly""" - account.log_on() - account.log_off() - account.disable() - - state = account.describe_state() - assert state["num_logons"] is 1 - assert state["num_logoffs"] is 1 - assert state["num_group_changes"] is 0 - assert state["username"] is "Jake" - assert state["password"] is "totally_hashed_password" - assert state["account_type"] is AccountType.USER.value - assert state["enabled"] is False - - account.reset_component_for_episode(episode=1) state = account.describe_state() assert state["num_logons"] is 0 assert state["num_logoffs"] is 0 @@ -39,13 +24,7 @@ def test_original_state(account): account.log_on() account.log_off() account.disable() - account.set_original_state() - account.log_on() - state = account.describe_state() - assert state["num_logons"] is 2 - - account.reset_component_for_episode(episode=2) state = account.describe_state() assert state["num_logons"] is 1 assert state["num_logoffs"] is 1 diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 9366d173..4defc80c 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -185,38 +185,6 @@ def test_get_file(file_system): file_system.show(full=True) -def test_reset_file_system(file_system): - # file and folder that existed originally - file_system.create_file(file_name="test_file.zip") - file_system.create_folder(folder_name="test_folder") - file_system.set_original_state() - - # create a new file - file_system.create_file(file_name="new_file.txt") - - # create a new folder - file_system.create_folder(folder_name="new_folder") - - # delete the file that existed originally - file_system.delete_file(folder_name="root", file_name="test_file.zip") - assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None - - # delete the folder that existed originally - file_system.delete_folder(folder_name="test_folder") - assert file_system.get_folder(folder_name="test_folder") is None - - # reset - file_system.reset_component_for_episode(episode=1) - - # deleted original file and folder should be back - assert file_system.get_file(folder_name="root", file_name="test_file.zip") - assert file_system.get_folder(folder_name="test_folder") - - # new file and folder should be removed - assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None - assert file_system.get_folder(folder_name="new_folder") is None - - @pytest.mark.skip(reason="Skipping until we tackle serialisation") def test_serialisation(file_system): """Test to check that the object serialisation works correctly.""" diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index 9d424697..2cfc3f11 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -31,7 +31,6 @@ def network(example_network) -> Network: assert len(example_network.computers) is 2 assert len(example_network.servers) is 2 - example_network.set_original_state() example_network.show() return example_network @@ -45,40 +44,6 @@ def test_describe_state(network): assert len(state["links"]) is 6 -def test_reset_network(network): - """ - Test that the network is properly reset. - - TODO: make sure that once implemented - any installed/uninstalled services, processes, apps, - etc are also removed/reinstalled - - """ - state_before = network.describe_state() - - client_1: Computer = network.get_node_by_hostname("client_1") - server_1: Computer = network.get_node_by_hostname("server_1") - - assert client_1.operating_state is NodeOperatingState.ON - assert server_1.operating_state is NodeOperatingState.ON - - client_1.power_off() - assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN - - server_1.power_off() - assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN - - assert network.describe_state() != state_before - - network.reset_component_for_episode(episode=1) - - assert client_1.operating_state is NodeOperatingState.ON - assert server_1.operating_state is NodeOperatingState.ON - # don't worry if UUIDs change - a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"]) - b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"]) - assert a == b - - def test_creating_container(): """Check that we can create a network container""" net = Network() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index ccf40c44..4bfd28d0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -19,7 +19,6 @@ def dos_bot() -> DoSBot: dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1")) - dos_bot.set_original_state() return dos_bot @@ -28,35 +27,6 @@ def test_dos_bot_creation(dos_bot): assert dos_bot is not None -def test_dos_bot_reset(dos_bot): - assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") - assert dos_bot.target_port is Port.POSTGRES_SERVER - assert dos_bot.payload is None - assert dos_bot.repeat is False - - dos_bot.configure( - target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True - ) - - # should reset the relevant items - dos_bot.reset_component_for_episode(episode=0) - assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") - assert dos_bot.target_port is Port.POSTGRES_SERVER - assert dos_bot.payload is None - assert dos_bot.repeat is False - - dos_bot.configure( - target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True - ) - dos_bot.set_original_state() - dos_bot.reset_component_for_episode(episode=1) - # should reset to the configured value - assert dos_bot.target_ip_address == IPv4Address("192.168.1.1") - assert dos_bot.target_port is Port.HTTP - assert dos_bot.payload == "payload" - assert dos_bot.repeat is True - - def test_dos_bot_cannot_run_when_node_offline(dos_bot): dos_bot_node: Computer = dos_bot.parent assert dos_bot_node.operating_state is NodeOperatingState.ON diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index e77cd895..6f680012 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -2,12 +2,14 @@ from typing import Dict import pytest +from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.sys_log import SysLog -from primaite.simulator.system.software import Software, SoftwareHealthState +from primaite.simulator.system.services.service import Service +from primaite.simulator.system.software import IOSoftware, SoftwareHealthState -class TestSoftware(Software): +class TestSoftware(Service): def describe_state(self) -> Dict: pass @@ -15,7 +17,11 @@ class TestSoftware(Software): @pytest.fixture(scope="function") def software(file_system): return TestSoftware( - name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service") + name="TestSoftware", + port=Port.ARP, + file_system=file_system, + sys_log=SysLog(hostname="test_service"), + protocol=IPProtocol.TCP, )