Merged PR 283: Refactor episode reset

## Summary
Instead of setting all attributes back to a snapshot, simply recreate the entire game and all agents from a cached copy of the config.

This removes the need for `set_original_state` and `reset_component_for_episode` methods on SimComponents.

## Test process
* Unit tests passing
* I've also tried adding a `__del__` method to simcomponent and agent and press env.reset() to verify that the ref count reaches 0, and nothing is secretly keeping using an old part of the simulation.

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [X] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [X] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #2317
This commit is contained in:
Marek Wolan
2024-02-29 11:11:03 +00:00
58 changed files with 515 additions and 969 deletions

View File

@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- Made environment reset completely recreate the game object.
- Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack.
- Changed the data manipulation scenario to include a second green agent on client 1.
- Refactored actions and observations to be configurable via object name, instead of UUID.

View File

@@ -108,6 +108,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
source/simulation
source/game_layer
source/config
source/environment
.. toctree::
:caption: Developer information:

View File

@@ -0,0 +1,10 @@
RL Environments
***************
RL environments are the objects that directly interface with RL libraries such as Stable-Baselines3 and Ray RLLib. The PrimAITE simulation is exposed via three different environment APIs:
* Gymnasium API - this is the standard interface that works with many RL libraries like SB3, Ray, Tianshou, etc. ``PrimaiteGymEnv`` adheres to the `Official Gymnasium documentation <https://gymnasium.farama.org/api/env/>`_.
* Ray Single agent API - For training a single Ray RLLib agent
* Ray MARL API - For training multi-agent systems with Ray RLLib. ``PrimaiteRayMARLEnv`` adheres to the `Official Ray documentation <https://docs.ray.io/en/latest/rllib/package_ref/env/multi_agent_env.html>`_.
There are Jupyter notebooks which demonstrate integration with each of these three environments. They are located in ``~/primaite/<VERSION>/notebooks/example_notebooks``.

View File

@@ -20,6 +20,11 @@ The game layer is responsible for managing agents and getting them to interface
PrimAITE Session
^^^^^^^^^^^^^^^
.. admonition:: Deprecated
:class: deprecated
PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The `session` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types.
Agents

View File

@@ -4,6 +4,11 @@
.. _run a primaite session:
.. admonition:: Deprecated
:class: deprecated
PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The ``session`` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
Run a PrimAITE Session
======================

View File

@@ -14,7 +14,7 @@ io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
save_pcap_logs: true
save_pcap_logs: false
save_sys_logs: true
@@ -656,12 +656,13 @@ simulation:
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server

View File

@@ -1,15 +1,23 @@
training_config:
rl_framework: RLLIB_multi_agent
# rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 256
deterministic_eval: false
n_agents: 2
agent_references:
- defender_1
- defender_2
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: true
game:
@@ -36,9 +44,9 @@ agents:
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_ref: client_2
- node_name: client_2
applications:
- application_ref: client_2_web_browser
- application_name: WebBrowser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -54,6 +62,31 @@ agents:
frequency: 4
variance: 3
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: WebBrowser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
reward_function:
reward_components:
- type: DUMMY
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
@@ -72,7 +105,7 @@ agents:
- type: NODE_OS_SCAN
options:
nodes:
- node_ref: client_1
- node_name: client_1
applications:
- application_name: DataManipulationBot
- node_name: client_2
@@ -104,25 +137,21 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: DNSServer
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -137,23 +166,23 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
- node_hostname: domain_controller
nic_num: 1
- node_ref: web_server
- node_hostname: web_server
nic_num: 1
- node_ref: database_server
- node_hostname: database_server
nic_num: 1
- node_ref: backup_server
- node_hostname: backup_server
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 1
- node_ref: client_1
- node_hostname: client_1
nic_num: 1
- node_ref: client_2
- node_hostname: client_2
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 2
ics: null
@@ -184,10 +213,10 @@ agents:
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
@@ -242,25 +271,25 @@ agents:
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_PATCH"
@@ -271,22 +300,22 @@ agents:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 1
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
@@ -303,63 +332,63 @@ agents:
action: "NODE_RESET"
options:
node_id: 5
22:
22: # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
23: # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 2
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
24: # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 3
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
25: # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 4
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 5
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 6
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
@@ -407,123 +436,148 @@ agents:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
options:
nodes:
- node_ref: domain_controller
- node_ref: web_server
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_ref: web_server_web_service
- node_ref: database_server
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_ref: database_service
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
weight: 0.34
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: client_1
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_hostname: client_2
agent_settings:
# ...
flatten_obs: true
- ref: defender_2
team: BLUE
@@ -537,25 +591,21 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: DNSServer
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -570,23 +620,23 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
- node_hostname: domain_controller
nic_num: 1
- node_ref: web_server
- node_hostname: web_server
nic_num: 1
- node_ref: database_server
- node_hostname: database_server
nic_num: 1
- node_ref: backup_server
- node_hostname: backup_server
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 1
- node_ref: client_1
- node_hostname: client_1
nic_num: 1
- node_ref: client_2
- node_hostname: client_2
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 2
ics: null
@@ -617,10 +667,10 @@ agents:
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
@@ -675,25 +725,25 @@ agents:
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_PATCH"
@@ -704,22 +754,22 @@ agents:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 1
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
@@ -736,63 +786,63 @@ agents:
action: "NODE_RESET"
options:
node_id: 5
22:
22: # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
23: # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 2
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
24: # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 3
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
25: # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 4
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 5
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 6
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
@@ -840,122 +890,148 @@ agents:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
options:
nodes:
- node_ref: domain_controller
- node_ref: web_server
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_ref: web_server_web_service
- node_ref: database_server
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_ref: database_service
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
weight: 0.34
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: client_1
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_hostname: client_2
agent_settings:
# ...
flatten_obs: true
@@ -1036,12 +1112,13 @@ simulation:
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
@@ -1093,10 +1170,14 @@ simulation:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- ref: client_1_web_browser
type: WebBrowser
options:
target_url: http://arcd.com/users/
services:
- ref: client_1_dns_client
type: DNSClient
@@ -1113,6 +1194,13 @@ simulation:
type: WebBrowser
options:
target_url: http://arcd.com/users/
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_2_dns_client
type: DNSClient

View File

@@ -80,18 +80,15 @@ class PrimaiteGame:
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
self.agents: Dict[str, AbstractAgent] = {}
"""Mapping from agent name to agent object."""
self.rl_agents: List[ProxyAgent] = []
"""Subset of agent list including only the reinforcement learning agents."""
self.rl_agents: Dict[str, ProxyAgent] = {}
"""Subset of agents which are intended for reinforcement learning."""
self.step_counter: int = 0
"""Current timestep within the episode."""
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteGameOptions
"""Special options that apply for the entire game."""
@@ -137,7 +134,7 @@ class PrimaiteGame:
self.update_agents(sim_state)
# Apply all actions to simulation as requests
agent_actions = self.apply_agent_actions() # noqa
self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
@@ -148,7 +145,7 @@ class PrimaiteGame:
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for agent in self.agents:
for _, agent in self.agents.items():
agent.update_observation(state)
agent.update_reward(state)
agent.reward_function.total_reward += agent.reward_function.current_reward
@@ -162,7 +159,7 @@ class PrimaiteGame:
"""
agent_actions = {}
for agent in self.agents:
for _, agent in self.agents.items():
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
@@ -185,20 +182,14 @@ class PrimaiteGame:
return True
return False
def reset(self) -> None:
"""Reset the game, this will reset the simulation."""
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(episode=self.episode_counter)
for agent in self.agents:
agent.reward_function.total_reward = 0.0
agent.reset_agent_for_episode()
def close(self) -> None:
"""Close the game, this will close the simulation."""
return NotImplemented
def setup_for_episode(self, episode: int) -> None:
"""Perform any final configuration of components to make them ready for the game to start."""
self.simulation.setup_for_episode(episode=episode)
@classmethod
def from_config(cls, cfg: Dict) -> "PrimaiteGame":
"""Create a PrimaiteGame object from a config dictionary.
@@ -272,7 +263,9 @@ class PrimaiteGame:
new_service = new_node.software_manager.software[service_type]
game.ref_map_services[service_ref] = new_service.uuid
else:
_LOGGER.warning(f"service type not found {service_type}")
msg = f"Configuration contains an invalid service type: {service_type}"
_LOGGER.error(msg)
raise ValueError(msg)
# service-dependent options
if service_type == "DNSClient":
if "options" in service_cfg:
@@ -311,7 +304,9 @@ class PrimaiteGame:
new_application = new_node.software_manager.software[application_type]
game.ref_map_applications[application_ref] = new_application.uuid
else:
_LOGGER.warning(f"application type not found {application_type}")
msg = f"Configuration contains an invalid application type: {application_type}"
_LOGGER.error(msg)
raise ValueError(msg)
if application_type == "DataManipulationBot":
if "options" in application_cfg:
@@ -402,7 +397,6 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
@@ -411,8 +405,7 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
game.rl_agents.append(new_agent)
game.rl_agents[agent_cfg["ref"]] = new_agent
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
@@ -421,11 +414,11 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
else:
_LOGGER.warning(f"agent type {agent_type} not found")
game.simulation.set_original_state()
msg(f"Configuration error: {agent_type} is not a valid agent type.")
_LOGGER.error(msg)
raise ValueError(msg)
game.agents[agent_cfg["ref"]] = new_agent
# Set the NMNE capture config
set_nmne_config(cfg["simulation"]["network"].get("nmne_config", {}))

View File

@@ -60,7 +60,7 @@
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" )\n",
" .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .training(train_batch_size=128)\n",
" )\n"
@@ -88,6 +88,13 @@
" param_space=config\n",
").fit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -54,7 +54,7 @@
"metadata": {},
"outputs": [],
"source": [
"env_config = {\"cfg\":cfg}\n",
"env_config = cfg\n",
"\n",
"config = (\n",
" PPOConfig()\n",

View File

@@ -27,9 +27,7 @@
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
" cfg = yaml.safe_load(f)\n"
]
},
{
@@ -38,7 +36,7 @@
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteGymEnv(game=game)"
"gym = PrimaiteGymEnv(game_config=cfg)"
]
},
{
@@ -65,7 +63,7 @@
"metadata": {},
"outputs": [],
"source": [
"model.learn(total_timesteps=1000)\n"
"model.learn(total_timesteps=10)\n"
]
},
{
@@ -76,6 +74,13 @@
"source": [
"model.save(\"deleteme\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -346,9 +346,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
@@ -358,9 +356,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"# Imports\n",
@@ -383,9 +379,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"# create the env\n",
@@ -396,10 +390,10 @@
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n",
" cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
"game = PrimaiteGame.from_config(cfg)\n",
"env = PrimaiteGymEnv(game = game)\n",
"# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n",
"env.agent.flatten_obs = False\n",
" # don't flatten observations so that we can see what is going on\n",
" cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n",
"\n",
"env = PrimaiteGymEnv(game_config = cfg)\n",
"obs, info = env.reset()\n",
"print('env created successfully')\n",
"pprint(obs)"
@@ -433,9 +427,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"for step in range(35):\n",
@@ -453,9 +445,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"pprint(obs['NODES'])"
@@ -471,9 +461,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(9) # scan database file\n",
@@ -499,9 +487,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
@@ -526,9 +512,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
@@ -551,9 +535,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"env.step(13) # Patch the database\n",
@@ -593,6 +575,22 @@
"obs['ACL']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.reset()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -603,7 +601,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "venv",
"language": "python",
"name": "python3"
},
@@ -617,9 +615,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 2
}

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple
from typing import Any, Dict, Optional, SupportsFloat, Tuple
import gymnasium
from gymnasium.core import ActType, ObsType
@@ -18,11 +18,23 @@ class PrimaiteGymEnv(gymnasium.Env):
assumptions about the agent list always having a list of length 1.
"""
def __init__(self, game: PrimaiteGame):
def __init__(self, game_config: Dict):
"""Initialise the environment."""
super().__init__()
self.game: "PrimaiteGame" = game
self.agent: ProxyAgent = self.game.rl_agents[0]
self.game_config: Dict = game_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
"""Current game."""
self._agent_name = next(iter(self.game.rl_agents))
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
self.episode_counter: int = 0
"""Current episode number."""
@property
def agent(self) -> ProxyAgent:
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
return self.game.rl_agents[self._agent_name]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
@@ -45,13 +57,13 @@ class PrimaiteGymEnv(gymnasium.Env):
return next_obs, reward, terminated, truncated, info
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
data = {
"episode": self.game.episode_counter,
"episode": self.episode_counter,
"step": self.game.step_counter,
"action": int(action),
"reward": int(reward),
@@ -63,10 +75,12 @@ class PrimaiteGymEnv(gymnasium.Env):
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
print(
f"Resetting environment, episode {self.game.episode_counter}, "
f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}"
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
self.game.reset()
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
@@ -88,12 +102,12 @@ class PrimaiteGymEnv(gymnasium.Env):
def _get_obs(self) -> ObsType:
"""Return the current observation."""
if not self.agent.flatten_obs:
return self.agent.observation_manager.current_observation
else:
if self.agent.flatten_obs:
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
else:
return self.agent.observation_manager.current_observation
class PrimaiteRayEnv(gymnasium.Env):
@@ -102,12 +116,11 @@ class PrimaiteRayEnv(gymnasium.Env):
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict[str, PrimaiteGame]
:param env_config: A dictionary containing the environment configuration.
:type env_config: Dict
"""
self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"]))
self.env.game.episode_counter -= 1
self.env = PrimaiteGymEnv(game_config=env_config)
self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
@@ -128,13 +141,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict[str, PrimaiteGame]
:type env_config: Dict
"""
self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"])
self.game_config: Dict = env_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
"""Reference to the primaite game"""
self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
"""List of all possible agents in the environment. This list should not change!"""
self._agent_ids = list(self.agents.keys())
self._agent_ids = list(self.game.rl_agents.keys())
"""Agent ids. This is a list of strings of agent names."""
self.episode_counter: int = 0
"""Current episode number."""
self.terminateds = set()
self.truncateds = set()
@@ -149,9 +165,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
)
super().__init__()
@property
def agents(self) -> Dict[str, ProxyAgent]:
"""Grab a fresh reference to the agents from this episode's game object."""
return {name: self.game.rl_agents[name] for name in self._agent_ids}
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
self.game.reset()
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
@@ -172,7 +195,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
agent_actions = self.game.apply_agent_actions()
self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
@@ -186,7 +209,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {"agent_actions": agent_actions}
infos = {name: {} for name, _ in self.agents.items()}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
@@ -194,13 +217,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
data = {
"episode": self.game.episode_counter,
"episode": self.episode_counter,
"step": self.game.step_counter,
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
"reward": rewards,
@@ -212,8 +235,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}
for name, agent in self.agents.items():
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs

View File

@@ -101,11 +101,11 @@ class PrimaiteSession:
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":
sess.env = PrimaiteRayEnv(env_config={"cfg": cfg})
sess.env = PrimaiteRayEnv(env_config=cfg)
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg})
sess.env = PrimaiteRayMARLEnv(env_config=cfg)
elif sess.training_options.rl_framework == "SB3":
sess.env = PrimaiteGymEnv(game=game)
sess.env = PrimaiteGymEnv(game_config=cfg)
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:

View File

@@ -153,22 +153,18 @@ class SimComponent(BaseModel):
uuid: str = Field(default_factory=lambda: str(uuid4()))
"""The component UUID."""
_original_state: Dict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._request_manager: RequestManager = self._init_request_manager()
self._parent: Optional["SimComponent"] = None
# @abstractmethod
def set_original_state(self):
"""Sets the original state."""
pass
def setup_for_episode(self, episode: int):
"""
Perform any additional setup on this component that can't happen during __init__.
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for key, value in self._original_state.items():
self.__setattr__(key, value)
For instance, some components may require for the entire network to exist before some configuration can be set.
"""
pass
def _init_request_manager(self) -> RequestManager:
"""

View File

@@ -42,19 +42,6 @@ class Account(SimComponent):
"Account Type, currently this can be service account (used by apps) or user account."
enabled: bool = True
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {
"num_logons",
"num_logoffs",
"num_group_changes",
"username",
"password",
"account_type",
"enabled",
}
self._original_state = self.model_dump(include=vals_to_include)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -73,20 +73,6 @@ class File(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})")
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting File ({self.path}) original state on node {self.sys_log.hostname}")
super().set_original_state()
vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}")
super().reset_component_for_episode(episode)
@property
def path(self) -> str:
"""

View File

@@ -34,43 +34,6 @@ class FileSystem(SimComponent):
if not self.folders:
self.create_folder("root")
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FileSystem original state on node {self.sys_log.hostname}")
for folder in self.folders.values():
folder.set_original_state()
# Capture a list of all 'original' file uuids
original_keys = list(self.folders.keys())
vals_to_include = {"sim_root"}
self._original_state.update(self.model_dump(include=vals_to_include))
self._original_state["original_folder_uuids"] = original_keys
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}")
# Move any 'original' folder that have been deleted back to folders
original_folder_uuids = self._original_state["original_folder_uuids"]
for uuid in original_folder_uuids:
if uuid in self.deleted_folders:
folder = self.deleted_folders[uuid]
self.deleted_folders.pop(uuid)
self.folders[uuid] = folder
# Clear any other deleted folders that aren't original (have been created by agent)
self.deleted_folders.clear()
# Now clear all non-original folders created by agent
current_folder_uuids = list(self.folders.keys())
for uuid in current_folder_uuids:
if uuid not in original_folder_uuids:
folder = self.folders[uuid]
self.folders.pop(uuid)
# Now reset all remaining folders
for folder in self.folders.values():
folder.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -85,11 +85,6 @@ class FileSystemItemABC(SimComponent):
deleted: bool = False
"If true, the FileSystemItem was deleted."
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"}
self._original_state = self.model_dump(include=vals_to_keep)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -49,49 +49,6 @@ class Folder(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting Folder ({self.name}) original state on node {self.sys_log.hostname}")
for file in self.files.values():
file.set_original_state()
super().set_original_state()
vals_to_include = {
"scan_duration",
"scan_countdown",
"red_scan_duration",
"red_scan_countdown",
"restore_duration",
"restore_countdown",
}
self._original_state.update(self.model_dump(include=vals_to_include))
self._original_state["original_file_uuids"] = list(self.files.keys())
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}")
# Move any 'original' file that have been deleted back to files
original_file_uuids = self._original_state["original_file_uuids"]
for uuid in original_file_uuids:
if uuid in self.deleted_files:
file = self.deleted_files[uuid]
self.deleted_files.pop(uuid)
self.files[uuid] = file
# Clear any other deleted files that aren't original (have been created by agent)
self.deleted_files.clear()
# Now clear all non-original files created by agent
current_file_uuids = list(self.files.keys())
for uuid in current_file_uuids:
if uuid not in original_file_uuids:
file = self.files[uuid]
self.files.pop(uuid)
# Now reset all remaining files
for file in self.files.values():
file.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(

View File

@@ -273,11 +273,6 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC)
return state
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
def enable(self):
"""
Enables this wired network interface and attempts to send a "hello" message to the default gateway.

View File

@@ -45,19 +45,12 @@ class Network(SimComponent):
self._nx_graph = MultiGraph()
def set_original_state(self):
"""Sets the original state."""
for node in self.nodes.values():
node.set_original_state()
for link in self.links.values():
link.set_original_state()
def reset_component_for_episode(self, episode: int):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for node in self.nodes.values():
node.reset_component_for_episode(episode)
node.setup_for_episode(episode=episode)
for link in self.links.values():
link.reset_component_for_episode(episode)
link.setup_for_episode(episode=episode)
for node in self.nodes.values():
node.power_on()
@@ -179,7 +172,7 @@ class Network(SimComponent):
def clear_links(self):
"""Clear all the links in the network by resetting their component state for the episode."""
for link in self.links.values():
link.reset_component_for_episode()
link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here.
def draw(self, seed: int = 123):
"""

View File

@@ -100,6 +100,15 @@ class NetworkInterface(SimComponent, ABC):
nmne: Dict = Field(default_factory=lambda: {})
"A dict containing details of the number of malicious network events captured."
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().setup_for_episode(episode=episode)
self.nmne = {}
if episode and self.pcap:
self.pcap.current_episode = episode
self.pcap.setup_logger()
self.enable()
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -127,15 +136,6 @@ class NetworkInterface(SimComponent, ABC):
state.update({"nmne": self.nmne})
return state
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
self.nmne = {}
if episode and self.pcap:
self.pcap.current_episode = episode
self.pcap.setup_logger()
self.enable()
@abstractmethod
def enable(self):
"""Enable the interface."""
@@ -547,14 +547,6 @@ class Link(SimComponent):
self.endpoint_b.connect_link(self)
self.endpoint_up()
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"bandwidth", "current_load"}
self._original_state = self.model_dump(include=vals_to_include)
super().set_original_state()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -740,50 +732,20 @@ class Node(SimComponent):
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
self._install_system_software()
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
for software in self.software_manager.software.values():
software.set_original_state()
self.file_system.set_original_state()
for network_interface in self.network_interfaces.values():
network_interface.set_original_state()
vals_to_include = {
"hostname",
"default_gateway",
"operating_state",
"revealed_to_red",
"start_up_duration",
"start_up_countdown",
"shut_down_duration",
"shut_down_countdown",
"is_resetting",
"node_scan_duration",
"node_scan_countdown",
"red_scan_countdown",
}
self._original_state = self.model_dump(include=vals_to_include)
def reset_component_for_episode(self, episode: int):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
# Reset Session Manager
self.session_manager.clear()
super().setup_for_episode(episode=episode)
# Reset File System
self.file_system.reset_component_for_episode(episode)
self.file_system.setup_for_episode(episode=episode)
# Reset all Nics
for network_interface in self.network_interfaces.values():
network_interface.reset_component_for_episode(episode)
network_interface.setup_for_episode(episode=episode)
for software in self.software_manager.software.values():
software.reset_component_for_episode(episode)
software.setup_for_episode(episode=episode)
if episode and self.sys_log:
self.sys_log.current_episode = episode

View File

@@ -209,11 +209,6 @@ class NIC(IPWiredNetworkInterface):
return state
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
def receive_frame(self, frame: Frame) -> bool:
"""
Attempt to receive and process a network frame from the connected Link.

View File

@@ -109,24 +109,6 @@ class Firewall(Router):
sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound"
)
self.set_original_state()
def set_original_state(self):
"""Set the original state for the Firewall."""
super().set_original_state()
vals_to_include = {
"internal_port",
"external_port",
"dmz_port",
"internal_inbound_acl",
"internal_outbound_acl",
"dmz_inbound_acl",
"dmz_outbound_acl",
"external_inbound_acl",
"external_outbound_acl",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def describe_state(self) -> Dict:
"""
Describes the current state of the Firewall.

View File

@@ -136,11 +136,6 @@ class ACLRule(SimComponent):
rule_strings.append(f"{key}={value}")
return ", ".join(rule_strings)
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
def describe_state(self) -> Dict:
"""
Describes the current state of the ACLRule.
@@ -296,48 +291,6 @@ class AccessControlList(SimComponent):
super().__init__(**kwargs)
self._acl = [None] * (self.max_acl_rules - 1)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
self.implicit_rule.set_original_state()
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
for i, rule in enumerate(self._acl):
if not rule:
continue
self._default_config[i] = {"action": rule.action.name}
if rule.src_ip_address:
self._default_config[i]["src_ip"] = str(rule.src_ip_address)
if rule.dst_ip_address:
self._default_config[i]["dst_ip"] = str(rule.dst_ip_address)
if rule.src_port:
self._default_config[i]["src_port"] = rule.src_port.name
if rule.dst_port:
self._default_config[i]["dst_port"] = rule.dst_port.name
if rule.protocol:
self._default_config[i]["protocol"] = rule.protocol.name
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.implicit_rule.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
self._reset_rules_to_default()
def _reset_rules_to_default(self) -> None:
"""Clear all ACL rules and set them to the default rules config."""
self._acl = [None] * (self.max_acl_rules - 1)
for r_num, r_cfg in self._default_config.items():
self.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
dst_ip_address=r_cfg.get("dst_ip"),
position=r_num,
)
def _init_request_manager(self) -> RequestManager:
# TODO: Add src and dst wildcard masks as positional args in this request.
@@ -616,11 +569,6 @@ class RouteEntry(SimComponent):
metric: float = 0.0
"The cost metric for this route. Default is 0.0."
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
self._original_values = self.model_dump(include=vals_to_include)
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteEntry.
@@ -653,17 +601,6 @@ class RouteTable(SimComponent):
default_route: Optional[RouteEntry] = None
sys_log: SysLog
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
self._original_state["routes_orig"] = self.routes
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.routes.clear()
self.routes = self._original_state["routes_orig"]
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteTable.
@@ -1104,8 +1041,6 @@ class Router(NetworkNode):
self._set_default_acl()
self.set_original_state()
def _install_system_software(self):
"""
Installs essential system software and network services on the router.
@@ -1131,20 +1066,7 @@ class Router(NetworkNode):
self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
def set_original_state(self):
"""
Sets or resets the router to its original configuration state.
This method is called to initialize the router's state or to revert it to a known good configuration during
network simulations or after configuration changes.
"""
self.acl.set_original_state()
self.route_table.set_original_state()
super().set_original_state()
vals_to_include = {"num_ports"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
def setup_for_episode(self, episode: int):
"""
Resets the router's components for a new network simulation episode.
@@ -1154,13 +1076,10 @@ class Router(NetworkNode):
:param episode: The episode number for which the router is being reset.
"""
self.software_manager.arp.clear()
self.acl.reset_component_for_episode(episode)
self.route_table.reset_component_for_episode(episode)
for i, network_interface in self.network_interface.items():
network_interface.reset_component_for_episode(episode)
for i, _ in self.network_interface.items():
self.enable_port(i)
super().reset_component_for_episode(episode)
super().setup_for_episode(episode=episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -1391,7 +1310,6 @@ class Router(NetworkNode):
network_interface.ip_address = ip_address
network_interface.subnet_mask = subnet_mask
self.sys_log.info(f"Configured Network Interface {network_interface}")
self.set_original_state()
def enable_port(self, port: int):
"""
@@ -1493,6 +1411,14 @@ class Router(NetworkNode):
subnet_mask=port_cfg["subnet_mask"],
)
if "acl" in cfg:
new.acl._default_config = cfg["acl"] # save the config to allow resetting
new.acl._reset_rules_to_default() # read the config and apply rules
for r_num, r_cfg in cfg["acl"].items():
new.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
dst_ip_address=r_cfg.get("dst_ip"),
position=r_num,
)
return new

View File

@@ -32,12 +32,6 @@ class SwitchPort(WiredNetworkInterface):
_connected_node: Optional[Switch] = None
"The Switch to which the SwitchPort is connected."
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
super().set_original_state()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -122,8 +122,6 @@ class WirelessRouter(Router):
self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0"))
self.set_original_state()
@property
def wireless_access_point(self) -> WirelessAccessPoint:
"""
@@ -166,7 +164,6 @@ class WirelessRouter(Router):
network_interface.ip_address = ip_address
network_interface.subnet_mask = subnet_mask
self.sys_log.info(f"Configured WAP {network_interface}")
self.set_original_state()
self.wireless_access_point.frequency = frequency # Set operating frequency
self.wireless_access_point.enable() # Re-enable the WAP with new settings

View File

@@ -21,13 +21,9 @@ class Simulation(SimComponent):
super().__init__(**kwargs)
def set_original_state(self):
"""Sets the original state."""
self.network.set_original_state()
def reset_component_for_episode(self, episode: int):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.network.reset_component_for_episode(episode)
self.network.setup_for_episode(episode=episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -38,12 +38,6 @@ class Application(IOSoftware):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -31,20 +31,6 @@ class DatabaseClient(Application):
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"server_ip_address", "server_password", "connected", "_query_success_tracker"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
self._query_success_tracker.clear()
def describe_state(self) -> Dict:
"""

View File

@@ -49,26 +49,6 @@ class DataManipulationBot(DatabaseClient):
super().__init__(**kwargs)
self.name = "DataManipulationBot"
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DataManipulationBot original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"server_ip_address",
"payload",
"server_password",
"port_scan_p_of_success",
"data_manipulation_p_of_success",
"attack_stage",
"repeat",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -57,27 +57,6 @@ class DoSBot(DatabaseClient, Application):
self.name = "DoSBot"
self.max_sessions = 1000 # override normal max sessions
def set_original_state(self):
"""Set the original state of the Denial of Service Bot."""
_LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"target_ip_address",
"target_port",
"payload",
"repeat",
"attack_stage",
"max_sessions",
"port_scan_p_of_success",
"dos_intensity",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -47,21 +47,8 @@ class WebBrowser(Application):
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
self.set_original_state()
self.run()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting WebBrowser original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
@@ -80,9 +67,6 @@ class WebBrowser(Application):
state["history"] = [hist_item.state() for hist_item in self.history]
return state
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
def get_webpage(self, url: Optional[str] = None) -> bool:
"""
Retrieve the webpage.

View File

@@ -24,12 +24,6 @@ class Process(Software):
operating_state: ProcessOperatingState
"The current operating state of the Process."
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -40,25 +40,6 @@ class DatabaseService(Service):
super().__init__(**kwargs)
self._create_db_file()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"password",
"connections",
"backup_server_ip",
"latest_backup_directory",
"latest_backup_file_name",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
self.clear_connections()
super().reset_component_for_episode(episode)
def configure_backup(self, backup_server: IPv4Address):
"""
Set up the database backup.

View File

@@ -29,18 +29,6 @@ class DNSClient(Service):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DNSClient original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"dns_server"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.dns_cache.clear()
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.

View File

@@ -28,20 +28,6 @@ class DNSServer(Service):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DNSServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"dns_table"}
self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"]
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.dns_table.clear()
for key, value in self._original_state["dns_table_orig"].items():
self.dns_table[key] = value
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.

View File

@@ -27,18 +27,6 @@ class FTPClient(FTPServiceABC):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FTPClient original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"connected"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.

View File

@@ -27,19 +27,6 @@ class FTPServer(FTPServiceABC):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FTPServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"server_password"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
self.clear_connections()
super().reset_component_for_episode(episode)
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.

View File

@@ -49,21 +49,12 @@ class NTPClient(Service):
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def send(
self,
payload: NTPPacket,
session_id: Optional[str] = None,
dest_ip_address: IPv4Address = None,
dest_port: [Port] = Port.NTP,
dest_port: Port = Port.NTP,
**kwargs,
) -> bool:
"""Requests NTP data from NTP server.

View File

@@ -34,16 +34,6 @@ class NTPServer(Service):
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including
resetting any stateful properties or statistics, and clearing any message
queues.
"""
pass
def receive(
self,
payload: NTPPacket,

View File

@@ -78,12 +78,6 @@ class Service(IOSoftware):
"""
return super().receive(payload=payload, session_id=session_id, **kwargs)
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state", "restart_duration", "restart_countdown"}
self._original_state.update(self.model_dump(include=vals_to_include))
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))

View File

@@ -23,18 +23,6 @@ class WebServer(Service):
last_response_status_code: Optional[HttpStatusCode] = None
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting WebServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"last_response_status_code"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -130,7 +118,7 @@ class WebServer(Service):
self.set_health_state(SoftwareHealthState.COMPROMISED)
return response
except Exception:
except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 )
# something went wrong on the server
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
return response

View File

@@ -3,7 +3,7 @@ from abc import abstractmethod
from datetime import datetime
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
@@ -13,6 +13,9 @@ from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.core.sys_log import SysLog
if TYPE_CHECKING:
from primaite.simulator.system.core.software_manager import SoftwareManager
class SoftwareType(Enum):
"""
@@ -84,7 +87,7 @@ class Software(SimComponent):
"The count of times the software has been scanned, defaults to 0."
revealed_to_red: bool = False
"Indicates if the software has been revealed to red agent, defaults is False."
software_manager: Any = None
software_manager: "SoftwareManager" = None
"An instance of Software Manager that is used by the parent node."
sys_log: SysLog = None
"An instance of SysLog that is used by the parent node."
@@ -97,19 +100,6 @@ class Software(SimComponent):
_patching_countdown: Optional[int] = None
"Current number of ticks left to patch the software."
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {
"name",
"health_state_actual",
"health_state_visible",
"criticality",
"patching_count",
"scanning_count",
"revealed_to_red",
}
self._original_state = self.model_dump(include=vals_to_include)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
@@ -248,12 +238,6 @@ class IOSoftware(Software):
_connections: Dict[str, Dict] = {}
"Active connections."
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -589,15 +589,16 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server

View File

@@ -593,15 +593,16 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
@@ -624,7 +625,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
type: FTPServer
- ref: security_suite
type: server

View File

@@ -1043,16 +1043,16 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
type: server
@@ -1074,7 +1074,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
type: FTPServer
- ref: security_suite
type: server

View File

@@ -599,15 +599,16 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
@@ -630,7 +631,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
type: FTPServer
- ref: security_suite
type: server

View File

@@ -600,15 +600,16 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
@@ -631,7 +632,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
type: FTPServer
- ref: security_suite
type: server

View File

@@ -510,6 +510,6 @@ def game_and_agent():
reward_function=reward_function,
)
game.agents.append(test_agent)
game.agents["test_agent"] = test_agent
return (game, test_agent)

View File

@@ -11,14 +11,12 @@ from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
# @pytest.mark.skip(reason="no way of currently testing this")
def test_sb3_compatibility():
"""Test that the Gymnasium environment can be used with an SB3 agent."""
with open(example_config_path(), "r") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
gym = PrimaiteGymEnv(game=game)
gym = PrimaiteGymEnv(game_config=cfg)
model = PPO("MlpPolicy", gym)
model.learn(total_timesteps=1000)

View File

@@ -42,20 +42,20 @@ def test_example_config():
assert len(game.agents) == 4 # red, blue and 2 green agents
# green agent 1
assert game.agents[0].agent_name == "client_2_green_user"
assert isinstance(game.agents[0], RandomAgent)
assert "client_2_green_user" in game.agents
assert isinstance(game.agents["client_2_green_user"], RandomAgent)
# green agent 2
assert game.agents[1].agent_name == "client_1_green_user"
assert isinstance(game.agents[1], RandomAgent)
assert "client_1_green_user" in game.agents
assert isinstance(game.agents["client_1_green_user"], RandomAgent)
# red agent
assert game.agents[2].agent_name == "client_1_data_manipulation_red_bot"
assert isinstance(game.agents[2], DataManipulationAgent)
assert "client_1_data_manipulation_red_bot" in game.agents
assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent)
# blue agent
assert game.agents[3].agent_name == "defender"
assert isinstance(game.agents[3], ProxyAgent)
assert "defender" in game.agents
assert isinstance(game.agents["defender"], ProxyAgent)
network: Network = game.simulation.network

View File

@@ -7,26 +7,11 @@ from primaite.simulator.domain.account import Account, AccountType
@pytest.fixture(scope="function")
def account() -> Account:
acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER)
acct.set_original_state()
return acct
def test_original_state(account):
"""Test the original state - see if it resets properly"""
account.log_on()
account.log_off()
account.disable()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is False
account.reset_component_for_episode(episode=1)
state = account.describe_state()
assert state["num_logons"] is 0
assert state["num_logoffs"] is 0
@@ -39,13 +24,7 @@ def test_original_state(account):
account.log_on()
account.log_off()
account.disable()
account.set_original_state()
account.log_on()
state = account.describe_state()
assert state["num_logons"] is 2
account.reset_component_for_episode(episode=2)
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1

View File

@@ -185,38 +185,6 @@ def test_get_file(file_system):
file_system.show(full=True)
def test_reset_file_system(file_system):
# file and folder that existed originally
file_system.create_file(file_name="test_file.zip")
file_system.create_folder(folder_name="test_folder")
file_system.set_original_state()
# create a new file
file_system.create_file(file_name="new_file.txt")
# create a new folder
file_system.create_folder(folder_name="new_folder")
# delete the file that existed originally
file_system.delete_file(folder_name="root", file_name="test_file.zip")
assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None
# delete the folder that existed originally
file_system.delete_folder(folder_name="test_folder")
assert file_system.get_folder(folder_name="test_folder") is None
# reset
file_system.reset_component_for_episode(episode=1)
# deleted original file and folder should be back
assert file_system.get_file(folder_name="root", file_name="test_file.zip")
assert file_system.get_folder(folder_name="test_folder")
# new file and folder should be removed
assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None
assert file_system.get_folder(folder_name="new_folder") is None
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
def test_serialisation(file_system):
"""Test to check that the object serialisation works correctly."""

View File

@@ -31,7 +31,6 @@ def network(example_network) -> Network:
assert len(example_network.computers) is 2
assert len(example_network.servers) is 2
example_network.set_original_state()
example_network.show()
return example_network
@@ -45,40 +44,6 @@ def test_describe_state(network):
assert len(state["links"]) is 6
def test_reset_network(network):
"""
Test that the network is properly reset.
TODO: make sure that once implemented - any installed/uninstalled services, processes, apps,
etc are also removed/reinstalled
"""
state_before = network.describe_state()
client_1: Computer = network.get_node_by_hostname("client_1")
server_1: Computer = network.get_node_by_hostname("server_1")
assert client_1.operating_state is NodeOperatingState.ON
assert server_1.operating_state is NodeOperatingState.ON
client_1.power_off()
assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN
server_1.power_off()
assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN
assert network.describe_state() != state_before
network.reset_component_for_episode(episode=1)
assert client_1.operating_state is NodeOperatingState.ON
assert server_1.operating_state is NodeOperatingState.ON
# don't worry if UUIDs change
a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"])
b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"])
assert a == b
def test_creating_container():
"""Check that we can create a network container"""
net = Network()

View File

@@ -19,7 +19,6 @@ def dos_bot() -> DoSBot:
dos_bot: DoSBot = computer.software_manager.software.get("DoSBot")
dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1"))
dos_bot.set_original_state()
return dos_bot
@@ -28,35 +27,6 @@ def test_dos_bot_creation(dos_bot):
assert dos_bot is not None
def test_dos_bot_reset(dos_bot):
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
assert dos_bot.target_port is Port.POSTGRES_SERVER
assert dos_bot.payload is None
assert dos_bot.repeat is False
dos_bot.configure(
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
)
# should reset the relevant items
dos_bot.reset_component_for_episode(episode=0)
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
assert dos_bot.target_port is Port.POSTGRES_SERVER
assert dos_bot.payload is None
assert dos_bot.repeat is False
dos_bot.configure(
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
)
dos_bot.set_original_state()
dos_bot.reset_component_for_episode(episode=1)
# should reset to the configured value
assert dos_bot.target_ip_address == IPv4Address("192.168.1.1")
assert dos_bot.target_port is Port.HTTP
assert dos_bot.payload == "payload"
assert dos_bot.repeat is True
def test_dos_bot_cannot_run_when_node_offline(dos_bot):
dos_bot_node: Computer = dos_bot.parent
assert dos_bot_node.operating_state is NodeOperatingState.ON

View File

@@ -2,12 +2,14 @@ from typing import Dict
import pytest
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.software import Software, SoftwareHealthState
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
class TestSoftware(Software):
class TestSoftware(Service):
def describe_state(self) -> Dict:
pass
@@ -15,7 +17,11 @@ class TestSoftware(Software):
@pytest.fixture(scope="function")
def software(file_system):
return TestSoftware(
name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service")
name="TestSoftware",
port=Port.ARP,
file_system=file_system,
sys_log=SysLog(hostname="test_service"),
protocol=IPProtocol.TCP,
)