diff --git a/CHANGELOG.md b/CHANGELOG.md index 260bfcf0..1c467206 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +- Made environment reset completely recreate the game object. - Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack. - Changed the data manipulation scenario to include a second green agent on client 1. - Refactored actions and observations to be configurable via object name, instead of UUID. diff --git a/docs/index.rst b/docs/index.rst index 9eae8adc..08e0ac21 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -108,6 +108,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! source/simulation source/game_layer source/config + source/environment .. toctree:: :caption: Developer information: diff --git a/docs/source/environment.rst b/docs/source/environment.rst new file mode 100644 index 00000000..2b76572d --- /dev/null +++ b/docs/source/environment.rst @@ -0,0 +1,10 @@ +RL Environments +*************** + +RL environments are the objects that directly interface with RL libraries such as Stable-Baselines3 and Ray RLLib. The PrimAITE simulation is exposed via three different environment APIs: + +* Gymnasium API - this is the standard interface that works with many RL libraries like SB3, Ray, Tianshou, etc. ``PrimaiteGymEnv`` adheres to the `Official Gymnasium documentation `_. +* Ray Single agent API - For training a single Ray RLLib agent +* Ray MARL API - For training multi-agent systems with Ray RLLib. ``PrimaiteRayMARLEnv`` adheres to the `Official Ray documentation `_. + +There are Jupyter notebooks which demonstrate integration with each of these three environments. They are located in ``~/primaite//notebooks/example_notebooks``. diff --git a/docs/source/game_layer.rst b/docs/source/game_layer.rst index eb9b17c3..39ab7bde 100644 --- a/docs/source/game_layer.rst +++ b/docs/source/game_layer.rst @@ -20,6 +20,11 @@ The game layer is responsible for managing agents and getting them to interface PrimAITE Session ^^^^^^^^^^^^^^^^ +.. admonition:: Deprecated + :class: deprecated + + PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The `session` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality. + ``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types. Agents diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 706397b6..87a3f03d 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -4,6 +4,11 @@ .. _run a primaite session: +.. admonition:: Deprecated + :class: deprecated + + PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The ``session`` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality. + Run a PrimAITE Session ====================== diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 6076553c..478124a9 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -14,7 +14,7 @@ io_settings: save_checkpoints: true checkpoint_interval: 5 save_step_metadata: false - save_pcap_logs: true + save_pcap_logs: false save_sys_logs: true @@ -656,12 +656,13 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index 12461547..2b54eb37 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -1,15 +1,23 @@ training_config: rl_framework: RLLIB_multi_agent - # rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 256 + deterministic_eval: false n_agents: 2 agent_references: - defender_1 - defender_2 + io_settings: save_checkpoints: true checkpoint_interval: 5 save_step_metadata: false + save_pcap_logs: false + save_sys_logs: true game: @@ -36,9 +44,9 @@ agents: - type: NODE_APPLICATION_EXECUTE options: nodes: - - node_ref: client_2 + - node_name: client_2 applications: - - application_ref: client_2_web_browser + - application_name: WebBrowser max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 @@ -54,6 +62,31 @@ agents: frequency: 4 variance: 3 + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + reward_function: + reward_components: + - type: DUMMY + + + + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -72,7 +105,7 @@ agents: - type: NODE_OS_SCAN options: nodes: - - node_ref: client_1 + - node_name: client_1 applications: - application_name: DataManipulationBot - node_name: client_2 @@ -104,25 +137,21 @@ agents: num_files_per_folder: 1 num_nics_per_node: 2 nodes: - - node_ref: domain_controller + - node_hostname: domain_controller services: - - service_ref: domain_controller_dns_server - - node_ref: web_server + - service_name: DNSServer + - node_hostname: web_server services: - - service_ref: web_server_database_client - - node_ref: database_server - services: - - service_ref: database_service + - service_name: WebServer + - node_hostname: database_server folders: - folder_name: database files: - file_name: database.db - - node_ref: backup_server - # services: - # - service_ref: backup_service - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 links: - link_ref: router_1___switch_1 - link_ref: router_1___switch_2 @@ -137,23 +166,23 @@ agents: acl: options: max_acl_rules: 10 - router_node_ref: router_1 + router_hostname: router_1 ip_address_order: - - node_ref: domain_controller + - node_hostname: domain_controller nic_num: 1 - - node_ref: web_server + - node_hostname: web_server nic_num: 1 - - node_ref: database_server + - node_hostname: database_server nic_num: 1 - - node_ref: backup_server + - node_hostname: backup_server nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 1 - - node_ref: client_1 + - node_hostname: client_1 nic_num: 1 - - node_ref: client_2 + - node_hostname: client_2 nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 2 ics: null @@ -184,10 +213,10 @@ agents: - type: NODE_RESET - type: NETWORK_ACL_ADDRULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_ACL_REMOVERULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_NIC_ENABLE - type: NETWORK_NIC_DISABLE @@ -242,25 +271,25 @@ agents: action: "NODE_FILE_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 11: action: "NODE_FILE_DELETE" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 12: action: "NODE_FILE_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 13: action: "NODE_SERVICE_PATCH" @@ -271,22 +300,22 @@ agents: action: "NODE_FOLDER_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: node_id: 2 - folder_id: 1 + folder_id: 0 18: action: "NODE_OS_SCAN" options: @@ -303,63 +332,63 @@ agents: action: "NODE_RESET" options: node_id: 5 - 22: + 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 7 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 23: + 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 24: # block tcp traffic from client 1 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 3 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 25: # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 26: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 27: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 @@ -407,123 +436,148 @@ agents: action: "NETWORK_NIC_DISABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 39: action: "NETWORK_NIC_ENABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 40: action: "NETWORK_NIC_DISABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 41: action: "NETWORK_NIC_ENABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 42: action: "NETWORK_NIC_DISABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 43: action: "NETWORK_NIC_ENABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 44: action: "NETWORK_NIC_DISABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 45: action: "NETWORK_NIC_ENABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 46: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 47: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 48: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 49: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 50: action: "NETWORK_NIC_DISABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 51: action: "NETWORK_NIC_ENABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 52: action: "NETWORK_NIC_DISABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 53: action: "NETWORK_NIC_ENABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 options: nodes: - - node_ref: domain_controller - - node_ref: web_server + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient services: - - service_ref: web_server_web_service - - node_ref: database_server + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db services: - - service_ref: database_service - - node_ref: backup_server - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.34 options: - node_ref: database_server + node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 options: - node_ref: web_server - service_ref: web_server_web_service + node_hostname: client_1 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 + options: + node_hostname: client_2 agent_settings: - # ... - + flatten_obs: true - ref: defender_2 team: BLUE @@ -537,25 +591,21 @@ agents: num_files_per_folder: 1 num_nics_per_node: 2 nodes: - - node_ref: domain_controller + - node_hostname: domain_controller services: - - service_ref: domain_controller_dns_server - - node_ref: web_server + - service_name: DNSServer + - node_hostname: web_server services: - - service_ref: web_server_database_client - - node_ref: database_server - services: - - service_ref: database_service + - service_name: WebServer + - node_hostname: database_server folders: - folder_name: database files: - file_name: database.db - - node_ref: backup_server - # services: - # - service_ref: backup_service - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 links: - link_ref: router_1___switch_1 - link_ref: router_1___switch_2 @@ -570,23 +620,23 @@ agents: acl: options: max_acl_rules: 10 - router_node_ref: router_1 + router_hostname: router_1 ip_address_order: - - node_ref: domain_controller + - node_hostname: domain_controller nic_num: 1 - - node_ref: web_server + - node_hostname: web_server nic_num: 1 - - node_ref: database_server + - node_hostname: database_server nic_num: 1 - - node_ref: backup_server + - node_hostname: backup_server nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 1 - - node_ref: client_1 + - node_hostname: client_1 nic_num: 1 - - node_ref: client_2 + - node_hostname: client_2 nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 2 ics: null @@ -617,10 +667,10 @@ agents: - type: NODE_RESET - type: NETWORK_ACL_ADDRULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_ACL_REMOVERULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_NIC_ENABLE - type: NETWORK_NIC_DISABLE @@ -675,25 +725,25 @@ agents: action: "NODE_FILE_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 11: action: "NODE_FILE_DELETE" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 12: action: "NODE_FILE_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 13: action: "NODE_SERVICE_PATCH" @@ -704,22 +754,22 @@ agents: action: "NODE_FOLDER_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: node_id: 2 - folder_id: 1 + folder_id: 0 18: action: "NODE_OS_SCAN" options: @@ -736,63 +786,63 @@ agents: action: "NODE_RESET" options: node_id: 5 - 22: + 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 7 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 23: + 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 24: # block tcp traffic from client 1 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 3 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 25: # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 26: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 27: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 @@ -840,122 +890,148 @@ agents: action: "NETWORK_NIC_DISABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 39: action: "NETWORK_NIC_ENABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 40: action: "NETWORK_NIC_DISABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 41: action: "NETWORK_NIC_ENABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 42: action: "NETWORK_NIC_DISABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 43: action: "NETWORK_NIC_ENABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 44: action: "NETWORK_NIC_DISABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 45: action: "NETWORK_NIC_ENABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 46: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 47: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 48: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 49: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 50: action: "NETWORK_NIC_DISABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 51: action: "NETWORK_NIC_ENABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 52: action: "NETWORK_NIC_DISABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 53: action: "NETWORK_NIC_ENABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 options: nodes: - - node_ref: domain_controller - - node_ref: web_server + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient services: - - service_ref: web_server_web_service - - node_ref: database_server + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db services: - - service_ref: database_service - - node_ref: backup_server - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.34 options: - node_ref: database_server + node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 options: - node_ref: web_server - service_ref: web_server_web_service + node_hostname: client_1 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 + options: + node_hostname: client_2 agent_settings: - # ... + flatten_obs: true @@ -1036,12 +1112,13 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -1093,10 +1170,14 @@ simulation: - ref: data_manipulation_bot type: DataManipulationBot options: - port_scan_p_of_success: 0.1 - data_manipulation_p_of_success: 0.1 + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - ref: client_1_web_browser + type: WebBrowser + options: + target_url: http://arcd.com/users/ services: - ref: client_1_dns_client type: DNSClient @@ -1113,6 +1194,13 @@ simulation: type: WebBrowser options: target_url: http://arcd.com/users/ + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 services: - ref: client_2_dns_client type: DNSClient diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 07b9bc9e..8d272418 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -83,18 +83,15 @@ class PrimaiteGame: self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" - self.agents: List[AbstractAgent] = [] - """List of agents.""" + self.agents: Dict[str, AbstractAgent] = {} + """Mapping from agent name to agent object.""" - self.rl_agents: List[ProxyAgent] = [] - """Subset of agent list including only the reinforcement learning agents.""" + self.rl_agents: Dict[str, ProxyAgent] = {} + """Subset of agents which are intended for reinforcement learning.""" self.step_counter: int = 0 """Current timestep within the episode.""" - self.episode_counter: int = 0 - """Current episode number.""" - self.options: PrimaiteGameOptions """Special options that apply for the entire game.""" @@ -140,7 +137,7 @@ class PrimaiteGame: self.update_agents(sim_state) # Apply all actions to simulation as requests - agent_actions = self.apply_agent_actions() # noqa + self.apply_agent_actions() # Advance timestep self.advance_timestep() @@ -151,7 +148,7 @@ class PrimaiteGame: def update_agents(self, state: Dict) -> None: """Update agents' observations and rewards based on the current state.""" - for agent in self.agents: + for _, agent in self.agents.items(): agent.update_observation(state) agent.update_reward(state) agent.reward_function.total_reward += agent.reward_function.current_reward @@ -165,7 +162,7 @@ class PrimaiteGame: """ agent_actions = {} - for agent in self.agents: + for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation rew = agent.reward_function.current_reward action_choice, options = agent.get_action(obs, rew) @@ -188,20 +185,14 @@ class PrimaiteGame: return True return False - def reset(self) -> None: - """Reset the game, this will reset the simulation.""" - self.episode_counter += 1 - self.step_counter = 0 - _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") - self.simulation.reset_component_for_episode(episode=self.episode_counter) - for agent in self.agents: - agent.reward_function.total_reward = 0.0 - agent.reset_agent_for_episode() - def close(self) -> None: """Close the game, this will close the simulation.""" return NotImplemented + def setup_for_episode(self, episode: int) -> None: + """Perform any final configuration of components to make them ready for the game to start.""" + self.simulation.setup_for_episode(episode=episode) + @classmethod def from_config(cls, cfg: Dict) -> "PrimaiteGame": """Create a PrimaiteGame object from a config dictionary. @@ -280,8 +271,9 @@ class PrimaiteGame: # start the service new_service.start() else: - _LOGGER.warning(f"service type not found {service_type}") - + msg = f"Configuration contains an invalid service type: {service_type}" + _LOGGER.error(msg) + raise ValueError(msg) # service-dependent options if service_type == "DNSClient": if "options" in service_cfg: @@ -318,7 +310,9 @@ class PrimaiteGame: new_application = new_node.software_manager.software[application_type] game.ref_map_applications[application_ref] = new_application.uuid else: - _LOGGER.warning(f"application type not found {application_type}") + msg = f"Configuration contains an invalid application type: {application_type}" + _LOGGER.error(msg) + raise ValueError(msg) # run the application new_application.run() @@ -419,7 +413,6 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) elif agent_type == "ProxyAgent": new_agent = ProxyAgent( agent_name=agent_cfg["ref"], @@ -428,8 +421,7 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) - game.rl_agents.append(new_agent) + game.rl_agents[agent_cfg["ref"]] = new_agent elif agent_type == "RedDatabaseCorruptingAgent": new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], @@ -438,11 +430,11 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) else: - _LOGGER.warning(f"agent type {agent_type} not found") - - game.simulation.set_original_state() + msg(f"Configuration error: {agent_type} is not a valid agent type.") + _LOGGER.error(msg) + raise ValueError(msg) + game.agents[agent_cfg["ref"]] = new_agent # Set the NMNE capture config set_nmne_config(cfg["simulation"]["network"].get("nmne_config", {})) diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb index 0d4b6d0e..4ef02443 100644 --- a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb @@ -60,7 +60,7 @@ " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", " )\n", - " .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n", " .rollouts(num_rollout_workers=0)\n", " .training(train_batch_size=128)\n", " )\n" @@ -88,6 +88,13 @@ " param_space=config\n", ").fit()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb index ea006ae9..3c27bdc6 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -54,7 +54,7 @@ "metadata": {}, "outputs": [], "source": [ - "env_config = {\"cfg\":cfg}\n", + "env_config = cfg\n", "\n", "config = (\n", " PPOConfig()\n", diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb index e5085c5e..0472854e 100644 --- a/src/primaite/notebooks/training_example_sb3.ipynb +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -27,9 +27,7 @@ "outputs": [], "source": [ "with open(example_config_path(), 'r') as f:\n", - " cfg = yaml.safe_load(f)\n", - "\n", - "game = PrimaiteGame.from_config(cfg)" + " cfg = yaml.safe_load(f)\n" ] }, { @@ -38,7 +36,7 @@ "metadata": {}, "outputs": [], "source": [ - "gym = PrimaiteGymEnv(game=game)" + "gym = PrimaiteGymEnv(game_config=cfg)" ] }, { @@ -65,7 +63,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.learn(total_timesteps=1000)\n" + "model.learn(total_timesteps=10)\n" ] }, { @@ -76,6 +74,13 @@ "source": [ "model.save(\"deleteme\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index b1e12370..c8f2595b 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -346,9 +346,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", @@ -358,9 +356,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "# Imports\n", @@ -383,9 +379,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "# create the env\n", @@ -396,10 +390,10 @@ " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n", " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", - "game = PrimaiteGame.from_config(cfg)\n", - "env = PrimaiteGymEnv(game = game)\n", - "# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n", - "env.agent.flatten_obs = False\n", + " # don't flatten observations so that we can see what is going on\n", + " cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n", + "\n", + "env = PrimaiteGymEnv(game_config = cfg)\n", "obs, info = env.reset()\n", "print('env created successfully')\n", "pprint(obs)" @@ -433,9 +427,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "for step in range(35):\n", @@ -453,9 +445,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "pprint(obs['NODES'])" @@ -471,9 +461,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(9) # scan database file\n", @@ -499,9 +487,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", @@ -526,9 +512,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", @@ -551,9 +535,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "env.step(13) # Patch the database\n", @@ -593,6 +575,22 @@ "obs['ACL']" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Reset the environment, you can rerun the other cells to verify that the attack works the same every episode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.reset()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -603,7 +601,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -617,9 +615,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 2 } diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a3831bc1..f8dbab9d 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple +from typing import Any, Dict, Optional, SupportsFloat, Tuple import gymnasium from gymnasium.core import ActType, ObsType @@ -18,11 +18,23 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. """ - def __init__(self, game: PrimaiteGame): + def __init__(self, game_config: Dict): """Initialise the environment.""" super().__init__() - self.game: "PrimaiteGame" = game - self.agent: ProxyAgent = self.game.rl_agents[0] + self.game_config: Dict = game_config + """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) + """Current game.""" + self._agent_name = next(iter(self.game.rl_agents)) + """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key.""" + + self.episode_counter: int = 0 + """Current episode number.""" + + @property + def agent(self) -> ProxyAgent: + """Grab a fresh reference to the agent object because it will be reinstantiated each episode.""" + return self.game.rl_agents[self._agent_name] def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" @@ -45,13 +57,13 @@ class PrimaiteGymEnv(gymnasium.Env): return next_obs, reward, terminated, truncated, info def _write_step_metadata_json(self, action: int, state: Dict, reward: int): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "action": int(action), "reward": int(reward), @@ -63,10 +75,12 @@ class PrimaiteGymEnv(gymnasium.Env): def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" print( - f"Resetting environment, episode {self.game.episode_counter}, " - f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}" + f"Resetting environment, episode {self.episode_counter}, " + f"avg. reward: {self.agent.reward_function.total_reward}" ) - self.game.reset() + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.game.setup_for_episode(episode=self.episode_counter) + self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() @@ -88,12 +102,12 @@ class PrimaiteGymEnv(gymnasium.Env): def _get_obs(self) -> ObsType: """Return the current observation.""" - if not self.agent.flatten_obs: - return self.agent.observation_manager.current_observation - else: + if self.agent.flatten_obs: unflat_space = self.agent.observation_manager.space unflat_obs = self.agent.observation_manager.current_observation return gymnasium.spaces.flatten(unflat_space, unflat_obs) + else: + return self.agent.observation_manager.current_observation class PrimaiteRayEnv(gymnasium.Env): @@ -102,12 +116,11 @@ class PrimaiteRayEnv(gymnasium.Env): def __init__(self, env_config: Dict) -> None: """Initialise the environment. - :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` - which is the PrimaiteGame instance. - :type env_config: Dict[str, PrimaiteGame] + :param env_config: A dictionary containing the environment configuration. + :type env_config: Dict """ - self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"])) - self.env.game.episode_counter -= 1 + self.env = PrimaiteGymEnv(game_config=env_config) + self.env.episode_counter -= 1 self.action_space = self.env.action_space self.observation_space = self.env.observation_space @@ -128,13 +141,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` which is the PrimaiteGame instance. - :type env_config: Dict[str, PrimaiteGame] + :type env_config: Dict """ - self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"]) + self.game_config: Dict = env_config + """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) """Reference to the primaite game""" - self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents} - """List of all possible agents in the environment. This list should not change!""" - self._agent_ids = list(self.agents.keys()) + self._agent_ids = list(self.game.rl_agents.keys()) + """Agent ids. This is a list of strings of agent names.""" + self.episode_counter: int = 0 + """Current episode number.""" self.terminateds = set() self.truncateds = set() @@ -149,9 +165,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): ) super().__init__() + @property + def agents(self) -> Dict[str, ProxyAgent]: + """Grab a fresh reference to the agents from this episode's game object.""" + return {name: self.game.rl_agents[name] for name in self._agent_ids} + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" - self.game.reset() + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.game.setup_for_episode(episode=self.episode_counter) + self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() @@ -172,7 +195,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - agent_actions = self.game.apply_agent_actions() + self.game.apply_agent_actions() # 2. Advance timestep self.game.advance_timestep() @@ -186,7 +209,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} terminateds = {name: False for name, _ in self.agents.items()} truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} - infos = {"agent_actions": agent_actions} + infos = {name: {} for name, _ in self.agents.items()} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: @@ -194,13 +217,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "actions": {agent_name: int(action) for agent_name, action in actions.items()}, "reward": rewards, @@ -212,8 +235,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" obs = {} - for name, agent in self.agents.items(): + for agent_name in self._agent_ids: + agent = self.game.rl_agents[agent_name] unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation - obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) + obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) return obs diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 5c663cfd..b8f80e95 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -101,11 +101,11 @@ class PrimaiteSession: # CREATE ENVIRONMENT if sess.training_options.rl_framework == "RLLIB_single_agent": - sess.env = PrimaiteRayEnv(env_config={"cfg": cfg}) + sess.env = PrimaiteRayEnv(env_config=cfg) elif sess.training_options.rl_framework == "RLLIB_multi_agent": - sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg}) + sess.env = PrimaiteRayMARLEnv(env_config=cfg) elif sess.training_options.rl_framework == "SB3": - sess.env = PrimaiteGymEnv(game=game) + sess.env = PrimaiteGymEnv(game_config=cfg) sess.policy = PolicyABC.from_config(sess.training_options, session=sess) if agent_load_path: diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 964dac01..99e9be7f 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -153,22 +153,18 @@ class SimComponent(BaseModel): uuid: str = Field(default_factory=lambda: str(uuid4())) """The component UUID.""" - _original_state: Dict = {} - def __init__(self, **kwargs): super().__init__(**kwargs) self._request_manager: RequestManager = self._init_request_manager() self._parent: Optional["SimComponent"] = None - # @abstractmethod - def set_original_state(self): - """Sets the original state.""" - pass + def setup_for_episode(self, episode: int): + """ + Perform any additional setup on this component that can't happen during __init__. - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - for key, value in self._original_state.items(): - self.__setattr__(key, value) + For instance, some components may require for the entire network to exist before some configuration can be set. + """ + pass def _init_request_manager(self) -> RequestManager: """ diff --git a/src/primaite/simulator/domain/account.py b/src/primaite/simulator/domain/account.py index d9dad06a..186caf5b 100644 --- a/src/primaite/simulator/domain/account.py +++ b/src/primaite/simulator/domain/account.py @@ -42,19 +42,6 @@ class Account(SimComponent): "Account Type, currently this can be service account (used by apps) or user account." enabled: bool = True - def set_original_state(self): - """Sets the original state.""" - vals_to_include = { - "num_logons", - "num_logoffs", - "num_group_changes", - "username", - "password", - "account_type", - "enabled", - } - self._original_state = self.model_dump(include=vals_to_include) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index 608a1d78..d9b02e8e 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -73,20 +73,6 @@ class File(FileSystemItemABC): self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})") - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting File ({self.path}) original state on node {self.sys_log.hostname}") - super().set_original_state() - vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}") - super().reset_component_for_episode(episode) - @property def path(self) -> str: """ diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index ee80587d..8fd4e5d7 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -34,43 +34,6 @@ class FileSystem(SimComponent): if not self.folders: self.create_folder("root") - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FileSystem original state on node {self.sys_log.hostname}") - for folder in self.folders.values(): - folder.set_original_state() - # Capture a list of all 'original' file uuids - original_keys = list(self.folders.keys()) - vals_to_include = {"sim_root"} - self._original_state.update(self.model_dump(include=vals_to_include)) - self._original_state["original_folder_uuids"] = original_keys - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}") - # Move any 'original' folder that have been deleted back to folders - original_folder_uuids = self._original_state["original_folder_uuids"] - for uuid in original_folder_uuids: - if uuid in self.deleted_folders: - folder = self.deleted_folders[uuid] - self.deleted_folders.pop(uuid) - self.folders[uuid] = folder - - # Clear any other deleted folders that aren't original (have been created by agent) - self.deleted_folders.clear() - - # Now clear all non-original folders created by agent - current_folder_uuids = list(self.folders.keys()) - for uuid in current_folder_uuids: - if uuid not in original_folder_uuids: - folder = self.folders[uuid] - self.folders.pop(uuid) - - # Now reset all remaining folders - for folder in self.folders.values(): - folder.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index c3e1426b..fbe5f4b3 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -85,11 +85,6 @@ class FileSystemItemABC(SimComponent): deleted: bool = False "If true, the FileSystemItem was deleted." - def set_original_state(self): - """Sets the original state.""" - vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"} - self._original_state = self.model_dump(include=vals_to_keep) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 13fdc597..771dc7a0 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -49,49 +49,6 @@ class Folder(FileSystemItemABC): self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting Folder ({self.name}) original state on node {self.sys_log.hostname}") - for file in self.files.values(): - file.set_original_state() - super().set_original_state() - vals_to_include = { - "scan_duration", - "scan_countdown", - "red_scan_duration", - "red_scan_countdown", - "restore_duration", - "restore_countdown", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - self._original_state["original_file_uuids"] = list(self.files.keys()) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}") - # Move any 'original' file that have been deleted back to files - original_file_uuids = self._original_state["original_file_uuids"] - for uuid in original_file_uuids: - if uuid in self.deleted_files: - file = self.deleted_files[uuid] - self.deleted_files.pop(uuid) - self.files[uuid] = file - - # Clear any other deleted files that aren't original (have been created by agent) - self.deleted_files.clear() - - # Now clear all non-original files created by agent - current_file_uuids = list(self.files.keys()) - for uuid in current_file_uuids: - if uuid not in original_file_uuids: - file = self.files[uuid] - self.files.pop(uuid) - - # Now reset all remaining files - for file in self.files.values(): - file.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 724b8728..d264f751 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -273,11 +273,6 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC) return state - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def enable(self): """ Enables this wired network interface and attempts to send a "hello" message to the default gateway. diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index b32d2630..b5a16430 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -45,19 +45,12 @@ class Network(SimComponent): self._nx_graph = MultiGraph() - def set_original_state(self): - """Sets the original state.""" - for node in self.nodes.values(): - node.set_original_state() - for link in self.links.values(): - link.set_original_state() - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" for node in self.nodes.values(): - node.reset_component_for_episode(episode) + node.setup_for_episode(episode=episode) for link in self.links.values(): - link.reset_component_for_episode(episode) + link.setup_for_episode(episode=episode) for node in self.nodes.values(): node.power_on() @@ -179,7 +172,7 @@ class Network(SimComponent): def clear_links(self): """Clear all the links in the network by resetting their component state for the episode.""" for link in self.links.values(): - link.reset_component_for_episode() + link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here. def draw(self, seed: int = 123): """ diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 35c90d05..ff79f314 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -100,6 +100,15 @@ class NetworkInterface(SimComponent, ABC): nmne: Dict = Field(default_factory=lambda: {}) "A dict containing details of the number of malicious network events captured." + def setup_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + super().setup_for_episode(episode=episode) + self.nmne = {} + if episode and self.pcap: + self.pcap.current_episode = episode + self.pcap.setup_logger() + self.enable() + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -127,15 +136,6 @@ class NetworkInterface(SimComponent, ABC): state.update({"nmne": self.nmne}) return state - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) - self.nmne = {} - if episode and self.pcap: - self.pcap.current_episode = episode - self.pcap.setup_logger() - self.enable() - @abstractmethod def enable(self): """Enable the interface.""" @@ -547,14 +547,6 @@ class Link(SimComponent): self.endpoint_b.connect_link(self) self.endpoint_up() - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"bandwidth", "current_load"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -740,50 +732,20 @@ class Node(SimComponent): self.session_manager.node = self self.session_manager.software_manager = self.software_manager self._install_system_software() - self.set_original_state() - def set_original_state(self): - """Sets the original state.""" - for software in self.software_manager.software.values(): - software.set_original_state() - - self.file_system.set_original_state() - - for network_interface in self.network_interfaces.values(): - network_interface.set_original_state() - - vals_to_include = { - "hostname", - "default_gateway", - "operating_state", - "revealed_to_red", - "start_up_duration", - "start_up_countdown", - "shut_down_duration", - "shut_down_countdown", - "is_resetting", - "node_scan_duration", - "node_scan_countdown", - "red_scan_countdown", - } - self._original_state = self.model_dump(include=vals_to_include) - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) - - # Reset Session Manager - self.session_manager.clear() + super().setup_for_episode(episode=episode) # Reset File System - self.file_system.reset_component_for_episode(episode) + self.file_system.setup_for_episode(episode=episode) # Reset all Nics for network_interface in self.network_interfaces.values(): - network_interface.reset_component_for_episode(episode) + network_interface.setup_for_episode(episode=episode) for software in self.software_manager.software.values(): - software.reset_component_for_episode(episode) + software.setup_for_episode(episode=episode) if episode and self.sys_log: self.sys_log.current_episode = episode diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 1d3b9926..cb3c1bd7 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -209,11 +209,6 @@ class NIC(IPWiredNetworkInterface): return state - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def receive_frame(self, frame: Frame) -> bool: """ Attempt to receive and process a network frame from the connected Link. diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index ce98cec4..b4d5cdba 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -111,24 +111,6 @@ class Firewall(Router): sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound" ) - self.set_original_state() - - def set_original_state(self): - """Set the original state for the Firewall.""" - super().set_original_state() - vals_to_include = { - "internal_port", - "external_port", - "dmz_port", - "internal_inbound_acl", - "internal_outbound_acl", - "dmz_inbound_acl", - "dmz_outbound_acl", - "external_inbound_acl", - "external_outbound_acl", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - def describe_state(self) -> Dict: """ Describes the current state of the Firewall. diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index a9e12401..5b45f59c 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -136,11 +136,6 @@ class ACLRule(SimComponent): rule_strings.append(f"{key}={value}") return ", ".join(rule_strings) - def set_original_state(self): - """Sets the original state.""" - vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"} - self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) - def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. @@ -296,48 +291,6 @@ class AccessControlList(SimComponent): super().__init__(**kwargs) self._acl = [None] * (self.max_acl_rules - 1) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - self.implicit_rule.set_original_state() - vals_to_keep = {"implicit_action", "max_acl_rules", "acl"} - self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) - - for i, rule in enumerate(self._acl): - if not rule: - continue - self._default_config[i] = {"action": rule.action.name} - if rule.src_ip_address: - self._default_config[i]["src_ip"] = str(rule.src_ip_address) - if rule.dst_ip_address: - self._default_config[i]["dst_ip"] = str(rule.dst_ip_address) - if rule.src_port: - self._default_config[i]["src_port"] = rule.src_port.name - if rule.dst_port: - self._default_config[i]["dst_port"] = rule.dst_port.name - if rule.protocol: - self._default_config[i]["protocol"] = rule.protocol.name - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.implicit_rule.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - self._reset_rules_to_default() - - def _reset_rules_to_default(self) -> None: - """Clear all ACL rules and set them to the default rules config.""" - self._acl = [None] * (self.max_acl_rules - 1) - for r_num, r_cfg in self._default_config.items(): - self.add_rule( - action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], - src_ip_address=r_cfg.get("src_ip"), - dst_ip_address=r_cfg.get("dst_ip"), - position=r_num, - ) def _init_request_manager(self) -> RequestManager: # TODO: Add src and dst wildcard masks as positional args in this request. @@ -616,11 +569,6 @@ class RouteEntry(SimComponent): metric: float = 0.0 "The cost metric for this route. Default is 0.0." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"} - self._original_values = self.model_dump(include=vals_to_include) - def describe_state(self) -> Dict: """ Describes the current state of the RouteEntry. @@ -653,17 +601,6 @@ class RouteTable(SimComponent): default_route: Optional[RouteEntry] = None sys_log: SysLog - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - self._original_state["routes_orig"] = self.routes - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.routes.clear() - self.routes = self._original_state["routes_orig"] - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the RouteTable. @@ -1104,8 +1041,6 @@ class Router(NetworkNode): self._set_default_acl() - self.set_original_state() - def _install_system_software(self): """ Installs essential system software and network services on the router. @@ -1131,20 +1066,7 @@ class Router(NetworkNode): self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - def set_original_state(self): - """ - Sets or resets the router to its original configuration state. - - This method is called to initialize the router's state or to revert it to a known good configuration during - network simulations or after configuration changes. - """ - self.acl.set_original_state() - self.route_table.set_original_state() - super().set_original_state() - vals_to_include = {"num_ports"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """ Resets the router's components for a new network simulation episode. @@ -1154,13 +1076,10 @@ class Router(NetworkNode): :param episode: The episode number for which the router is being reset. """ self.software_manager.arp.clear() - self.acl.reset_component_for_episode(episode) - self.route_table.reset_component_for_episode(episode) - for i, network_interface in self.network_interface.items(): - network_interface.reset_component_for_episode(episode) + for i, _ in self.network_interface.items(): self.enable_port(i) - super().reset_component_for_episode(episode) + super().setup_for_episode(episode=episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -1391,7 +1310,6 @@ class Router(NetworkNode): network_interface.ip_address = ip_address network_interface.subnet_mask = subnet_mask self.sys_log.info(f"Configured Network Interface {network_interface}") - self.set_original_state() def enable_port(self, port: int): """ @@ -1493,8 +1411,16 @@ class Router(NetworkNode): subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")), ) if "acl" in cfg: - router.acl._default_config = cfg["acl"] # save the config to allow resetting - router.acl._reset_rules_to_default() # read the config and apply rules + for r_num, r_cfg in cfg["acl"].items(): + router.acl.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) if "routes" in cfg: for route in cfg.get("routes"): router.route_table.add_route( diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 33e6ee9a..557ea287 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -32,12 +32,6 @@ class SwitchPort(WiredNetworkInterface): _connected_node: Optional[Switch] = None "The Switch to which the SwitchPort is connected." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index dd0b58d3..91833d6a 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -122,8 +122,6 @@ class WirelessRouter(Router): self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")) - self.set_original_state() - @property def wireless_access_point(self) -> WirelessAccessPoint: """ @@ -166,7 +164,6 @@ class WirelessRouter(Router): network_interface.ip_address = ip_address network_interface.subnet_mask = subnet_mask self.sys_log.info(f"Configured WAP {network_interface}") - self.set_original_state() self.wireless_access_point.frequency = frequency # Set operating frequency self.wireless_access_point.enable() # Re-enable the WAP with new settings diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 896861e6..a2285d92 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -21,13 +21,9 @@ class Simulation(SimComponent): super().__init__(**kwargs) - def set_original_state(self): - """Sets the original state.""" - self.network.set_original_state() - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - self.network.reset_component_for_episode(episode) + self.network.setup_for_episode(episode=episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 322ac808..513606a9 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -38,12 +38,6 @@ class Application(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 50d9f3d4..57fd3a46 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -31,20 +31,6 @@ class DatabaseClient(Application): kwargs["port"] = Port.POSTGRES_SERVER kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"server_ip_address", "server_password", "connected", "_query_success_tracker"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - self._query_success_tracker.clear() def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index a844f059..5fe951b7 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -49,26 +49,6 @@ class DataManipulationBot(DatabaseClient): super().__init__(**kwargs) self.name = "DataManipulationBot" - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DataManipulationBot original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "server_ip_address", - "payload", - "server_password", - "port_scan_p_of_success", - "data_manipulation_p_of_success", - "attack_stage", - "repeat", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index dfc48dd3..9dac6b25 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -57,27 +57,6 @@ class DoSBot(DatabaseClient, Application): self.name = "DoSBot" self.max_sessions = 1000 # override normal max sessions - def set_original_state(self): - """Set the original state of the Denial of Service Bot.""" - _LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "target_ip_address", - "target_port", - "payload", - "repeat", - "attack_stage", - "max_sessions", - "port_scan_p_of_success", - "dos_intensity", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index eef0ed5d..6f2c479c 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -47,21 +47,8 @@ class WebBrowser(Application): kwargs["port"] = Port.HTTP super().__init__(**kwargs) - self.set_original_state() self.run() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting WebBrowser original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -80,9 +67,6 @@ class WebBrowser(Application): state["history"] = [hist_item.state() for hist_item in self.history] return state - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - def get_webpage(self, url: Optional[str] = None) -> bool: """ Retrieve the webpage. diff --git a/src/primaite/simulator/system/processes/process.py b/src/primaite/simulator/system/processes/process.py index b753e3ad..458a6b5c 100644 --- a/src/primaite/simulator/system/processes/process.py +++ b/src/primaite/simulator/system/processes/process.py @@ -24,12 +24,6 @@ class Process(Software): operating_state: ProcessOperatingState "The current operating state of the Process." - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 726d213e..9fdfd5ff 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -41,25 +41,6 @@ class DatabaseService(Service): super().__init__(**kwargs) self._create_db_file() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "password", - "connections", - "backup_server_ip", - "latest_backup_directory", - "latest_backup_file_name", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}") - self.clear_connections() - super().reset_component_for_episode(episode) - def configure_backup(self, backup_server: IPv4Address): """ Set up the database backup. diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 2d3879ff..967af6b2 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -29,18 +29,6 @@ class DNSClient(Service): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DNSClient original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"dns_server"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.dns_cache.clear() - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the software. diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 8decf7e9..4d0ebbb8 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -28,20 +28,6 @@ class DNSServer(Service): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DNSServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"dns_table"} - self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"] - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.dns_table.clear() - for key, value in self._original_state["dns_table_orig"].items(): - self.dns_table[key] = value - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the software. diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 39bc57f0..7c334ced 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -27,18 +27,6 @@ class FTPClient(FTPServiceABC): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FTPClient original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"connected"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index a82b0919..c5330de2 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -27,19 +27,6 @@ class FTPServer(FTPServiceABC): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FTPServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"server_password"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}") - self.clear_connections() - super().reset_component_for_episode(episode) - def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 43d1d783..ad00065c 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -49,21 +49,12 @@ class NTPClient(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def send( self, payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: [Port] = Port.NTP, + dest_port: Port = Port.NTP, **kwargs, ) -> bool: """Requests NTP data from NTP server. diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 3ae80936..f9d9ee7c 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -34,16 +34,6 @@ class NTPServer(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including - resetting any stateful properties or statistics, and clearing any message - queues. - """ - pass - def receive( self, payload: NTPPacket, diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 162678a0..4102657c 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -78,12 +78,6 @@ class Service(IOSoftware): """ return super().receive(payload=payload, session_id=session_id, **kwargs) - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state", "restart_duration", "restart_countdown"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index eaea6bb1..5e7591e9 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -23,18 +23,6 @@ class WebServer(Service): last_response_status_code: Optional[HttpStatusCode] = None - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting WebServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"last_response_status_code"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -130,7 +118,7 @@ class WebServer(Service): self.set_health_state(SoftwareHealthState.COMPROMISED) return response - except Exception: + except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 ) # something went wrong on the server response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR return response diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index ce39930b..8864659c 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -3,7 +3,7 @@ from abc import abstractmethod from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv4Network -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder @@ -13,6 +13,9 @@ from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog +if TYPE_CHECKING: + from primaite.simulator.system.core.software_manager import SoftwareManager + class SoftwareType(Enum): """ @@ -84,7 +87,7 @@ class Software(SimComponent): "The count of times the software has been scanned, defaults to 0." revealed_to_red: bool = False "Indicates if the software has been revealed to red agent, defaults is False." - software_manager: Any = None + software_manager: "SoftwareManager" = None "An instance of Software Manager that is used by the parent node." sys_log: SysLog = None "An instance of SysLog that is used by the parent node." @@ -97,19 +100,6 @@ class Software(SimComponent): _patching_countdown: Optional[int] = None "Current number of ticks left to patch the software." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = { - "name", - "health_state_actual", - "health_state_visible", - "criticality", - "patching_count", - "scanning_count", - "revealed_to_red", - } - self._original_state = self.model_dump(include=vals_to_include) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -248,12 +238,6 @@ class IOSoftware(Software): _connections: Dict[str, Dict] = {} "Active connections." - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 5bdc3273..c76aeef6 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -589,15 +589,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 8361e318..1cb59f87 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -593,15 +593,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -624,7 +625,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 87bd9d1c..b1b15372 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -1043,16 +1043,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - - ref: database_server type: server @@ -1074,7 +1074,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 76190a64..e5f9d544 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -599,15 +599,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -630,7 +631,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index 5d004c7e..10e088d8 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -600,15 +600,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -631,7 +632,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/conftest.py b/tests/conftest.py index ada89026..dbfff2f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -529,6 +529,6 @@ def game_and_agent(): reward_function=reward_function, ) - game.agents.append(test_agent) + game.agents["test_agent"] = test_agent return (game, test_agent) diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index 91cf5c1e..c48ddbc9 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -11,14 +11,12 @@ from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv -# @pytest.mark.skip(reason="no way of currently testing this") def test_sb3_compatibility(): """Test that the Gymnasium environment can be used with an SB3 agent.""" with open(example_config_path(), "r") as f: cfg = yaml.safe_load(f) - game = PrimaiteGame.from_config(cfg) - gym = PrimaiteGymEnv(game=game) + gym = PrimaiteGymEnv(game_config=cfg) model = PPO("MlpPolicy", gym) model.learn(total_timesteps=1000) diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index 54dca371..306f591d 100644 --- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -14,7 +14,47 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer from primaite.simulator.system.services.web_server.web_server import WebServer -from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, load_config +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +def test_example_config(): + """Test that the example config can be parsed properly.""" + game = load_config(example_config_path()) + + assert len(game.agents) == 4 # red, blue and 2 green agents + + # green agent 1 + assert "client_2_green_user" in game.agents + assert isinstance(game.agents["client_2_green_user"], RandomAgent) + + # green agent 2 + assert "client_1_green_user" in game.agents + assert isinstance(game.agents["client_1_green_user"], RandomAgent) + + # red agent + assert "client_1_data_manipulation_red_bot" in game.agents + assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent) + + # blue agent + assert "defender" in game.agents + assert isinstance(game.agents["defender"], ProxyAgent) + + network: Network = game.simulation.network + + assert len(network.nodes) == 10 # 10 nodes in example network + assert len(network.routers) == 1 # 1 router in network + assert len(network.switches) == 2 # 2 switches in network + assert len(network.servers) == 5 # 5 servers in network def test_node_software_install(): diff --git a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py index 01ad3871..786fe851 100644 --- a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py +++ b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py @@ -7,26 +7,11 @@ from primaite.simulator.domain.account import Account, AccountType @pytest.fixture(scope="function") def account() -> Account: acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER) - acct.set_original_state() return acct def test_original_state(account): """Test the original state - see if it resets properly""" - account.log_on() - account.log_off() - account.disable() - - state = account.describe_state() - assert state["num_logons"] is 1 - assert state["num_logoffs"] is 1 - assert state["num_group_changes"] is 0 - assert state["username"] is "Jake" - assert state["password"] is "totally_hashed_password" - assert state["account_type"] is AccountType.USER.value - assert state["enabled"] is False - - account.reset_component_for_episode(episode=1) state = account.describe_state() assert state["num_logons"] is 0 assert state["num_logoffs"] is 0 @@ -39,13 +24,7 @@ def test_original_state(account): account.log_on() account.log_off() account.disable() - account.set_original_state() - account.log_on() - state = account.describe_state() - assert state["num_logons"] is 2 - - account.reset_component_for_episode(episode=2) state = account.describe_state() assert state["num_logons"] is 1 assert state["num_logoffs"] is 1 diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 9366d173..4defc80c 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -185,38 +185,6 @@ def test_get_file(file_system): file_system.show(full=True) -def test_reset_file_system(file_system): - # file and folder that existed originally - file_system.create_file(file_name="test_file.zip") - file_system.create_folder(folder_name="test_folder") - file_system.set_original_state() - - # create a new file - file_system.create_file(file_name="new_file.txt") - - # create a new folder - file_system.create_folder(folder_name="new_folder") - - # delete the file that existed originally - file_system.delete_file(folder_name="root", file_name="test_file.zip") - assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None - - # delete the folder that existed originally - file_system.delete_folder(folder_name="test_folder") - assert file_system.get_folder(folder_name="test_folder") is None - - # reset - file_system.reset_component_for_episode(episode=1) - - # deleted original file and folder should be back - assert file_system.get_file(folder_name="root", file_name="test_file.zip") - assert file_system.get_folder(folder_name="test_folder") - - # new file and folder should be removed - assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None - assert file_system.get_folder(folder_name="new_folder") is None - - @pytest.mark.skip(reason="Skipping until we tackle serialisation") def test_serialisation(file_system): """Test to check that the object serialisation works correctly.""" diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index 9d424697..2cfc3f11 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -31,7 +31,6 @@ def network(example_network) -> Network: assert len(example_network.computers) is 2 assert len(example_network.servers) is 2 - example_network.set_original_state() example_network.show() return example_network @@ -45,40 +44,6 @@ def test_describe_state(network): assert len(state["links"]) is 6 -def test_reset_network(network): - """ - Test that the network is properly reset. - - TODO: make sure that once implemented - any installed/uninstalled services, processes, apps, - etc are also removed/reinstalled - - """ - state_before = network.describe_state() - - client_1: Computer = network.get_node_by_hostname("client_1") - server_1: Computer = network.get_node_by_hostname("server_1") - - assert client_1.operating_state is NodeOperatingState.ON - assert server_1.operating_state is NodeOperatingState.ON - - client_1.power_off() - assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN - - server_1.power_off() - assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN - - assert network.describe_state() != state_before - - network.reset_component_for_episode(episode=1) - - assert client_1.operating_state is NodeOperatingState.ON - assert server_1.operating_state is NodeOperatingState.ON - # don't worry if UUIDs change - a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"]) - b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"]) - assert a == b - - def test_creating_container(): """Check that we can create a network container""" net = Network() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index ccf40c44..4bfd28d0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -19,7 +19,6 @@ def dos_bot() -> DoSBot: dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1")) - dos_bot.set_original_state() return dos_bot @@ -28,35 +27,6 @@ def test_dos_bot_creation(dos_bot): assert dos_bot is not None -def test_dos_bot_reset(dos_bot): - assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") - assert dos_bot.target_port is Port.POSTGRES_SERVER - assert dos_bot.payload is None - assert dos_bot.repeat is False - - dos_bot.configure( - target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True - ) - - # should reset the relevant items - dos_bot.reset_component_for_episode(episode=0) - assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") - assert dos_bot.target_port is Port.POSTGRES_SERVER - assert dos_bot.payload is None - assert dos_bot.repeat is False - - dos_bot.configure( - target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True - ) - dos_bot.set_original_state() - dos_bot.reset_component_for_episode(episode=1) - # should reset to the configured value - assert dos_bot.target_ip_address == IPv4Address("192.168.1.1") - assert dos_bot.target_port is Port.HTTP - assert dos_bot.payload == "payload" - assert dos_bot.repeat is True - - def test_dos_bot_cannot_run_when_node_offline(dos_bot): dos_bot_node: Computer = dos_bot.parent assert dos_bot_node.operating_state is NodeOperatingState.ON diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index e77cd895..6f680012 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -2,12 +2,14 @@ from typing import Dict import pytest +from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.sys_log import SysLog -from primaite.simulator.system.software import Software, SoftwareHealthState +from primaite.simulator.system.services.service import Service +from primaite.simulator.system.software import IOSoftware, SoftwareHealthState -class TestSoftware(Software): +class TestSoftware(Service): def describe_state(self) -> Dict: pass @@ -15,7 +17,11 @@ class TestSoftware(Software): @pytest.fixture(scope="function") def software(file_system): return TestSoftware( - name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service") + name="TestSoftware", + port=Port.ARP, + file_system=file_system, + sys_log=SysLog(hostname="test_service"), + protocol=IPProtocol.TCP, )