Merged PR 283: Refactor episode reset
## Summary Instead of setting all attributes back to a snapshot, simply recreate the entire game and all agents from a cached copy of the config. This removes the need for `set_original_state` and `reset_component_for_episode` methods on SimComponents. ## Test process * Unit tests passing * I've also tried adding a `__del__` method to simcomponent and agent and press env.reset() to verify that the ref count reaches 0, and nothing is secretly keeping using an old part of the simulation. ## Checklist - [X] PR is linked to a **work item** - [X] **acceptance criteria** of linked ticket are met - [X] performed **self-review** of the code - [X] written **tests** for any new functionality added with this PR - [X] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [X] updated the **change log** - [X] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code Related work items: #2317
This commit is contained in:
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
- Made environment reset completely recreate the game object.
|
||||
- Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack.
|
||||
- Changed the data manipulation scenario to include a second green agent on client 1.
|
||||
- Refactored actions and observations to be configurable via object name, instead of UUID.
|
||||
|
||||
@@ -108,6 +108,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
source/simulation
|
||||
source/game_layer
|
||||
source/config
|
||||
source/environment
|
||||
|
||||
.. toctree::
|
||||
:caption: Developer information:
|
||||
|
||||
10
docs/source/environment.rst
Normal file
10
docs/source/environment.rst
Normal file
@@ -0,0 +1,10 @@
|
||||
RL Environments
|
||||
***************
|
||||
|
||||
RL environments are the objects that directly interface with RL libraries such as Stable-Baselines3 and Ray RLLib. The PrimAITE simulation is exposed via three different environment APIs:
|
||||
|
||||
* Gymnasium API - this is the standard interface that works with many RL libraries like SB3, Ray, Tianshou, etc. ``PrimaiteGymEnv`` adheres to the `Official Gymnasium documentation <https://gymnasium.farama.org/api/env/>`_.
|
||||
* Ray Single agent API - For training a single Ray RLLib agent
|
||||
* Ray MARL API - For training multi-agent systems with Ray RLLib. ``PrimaiteRayMARLEnv`` adheres to the `Official Ray documentation <https://docs.ray.io/en/latest/rllib/package_ref/env/multi_agent_env.html>`_.
|
||||
|
||||
There are Jupyter notebooks which demonstrate integration with each of these three environments. They are located in ``~/primaite/<VERSION>/notebooks/example_notebooks``.
|
||||
@@ -20,6 +20,11 @@ The game layer is responsible for managing agents and getting them to interface
|
||||
PrimAITE Session
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. admonition:: Deprecated
|
||||
:class: deprecated
|
||||
|
||||
PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The `session` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
|
||||
|
||||
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types.
|
||||
|
||||
Agents
|
||||
|
||||
@@ -4,6 +4,11 @@
|
||||
|
||||
.. _run a primaite session:
|
||||
|
||||
.. admonition:: Deprecated
|
||||
:class: deprecated
|
||||
|
||||
PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The ``session`` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
|
||||
|
||||
Run a PrimAITE Session
|
||||
======================
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: true
|
||||
|
||||
|
||||
@@ -656,12 +656,13 @@ simulation:
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
training_config:
|
||||
rl_framework: RLLIB_multi_agent
|
||||
# rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 1
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 256
|
||||
deterministic_eval: false
|
||||
n_agents: 2
|
||||
agent_references:
|
||||
- defender_1
|
||||
- defender_2
|
||||
|
||||
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: true
|
||||
|
||||
|
||||
game:
|
||||
@@ -36,9 +44,9 @@ agents:
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_2
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_ref: client_2_web_browser
|
||||
- application_name: WebBrowser
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
@@ -54,6 +62,31 @@ agents:
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
|
||||
|
||||
|
||||
- ref: data_manipulation_attacker
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
@@ -72,7 +105,7 @@ agents:
|
||||
- type: NODE_OS_SCAN
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
- node_name: client_2
|
||||
@@ -104,25 +137,21 @@ agents:
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
- service_name: WebServer
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
@@ -137,23 +166,23 @@ agents:
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
|
||||
@@ -184,10 +213,10 @@ agents:
|
||||
- type: NODE_RESET
|
||||
- type: NETWORK_ACL_ADDRULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
target_router_hostname: router_1
|
||||
- type: NETWORK_ACL_REMOVERULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
target_router_hostname: router_1
|
||||
- type: NETWORK_NIC_ENABLE
|
||||
- type: NETWORK_NIC_DISABLE
|
||||
|
||||
@@ -242,25 +271,25 @@ agents:
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
13:
|
||||
action: "NODE_SERVICE_PATCH"
|
||||
@@ -271,22 +300,22 @@ agents:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
@@ -303,63 +332,63 @@ agents:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 5
|
||||
22:
|
||||
22: # "ACL: ADDRULE - Block outgoing traffic from client 1"
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
23: # "ACL: ADDRULE - Block outgoing traffic from client 2"
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
24: # block tcp traffic from client 1 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
25: # block tcp traffic from client 2 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
@@ -407,123 +436,148 @@ agents:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
39:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
40:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
41:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
42:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
43:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
44:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
45:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
46:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
47:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
48:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 2
|
||||
nic_id: 1
|
||||
49:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 2
|
||||
nic_id: 1
|
||||
50:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
51:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
52:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
53:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
- node_ref: web_server
|
||||
- node_name: domain_controller
|
||||
- node_name: web_server
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
services:
|
||||
- service_ref: web_server_web_service
|
||||
- node_ref: database_server
|
||||
- service_name: WebServer
|
||||
- node_name: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
services:
|
||||
- service_ref: database_service
|
||||
- node_ref: backup_server
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
- service_name: DatabaseService
|
||||
- node_name: backup_server
|
||||
- node_name: security_suite
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_name: domain_controller
|
||||
nic_num: 1
|
||||
- node_name: web_server
|
||||
nic_num: 1
|
||||
- node_name: database_server
|
||||
nic_num: 1
|
||||
- node_name: backup_server
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 1
|
||||
- node_name: client_1
|
||||
nic_num: 1
|
||||
- node_name: client_2
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 2
|
||||
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
weight: 0.34
|
||||
options:
|
||||
node_ref: database_server
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.33
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
node_hostname: client_1
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.33
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
|
||||
flatten_obs: true
|
||||
|
||||
- ref: defender_2
|
||||
team: BLUE
|
||||
@@ -537,25 +591,21 @@ agents:
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
- service_name: WebServer
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
@@ -570,23 +620,23 @@ agents:
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
|
||||
@@ -617,10 +667,10 @@ agents:
|
||||
- type: NODE_RESET
|
||||
- type: NETWORK_ACL_ADDRULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
target_router_hostname: router_1
|
||||
- type: NETWORK_ACL_REMOVERULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
target_router_hostname: router_1
|
||||
- type: NETWORK_NIC_ENABLE
|
||||
- type: NETWORK_NIC_DISABLE
|
||||
|
||||
@@ -675,25 +725,25 @@ agents:
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
13:
|
||||
action: "NODE_SERVICE_PATCH"
|
||||
@@ -704,22 +754,22 @@ agents:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
folder_id: 0
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
@@ -736,63 +786,63 @@ agents:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 5
|
||||
22:
|
||||
22: # "ACL: ADDRULE - Block outgoing traffic from client 1"
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
23: # "ACL: ADDRULE - Block outgoing traffic from client 2"
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
24: # block tcp traffic from client 1 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
25: # block tcp traffic from client 2 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
@@ -840,122 +890,148 @@ agents:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
39:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
40:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
41:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
42:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
43:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
44:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
45:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
46:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
47:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
48:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 2
|
||||
nic_id: 1
|
||||
49:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 2
|
||||
nic_id: 1
|
||||
50:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
51:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
52:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
53:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
nic_id: 0
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
- node_ref: web_server
|
||||
- node_name: domain_controller
|
||||
- node_name: web_server
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
services:
|
||||
- service_ref: web_server_web_service
|
||||
- node_ref: database_server
|
||||
- service_name: WebServer
|
||||
- node_name: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
services:
|
||||
- service_ref: database_service
|
||||
- node_ref: backup_server
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
- service_name: DatabaseService
|
||||
- node_name: backup_server
|
||||
- node_name: security_suite
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_name: domain_controller
|
||||
nic_num: 1
|
||||
- node_name: web_server
|
||||
nic_num: 1
|
||||
- node_name: database_server
|
||||
nic_num: 1
|
||||
- node_name: backup_server
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 1
|
||||
- node_name: client_1
|
||||
nic_num: 1
|
||||
- node_name: client_2
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 2
|
||||
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
weight: 0.34
|
||||
options:
|
||||
node_ref: database_server
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.33
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
node_hostname: client_1
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.33
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
flatten_obs: true
|
||||
|
||||
|
||||
|
||||
@@ -1036,12 +1112,13 @@ simulation:
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
@@ -1093,10 +1170,14 @@ simulation:
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.1
|
||||
data_manipulation_p_of_success: 0.1
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
- ref: client_1_web_browser
|
||||
type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
services:
|
||||
- ref: client_1_dns_client
|
||||
type: DNSClient
|
||||
@@ -1113,6 +1194,13 @@ simulation:
|
||||
type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
@@ -80,18 +80,15 @@ class PrimaiteGame:
|
||||
self.simulation: Simulation = Simulation()
|
||||
"""Simulation object with which the agents will interact."""
|
||||
|
||||
self.agents: List[AbstractAgent] = []
|
||||
"""List of agents."""
|
||||
self.agents: Dict[str, AbstractAgent] = {}
|
||||
"""Mapping from agent name to agent object."""
|
||||
|
||||
self.rl_agents: List[ProxyAgent] = []
|
||||
"""Subset of agent list including only the reinforcement learning agents."""
|
||||
self.rl_agents: Dict[str, ProxyAgent] = {}
|
||||
"""Subset of agents which are intended for reinforcement learning."""
|
||||
|
||||
self.step_counter: int = 0
|
||||
"""Current timestep within the episode."""
|
||||
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
self.options: PrimaiteGameOptions
|
||||
"""Special options that apply for the entire game."""
|
||||
|
||||
@@ -137,7 +134,7 @@ class PrimaiteGame:
|
||||
self.update_agents(sim_state)
|
||||
|
||||
# Apply all actions to simulation as requests
|
||||
agent_actions = self.apply_agent_actions() # noqa
|
||||
self.apply_agent_actions()
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
@@ -148,7 +145,7 @@ class PrimaiteGame:
|
||||
|
||||
def update_agents(self, state: Dict) -> None:
|
||||
"""Update agents' observations and rewards based on the current state."""
|
||||
for agent in self.agents:
|
||||
for _, agent in self.agents.items():
|
||||
agent.update_observation(state)
|
||||
agent.update_reward(state)
|
||||
agent.reward_function.total_reward += agent.reward_function.current_reward
|
||||
@@ -162,7 +159,7 @@ class PrimaiteGame:
|
||||
|
||||
"""
|
||||
agent_actions = {}
|
||||
for agent in self.agents:
|
||||
for _, agent in self.agents.items():
|
||||
obs = agent.observation_manager.current_observation
|
||||
rew = agent.reward_function.current_reward
|
||||
action_choice, options = agent.get_action(obs, rew)
|
||||
@@ -185,20 +182,14 @@ class PrimaiteGame:
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the game, this will reset the simulation."""
|
||||
self.episode_counter += 1
|
||||
self.step_counter = 0
|
||||
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
|
||||
self.simulation.reset_component_for_episode(episode=self.episode_counter)
|
||||
for agent in self.agents:
|
||||
agent.reward_function.total_reward = 0.0
|
||||
agent.reset_agent_for_episode()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the game, this will close the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
def setup_for_episode(self, episode: int) -> None:
|
||||
"""Perform any final configuration of components to make them ready for the game to start."""
|
||||
self.simulation.setup_for_episode(episode=episode)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict) -> "PrimaiteGame":
|
||||
"""Create a PrimaiteGame object from a config dictionary.
|
||||
@@ -272,7 +263,9 @@ class PrimaiteGame:
|
||||
new_service = new_node.software_manager.software[service_type]
|
||||
game.ref_map_services[service_ref] = new_service.uuid
|
||||
else:
|
||||
_LOGGER.warning(f"service type not found {service_type}")
|
||||
msg = f"Configuration contains an invalid service type: {service_type}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
# service-dependent options
|
||||
if service_type == "DNSClient":
|
||||
if "options" in service_cfg:
|
||||
@@ -311,7 +304,9 @@ class PrimaiteGame:
|
||||
new_application = new_node.software_manager.software[application_type]
|
||||
game.ref_map_applications[application_ref] = new_application.uuid
|
||||
else:
|
||||
_LOGGER.warning(f"application type not found {application_type}")
|
||||
msg = f"Configuration contains an invalid application type: {application_type}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
if application_type == "DataManipulationBot":
|
||||
if "options" in application_cfg:
|
||||
@@ -402,7 +397,6 @@ class PrimaiteGame:
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
game.agents.append(new_agent)
|
||||
elif agent_type == "ProxyAgent":
|
||||
new_agent = ProxyAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
@@ -411,8 +405,7 @@ class PrimaiteGame:
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
game.agents.append(new_agent)
|
||||
game.rl_agents.append(new_agent)
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
new_agent = DataManipulationAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
@@ -421,11 +414,11 @@ class PrimaiteGame:
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
game.agents.append(new_agent)
|
||||
else:
|
||||
_LOGGER.warning(f"agent type {agent_type} not found")
|
||||
|
||||
game.simulation.set_original_state()
|
||||
msg(f"Configuration error: {agent_type} is not a valid agent type.")
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
|
||||
# Set the NMNE capture config
|
||||
set_nmne_config(cfg["simulation"]["network"].get("nmne_config", {}))
|
||||
|
||||
@@ -60,7 +60,7 @@
|
||||
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
|
||||
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
|
||||
" )\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
" )\n"
|
||||
@@ -88,6 +88,13 @@
|
||||
" param_space=config\n",
|
||||
").fit()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env_config = {\"cfg\":cfg}\n",
|
||||
"env_config = cfg\n",
|
||||
"\n",
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
|
||||
@@ -27,9 +27,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(example_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"game = PrimaiteGame.from_config(cfg)"
|
||||
" cfg = yaml.safe_load(f)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -38,7 +36,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym = PrimaiteGymEnv(game=game)"
|
||||
"gym = PrimaiteGymEnv(game_config=cfg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -65,7 +63,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.learn(total_timesteps=1000)\n"
|
||||
"model.learn(total_timesteps=10)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -76,6 +74,13 @@
|
||||
"source": [
|
||||
"model.save(\"deleteme\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -346,9 +346,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
@@ -358,9 +356,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Imports\n",
|
||||
@@ -383,9 +379,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# create the env\n",
|
||||
@@ -396,10 +390,10 @@
|
||||
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n",
|
||||
" cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
|
||||
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
|
||||
"game = PrimaiteGame.from_config(cfg)\n",
|
||||
"env = PrimaiteGymEnv(game = game)\n",
|
||||
"# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n",
|
||||
"env.agent.flatten_obs = False\n",
|
||||
" # don't flatten observations so that we can see what is going on\n",
|
||||
" cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(game_config = cfg)\n",
|
||||
"obs, info = env.reset()\n",
|
||||
"print('env created successfully')\n",
|
||||
"pprint(obs)"
|
||||
@@ -433,9 +427,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for step in range(35):\n",
|
||||
@@ -453,9 +445,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pprint(obs['NODES'])"
|
||||
@@ -471,9 +461,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(9) # scan database file\n",
|
||||
@@ -499,9 +487,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
|
||||
@@ -526,9 +512,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
|
||||
@@ -551,9 +535,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(13) # Patch the database\n",
|
||||
@@ -593,6 +575,22 @@
|
||||
"obs['ACL']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -603,7 +601,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -617,9 +615,9 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
@@ -18,11 +18,23 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
def __init__(self, game: PrimaiteGame):
|
||||
def __init__(self, game_config: Dict):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.game: "PrimaiteGame" = game
|
||||
self.agent: ProxyAgent = self.game.rl_agents[0]
|
||||
self.game_config: Dict = game_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
|
||||
"""Current game."""
|
||||
self._agent_name = next(iter(self.game.rl_agents))
|
||||
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
|
||||
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
@property
|
||||
def agent(self) -> ProxyAgent:
|
||||
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
|
||||
return self.game.rl_agents[self._agent_name]
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
@@ -45,13 +57,13 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{self.game.step_counter}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.game.episode_counter,
|
||||
"episode": self.episode_counter,
|
||||
"step": self.game.step_counter,
|
||||
"action": int(action),
|
||||
"reward": int(reward),
|
||||
@@ -63,10 +75,12 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
print(
|
||||
f"Resetting environment, episode {self.game.episode_counter}, "
|
||||
f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}"
|
||||
f"Resetting environment, episode {self.episode_counter}, "
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
self.game.reset()
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
@@ -88,12 +102,12 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
|
||||
def _get_obs(self) -> ObsType:
|
||||
"""Return the current observation."""
|
||||
if not self.agent.flatten_obs:
|
||||
return self.agent.observation_manager.current_observation
|
||||
else:
|
||||
if self.agent.flatten_obs:
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
else:
|
||||
return self.agent.observation_manager.current_observation
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
@@ -102,12 +116,11 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict[str, PrimaiteGame]
|
||||
:param env_config: A dictionary containing the environment configuration.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"]))
|
||||
self.env.game.episode_counter -= 1
|
||||
self.env = PrimaiteGymEnv(game_config=env_config)
|
||||
self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
@@ -128,13 +141,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict[str, PrimaiteGame]
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"])
|
||||
self.game_config: Dict = env_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
|
||||
"""Reference to the primaite game"""
|
||||
self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
|
||||
"""List of all possible agents in the environment. This list should not change!"""
|
||||
self._agent_ids = list(self.agents.keys())
|
||||
self._agent_ids = list(self.game.rl_agents.keys())
|
||||
"""Agent ids. This is a list of strings of agent names."""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
@@ -149,9 +165,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def agents(self) -> Dict[str, ProxyAgent]:
|
||||
"""Grab a fresh reference to the agents from this episode's game object."""
|
||||
return {name: self.game.rl_agents[name] for name in self._agent_ids}
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
self.game.reset()
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
@@ -172,7 +195,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
@@ -186,7 +209,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
|
||||
terminateds = {name: False for name, _ in self.agents.items()}
|
||||
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
|
||||
infos = {"agent_actions": agent_actions}
|
||||
infos = {name: {} for name, _ in self.agents.items()}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
@@ -194,13 +217,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{self.game.step_counter}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.game.episode_counter,
|
||||
"episode": self.episode_counter,
|
||||
"step": self.game.step_counter,
|
||||
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
|
||||
"reward": rewards,
|
||||
@@ -212,8 +235,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
obs = {}
|
||||
for name, agent in self.agents.items():
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
return obs
|
||||
|
||||
@@ -101,11 +101,11 @@ class PrimaiteSession:
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
if sess.training_options.rl_framework == "RLLIB_single_agent":
|
||||
sess.env = PrimaiteRayEnv(env_config={"cfg": cfg})
|
||||
sess.env = PrimaiteRayEnv(env_config=cfg)
|
||||
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
|
||||
sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg})
|
||||
sess.env = PrimaiteRayMARLEnv(env_config=cfg)
|
||||
elif sess.training_options.rl_framework == "SB3":
|
||||
sess.env = PrimaiteGymEnv(game=game)
|
||||
sess.env = PrimaiteGymEnv(game_config=cfg)
|
||||
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
if agent_load_path:
|
||||
|
||||
@@ -153,22 +153,18 @@ class SimComponent(BaseModel):
|
||||
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
||||
"""The component UUID."""
|
||||
|
||||
_original_state: Dict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._request_manager: RequestManager = self._init_request_manager()
|
||||
self._parent: Optional["SimComponent"] = None
|
||||
|
||||
# @abstractmethod
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
pass
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""
|
||||
Perform any additional setup on this component that can't happen during __init__.
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
for key, value in self._original_state.items():
|
||||
self.__setattr__(key, value)
|
||||
For instance, some components may require for the entire network to exist before some configuration can be set.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
|
||||
@@ -42,19 +42,6 @@ class Account(SimComponent):
|
||||
"Account Type, currently this can be service account (used by apps) or user account."
|
||||
enabled: bool = True
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {
|
||||
"num_logons",
|
||||
"num_logoffs",
|
||||
"num_group_changes",
|
||||
"username",
|
||||
"password",
|
||||
"account_type",
|
||||
"enabled",
|
||||
}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
@@ -73,20 +73,6 @@ class File(FileSystemItemABC):
|
||||
|
||||
self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})")
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting File ({self.path}) original state on node {self.sys_log.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -34,43 +34,6 @@ class FileSystem(SimComponent):
|
||||
if not self.folders:
|
||||
self.create_folder("root")
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting FileSystem original state on node {self.sys_log.hostname}")
|
||||
for folder in self.folders.values():
|
||||
folder.set_original_state()
|
||||
# Capture a list of all 'original' file uuids
|
||||
original_keys = list(self.folders.keys())
|
||||
vals_to_include = {"sim_root"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
self._original_state["original_folder_uuids"] = original_keys
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}")
|
||||
# Move any 'original' folder that have been deleted back to folders
|
||||
original_folder_uuids = self._original_state["original_folder_uuids"]
|
||||
for uuid in original_folder_uuids:
|
||||
if uuid in self.deleted_folders:
|
||||
folder = self.deleted_folders[uuid]
|
||||
self.deleted_folders.pop(uuid)
|
||||
self.folders[uuid] = folder
|
||||
|
||||
# Clear any other deleted folders that aren't original (have been created by agent)
|
||||
self.deleted_folders.clear()
|
||||
|
||||
# Now clear all non-original folders created by agent
|
||||
current_folder_uuids = list(self.folders.keys())
|
||||
for uuid in current_folder_uuids:
|
||||
if uuid not in original_folder_uuids:
|
||||
folder = self.folders[uuid]
|
||||
self.folders.pop(uuid)
|
||||
|
||||
# Now reset all remaining folders
|
||||
for folder in self.folders.values():
|
||||
folder.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
|
||||
@@ -85,11 +85,6 @@ class FileSystemItemABC(SimComponent):
|
||||
deleted: bool = False
|
||||
"If true, the FileSystemItem was deleted."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
@@ -49,49 +49,6 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting Folder ({self.name}) original state on node {self.sys_log.hostname}")
|
||||
for file in self.files.values():
|
||||
file.set_original_state()
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"scan_duration",
|
||||
"scan_countdown",
|
||||
"red_scan_duration",
|
||||
"red_scan_countdown",
|
||||
"restore_duration",
|
||||
"restore_countdown",
|
||||
}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
self._original_state["original_file_uuids"] = list(self.files.keys())
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}")
|
||||
# Move any 'original' file that have been deleted back to files
|
||||
original_file_uuids = self._original_state["original_file_uuids"]
|
||||
for uuid in original_file_uuids:
|
||||
if uuid in self.deleted_files:
|
||||
file = self.deleted_files[uuid]
|
||||
self.deleted_files.pop(uuid)
|
||||
self.files[uuid] = file
|
||||
|
||||
# Clear any other deleted files that aren't original (have been created by agent)
|
||||
self.deleted_files.clear()
|
||||
|
||||
# Now clear all non-original files created by agent
|
||||
current_file_uuids = list(self.files.keys())
|
||||
for uuid in current_file_uuids:
|
||||
if uuid not in original_file_uuids:
|
||||
file = self.files[uuid]
|
||||
self.files.pop(uuid)
|
||||
|
||||
# Now reset all remaining files
|
||||
for file in self.files.values():
|
||||
file.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
|
||||
@@ -273,11 +273,6 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC)
|
||||
|
||||
return state
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def enable(self):
|
||||
"""
|
||||
Enables this wired network interface and attempts to send a "hello" message to the default gateway.
|
||||
|
||||
@@ -45,19 +45,12 @@ class Network(SimComponent):
|
||||
|
||||
self._nx_graph = MultiGraph()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
for node in self.nodes.values():
|
||||
node.set_original_state()
|
||||
for link in self.links.values():
|
||||
link.set_original_state()
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
for node in self.nodes.values():
|
||||
node.reset_component_for_episode(episode)
|
||||
node.setup_for_episode(episode=episode)
|
||||
for link in self.links.values():
|
||||
link.reset_component_for_episode(episode)
|
||||
link.setup_for_episode(episode=episode)
|
||||
|
||||
for node in self.nodes.values():
|
||||
node.power_on()
|
||||
@@ -179,7 +172,7 @@ class Network(SimComponent):
|
||||
def clear_links(self):
|
||||
"""Clear all the links in the network by resetting their component state for the episode."""
|
||||
for link in self.links.values():
|
||||
link.reset_component_for_episode()
|
||||
link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here.
|
||||
|
||||
def draw(self, seed: int = 123):
|
||||
"""
|
||||
|
||||
@@ -100,6 +100,15 @@ class NetworkInterface(SimComponent, ABC):
|
||||
nmne: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the number of malicious network events captured."
|
||||
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().setup_for_episode(episode=episode)
|
||||
self.nmne = {}
|
||||
if episode and self.pcap:
|
||||
self.pcap.current_episode = episode
|
||||
self.pcap.setup_logger()
|
||||
self.enable()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
@@ -127,15 +136,6 @@ class NetworkInterface(SimComponent, ABC):
|
||||
state.update({"nmne": self.nmne})
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().reset_component_for_episode(episode)
|
||||
self.nmne = {}
|
||||
if episode and self.pcap:
|
||||
self.pcap.current_episode = episode
|
||||
self.pcap.setup_logger()
|
||||
self.enable()
|
||||
|
||||
@abstractmethod
|
||||
def enable(self):
|
||||
"""Enable the interface."""
|
||||
@@ -547,14 +547,6 @@ class Link(SimComponent):
|
||||
self.endpoint_b.connect_link(self)
|
||||
self.endpoint_up()
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"bandwidth", "current_load"}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
super().set_original_state()
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -740,50 +732,20 @@ class Node(SimComponent):
|
||||
self.session_manager.node = self
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
self._install_system_software()
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
for software in self.software_manager.software.values():
|
||||
software.set_original_state()
|
||||
|
||||
self.file_system.set_original_state()
|
||||
|
||||
for network_interface in self.network_interfaces.values():
|
||||
network_interface.set_original_state()
|
||||
|
||||
vals_to_include = {
|
||||
"hostname",
|
||||
"default_gateway",
|
||||
"operating_state",
|
||||
"revealed_to_red",
|
||||
"start_up_duration",
|
||||
"start_up_countdown",
|
||||
"shut_down_duration",
|
||||
"shut_down_countdown",
|
||||
"is_resetting",
|
||||
"node_scan_duration",
|
||||
"node_scan_countdown",
|
||||
"red_scan_countdown",
|
||||
}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
# Reset Session Manager
|
||||
self.session_manager.clear()
|
||||
super().setup_for_episode(episode=episode)
|
||||
|
||||
# Reset File System
|
||||
self.file_system.reset_component_for_episode(episode)
|
||||
self.file_system.setup_for_episode(episode=episode)
|
||||
|
||||
# Reset all Nics
|
||||
for network_interface in self.network_interfaces.values():
|
||||
network_interface.reset_component_for_episode(episode)
|
||||
network_interface.setup_for_episode(episode=episode)
|
||||
|
||||
for software in self.software_manager.software.values():
|
||||
software.reset_component_for_episode(episode)
|
||||
software.setup_for_episode(episode=episode)
|
||||
|
||||
if episode and self.sys_log:
|
||||
self.sys_log.current_episode = episode
|
||||
|
||||
@@ -209,11 +209,6 @@ class NIC(IPWiredNetworkInterface):
|
||||
|
||||
return state
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Attempt to receive and process a network frame from the connected Link.
|
||||
|
||||
@@ -109,24 +109,6 @@ class Firewall(Router):
|
||||
sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound"
|
||||
)
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Set the original state for the Firewall."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"internal_port",
|
||||
"external_port",
|
||||
"dmz_port",
|
||||
"internal_inbound_acl",
|
||||
"internal_outbound_acl",
|
||||
"dmz_inbound_acl",
|
||||
"dmz_outbound_acl",
|
||||
"external_inbound_acl",
|
||||
"external_outbound_acl",
|
||||
}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the Firewall.
|
||||
|
||||
@@ -136,11 +136,6 @@ class ACLRule(SimComponent):
|
||||
rule_strings.append(f"{key}={value}")
|
||||
return ", ".join(rule_strings)
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the ACLRule.
|
||||
@@ -296,48 +291,6 @@ class AccessControlList(SimComponent):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
self.implicit_rule.set_original_state()
|
||||
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
|
||||
|
||||
for i, rule in enumerate(self._acl):
|
||||
if not rule:
|
||||
continue
|
||||
self._default_config[i] = {"action": rule.action.name}
|
||||
if rule.src_ip_address:
|
||||
self._default_config[i]["src_ip"] = str(rule.src_ip_address)
|
||||
if rule.dst_ip_address:
|
||||
self._default_config[i]["dst_ip"] = str(rule.dst_ip_address)
|
||||
if rule.src_port:
|
||||
self._default_config[i]["src_port"] = rule.src_port.name
|
||||
if rule.dst_port:
|
||||
self._default_config[i]["dst_port"] = rule.dst_port.name
|
||||
if rule.protocol:
|
||||
self._default_config[i]["protocol"] = rule.protocol.name
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.implicit_rule.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
self._reset_rules_to_default()
|
||||
|
||||
def _reset_rules_to_default(self) -> None:
|
||||
"""Clear all ACL rules and set them to the default rules config."""
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
for r_num, r_cfg in self._default_config.items():
|
||||
self.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
# TODO: Add src and dst wildcard masks as positional args in this request.
|
||||
@@ -616,11 +569,6 @@ class RouteEntry(SimComponent):
|
||||
metric: float = 0.0
|
||||
"The cost metric for this route. Default is 0.0."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
|
||||
self._original_values = self.model_dump(include=vals_to_include)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the RouteEntry.
|
||||
@@ -653,17 +601,6 @@ class RouteTable(SimComponent):
|
||||
default_route: Optional[RouteEntry] = None
|
||||
sys_log: SysLog
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
self._original_state["routes_orig"] = self.routes
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.routes.clear()
|
||||
self.routes = self._original_state["routes_orig"]
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the RouteTable.
|
||||
@@ -1104,8 +1041,6 @@ class Router(NetworkNode):
|
||||
|
||||
self._set_default_acl()
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
def _install_system_software(self):
|
||||
"""
|
||||
Installs essential system software and network services on the router.
|
||||
@@ -1131,20 +1066,7 @@ class Router(NetworkNode):
|
||||
self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
|
||||
self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
|
||||
def set_original_state(self):
|
||||
"""
|
||||
Sets or resets the router to its original configuration state.
|
||||
|
||||
This method is called to initialize the router's state or to revert it to a known good configuration during
|
||||
network simulations or after configuration changes.
|
||||
"""
|
||||
self.acl.set_original_state()
|
||||
self.route_table.set_original_state()
|
||||
super().set_original_state()
|
||||
vals_to_include = {"num_ports"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the router's components for a new network simulation episode.
|
||||
|
||||
@@ -1154,13 +1076,10 @@ class Router(NetworkNode):
|
||||
:param episode: The episode number for which the router is being reset.
|
||||
"""
|
||||
self.software_manager.arp.clear()
|
||||
self.acl.reset_component_for_episode(episode)
|
||||
self.route_table.reset_component_for_episode(episode)
|
||||
for i, network_interface in self.network_interface.items():
|
||||
network_interface.reset_component_for_episode(episode)
|
||||
for i, _ in self.network_interface.items():
|
||||
self.enable_port(i)
|
||||
|
||||
super().reset_component_for_episode(episode)
|
||||
super().setup_for_episode(episode=episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
@@ -1391,7 +1310,6 @@ class Router(NetworkNode):
|
||||
network_interface.ip_address = ip_address
|
||||
network_interface.subnet_mask = subnet_mask
|
||||
self.sys_log.info(f"Configured Network Interface {network_interface}")
|
||||
self.set_original_state()
|
||||
|
||||
def enable_port(self, port: int):
|
||||
"""
|
||||
@@ -1493,6 +1411,14 @@ class Router(NetworkNode):
|
||||
subnet_mask=port_cfg["subnet_mask"],
|
||||
)
|
||||
if "acl" in cfg:
|
||||
new.acl._default_config = cfg["acl"] # save the config to allow resetting
|
||||
new.acl._reset_rules_to_default() # read the config and apply rules
|
||||
for r_num, r_cfg in cfg["acl"].items():
|
||||
new.acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
position=r_num,
|
||||
)
|
||||
return new
|
||||
|
||||
@@ -32,12 +32,6 @@ class SwitchPort(WiredNetworkInterface):
|
||||
_connected_node: Optional[Switch] = None
|
||||
"The Switch to which the SwitchPort is connected."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
super().set_original_state()
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
@@ -122,8 +122,6 @@ class WirelessRouter(Router):
|
||||
|
||||
self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0"))
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
@property
|
||||
def wireless_access_point(self) -> WirelessAccessPoint:
|
||||
"""
|
||||
@@ -166,7 +164,6 @@ class WirelessRouter(Router):
|
||||
network_interface.ip_address = ip_address
|
||||
network_interface.subnet_mask = subnet_mask
|
||||
self.sys_log.info(f"Configured WAP {network_interface}")
|
||||
self.set_original_state()
|
||||
self.wireless_access_point.frequency = frequency # Set operating frequency
|
||||
self.wireless_access_point.enable() # Re-enable the WAP with new settings
|
||||
|
||||
|
||||
@@ -21,13 +21,9 @@ class Simulation(SimComponent):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
self.network.set_original_state()
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.network.reset_component_for_episode(episode)
|
||||
self.network.setup_for_episode(episode=episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
@@ -38,12 +38,6 @@ class Application(IOSoftware):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -31,20 +31,6 @@ class DatabaseClient(Application):
|
||||
kwargs["port"] = Port.POSTGRES_SERVER
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {"server_ip_address", "server_password", "connected", "_query_success_tracker"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
self._query_success_tracker.clear()
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -49,26 +49,6 @@ class DataManipulationBot(DatabaseClient):
|
||||
super().__init__(**kwargs)
|
||||
self.name = "DataManipulationBot"
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting DataManipulationBot original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"server_ip_address",
|
||||
"payload",
|
||||
"server_password",
|
||||
"port_scan_p_of_success",
|
||||
"data_manipulation_p_of_success",
|
||||
"attack_stage",
|
||||
"repeat",
|
||||
}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
|
||||
@@ -57,27 +57,6 @@ class DoSBot(DatabaseClient, Application):
|
||||
self.name = "DoSBot"
|
||||
self.max_sessions = 1000 # override normal max sessions
|
||||
|
||||
def set_original_state(self):
|
||||
"""Set the original state of the Denial of Service Bot."""
|
||||
_LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"target_ip_address",
|
||||
"target_port",
|
||||
"payload",
|
||||
"repeat",
|
||||
"attack_stage",
|
||||
"max_sessions",
|
||||
"port_scan_p_of_success",
|
||||
"dos_intensity",
|
||||
}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
|
||||
@@ -47,21 +47,8 @@ class WebBrowser(Application):
|
||||
kwargs["port"] = Port.HTTP
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.set_original_state()
|
||||
self.run()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting WebBrowser original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
@@ -80,9 +67,6 @@ class WebBrowser(Application):
|
||||
state["history"] = [hist_item.state() for hist_item in self.history]
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
|
||||
def get_webpage(self, url: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Retrieve the webpage.
|
||||
|
||||
@@ -24,12 +24,6 @@ class Process(Software):
|
||||
operating_state: ProcessOperatingState
|
||||
"The current operating state of the Process."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"operating_state"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -40,25 +40,6 @@ class DatabaseService(Service):
|
||||
super().__init__(**kwargs)
|
||||
self._create_db_file()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"password",
|
||||
"connections",
|
||||
"backup_server_ip",
|
||||
"latest_backup_directory",
|
||||
"latest_backup_file_name",
|
||||
}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
|
||||
self.clear_connections()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def configure_backup(self, backup_server: IPv4Address):
|
||||
"""
|
||||
Set up the database backup.
|
||||
|
||||
@@ -29,18 +29,6 @@ class DNSClient(Service):
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting DNSClient original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {"dns_server"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.dns_cache.clear()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
|
||||
@@ -28,20 +28,6 @@ class DNSServer(Service):
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting DNSServer original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {"dns_table"}
|
||||
self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"]
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.dns_table.clear()
|
||||
for key, value in self._original_state["dns_table_orig"].items():
|
||||
self.dns_table[key] = value
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the software.
|
||||
|
||||
@@ -27,18 +27,6 @@ class FTPClient(FTPServiceABC):
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting FTPClient original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {"connected"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
"""
|
||||
Process the command in the FTP Packet.
|
||||
|
||||
@@ -27,19 +27,6 @@ class FTPServer(FTPServiceABC):
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting FTPServer original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {"server_password"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
|
||||
self.clear_connections()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
"""
|
||||
Process the command in the FTP Packet.
|
||||
|
||||
@@ -49,21 +49,12 @@ class NTPClient(Service):
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including resetting any
|
||||
stateful properties or statistics, and clearing any message queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def send(
|
||||
self,
|
||||
payload: NTPPacket,
|
||||
session_id: Optional[str] = None,
|
||||
dest_ip_address: IPv4Address = None,
|
||||
dest_port: [Port] = Port.NTP,
|
||||
dest_port: Port = Port.NTP,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""Requests NTP data from NTP server.
|
||||
|
||||
@@ -34,16 +34,6 @@ class NTPServer(Service):
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
Resets the Service component for a new episode.
|
||||
|
||||
This method ensures the Service is ready for a new episode, including
|
||||
resetting any stateful properties or statistics, and clearing any message
|
||||
queues.
|
||||
"""
|
||||
pass
|
||||
|
||||
def receive(
|
||||
self,
|
||||
payload: NTPPacket,
|
||||
|
||||
@@ -78,12 +78,6 @@ class Service(IOSoftware):
|
||||
"""
|
||||
return super().receive(payload=payload, session_id=session_id, **kwargs)
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"operating_state", "restart_duration", "restart_countdown"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
|
||||
|
||||
@@ -23,18 +23,6 @@ class WebServer(Service):
|
||||
|
||||
last_response_status_code: Optional[HttpStatusCode] = None
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
_LOGGER.debug(f"Setting WebServer original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {"last_response_status_code"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -130,7 +118,7 @@ class WebServer(Service):
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
return response
|
||||
except Exception:
|
||||
except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 )
|
||||
# something went wrong on the server
|
||||
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
|
||||
return response
|
||||
|
||||
@@ -3,7 +3,7 @@ from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
@@ -13,6 +13,9 @@ from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.session_manager import Session
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
|
||||
class SoftwareType(Enum):
|
||||
"""
|
||||
@@ -84,7 +87,7 @@ class Software(SimComponent):
|
||||
"The count of times the software has been scanned, defaults to 0."
|
||||
revealed_to_red: bool = False
|
||||
"Indicates if the software has been revealed to red agent, defaults is False."
|
||||
software_manager: Any = None
|
||||
software_manager: "SoftwareManager" = None
|
||||
"An instance of Software Manager that is used by the parent node."
|
||||
sys_log: SysLog = None
|
||||
"An instance of SysLog that is used by the parent node."
|
||||
@@ -97,19 +100,6 @@ class Software(SimComponent):
|
||||
_patching_countdown: Optional[int] = None
|
||||
"Current number of ticks left to patch the software."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {
|
||||
"name",
|
||||
"health_state_actual",
|
||||
"health_state_visible",
|
||||
"criticality",
|
||||
"patching_count",
|
||||
"scanning_count",
|
||||
"revealed_to_red",
|
||||
}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
@@ -248,12 +238,6 @@ class IOSoftware(Software):
|
||||
_connections: Dict[str, Dict] = {}
|
||||
"Active connections."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -589,15 +589,16 @@ simulation:
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
|
||||
@@ -593,15 +593,16 @@ simulation:
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
@@ -624,7 +625,7 @@ simulation:
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
type: FTPServer
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
|
||||
@@ -1043,16 +1043,16 @@ simulation:
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
type: server
|
||||
@@ -1074,7 +1074,7 @@ simulation:
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
type: FTPServer
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
|
||||
@@ -599,15 +599,16 @@ simulation:
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
@@ -630,7 +631,7 @@ simulation:
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
type: FTPServer
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
|
||||
@@ -600,15 +600,16 @@ simulation:
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
@@ -631,7 +632,7 @@ simulation:
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
type: FTPServer
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
|
||||
@@ -510,6 +510,6 @@ def game_and_agent():
|
||||
reward_function=reward_function,
|
||||
)
|
||||
|
||||
game.agents.append(test_agent)
|
||||
game.agents["test_agent"] = test_agent
|
||||
|
||||
return (game, test_agent)
|
||||
|
||||
@@ -11,14 +11,12 @@ from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
|
||||
|
||||
# @pytest.mark.skip(reason="no way of currently testing this")
|
||||
def test_sb3_compatibility():
|
||||
"""Test that the Gymnasium environment can be used with an SB3 agent."""
|
||||
with open(example_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
gym = PrimaiteGymEnv(game=game)
|
||||
gym = PrimaiteGymEnv(game_config=cfg)
|
||||
model = PPO("MlpPolicy", gym)
|
||||
|
||||
model.learn(total_timesteps=1000)
|
||||
|
||||
@@ -42,20 +42,20 @@ def test_example_config():
|
||||
assert len(game.agents) == 4 # red, blue and 2 green agents
|
||||
|
||||
# green agent 1
|
||||
assert game.agents[0].agent_name == "client_2_green_user"
|
||||
assert isinstance(game.agents[0], RandomAgent)
|
||||
assert "client_2_green_user" in game.agents
|
||||
assert isinstance(game.agents["client_2_green_user"], RandomAgent)
|
||||
|
||||
# green agent 2
|
||||
assert game.agents[1].agent_name == "client_1_green_user"
|
||||
assert isinstance(game.agents[1], RandomAgent)
|
||||
assert "client_1_green_user" in game.agents
|
||||
assert isinstance(game.agents["client_1_green_user"], RandomAgent)
|
||||
|
||||
# red agent
|
||||
assert game.agents[2].agent_name == "client_1_data_manipulation_red_bot"
|
||||
assert isinstance(game.agents[2], DataManipulationAgent)
|
||||
assert "client_1_data_manipulation_red_bot" in game.agents
|
||||
assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent)
|
||||
|
||||
# blue agent
|
||||
assert game.agents[3].agent_name == "defender"
|
||||
assert isinstance(game.agents[3], ProxyAgent)
|
||||
assert "defender" in game.agents
|
||||
assert isinstance(game.agents["defender"], ProxyAgent)
|
||||
|
||||
network: Network = game.simulation.network
|
||||
|
||||
|
||||
@@ -7,26 +7,11 @@ from primaite.simulator.domain.account import Account, AccountType
|
||||
@pytest.fixture(scope="function")
|
||||
def account() -> Account:
|
||||
acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER)
|
||||
acct.set_original_state()
|
||||
return acct
|
||||
|
||||
|
||||
def test_original_state(account):
|
||||
"""Test the original state - see if it resets properly"""
|
||||
account.log_on()
|
||||
account.log_off()
|
||||
account.disable()
|
||||
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 1
|
||||
assert state["num_logoffs"] is 1
|
||||
assert state["num_group_changes"] is 0
|
||||
assert state["username"] is "Jake"
|
||||
assert state["password"] is "totally_hashed_password"
|
||||
assert state["account_type"] is AccountType.USER.value
|
||||
assert state["enabled"] is False
|
||||
|
||||
account.reset_component_for_episode(episode=1)
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 0
|
||||
assert state["num_logoffs"] is 0
|
||||
@@ -39,13 +24,7 @@ def test_original_state(account):
|
||||
account.log_on()
|
||||
account.log_off()
|
||||
account.disable()
|
||||
account.set_original_state()
|
||||
|
||||
account.log_on()
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 2
|
||||
|
||||
account.reset_component_for_episode(episode=2)
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 1
|
||||
assert state["num_logoffs"] is 1
|
||||
|
||||
@@ -185,38 +185,6 @@ def test_get_file(file_system):
|
||||
file_system.show(full=True)
|
||||
|
||||
|
||||
def test_reset_file_system(file_system):
|
||||
# file and folder that existed originally
|
||||
file_system.create_file(file_name="test_file.zip")
|
||||
file_system.create_folder(folder_name="test_folder")
|
||||
file_system.set_original_state()
|
||||
|
||||
# create a new file
|
||||
file_system.create_file(file_name="new_file.txt")
|
||||
|
||||
# create a new folder
|
||||
file_system.create_folder(folder_name="new_folder")
|
||||
|
||||
# delete the file that existed originally
|
||||
file_system.delete_file(folder_name="root", file_name="test_file.zip")
|
||||
assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None
|
||||
|
||||
# delete the folder that existed originally
|
||||
file_system.delete_folder(folder_name="test_folder")
|
||||
assert file_system.get_folder(folder_name="test_folder") is None
|
||||
|
||||
# reset
|
||||
file_system.reset_component_for_episode(episode=1)
|
||||
|
||||
# deleted original file and folder should be back
|
||||
assert file_system.get_file(folder_name="root", file_name="test_file.zip")
|
||||
assert file_system.get_folder(folder_name="test_folder")
|
||||
|
||||
# new file and folder should be removed
|
||||
assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None
|
||||
assert file_system.get_folder(folder_name="new_folder") is None
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
|
||||
def test_serialisation(file_system):
|
||||
"""Test to check that the object serialisation works correctly."""
|
||||
|
||||
@@ -31,7 +31,6 @@ def network(example_network) -> Network:
|
||||
assert len(example_network.computers) is 2
|
||||
assert len(example_network.servers) is 2
|
||||
|
||||
example_network.set_original_state()
|
||||
example_network.show()
|
||||
|
||||
return example_network
|
||||
@@ -45,40 +44,6 @@ def test_describe_state(network):
|
||||
assert len(state["links"]) is 6
|
||||
|
||||
|
||||
def test_reset_network(network):
|
||||
"""
|
||||
Test that the network is properly reset.
|
||||
|
||||
TODO: make sure that once implemented - any installed/uninstalled services, processes, apps,
|
||||
etc are also removed/reinstalled
|
||||
|
||||
"""
|
||||
state_before = network.describe_state()
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
server_1: Computer = network.get_node_by_hostname("server_1")
|
||||
|
||||
assert client_1.operating_state is NodeOperatingState.ON
|
||||
assert server_1.operating_state is NodeOperatingState.ON
|
||||
|
||||
client_1.power_off()
|
||||
assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN
|
||||
|
||||
server_1.power_off()
|
||||
assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN
|
||||
|
||||
assert network.describe_state() != state_before
|
||||
|
||||
network.reset_component_for_episode(episode=1)
|
||||
|
||||
assert client_1.operating_state is NodeOperatingState.ON
|
||||
assert server_1.operating_state is NodeOperatingState.ON
|
||||
# don't worry if UUIDs change
|
||||
a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"])
|
||||
b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"])
|
||||
assert a == b
|
||||
|
||||
|
||||
def test_creating_container():
|
||||
"""Check that we can create a network container"""
|
||||
net = Network()
|
||||
|
||||
@@ -19,7 +19,6 @@ def dos_bot() -> DoSBot:
|
||||
|
||||
dos_bot: DoSBot = computer.software_manager.software.get("DoSBot")
|
||||
dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1"))
|
||||
dos_bot.set_original_state()
|
||||
return dos_bot
|
||||
|
||||
|
||||
@@ -28,35 +27,6 @@ def test_dos_bot_creation(dos_bot):
|
||||
assert dos_bot is not None
|
||||
|
||||
|
||||
def test_dos_bot_reset(dos_bot):
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
|
||||
assert dos_bot.target_port is Port.POSTGRES_SERVER
|
||||
assert dos_bot.payload is None
|
||||
assert dos_bot.repeat is False
|
||||
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
|
||||
)
|
||||
|
||||
# should reset the relevant items
|
||||
dos_bot.reset_component_for_episode(episode=0)
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
|
||||
assert dos_bot.target_port is Port.POSTGRES_SERVER
|
||||
assert dos_bot.payload is None
|
||||
assert dos_bot.repeat is False
|
||||
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
|
||||
)
|
||||
dos_bot.set_original_state()
|
||||
dos_bot.reset_component_for_episode(episode=1)
|
||||
# should reset to the configured value
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.1.1")
|
||||
assert dos_bot.target_port is Port.HTTP
|
||||
assert dos_bot.payload == "payload"
|
||||
assert dos_bot.repeat is True
|
||||
|
||||
|
||||
def test_dos_bot_cannot_run_when_node_offline(dos_bot):
|
||||
dos_bot_node: Computer = dos_bot.parent
|
||||
assert dos_bot_node.operating_state is NodeOperatingState.ON
|
||||
|
||||
@@ -2,12 +2,14 @@ from typing import Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.software import Software, SoftwareHealthState
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
|
||||
|
||||
class TestSoftware(Software):
|
||||
class TestSoftware(Service):
|
||||
def describe_state(self) -> Dict:
|
||||
pass
|
||||
|
||||
@@ -15,7 +17,11 @@ class TestSoftware(Software):
|
||||
@pytest.fixture(scope="function")
|
||||
def software(file_system):
|
||||
return TestSoftware(
|
||||
name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service")
|
||||
name="TestSoftware",
|
||||
port=Port.ARP,
|
||||
file_system=file_system,
|
||||
sys_log=SysLog(hostname="test_service"),
|
||||
protocol=IPProtocol.TCP,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user