diff --git a/CHANGELOG.md b/CHANGELOG.md
index 260bfcf0..1c467206 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
+- Made environment reset completely recreate the game object.
- Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack.
- Changed the data manipulation scenario to include a second green agent on client 1.
- Refactored actions and observations to be configurable via object name, instead of UUID.
diff --git a/docs/index.rst b/docs/index.rst
index 9eae8adc..08e0ac21 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -108,6 +108,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
source/simulation
source/game_layer
source/config
+ source/environment
.. toctree::
:caption: Developer information:
diff --git a/docs/source/environment.rst b/docs/source/environment.rst
new file mode 100644
index 00000000..2b76572d
--- /dev/null
+++ b/docs/source/environment.rst
@@ -0,0 +1,10 @@
+RL Environments
+***************
+
+RL environments are the objects that directly interface with RL libraries such as Stable-Baselines3 and Ray RLLib. The PrimAITE simulation is exposed via three different environment APIs:
+
+* Gymnasium API - this is the standard interface that works with many RL libraries like SB3, Ray, Tianshou, etc. ``PrimaiteGymEnv`` adheres to the `Official Gymnasium documentation `_.
+* Ray Single agent API - For training a single Ray RLLib agent
+* Ray MARL API - For training multi-agent systems with Ray RLLib. ``PrimaiteRayMARLEnv`` adheres to the `Official Ray documentation `_.
+
+There are Jupyter notebooks which demonstrate integration with each of these three environments. They are located in ``~/primaite//notebooks/example_notebooks``.
diff --git a/docs/source/game_layer.rst b/docs/source/game_layer.rst
index eb9b17c3..39ab7bde 100644
--- a/docs/source/game_layer.rst
+++ b/docs/source/game_layer.rst
@@ -20,6 +20,11 @@ The game layer is responsible for managing agents and getting them to interface
PrimAITE Session
^^^^^^^^^^^^^^^^
+.. admonition:: Deprecated
+ :class: deprecated
+
+ PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The `session` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
+
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types.
Agents
diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst
index 706397b6..87a3f03d 100644
--- a/docs/source/primaite_session.rst
+++ b/docs/source/primaite_session.rst
@@ -4,6 +4,11 @@
.. _run a primaite session:
+.. admonition:: Deprecated
+ :class: deprecated
+
+ PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The ``session`` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
+
Run a PrimAITE Session
======================
diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml
index 6076553c..478124a9 100644
--- a/src/primaite/config/_package_data/example_config.yaml
+++ b/src/primaite/config/_package_data/example_config.yaml
@@ -14,7 +14,7 @@ io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
- save_pcap_logs: true
+ save_pcap_logs: false
save_sys_logs: true
@@ -656,12 +656,13 @@ simulation:
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
+ - ref: web_server_web_service
+ type: WebServer
+ applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- - ref: web_server_web_service
- type: WebServer
- ref: database_server
diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml
index 12461547..2b54eb37 100644
--- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml
+++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml
@@ -1,15 +1,23 @@
training_config:
rl_framework: RLLIB_multi_agent
- # rl_framework: SB3
+ rl_algorithm: PPO
+ seed: 333
+ n_learn_episodes: 1
+ n_eval_episodes: 5
+ max_steps_per_episode: 256
+ deterministic_eval: false
n_agents: 2
agent_references:
- defender_1
- defender_2
+
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
+ save_pcap_logs: false
+ save_sys_logs: true
game:
@@ -36,9 +44,9 @@ agents:
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- - node_ref: client_2
+ - node_name: client_2
applications:
- - application_ref: client_2_web_browser
+ - application_name: WebBrowser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -54,6 +62,31 @@ agents:
frequency: 4
variance: 3
+ - ref: client_1_green_user
+ team: GREEN
+ type: GreenWebBrowsingAgent
+ observation_space:
+ type: UC2GreenObservation
+ action_space:
+ action_list:
+ - type: DONOTHING
+ - type: NODE_APPLICATION_EXECUTE
+ options:
+ nodes:
+ - node_name: client_1
+ applications:
+ - application_name: WebBrowser
+ max_folders_per_node: 1
+ max_files_per_folder: 1
+ max_services_per_node: 1
+ max_applications_per_node: 1
+ reward_function:
+ reward_components:
+ - type: DUMMY
+
+
+
+
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
@@ -72,7 +105,7 @@ agents:
- type: NODE_OS_SCAN
options:
nodes:
- - node_ref: client_1
+ - node_name: client_1
applications:
- application_name: DataManipulationBot
- node_name: client_2
@@ -104,25 +137,21 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- - node_ref: domain_controller
+ - node_hostname: domain_controller
services:
- - service_ref: domain_controller_dns_server
- - node_ref: web_server
+ - service_name: DNSServer
+ - node_hostname: web_server
services:
- - service_ref: web_server_database_client
- - node_ref: database_server
- services:
- - service_ref: database_service
+ - service_name: WebServer
+ - node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- - node_ref: backup_server
- # services:
- # - service_ref: backup_service
- - node_ref: security_suite
- - node_ref: client_1
- - node_ref: client_2
+ - node_hostname: backup_server
+ - node_hostname: security_suite
+ - node_hostname: client_1
+ - node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -137,23 +166,23 @@ agents:
acl:
options:
max_acl_rules: 10
- router_node_ref: router_1
+ router_hostname: router_1
ip_address_order:
- - node_ref: domain_controller
+ - node_hostname: domain_controller
nic_num: 1
- - node_ref: web_server
+ - node_hostname: web_server
nic_num: 1
- - node_ref: database_server
+ - node_hostname: database_server
nic_num: 1
- - node_ref: backup_server
+ - node_hostname: backup_server
nic_num: 1
- - node_ref: security_suite
+ - node_hostname: security_suite
nic_num: 1
- - node_ref: client_1
+ - node_hostname: client_1
nic_num: 1
- - node_ref: client_2
+ - node_hostname: client_2
nic_num: 1
- - node_ref: security_suite
+ - node_hostname: security_suite
nic_num: 2
ics: null
@@ -184,10 +213,10 @@ agents:
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
- target_router_ref: router_1
+ target_router_hostname: router_1
- type: NETWORK_ACL_REMOVERULE
options:
- target_router_ref: router_1
+ target_router_hostname: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
@@ -242,25 +271,25 @@ agents:
action: "NODE_FILE_SCAN"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_PATCH"
@@ -271,22 +300,22 @@ agents:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
@@ -303,63 +332,63 @@ agents:
action: "NODE_RESET"
options:
node_id: 5
- 22:
+ 22: # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
- source_ip_id: 7
- dest_ip_id: 1
+ source_ip_id: 7 # client 1
+ dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
- 23:
+ 23: # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
+ position: 2
permission: 2
- source_ip_id: 8
- dest_ip_id: 1
+ source_ip_id: 8 # client 2
+ dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
- 24:
+ 24: # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
+ position: 3
permission: 2
- source_ip_id: 7
- dest_ip_id: 3
+ source_ip_id: 7 # client 1
+ dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
- 25:
+ 25: # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
+ position: 4
permission: 2
- source_ip_id: 8
- dest_ip_id: 3
+ source_ip_id: 8 # client 2
+ dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
+ position: 5
permission: 2
- source_ip_id: 7
- dest_ip_id: 4
+ source_ip_id: 7 # client 1
+ dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
+ position: 6
permission: 2
- source_ip_id: 8
- dest_ip_id: 4
+ source_ip_id: 8 # client 2
+ dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
@@ -407,123 +436,148 @@ agents:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 0
- nic_id: 1
+ nic_id: 0
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 0
- nic_id: 1
+ nic_id: 0
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
- nic_id: 1
+ nic_id: 0
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
- nic_id: 1
+ nic_id: 0
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
- nic_id: 1
+ nic_id: 0
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
- nic_id: 1
+ nic_id: 0
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
- nic_id: 1
+ nic_id: 0
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
- nic_id: 1
+ nic_id: 0
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
- nic_id: 1
+ nic_id: 0
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
- nic_id: 1
+ nic_id: 0
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
- nic_id: 2
+ nic_id: 1
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
- nic_id: 2
+ nic_id: 1
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
- nic_id: 1
+ nic_id: 0
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
- nic_id: 1
+ nic_id: 0
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
- nic_id: 1
+ nic_id: 0
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
- nic_id: 1
+ nic_id: 0
options:
nodes:
- - node_ref: domain_controller
- - node_ref: web_server
+ - node_name: domain_controller
+ - node_name: web_server
+ applications:
+ - application_name: DatabaseClient
services:
- - service_ref: web_server_web_service
- - node_ref: database_server
+ - service_name: WebServer
+ - node_name: database_server
+ folders:
+ - folder_name: database
+ files:
+ - file_name: database.db
services:
- - service_ref: database_service
- - node_ref: backup_server
- - node_ref: security_suite
- - node_ref: client_1
- - node_ref: client_2
+ - service_name: DatabaseService
+ - node_name: backup_server
+ - node_name: security_suite
+ - node_name: client_1
+ - node_name: client_2
+
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
+ ip_address_order:
+ - node_name: domain_controller
+ nic_num: 1
+ - node_name: web_server
+ nic_num: 1
+ - node_name: database_server
+ nic_num: 1
+ - node_name: backup_server
+ nic_num: 1
+ - node_name: security_suite
+ nic_num: 1
+ - node_name: client_1
+ nic_num: 1
+ - node_name: client_2
+ nic_num: 1
+ - node_name: security_suite
+ nic_num: 2
+
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
- weight: 0.5
+ weight: 0.34
options:
- node_ref: database_server
+ node_hostname: database_server
folder_name: database
file_name: database.db
-
-
- - type: WEB_SERVER_404_PENALTY
- weight: 0.5
+ - type: WEBPAGE_UNAVAILABLE_PENALTY
+ weight: 0.33
options:
- node_ref: web_server
- service_ref: web_server_web_service
+ node_hostname: client_1
+ - type: WEBPAGE_UNAVAILABLE_PENALTY
+ weight: 0.33
+ options:
+ node_hostname: client_2
agent_settings:
- # ...
-
+ flatten_obs: true
- ref: defender_2
team: BLUE
@@ -537,25 +591,21 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- - node_ref: domain_controller
+ - node_hostname: domain_controller
services:
- - service_ref: domain_controller_dns_server
- - node_ref: web_server
+ - service_name: DNSServer
+ - node_hostname: web_server
services:
- - service_ref: web_server_database_client
- - node_ref: database_server
- services:
- - service_ref: database_service
+ - service_name: WebServer
+ - node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- - node_ref: backup_server
- # services:
- # - service_ref: backup_service
- - node_ref: security_suite
- - node_ref: client_1
- - node_ref: client_2
+ - node_hostname: backup_server
+ - node_hostname: security_suite
+ - node_hostname: client_1
+ - node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -570,23 +620,23 @@ agents:
acl:
options:
max_acl_rules: 10
- router_node_ref: router_1
+ router_hostname: router_1
ip_address_order:
- - node_ref: domain_controller
+ - node_hostname: domain_controller
nic_num: 1
- - node_ref: web_server
+ - node_hostname: web_server
nic_num: 1
- - node_ref: database_server
+ - node_hostname: database_server
nic_num: 1
- - node_ref: backup_server
+ - node_hostname: backup_server
nic_num: 1
- - node_ref: security_suite
+ - node_hostname: security_suite
nic_num: 1
- - node_ref: client_1
+ - node_hostname: client_1
nic_num: 1
- - node_ref: client_2
+ - node_hostname: client_2
nic_num: 1
- - node_ref: security_suite
+ - node_hostname: security_suite
nic_num: 2
ics: null
@@ -617,10 +667,10 @@ agents:
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
- target_router_ref: router_1
+ target_router_hostname: router_1
- type: NETWORK_ACL_REMOVERULE
options:
- target_router_ref: router_1
+ target_router_hostname: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
@@ -675,25 +725,25 @@ agents:
action: "NODE_FILE_SCAN"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_PATCH"
@@ -704,22 +754,22 @@ agents:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
- folder_id: 1
+ folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
@@ -736,63 +786,63 @@ agents:
action: "NODE_RESET"
options:
node_id: 5
- 22:
+ 22: # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
- source_ip_id: 7
- dest_ip_id: 1
+ source_ip_id: 7 # client 1
+ dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
- 23:
+ 23: # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
+ position: 2
permission: 2
- source_ip_id: 8
- dest_ip_id: 1
+ source_ip_id: 8 # client 2
+ dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
- 24:
+ 24: # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
+ position: 3
permission: 2
- source_ip_id: 7
- dest_ip_id: 3
+ source_ip_id: 7 # client 1
+ dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
- 25:
+ 25: # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
+ position: 4
permission: 2
- source_ip_id: 8
- dest_ip_id: 3
+ source_ip_id: 8 # client 2
+ dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
+ position: 5
permission: 2
- source_ip_id: 7
- dest_ip_id: 4
+ source_ip_id: 7 # client 1
+ dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
+ position: 6
permission: 2
- source_ip_id: 8
- dest_ip_id: 4
+ source_ip_id: 8 # client 2
+ dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
@@ -840,122 +890,148 @@ agents:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 0
- nic_id: 1
+ nic_id: 0
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 0
- nic_id: 1
+ nic_id: 0
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
- nic_id: 1
+ nic_id: 0
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
- nic_id: 1
+ nic_id: 0
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
- nic_id: 1
+ nic_id: 0
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
- nic_id: 1
+ nic_id: 0
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
- nic_id: 1
+ nic_id: 0
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
- nic_id: 1
+ nic_id: 0
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
- nic_id: 1
+ nic_id: 0
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
- nic_id: 1
+ nic_id: 0
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
- nic_id: 2
+ nic_id: 1
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
- nic_id: 2
+ nic_id: 1
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
- nic_id: 1
+ nic_id: 0
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
- nic_id: 1
+ nic_id: 0
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
- nic_id: 1
+ nic_id: 0
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
- nic_id: 1
+ nic_id: 0
options:
nodes:
- - node_ref: domain_controller
- - node_ref: web_server
+ - node_name: domain_controller
+ - node_name: web_server
+ applications:
+ - application_name: DatabaseClient
services:
- - service_ref: web_server_web_service
- - node_ref: database_server
+ - service_name: WebServer
+ - node_name: database_server
+ folders:
+ - folder_name: database
+ files:
+ - file_name: database.db
services:
- - service_ref: database_service
- - node_ref: backup_server
- - node_ref: security_suite
- - node_ref: client_1
- - node_ref: client_2
+ - service_name: DatabaseService
+ - node_name: backup_server
+ - node_name: security_suite
+ - node_name: client_1
+ - node_name: client_2
+
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
+ ip_address_order:
+ - node_name: domain_controller
+ nic_num: 1
+ - node_name: web_server
+ nic_num: 1
+ - node_name: database_server
+ nic_num: 1
+ - node_name: backup_server
+ nic_num: 1
+ - node_name: security_suite
+ nic_num: 1
+ - node_name: client_1
+ nic_num: 1
+ - node_name: client_2
+ nic_num: 1
+ - node_name: security_suite
+ nic_num: 2
+
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
- weight: 0.5
+ weight: 0.34
options:
- node_ref: database_server
+ node_hostname: database_server
folder_name: database
file_name: database.db
-
-
- - type: WEB_SERVER_404_PENALTY
- weight: 0.5
+ - type: WEBPAGE_UNAVAILABLE_PENALTY
+ weight: 0.33
options:
- node_ref: web_server
- service_ref: web_server_web_service
+ node_hostname: client_1
+ - type: WEBPAGE_UNAVAILABLE_PENALTY
+ weight: 0.33
+ options:
+ node_hostname: client_2
agent_settings:
- # ...
+ flatten_obs: true
@@ -1036,12 +1112,13 @@ simulation:
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
+ - ref: web_server_web_service
+ type: WebServer
+ applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- - ref: web_server_web_service
- type: WebServer
- ref: database_server
@@ -1093,10 +1170,14 @@ simulation:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
- port_scan_p_of_success: 0.1
- data_manipulation_p_of_success: 0.1
+ port_scan_p_of_success: 0.8
+ data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
+ - ref: client_1_web_browser
+ type: WebBrowser
+ options:
+ target_url: http://arcd.com/users/
services:
- ref: client_1_dns_client
type: DNSClient
@@ -1113,6 +1194,13 @@ simulation:
type: WebBrowser
options:
target_url: http://arcd.com/users/
+ - ref: data_manipulation_bot
+ type: DataManipulationBot
+ options:
+ port_scan_p_of_success: 0.8
+ data_manipulation_p_of_success: 0.8
+ payload: "DELETE"
+ server_ip: 192.168.1.14
services:
- ref: client_2_dns_client
type: DNSClient
diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py
index 07b9bc9e..8d272418 100644
--- a/src/primaite/game/game.py
+++ b/src/primaite/game/game.py
@@ -83,18 +83,15 @@ class PrimaiteGame:
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
- self.agents: List[AbstractAgent] = []
- """List of agents."""
+ self.agents: Dict[str, AbstractAgent] = {}
+ """Mapping from agent name to agent object."""
- self.rl_agents: List[ProxyAgent] = []
- """Subset of agent list including only the reinforcement learning agents."""
+ self.rl_agents: Dict[str, ProxyAgent] = {}
+ """Subset of agents which are intended for reinforcement learning."""
self.step_counter: int = 0
"""Current timestep within the episode."""
- self.episode_counter: int = 0
- """Current episode number."""
-
self.options: PrimaiteGameOptions
"""Special options that apply for the entire game."""
@@ -140,7 +137,7 @@ class PrimaiteGame:
self.update_agents(sim_state)
# Apply all actions to simulation as requests
- agent_actions = self.apply_agent_actions() # noqa
+ self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
@@ -151,7 +148,7 @@ class PrimaiteGame:
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
- for agent in self.agents:
+ for _, agent in self.agents.items():
agent.update_observation(state)
agent.update_reward(state)
agent.reward_function.total_reward += agent.reward_function.current_reward
@@ -165,7 +162,7 @@ class PrimaiteGame:
"""
agent_actions = {}
- for agent in self.agents:
+ for _, agent in self.agents.items():
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
@@ -188,20 +185,14 @@ class PrimaiteGame:
return True
return False
- def reset(self) -> None:
- """Reset the game, this will reset the simulation."""
- self.episode_counter += 1
- self.step_counter = 0
- _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
- self.simulation.reset_component_for_episode(episode=self.episode_counter)
- for agent in self.agents:
- agent.reward_function.total_reward = 0.0
- agent.reset_agent_for_episode()
-
def close(self) -> None:
"""Close the game, this will close the simulation."""
return NotImplemented
+ def setup_for_episode(self, episode: int) -> None:
+ """Perform any final configuration of components to make them ready for the game to start."""
+ self.simulation.setup_for_episode(episode=episode)
+
@classmethod
def from_config(cls, cfg: Dict) -> "PrimaiteGame":
"""Create a PrimaiteGame object from a config dictionary.
@@ -280,8 +271,9 @@ class PrimaiteGame:
# start the service
new_service.start()
else:
- _LOGGER.warning(f"service type not found {service_type}")
-
+ msg = f"Configuration contains an invalid service type: {service_type}"
+ _LOGGER.error(msg)
+ raise ValueError(msg)
# service-dependent options
if service_type == "DNSClient":
if "options" in service_cfg:
@@ -318,7 +310,9 @@ class PrimaiteGame:
new_application = new_node.software_manager.software[application_type]
game.ref_map_applications[application_ref] = new_application.uuid
else:
- _LOGGER.warning(f"application type not found {application_type}")
+ msg = f"Configuration contains an invalid application type: {application_type}"
+ _LOGGER.error(msg)
+ raise ValueError(msg)
# run the application
new_application.run()
@@ -419,7 +413,6 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
- game.agents.append(new_agent)
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
@@ -428,8 +421,7 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
- game.agents.append(new_agent)
- game.rl_agents.append(new_agent)
+ game.rl_agents[agent_cfg["ref"]] = new_agent
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
@@ -438,11 +430,11 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
- game.agents.append(new_agent)
else:
- _LOGGER.warning(f"agent type {agent_type} not found")
-
- game.simulation.set_original_state()
+ msg(f"Configuration error: {agent_type} is not a valid agent type.")
+ _LOGGER.error(msg)
+ raise ValueError(msg)
+ game.agents[agent_cfg["ref"]] = new_agent
# Set the NMNE capture config
set_nmne_config(cfg["simulation"]["network"].get("nmne_config", {}))
diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb
index 0d4b6d0e..4ef02443 100644
--- a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb
+++ b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb
@@ -60,7 +60,7 @@
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" )\n",
- " .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n",
+ " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .training(train_batch_size=128)\n",
" )\n"
@@ -88,6 +88,13 @@
" param_space=config\n",
").fit()"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb
index ea006ae9..3c27bdc6 100644
--- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb
+++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb
@@ -54,7 +54,7 @@
"metadata": {},
"outputs": [],
"source": [
- "env_config = {\"cfg\":cfg}\n",
+ "env_config = cfg\n",
"\n",
"config = (\n",
" PPOConfig()\n",
diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb
index e5085c5e..0472854e 100644
--- a/src/primaite/notebooks/training_example_sb3.ipynb
+++ b/src/primaite/notebooks/training_example_sb3.ipynb
@@ -27,9 +27,7 @@
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
- " cfg = yaml.safe_load(f)\n",
- "\n",
- "game = PrimaiteGame.from_config(cfg)"
+ " cfg = yaml.safe_load(f)\n"
]
},
{
@@ -38,7 +36,7 @@
"metadata": {},
"outputs": [],
"source": [
- "gym = PrimaiteGymEnv(game=game)"
+ "gym = PrimaiteGymEnv(game_config=cfg)"
]
},
{
@@ -65,7 +63,7 @@
"metadata": {},
"outputs": [],
"source": [
- "model.learn(total_timesteps=1000)\n"
+ "model.learn(total_timesteps=10)\n"
]
},
{
@@ -76,6 +74,13 @@
"source": [
"model.save(\"deleteme\")"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb
index b1e12370..c8f2595b 100644
--- a/src/primaite/notebooks/uc2_demo.ipynb
+++ b/src/primaite/notebooks/uc2_demo.ipynb
@@ -346,9 +346,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
@@ -358,9 +356,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"# Imports\n",
@@ -383,9 +379,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"# create the env\n",
@@ -396,10 +390,10 @@
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n",
" cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
- "game = PrimaiteGame.from_config(cfg)\n",
- "env = PrimaiteGymEnv(game = game)\n",
- "# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n",
- "env.agent.flatten_obs = False\n",
+ " # don't flatten observations so that we can see what is going on\n",
+ " cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n",
+ "\n",
+ "env = PrimaiteGymEnv(game_config = cfg)\n",
"obs, info = env.reset()\n",
"print('env created successfully')\n",
"pprint(obs)"
@@ -433,9 +427,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"for step in range(35):\n",
@@ -453,9 +445,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"pprint(obs['NODES'])"
@@ -471,9 +461,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(9) # scan database file\n",
@@ -499,9 +487,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
@@ -526,9 +512,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
@@ -551,9 +535,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"env.step(13) # Patch the database\n",
@@ -593,6 +575,22 @@
"obs['ACL']"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Reset the environment, you can rerun the other cells to verify that the attack works the same every episode."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "env.reset()"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -603,7 +601,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "venv",
"language": "python",
"name": "python3"
},
@@ -617,9 +615,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.10"
+ "version": "3.10.12"
}
},
"nbformat": 4,
- "nbformat_minor": 4
+ "nbformat_minor": 2
}
diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py
index a3831bc1..f8dbab9d 100644
--- a/src/primaite/session/environment.py
+++ b/src/primaite/session/environment.py
@@ -1,5 +1,5 @@
import json
-from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple
+from typing import Any, Dict, Optional, SupportsFloat, Tuple
import gymnasium
from gymnasium.core import ActType, ObsType
@@ -18,11 +18,23 @@ class PrimaiteGymEnv(gymnasium.Env):
assumptions about the agent list always having a list of length 1.
"""
- def __init__(self, game: PrimaiteGame):
+ def __init__(self, game_config: Dict):
"""Initialise the environment."""
super().__init__()
- self.game: "PrimaiteGame" = game
- self.agent: ProxyAgent = self.game.rl_agents[0]
+ self.game_config: Dict = game_config
+ """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
+ self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
+ """Current game."""
+ self._agent_name = next(iter(self.game.rl_agents))
+ """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
+
+ self.episode_counter: int = 0
+ """Current episode number."""
+
+ @property
+ def agent(self) -> ProxyAgent:
+ """Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
+ return self.game.rl_agents[self._agent_name]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
@@ -45,13 +57,13 @@ class PrimaiteGymEnv(gymnasium.Env):
return next_obs, reward, terminated, truncated, info
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
- output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
+ output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
data = {
- "episode": self.game.episode_counter,
+ "episode": self.episode_counter,
"step": self.game.step_counter,
"action": int(action),
"reward": int(reward),
@@ -63,10 +75,12 @@ class PrimaiteGymEnv(gymnasium.Env):
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
print(
- f"Resetting environment, episode {self.game.episode_counter}, "
- f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}"
+ f"Resetting environment, episode {self.episode_counter}, "
+ f"avg. reward: {self.agent.reward_function.total_reward}"
)
- self.game.reset()
+ self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
+ self.game.setup_for_episode(episode=self.episode_counter)
+ self.episode_counter += 1
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
@@ -88,12 +102,12 @@ class PrimaiteGymEnv(gymnasium.Env):
def _get_obs(self) -> ObsType:
"""Return the current observation."""
- if not self.agent.flatten_obs:
- return self.agent.observation_manager.current_observation
- else:
+ if self.agent.flatten_obs:
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
+ else:
+ return self.agent.observation_manager.current_observation
class PrimaiteRayEnv(gymnasium.Env):
@@ -102,12 +116,11 @@ class PrimaiteRayEnv(gymnasium.Env):
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
- :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
- which is the PrimaiteGame instance.
- :type env_config: Dict[str, PrimaiteGame]
+ :param env_config: A dictionary containing the environment configuration.
+ :type env_config: Dict
"""
- self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"]))
- self.env.game.episode_counter -= 1
+ self.env = PrimaiteGymEnv(game_config=env_config)
+ self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
@@ -128,13 +141,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
- :type env_config: Dict[str, PrimaiteGame]
+ :type env_config: Dict
"""
- self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"])
+ self.game_config: Dict = env_config
+ """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
+ self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
"""Reference to the primaite game"""
- self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
- """List of all possible agents in the environment. This list should not change!"""
- self._agent_ids = list(self.agents.keys())
+ self._agent_ids = list(self.game.rl_agents.keys())
+ """Agent ids. This is a list of strings of agent names."""
+ self.episode_counter: int = 0
+ """Current episode number."""
self.terminateds = set()
self.truncateds = set()
@@ -149,9 +165,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
)
super().__init__()
+ @property
+ def agents(self) -> Dict[str, ProxyAgent]:
+ """Grab a fresh reference to the agents from this episode's game object."""
+ return {name: self.game.rl_agents[name] for name in self._agent_ids}
+
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
- self.game.reset()
+ self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
+ self.game.setup_for_episode(episode=self.episode_counter)
+ self.episode_counter += 1
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
@@ -172,7 +195,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
- agent_actions = self.game.apply_agent_actions()
+ self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
@@ -186,7 +209,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
- infos = {"agent_actions": agent_actions}
+ infos = {name: {} for name, _ in self.agents.items()}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
@@ -194,13 +217,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
- output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
+ output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
data = {
- "episode": self.game.episode_counter,
+ "episode": self.episode_counter,
"step": self.game.step_counter,
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
"reward": rewards,
@@ -212,8 +235,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}
- for name, agent in self.agents.items():
+ for agent_name in self._agent_ids:
+ agent = self.game.rl_agents[agent_name]
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
- obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
+ obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs
diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py
index 5c663cfd..b8f80e95 100644
--- a/src/primaite/session/session.py
+++ b/src/primaite/session/session.py
@@ -101,11 +101,11 @@ class PrimaiteSession:
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":
- sess.env = PrimaiteRayEnv(env_config={"cfg": cfg})
+ sess.env = PrimaiteRayEnv(env_config=cfg)
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
- sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg})
+ sess.env = PrimaiteRayMARLEnv(env_config=cfg)
elif sess.training_options.rl_framework == "SB3":
- sess.env = PrimaiteGymEnv(game=game)
+ sess.env = PrimaiteGymEnv(game_config=cfg)
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py
index 964dac01..99e9be7f 100644
--- a/src/primaite/simulator/core.py
+++ b/src/primaite/simulator/core.py
@@ -153,22 +153,18 @@ class SimComponent(BaseModel):
uuid: str = Field(default_factory=lambda: str(uuid4()))
"""The component UUID."""
- _original_state: Dict = {}
-
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._request_manager: RequestManager = self._init_request_manager()
self._parent: Optional["SimComponent"] = None
- # @abstractmethod
- def set_original_state(self):
- """Sets the original state."""
- pass
+ def setup_for_episode(self, episode: int):
+ """
+ Perform any additional setup on this component that can't happen during __init__.
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- for key, value in self._original_state.items():
- self.__setattr__(key, value)
+ For instance, some components may require for the entire network to exist before some configuration can be set.
+ """
+ pass
def _init_request_manager(self) -> RequestManager:
"""
diff --git a/src/primaite/simulator/domain/account.py b/src/primaite/simulator/domain/account.py
index d9dad06a..186caf5b 100644
--- a/src/primaite/simulator/domain/account.py
+++ b/src/primaite/simulator/domain/account.py
@@ -42,19 +42,6 @@ class Account(SimComponent):
"Account Type, currently this can be service account (used by apps) or user account."
enabled: bool = True
- def set_original_state(self):
- """Sets the original state."""
- vals_to_include = {
- "num_logons",
- "num_logoffs",
- "num_group_changes",
- "username",
- "password",
- "account_type",
- "enabled",
- }
- self._original_state = self.model_dump(include=vals_to_include)
-
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py
index 608a1d78..d9b02e8e 100644
--- a/src/primaite/simulator/file_system/file.py
+++ b/src/primaite/simulator/file_system/file.py
@@ -73,20 +73,6 @@ class File(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})")
- self.set_original_state()
-
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting File ({self.path}) original state on node {self.sys_log.hostname}")
- super().set_original_state()
- vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- _LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}")
- super().reset_component_for_episode(episode)
-
@property
def path(self) -> str:
"""
diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py
index ee80587d..8fd4e5d7 100644
--- a/src/primaite/simulator/file_system/file_system.py
+++ b/src/primaite/simulator/file_system/file_system.py
@@ -34,43 +34,6 @@ class FileSystem(SimComponent):
if not self.folders:
self.create_folder("root")
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting FileSystem original state on node {self.sys_log.hostname}")
- for folder in self.folders.values():
- folder.set_original_state()
- # Capture a list of all 'original' file uuids
- original_keys = list(self.folders.keys())
- vals_to_include = {"sim_root"}
- self._original_state.update(self.model_dump(include=vals_to_include))
- self._original_state["original_folder_uuids"] = original_keys
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- _LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}")
- # Move any 'original' folder that have been deleted back to folders
- original_folder_uuids = self._original_state["original_folder_uuids"]
- for uuid in original_folder_uuids:
- if uuid in self.deleted_folders:
- folder = self.deleted_folders[uuid]
- self.deleted_folders.pop(uuid)
- self.folders[uuid] = folder
-
- # Clear any other deleted folders that aren't original (have been created by agent)
- self.deleted_folders.clear()
-
- # Now clear all non-original folders created by agent
- current_folder_uuids = list(self.folders.keys())
- for uuid in current_folder_uuids:
- if uuid not in original_folder_uuids:
- folder = self.folders[uuid]
- self.folders.pop(uuid)
-
- # Now reset all remaining folders
- for folder in self.folders.values():
- folder.reset_component_for_episode(episode)
- super().reset_component_for_episode(episode)
-
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py
index c3e1426b..fbe5f4b3 100644
--- a/src/primaite/simulator/file_system/file_system_item_abc.py
+++ b/src/primaite/simulator/file_system/file_system_item_abc.py
@@ -85,11 +85,6 @@ class FileSystemItemABC(SimComponent):
deleted: bool = False
"If true, the FileSystemItem was deleted."
- def set_original_state(self):
- """Sets the original state."""
- vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"}
- self._original_state = self.model_dump(include=vals_to_keep)
-
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py
index 13fdc597..771dc7a0 100644
--- a/src/primaite/simulator/file_system/folder.py
+++ b/src/primaite/simulator/file_system/folder.py
@@ -49,49 +49,6 @@ class Folder(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting Folder ({self.name}) original state on node {self.sys_log.hostname}")
- for file in self.files.values():
- file.set_original_state()
- super().set_original_state()
- vals_to_include = {
- "scan_duration",
- "scan_countdown",
- "red_scan_duration",
- "red_scan_countdown",
- "restore_duration",
- "restore_countdown",
- }
- self._original_state.update(self.model_dump(include=vals_to_include))
- self._original_state["original_file_uuids"] = list(self.files.keys())
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- _LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}")
- # Move any 'original' file that have been deleted back to files
- original_file_uuids = self._original_state["original_file_uuids"]
- for uuid in original_file_uuids:
- if uuid in self.deleted_files:
- file = self.deleted_files[uuid]
- self.deleted_files.pop(uuid)
- self.files[uuid] = file
-
- # Clear any other deleted files that aren't original (have been created by agent)
- self.deleted_files.clear()
-
- # Now clear all non-original files created by agent
- current_file_uuids = list(self.files.keys())
- for uuid in current_file_uuids:
- if uuid not in original_file_uuids:
- file = self.files[uuid]
- self.files.pop(uuid)
-
- # Now reset all remaining files
- for file in self.files.values():
- file.reset_component_for_episode(episode)
- super().reset_component_for_episode(episode)
-
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py
index 724b8728..d264f751 100644
--- a/src/primaite/simulator/network/airspace.py
+++ b/src/primaite/simulator/network/airspace.py
@@ -273,11 +273,6 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC)
return state
- def set_original_state(self):
- """Sets the original state."""
- vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
- self._original_state = self.model_dump(include=vals_to_include)
-
def enable(self):
"""
Enables this wired network interface and attempts to send a "hello" message to the default gateway.
diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py
index b32d2630..b5a16430 100644
--- a/src/primaite/simulator/network/container.py
+++ b/src/primaite/simulator/network/container.py
@@ -45,19 +45,12 @@ class Network(SimComponent):
self._nx_graph = MultiGraph()
- def set_original_state(self):
- """Sets the original state."""
- for node in self.nodes.values():
- node.set_original_state()
- for link in self.links.values():
- link.set_original_state()
-
- def reset_component_for_episode(self, episode: int):
+ def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for node in self.nodes.values():
- node.reset_component_for_episode(episode)
+ node.setup_for_episode(episode=episode)
for link in self.links.values():
- link.reset_component_for_episode(episode)
+ link.setup_for_episode(episode=episode)
for node in self.nodes.values():
node.power_on()
@@ -179,7 +172,7 @@ class Network(SimComponent):
def clear_links(self):
"""Clear all the links in the network by resetting their component state for the episode."""
for link in self.links.values():
- link.reset_component_for_episode()
+ link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here.
def draw(self, seed: int = 123):
"""
diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py
index 35c90d05..ff79f314 100644
--- a/src/primaite/simulator/network/hardware/base.py
+++ b/src/primaite/simulator/network/hardware/base.py
@@ -100,6 +100,15 @@ class NetworkInterface(SimComponent, ABC):
nmne: Dict = Field(default_factory=lambda: {})
"A dict containing details of the number of malicious network events captured."
+ def setup_for_episode(self, episode: int):
+ """Reset the original state of the SimComponent."""
+ super().setup_for_episode(episode=episode)
+ self.nmne = {}
+ if episode and self.pcap:
+ self.pcap.current_episode = episode
+ self.pcap.setup_logger()
+ self.enable()
+
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -127,15 +136,6 @@ class NetworkInterface(SimComponent, ABC):
state.update({"nmne": self.nmne})
return state
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- super().reset_component_for_episode(episode)
- self.nmne = {}
- if episode and self.pcap:
- self.pcap.current_episode = episode
- self.pcap.setup_logger()
- self.enable()
-
@abstractmethod
def enable(self):
"""Enable the interface."""
@@ -547,14 +547,6 @@ class Link(SimComponent):
self.endpoint_b.connect_link(self)
self.endpoint_up()
- self.set_original_state()
-
- def set_original_state(self):
- """Sets the original state."""
- vals_to_include = {"bandwidth", "current_load"}
- self._original_state = self.model_dump(include=vals_to_include)
- super().set_original_state()
-
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -740,50 +732,20 @@ class Node(SimComponent):
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
self._install_system_software()
- self.set_original_state()
- def set_original_state(self):
- """Sets the original state."""
- for software in self.software_manager.software.values():
- software.set_original_state()
-
- self.file_system.set_original_state()
-
- for network_interface in self.network_interfaces.values():
- network_interface.set_original_state()
-
- vals_to_include = {
- "hostname",
- "default_gateway",
- "operating_state",
- "revealed_to_red",
- "start_up_duration",
- "start_up_countdown",
- "shut_down_duration",
- "shut_down_countdown",
- "is_resetting",
- "node_scan_duration",
- "node_scan_countdown",
- "red_scan_countdown",
- }
- self._original_state = self.model_dump(include=vals_to_include)
-
- def reset_component_for_episode(self, episode: int):
+ def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
- super().reset_component_for_episode(episode)
-
- # Reset Session Manager
- self.session_manager.clear()
+ super().setup_for_episode(episode=episode)
# Reset File System
- self.file_system.reset_component_for_episode(episode)
+ self.file_system.setup_for_episode(episode=episode)
# Reset all Nics
for network_interface in self.network_interfaces.values():
- network_interface.reset_component_for_episode(episode)
+ network_interface.setup_for_episode(episode=episode)
for software in self.software_manager.software.values():
- software.reset_component_for_episode(episode)
+ software.setup_for_episode(episode=episode)
if episode and self.sys_log:
self.sys_log.current_episode = episode
diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py
index 1d3b9926..cb3c1bd7 100644
--- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py
+++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py
@@ -209,11 +209,6 @@ class NIC(IPWiredNetworkInterface):
return state
- def set_original_state(self):
- """Sets the original state."""
- vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
- self._original_state = self.model_dump(include=vals_to_include)
-
def receive_frame(self, frame: Frame) -> bool:
"""
Attempt to receive and process a network frame from the connected Link.
diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py
index ce98cec4..b4d5cdba 100644
--- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py
+++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py
@@ -111,24 +111,6 @@ class Firewall(Router):
sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound"
)
- self.set_original_state()
-
- def set_original_state(self):
- """Set the original state for the Firewall."""
- super().set_original_state()
- vals_to_include = {
- "internal_port",
- "external_port",
- "dmz_port",
- "internal_inbound_acl",
- "internal_outbound_acl",
- "dmz_inbound_acl",
- "dmz_outbound_acl",
- "external_inbound_acl",
- "external_outbound_acl",
- }
- self._original_state.update(self.model_dump(include=vals_to_include))
-
def describe_state(self) -> Dict:
"""
Describes the current state of the Firewall.
diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py
index a9e12401..5b45f59c 100644
--- a/src/primaite/simulator/network/hardware/nodes/network/router.py
+++ b/src/primaite/simulator/network/hardware/nodes/network/router.py
@@ -136,11 +136,6 @@ class ACLRule(SimComponent):
rule_strings.append(f"{key}={value}")
return ", ".join(rule_strings)
- def set_original_state(self):
- """Sets the original state."""
- vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"}
- self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
-
def describe_state(self) -> Dict:
"""
Describes the current state of the ACLRule.
@@ -296,48 +291,6 @@ class AccessControlList(SimComponent):
super().__init__(**kwargs)
self._acl = [None] * (self.max_acl_rules - 1)
- self.set_original_state()
-
- def set_original_state(self):
- """Sets the original state."""
- self.implicit_rule.set_original_state()
- vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
- self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
-
- for i, rule in enumerate(self._acl):
- if not rule:
- continue
- self._default_config[i] = {"action": rule.action.name}
- if rule.src_ip_address:
- self._default_config[i]["src_ip"] = str(rule.src_ip_address)
- if rule.dst_ip_address:
- self._default_config[i]["dst_ip"] = str(rule.dst_ip_address)
- if rule.src_port:
- self._default_config[i]["src_port"] = rule.src_port.name
- if rule.dst_port:
- self._default_config[i]["dst_port"] = rule.dst_port.name
- if rule.protocol:
- self._default_config[i]["protocol"] = rule.protocol.name
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- self.implicit_rule.reset_component_for_episode(episode)
- super().reset_component_for_episode(episode)
- self._reset_rules_to_default()
-
- def _reset_rules_to_default(self) -> None:
- """Clear all ACL rules and set them to the default rules config."""
- self._acl = [None] * (self.max_acl_rules - 1)
- for r_num, r_cfg in self._default_config.items():
- self.add_rule(
- action=ACLAction[r_cfg["action"]],
- src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
- dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
- protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
- src_ip_address=r_cfg.get("src_ip"),
- dst_ip_address=r_cfg.get("dst_ip"),
- position=r_num,
- )
def _init_request_manager(self) -> RequestManager:
# TODO: Add src and dst wildcard masks as positional args in this request.
@@ -616,11 +569,6 @@ class RouteEntry(SimComponent):
metric: float = 0.0
"The cost metric for this route. Default is 0.0."
- def set_original_state(self):
- """Sets the original state."""
- vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
- self._original_values = self.model_dump(include=vals_to_include)
-
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteEntry.
@@ -653,17 +601,6 @@ class RouteTable(SimComponent):
default_route: Optional[RouteEntry] = None
sys_log: SysLog
- def set_original_state(self):
- """Sets the original state."""
- super().set_original_state()
- self._original_state["routes_orig"] = self.routes
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- self.routes.clear()
- self.routes = self._original_state["routes_orig"]
- super().reset_component_for_episode(episode)
-
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteTable.
@@ -1104,8 +1041,6 @@ class Router(NetworkNode):
self._set_default_acl()
- self.set_original_state()
-
def _install_system_software(self):
"""
Installs essential system software and network services on the router.
@@ -1131,20 +1066,7 @@ class Router(NetworkNode):
self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
- def set_original_state(self):
- """
- Sets or resets the router to its original configuration state.
-
- This method is called to initialize the router's state or to revert it to a known good configuration during
- network simulations or after configuration changes.
- """
- self.acl.set_original_state()
- self.route_table.set_original_state()
- super().set_original_state()
- vals_to_include = {"num_ports"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
- def reset_component_for_episode(self, episode: int):
+ def setup_for_episode(self, episode: int):
"""
Resets the router's components for a new network simulation episode.
@@ -1154,13 +1076,10 @@ class Router(NetworkNode):
:param episode: The episode number for which the router is being reset.
"""
self.software_manager.arp.clear()
- self.acl.reset_component_for_episode(episode)
- self.route_table.reset_component_for_episode(episode)
- for i, network_interface in self.network_interface.items():
- network_interface.reset_component_for_episode(episode)
+ for i, _ in self.network_interface.items():
self.enable_port(i)
- super().reset_component_for_episode(episode)
+ super().setup_for_episode(episode=episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -1391,7 +1310,6 @@ class Router(NetworkNode):
network_interface.ip_address = ip_address
network_interface.subnet_mask = subnet_mask
self.sys_log.info(f"Configured Network Interface {network_interface}")
- self.set_original_state()
def enable_port(self, port: int):
"""
@@ -1493,8 +1411,16 @@ class Router(NetworkNode):
subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")),
)
if "acl" in cfg:
- router.acl._default_config = cfg["acl"] # save the config to allow resetting
- router.acl._reset_rules_to_default() # read the config and apply rules
+ for r_num, r_cfg in cfg["acl"].items():
+ router.acl.add_rule(
+ action=ACLAction[r_cfg["action"]],
+ src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
+ dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
+ protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
+ src_ip_address=r_cfg.get("src_ip"),
+ dst_ip_address=r_cfg.get("dst_ip"),
+ position=r_num,
+ )
if "routes" in cfg:
for route in cfg.get("routes"):
router.route_table.add_route(
diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py
index 33e6ee9a..557ea287 100644
--- a/src/primaite/simulator/network/hardware/nodes/network/switch.py
+++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py
@@ -32,12 +32,6 @@ class SwitchPort(WiredNetworkInterface):
_connected_node: Optional[Switch] = None
"The Switch to which the SwitchPort is connected."
- def set_original_state(self):
- """Sets the original state."""
- vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"}
- self._original_state = self.model_dump(include=vals_to_include)
- super().set_original_state()
-
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py
index dd0b58d3..91833d6a 100644
--- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py
+++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py
@@ -122,8 +122,6 @@ class WirelessRouter(Router):
self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0"))
- self.set_original_state()
-
@property
def wireless_access_point(self) -> WirelessAccessPoint:
"""
@@ -166,7 +164,6 @@ class WirelessRouter(Router):
network_interface.ip_address = ip_address
network_interface.subnet_mask = subnet_mask
self.sys_log.info(f"Configured WAP {network_interface}")
- self.set_original_state()
self.wireless_access_point.frequency = frequency # Set operating frequency
self.wireless_access_point.enable() # Re-enable the WAP with new settings
diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py
index 896861e6..a2285d92 100644
--- a/src/primaite/simulator/sim_container.py
+++ b/src/primaite/simulator/sim_container.py
@@ -21,13 +21,9 @@ class Simulation(SimComponent):
super().__init__(**kwargs)
- def set_original_state(self):
- """Sets the original state."""
- self.network.set_original_state()
-
- def reset_component_for_episode(self, episode: int):
+ def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
- self.network.reset_component_for_episode(episode)
+ self.network.setup_for_episode(episode=episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py
index 322ac808..513606a9 100644
--- a/src/primaite/simulator/system/applications/application.py
+++ b/src/primaite/simulator/system/applications/application.py
@@ -38,12 +38,6 @@ class Application(IOSoftware):
def __init__(self, **kwargs):
super().__init__(**kwargs)
- def set_original_state(self):
- """Sets the original state."""
- super().set_original_state()
- vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
@abstractmethod
def describe_state(self) -> Dict:
"""
diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py
index 50d9f3d4..57fd3a46 100644
--- a/src/primaite/simulator/system/applications/database_client.py
+++ b/src/primaite/simulator/system/applications/database_client.py
@@ -31,20 +31,6 @@ class DatabaseClient(Application):
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
- self.set_original_state()
-
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}")
- super().set_original_state()
- vals_to_include = {"server_ip_address", "server_password", "connected", "_query_success_tracker"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- _LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}")
- super().reset_component_for_episode(episode)
- self._query_success_tracker.clear()
def describe_state(self) -> Dict:
"""
diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py
index a844f059..5fe951b7 100644
--- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py
+++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py
@@ -49,26 +49,6 @@ class DataManipulationBot(DatabaseClient):
super().__init__(**kwargs)
self.name = "DataManipulationBot"
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting DataManipulationBot original state on node {self.software_manager.node.hostname}")
- super().set_original_state()
- vals_to_include = {
- "server_ip_address",
- "payload",
- "server_password",
- "port_scan_p_of_success",
- "data_manipulation_p_of_success",
- "attack_stage",
- "repeat",
- }
- self._original_state.update(self.model_dump(include=vals_to_include))
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- _LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}")
- super().reset_component_for_episode(episode)
-
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py
index dfc48dd3..9dac6b25 100644
--- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py
+++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py
@@ -57,27 +57,6 @@ class DoSBot(DatabaseClient, Application):
self.name = "DoSBot"
self.max_sessions = 1000 # override normal max sessions
- def set_original_state(self):
- """Set the original state of the Denial of Service Bot."""
- _LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}")
- super().set_original_state()
- vals_to_include = {
- "target_ip_address",
- "target_port",
- "payload",
- "repeat",
- "attack_stage",
- "max_sessions",
- "port_scan_p_of_success",
- "dos_intensity",
- }
- self._original_state.update(self.model_dump(include=vals_to_include))
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- _LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}")
- super().reset_component_for_episode(episode)
-
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py
index eef0ed5d..6f2c479c 100644
--- a/src/primaite/simulator/system/applications/web_browser.py
+++ b/src/primaite/simulator/system/applications/web_browser.py
@@ -47,21 +47,8 @@ class WebBrowser(Application):
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
- self.set_original_state()
self.run()
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting WebBrowser original state on node {self.software_manager.node.hostname}")
- super().set_original_state()
- vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- _LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}")
- super().reset_component_for_episode(episode)
-
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
@@ -80,9 +67,6 @@ class WebBrowser(Application):
state["history"] = [hist_item.state() for hist_item in self.history]
return state
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
-
def get_webpage(self, url: Optional[str] = None) -> bool:
"""
Retrieve the webpage.
diff --git a/src/primaite/simulator/system/processes/process.py b/src/primaite/simulator/system/processes/process.py
index b753e3ad..458a6b5c 100644
--- a/src/primaite/simulator/system/processes/process.py
+++ b/src/primaite/simulator/system/processes/process.py
@@ -24,12 +24,6 @@ class Process(Software):
operating_state: ProcessOperatingState
"The current operating state of the Process."
- def set_original_state(self):
- """Sets the original state."""
- super().set_original_state()
- vals_to_include = {"operating_state"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
@abstractmethod
def describe_state(self) -> Dict:
"""
diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py
index 726d213e..9fdfd5ff 100644
--- a/src/primaite/simulator/system/services/database/database_service.py
+++ b/src/primaite/simulator/system/services/database/database_service.py
@@ -41,25 +41,6 @@ class DatabaseService(Service):
super().__init__(**kwargs)
self._create_db_file()
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}")
- super().set_original_state()
- vals_to_include = {
- "password",
- "connections",
- "backup_server_ip",
- "latest_backup_directory",
- "latest_backup_file_name",
- }
- self._original_state.update(self.model_dump(include=vals_to_include))
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- _LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
- self.clear_connections()
- super().reset_component_for_episode(episode)
-
def configure_backup(self, backup_server: IPv4Address):
"""
Set up the database backup.
diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py
index 2d3879ff..967af6b2 100644
--- a/src/primaite/simulator/system/services/dns/dns_client.py
+++ b/src/primaite/simulator/system/services/dns/dns_client.py
@@ -29,18 +29,6 @@ class DNSClient(Service):
super().__init__(**kwargs)
self.start()
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting DNSClient original state on node {self.software_manager.node.hostname}")
- super().set_original_state()
- vals_to_include = {"dns_server"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- self.dns_cache.clear()
- super().reset_component_for_episode(episode)
-
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py
index 8decf7e9..4d0ebbb8 100644
--- a/src/primaite/simulator/system/services/dns/dns_server.py
+++ b/src/primaite/simulator/system/services/dns/dns_server.py
@@ -28,20 +28,6 @@ class DNSServer(Service):
super().__init__(**kwargs)
self.start()
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting DNSServer original state on node {self.software_manager.node.hostname}")
- super().set_original_state()
- vals_to_include = {"dns_table"}
- self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"]
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- self.dns_table.clear()
- for key, value in self._original_state["dns_table_orig"].items():
- self.dns_table[key] = value
- super().reset_component_for_episode(episode)
-
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py
index 39bc57f0..7c334ced 100644
--- a/src/primaite/simulator/system/services/ftp/ftp_client.py
+++ b/src/primaite/simulator/system/services/ftp/ftp_client.py
@@ -27,18 +27,6 @@ class FTPClient(FTPServiceABC):
super().__init__(**kwargs)
self.start()
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting FTPClient original state on node {self.software_manager.node.hostname}")
- super().set_original_state()
- vals_to_include = {"connected"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- _LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}")
- super().reset_component_for_episode(episode)
-
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.
diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py
index a82b0919..c5330de2 100644
--- a/src/primaite/simulator/system/services/ftp/ftp_server.py
+++ b/src/primaite/simulator/system/services/ftp/ftp_server.py
@@ -27,19 +27,6 @@ class FTPServer(FTPServiceABC):
super().__init__(**kwargs)
self.start()
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting FTPServer original state on node {self.software_manager.node.hostname}")
- super().set_original_state()
- vals_to_include = {"server_password"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- _LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
- self.clear_connections()
- super().reset_component_for_episode(episode)
-
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.
diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py
index 43d1d783..ad00065c 100644
--- a/src/primaite/simulator/system/services/ntp/ntp_client.py
+++ b/src/primaite/simulator/system/services/ntp/ntp_client.py
@@ -49,21 +49,12 @@ class NTPClient(Service):
state = super().describe_state()
return state
- def reset_component_for_episode(self, episode: int):
- """
- Resets the Service component for a new episode.
-
- This method ensures the Service is ready for a new episode, including resetting any
- stateful properties or statistics, and clearing any message queues.
- """
- pass
-
def send(
self,
payload: NTPPacket,
session_id: Optional[str] = None,
dest_ip_address: IPv4Address = None,
- dest_port: [Port] = Port.NTP,
+ dest_port: Port = Port.NTP,
**kwargs,
) -> bool:
"""Requests NTP data from NTP server.
diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py
index 3ae80936..f9d9ee7c 100644
--- a/src/primaite/simulator/system/services/ntp/ntp_server.py
+++ b/src/primaite/simulator/system/services/ntp/ntp_server.py
@@ -34,16 +34,6 @@ class NTPServer(Service):
state = super().describe_state()
return state
- def reset_component_for_episode(self, episode: int):
- """
- Resets the Service component for a new episode.
-
- This method ensures the Service is ready for a new episode, including
- resetting any stateful properties or statistics, and clearing any message
- queues.
- """
- pass
-
def receive(
self,
payload: NTPPacket,
diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py
index 162678a0..4102657c 100644
--- a/src/primaite/simulator/system/services/service.py
+++ b/src/primaite/simulator/system/services/service.py
@@ -78,12 +78,6 @@ class Service(IOSoftware):
"""
return super().receive(payload=payload, session_id=session_id, **kwargs)
- def set_original_state(self):
- """Sets the original state."""
- super().set_original_state()
- vals_to_include = {"operating_state", "restart_duration", "restart_countdown"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py
index eaea6bb1..5e7591e9 100644
--- a/src/primaite/simulator/system/services/web_server/web_server.py
+++ b/src/primaite/simulator/system/services/web_server/web_server.py
@@ -23,18 +23,6 @@ class WebServer(Service):
last_response_status_code: Optional[HttpStatusCode] = None
- def set_original_state(self):
- """Sets the original state."""
- _LOGGER.debug(f"Setting WebServer original state on node {self.software_manager.node.hostname}")
- super().set_original_state()
- vals_to_include = {"last_response_status_code"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
- def reset_component_for_episode(self, episode: int):
- """Reset the original state of the SimComponent."""
- _LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}")
- super().reset_component_for_episode(episode)
-
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -130,7 +118,7 @@ class WebServer(Service):
self.set_health_state(SoftwareHealthState.COMPROMISED)
return response
- except Exception:
+ except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 )
# something went wrong on the server
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
return response
diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py
index ce39930b..8864659c 100644
--- a/src/primaite/simulator/system/software.py
+++ b/src/primaite/simulator/system/software.py
@@ -3,7 +3,7 @@ from abc import abstractmethod
from datetime import datetime
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
@@ -13,6 +13,9 @@ from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.core.sys_log import SysLog
+if TYPE_CHECKING:
+ from primaite.simulator.system.core.software_manager import SoftwareManager
+
class SoftwareType(Enum):
"""
@@ -84,7 +87,7 @@ class Software(SimComponent):
"The count of times the software has been scanned, defaults to 0."
revealed_to_red: bool = False
"Indicates if the software has been revealed to red agent, defaults is False."
- software_manager: Any = None
+ software_manager: "SoftwareManager" = None
"An instance of Software Manager that is used by the parent node."
sys_log: SysLog = None
"An instance of SysLog that is used by the parent node."
@@ -97,19 +100,6 @@ class Software(SimComponent):
_patching_countdown: Optional[int] = None
"Current number of ticks left to patch the software."
- def set_original_state(self):
- """Sets the original state."""
- vals_to_include = {
- "name",
- "health_state_actual",
- "health_state_visible",
- "criticality",
- "patching_count",
- "scanning_count",
- "revealed_to_red",
- }
- self._original_state = self.model_dump(include=vals_to_include)
-
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
@@ -248,12 +238,6 @@ class IOSoftware(Software):
_connections: Dict[str, Dict] = {}
"Active connections."
- def set_original_state(self):
- """Sets the original state."""
- super().set_original_state()
- vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"}
- self._original_state.update(self.model_dump(include=vals_to_include))
-
@abstractmethod
def describe_state(self) -> Dict:
"""
diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml
index 5bdc3273..c76aeef6 100644
--- a/tests/assets/configs/bad_primaite_session.yaml
+++ b/tests/assets/configs/bad_primaite_session.yaml
@@ -589,15 +589,16 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
- default_gateway: 192.168.1.10
+ default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
+ - ref: web_server_web_service
+ type: WebServer
+ applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- - ref: web_server_web_service
- type: WebServer
- ref: database_server
diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml
index 8361e318..1cb59f87 100644
--- a/tests/assets/configs/eval_only_primaite_session.yaml
+++ b/tests/assets/configs/eval_only_primaite_session.yaml
@@ -593,15 +593,16 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
- default_gateway: 192.168.1.10
+ default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
+ - ref: web_server_web_service
+ type: WebServer
+ applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- - ref: web_server_web_service
- type: WebServer
- ref: database_server
@@ -624,7 +625,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
- type: DatabaseBackup
+ type: FTPServer
- ref: security_suite
type: server
diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml
index 87bd9d1c..b1b15372 100644
--- a/tests/assets/configs/multi_agent_session.yaml
+++ b/tests/assets/configs/multi_agent_session.yaml
@@ -1043,16 +1043,16 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
- default_gateway: 192.168.1.10
+ default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
+ - ref: web_server_web_service
+ type: WebServer
+ applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- - ref: web_server_web_service
- type: WebServer
-
- ref: database_server
type: server
@@ -1074,7 +1074,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
- type: DatabaseBackup
+ type: FTPServer
- ref: security_suite
type: server
diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml
index 76190a64..e5f9d544 100644
--- a/tests/assets/configs/test_primaite_session.yaml
+++ b/tests/assets/configs/test_primaite_session.yaml
@@ -599,15 +599,16 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
- default_gateway: 192.168.1.10
+ default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
+ - ref: web_server_web_service
+ type: WebServer
+ applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- - ref: web_server_web_service
- type: WebServer
- ref: database_server
@@ -630,7 +631,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
- type: DatabaseBackup
+ type: FTPServer
- ref: security_suite
type: server
diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml
index 5d004c7e..10e088d8 100644
--- a/tests/assets/configs/train_only_primaite_session.yaml
+++ b/tests/assets/configs/train_only_primaite_session.yaml
@@ -600,15 +600,16 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
- default_gateway: 192.168.1.10
+ default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
+ - ref: web_server_web_service
+ type: WebServer
+ applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- - ref: web_server_web_service
- type: WebServer
- ref: database_server
@@ -631,7 +632,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
- type: DatabaseBackup
+ type: FTPServer
- ref: security_suite
type: server
diff --git a/tests/conftest.py b/tests/conftest.py
index ada89026..dbfff2f3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -529,6 +529,6 @@ def game_and_agent():
reward_function=reward_function,
)
- game.agents.append(test_agent)
+ game.agents["test_agent"] = test_agent
return (game, test_agent)
diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py
index 91cf5c1e..c48ddbc9 100644
--- a/tests/e2e_integration_tests/environments/test_sb3_environment.py
+++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py
@@ -11,14 +11,12 @@ from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
-# @pytest.mark.skip(reason="no way of currently testing this")
def test_sb3_compatibility():
"""Test that the Gymnasium environment can be used with an SB3 agent."""
with open(example_config_path(), "r") as f:
cfg = yaml.safe_load(f)
- game = PrimaiteGame.from_config(cfg)
- gym = PrimaiteGymEnv(game=game)
+ gym = PrimaiteGymEnv(game_config=cfg)
model = PPO("MlpPolicy", gym)
model.learn(total_timesteps=1000)
diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py
index 54dca371..306f591d 100644
--- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py
+++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py
@@ -14,7 +14,47 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.web_server.web_server import WebServer
-from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, load_config
+from tests import TEST_ASSETS_ROOT
+
+BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
+
+
+def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
+ """Returns a PrimaiteGame object which loads the contents of a given yaml path."""
+ with open(config_path, "r") as f:
+ cfg = yaml.safe_load(f)
+
+ return PrimaiteGame.from_config(cfg)
+
+
+def test_example_config():
+ """Test that the example config can be parsed properly."""
+ game = load_config(example_config_path())
+
+ assert len(game.agents) == 4 # red, blue and 2 green agents
+
+ # green agent 1
+ assert "client_2_green_user" in game.agents
+ assert isinstance(game.agents["client_2_green_user"], RandomAgent)
+
+ # green agent 2
+ assert "client_1_green_user" in game.agents
+ assert isinstance(game.agents["client_1_green_user"], RandomAgent)
+
+ # red agent
+ assert "client_1_data_manipulation_red_bot" in game.agents
+ assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent)
+
+ # blue agent
+ assert "defender" in game.agents
+ assert isinstance(game.agents["defender"], ProxyAgent)
+
+ network: Network = game.simulation.network
+
+ assert len(network.nodes) == 10 # 10 nodes in example network
+ assert len(network.routers) == 1 # 1 router in network
+ assert len(network.switches) == 2 # 2 switches in network
+ assert len(network.servers) == 5 # 5 servers in network
def test_node_software_install():
diff --git a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py
index 01ad3871..786fe851 100644
--- a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py
+++ b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py
@@ -7,26 +7,11 @@ from primaite.simulator.domain.account import Account, AccountType
@pytest.fixture(scope="function")
def account() -> Account:
acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER)
- acct.set_original_state()
return acct
def test_original_state(account):
"""Test the original state - see if it resets properly"""
- account.log_on()
- account.log_off()
- account.disable()
-
- state = account.describe_state()
- assert state["num_logons"] is 1
- assert state["num_logoffs"] is 1
- assert state["num_group_changes"] is 0
- assert state["username"] is "Jake"
- assert state["password"] is "totally_hashed_password"
- assert state["account_type"] is AccountType.USER.value
- assert state["enabled"] is False
-
- account.reset_component_for_episode(episode=1)
state = account.describe_state()
assert state["num_logons"] is 0
assert state["num_logoffs"] is 0
@@ -39,13 +24,7 @@ def test_original_state(account):
account.log_on()
account.log_off()
account.disable()
- account.set_original_state()
- account.log_on()
- state = account.describe_state()
- assert state["num_logons"] is 2
-
- account.reset_component_for_episode(episode=2)
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py
index 9366d173..4defc80c 100644
--- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py
+++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py
@@ -185,38 +185,6 @@ def test_get_file(file_system):
file_system.show(full=True)
-def test_reset_file_system(file_system):
- # file and folder that existed originally
- file_system.create_file(file_name="test_file.zip")
- file_system.create_folder(folder_name="test_folder")
- file_system.set_original_state()
-
- # create a new file
- file_system.create_file(file_name="new_file.txt")
-
- # create a new folder
- file_system.create_folder(folder_name="new_folder")
-
- # delete the file that existed originally
- file_system.delete_file(folder_name="root", file_name="test_file.zip")
- assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None
-
- # delete the folder that existed originally
- file_system.delete_folder(folder_name="test_folder")
- assert file_system.get_folder(folder_name="test_folder") is None
-
- # reset
- file_system.reset_component_for_episode(episode=1)
-
- # deleted original file and folder should be back
- assert file_system.get_file(folder_name="root", file_name="test_file.zip")
- assert file_system.get_folder(folder_name="test_folder")
-
- # new file and folder should be removed
- assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None
- assert file_system.get_folder(folder_name="new_folder") is None
-
-
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
def test_serialisation(file_system):
"""Test to check that the object serialisation works correctly."""
diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py
index 9d424697..2cfc3f11 100644
--- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py
+++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py
@@ -31,7 +31,6 @@ def network(example_network) -> Network:
assert len(example_network.computers) is 2
assert len(example_network.servers) is 2
- example_network.set_original_state()
example_network.show()
return example_network
@@ -45,40 +44,6 @@ def test_describe_state(network):
assert len(state["links"]) is 6
-def test_reset_network(network):
- """
- Test that the network is properly reset.
-
- TODO: make sure that once implemented - any installed/uninstalled services, processes, apps,
- etc are also removed/reinstalled
-
- """
- state_before = network.describe_state()
-
- client_1: Computer = network.get_node_by_hostname("client_1")
- server_1: Computer = network.get_node_by_hostname("server_1")
-
- assert client_1.operating_state is NodeOperatingState.ON
- assert server_1.operating_state is NodeOperatingState.ON
-
- client_1.power_off()
- assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN
-
- server_1.power_off()
- assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN
-
- assert network.describe_state() != state_before
-
- network.reset_component_for_episode(episode=1)
-
- assert client_1.operating_state is NodeOperatingState.ON
- assert server_1.operating_state is NodeOperatingState.ON
- # don't worry if UUIDs change
- a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"])
- b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"])
- assert a == b
-
-
def test_creating_container():
"""Check that we can create a network container"""
net = Network()
diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py
index ccf40c44..4bfd28d0 100644
--- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py
+++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py
@@ -19,7 +19,6 @@ def dos_bot() -> DoSBot:
dos_bot: DoSBot = computer.software_manager.software.get("DoSBot")
dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1"))
- dos_bot.set_original_state()
return dos_bot
@@ -28,35 +27,6 @@ def test_dos_bot_creation(dos_bot):
assert dos_bot is not None
-def test_dos_bot_reset(dos_bot):
- assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
- assert dos_bot.target_port is Port.POSTGRES_SERVER
- assert dos_bot.payload is None
- assert dos_bot.repeat is False
-
- dos_bot.configure(
- target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
- )
-
- # should reset the relevant items
- dos_bot.reset_component_for_episode(episode=0)
- assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
- assert dos_bot.target_port is Port.POSTGRES_SERVER
- assert dos_bot.payload is None
- assert dos_bot.repeat is False
-
- dos_bot.configure(
- target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
- )
- dos_bot.set_original_state()
- dos_bot.reset_component_for_episode(episode=1)
- # should reset to the configured value
- assert dos_bot.target_ip_address == IPv4Address("192.168.1.1")
- assert dos_bot.target_port is Port.HTTP
- assert dos_bot.payload == "payload"
- assert dos_bot.repeat is True
-
-
def test_dos_bot_cannot_run_when_node_offline(dos_bot):
dos_bot_node: Computer = dos_bot.parent
assert dos_bot_node.operating_state is NodeOperatingState.ON
diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py
index e77cd895..6f680012 100644
--- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py
+++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py
@@ -2,12 +2,14 @@ from typing import Dict
import pytest
+from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.sys_log import SysLog
-from primaite.simulator.system.software import Software, SoftwareHealthState
+from primaite.simulator.system.services.service import Service
+from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
-class TestSoftware(Software):
+class TestSoftware(Service):
def describe_state(self) -> Dict:
pass
@@ -15,7 +17,11 @@ class TestSoftware(Software):
@pytest.fixture(scope="function")
def software(file_system):
return TestSoftware(
- name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service")
+ name="TestSoftware",
+ port=Port.ARP,
+ file_system=file_system,
+ sys_log=SysLog(hostname="test_service"),
+ protocol=IPProtocol.TCP,
)