#2887 - Resolve conflicts from merge
This commit is contained in:
@@ -30,35 +30,22 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
node_name: client_2
|
||||
application_name: WebBrowser
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
node_name: client_2
|
||||
application_name: DatabaseClient
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
@@ -79,35 +66,22 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
node_name: client_1
|
||||
application_name: WebBrowser
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
node_name: client_1
|
||||
application_name: WebBrowser
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
@@ -128,33 +102,12 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_settings:
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
agent_settings:
|
||||
possible_start_nodes: [client_1, client_2]
|
||||
target_application: DataManipulationBot
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
@@ -208,8 +161,8 @@ agents:
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
@@ -235,490 +188,426 @@ agents:
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_SERVICE_FIX
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: ROUTER_ACL_ADDRULE
|
||||
- type: ROUTER_ACL_REMOVERULE
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
action: node_service_scan
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
action: node_service_stop
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
action: "node_service_start"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
action: "node_service_pause"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
action: "node_service_resume"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
action: "node_service_restart"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
action: "node_service_disable"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
action: "node_service_enable"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
9: # check database.db file
|
||||
action: "NODE_FILE_SCAN"
|
||||
action: "node_file_scan"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
action: "node_file_scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
action: "node_file_delete"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
action: "node_file_repair"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
13:
|
||||
action: "NODE_SERVICE_FIX"
|
||||
action: "node_service_fix"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 0
|
||||
node_name: database_server
|
||||
service_name: DatabaseService
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
action: "node_folder_scan"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
action: "node_folder_scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
action: "node_folder_repair"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
action: "node_folder_restore"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: domain_controller
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: domain_controller
|
||||
20:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: domain_controller
|
||||
21:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: domain_controller
|
||||
22:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: web_server
|
||||
23:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: web_server
|
||||
24:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: web_server
|
||||
25:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: web_server
|
||||
26: # old action num: 18
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 2
|
||||
node_name: database_server
|
||||
27:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 2
|
||||
node_name: database_server
|
||||
28:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 2
|
||||
node_name: database_server
|
||||
29:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 2
|
||||
node_name: database_server
|
||||
30:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 3
|
||||
node_name: backup_server
|
||||
31:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 3
|
||||
node_name: backup_server
|
||||
32:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 3
|
||||
node_name: backup_server
|
||||
33:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 3
|
||||
node_name: backup_server
|
||||
34:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 4
|
||||
node_name: security_suite
|
||||
35:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 4
|
||||
node_name: security_suite
|
||||
36:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 4
|
||||
node_name: security_suite
|
||||
37:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 4
|
||||
node_name: security_suite
|
||||
38:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 5
|
||||
node_name: client_1
|
||||
39: # old action num: 19 # shutdown client 1
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 5
|
||||
node_name: client_1
|
||||
40: # old action num: 20
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 5
|
||||
node_name: client_1
|
||||
41: # old action num: 21
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 5
|
||||
node_name: client_1
|
||||
42:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 6
|
||||
node_name: client_2
|
||||
43:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 6
|
||||
node_name: client_2
|
||||
44:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 6
|
||||
node_name: client_2
|
||||
45:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 6
|
||||
node_name: client_2
|
||||
|
||||
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.21 # client 1
|
||||
dst_ip: ALL # ALL
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: ALL
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.22 # client 2
|
||||
dst_ip: ALL # ALL
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: ALL
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
48: # old action num: 24 # block tcp traffic from client 1 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.21 # client 1
|
||||
dst_ip: 192.168.1.12 # web server
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: TCP
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
49: # old action num: 25 # block tcp traffic from client 2 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.22 # client 2
|
||||
dst_ip: 192.168.1.12 # web server
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: TCP
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
50: # old action num: 26
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.21 # client 1
|
||||
dst_ip: 192.168.1.14 # database
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: TCP
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
51: # old action num: 27
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.22 # client 2
|
||||
dst_ip: 192.168.1.14 # database
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: TCP
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
52: # old action num: 28
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 0
|
||||
53: # old action num: 29
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 1
|
||||
54: # old action num: 30
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 2
|
||||
55: # old action num: 31
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 3
|
||||
56: # old action num: 32
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 4
|
||||
57: # old action num: 33
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 5
|
||||
58: # old action num: 34
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 6
|
||||
59: # old action num: 35
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 7
|
||||
60: # old action num: 36
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 8
|
||||
61: # old action num: 37
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 9
|
||||
62: # old action num: 38
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: domain_controller
|
||||
nic_num: 1
|
||||
63: # old action num: 39
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: domain_controller
|
||||
nic_num: 1
|
||||
64: # old action num: 40
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
node_name: web_server
|
||||
nic_num: 1
|
||||
65: # old action num: 41
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
node_name: web_server
|
||||
nic_num: 1
|
||||
66: # old action num: 42
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
node_name: database_server
|
||||
nic_num: 1
|
||||
67: # old action num: 43
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
node_name: database_server
|
||||
nic_num: 1
|
||||
68: # old action num: 44
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
node_name: backup_server
|
||||
nic_num: 1
|
||||
69: # old action num: 45
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
node_name: backup_server
|
||||
nic_num: 1
|
||||
70: # old action num: 46
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
node_name: security_suite
|
||||
nic_num: 1
|
||||
71: # old action num: 47
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
node_name: security_suite
|
||||
nic_num: 1
|
||||
72: # old action num: 48
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
node_name: security_suite
|
||||
nic_num: 2
|
||||
73: # old action num: 49
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
node_name: security_suite
|
||||
nic_num: 2
|
||||
74: # old action num: 50
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
75: # old action num: 51
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
76: # old action num: 52
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
node_name: client_2
|
||||
nic_num: 1
|
||||
77: # old action num: 53
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
node_name: client_2
|
||||
nic_num: 1
|
||||
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_name: domain_controller
|
||||
- node_name: web_server
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_name: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
services:
|
||||
- service_name: DatabaseService
|
||||
- node_name: backup_server
|
||||
- node_name: security_suite
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,68 +6,48 @@ game:
|
||||
agents:
|
||||
- ref: RL_Agent
|
||||
type: ProxyAgent
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_SHUTDOWN
|
||||
action: node_shutdown
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
2:
|
||||
action: NODE_SHUTDOWN
|
||||
action: node_shutdown
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
3:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
4:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
5:
|
||||
action: HOST_NIC_DISABLE
|
||||
action: host_nic_disable
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
6:
|
||||
action: HOST_NIC_DISABLE
|
||||
action: host_nic_disable
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
node_name: server
|
||||
nic_num: 1
|
||||
7:
|
||||
action: HOST_NIC_ENABLE
|
||||
action: host_nic_enable
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
8:
|
||||
action: HOST_NIC_ENABLE
|
||||
action: host_nic_enable
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
- node_name: server
|
||||
max_folders_per_node: 0
|
||||
max_files_per_folder: 0
|
||||
max_services_per_node: 0
|
||||
max_nics_per_node: 1
|
||||
max_acl_rules: 0
|
||||
ip_list:
|
||||
- 192.168.1.2
|
||||
- 192.168.1.3
|
||||
reward_function:
|
||||
reward_components: []
|
||||
node_name: server
|
||||
nic_num: 1
|
||||
|
||||
simulation:
|
||||
network:
|
||||
|
||||
@@ -6,25 +6,17 @@ agents: &greens
|
||||
action_probabilities:
|
||||
0: 0.2
|
||||
1: 0.8
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
node_name: client
|
||||
application_name: DatabaseClient
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -6,25 +6,17 @@ agents: &greens
|
||||
action_probabilities:
|
||||
0: 0.95
|
||||
1: 0.05
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
node_name: client
|
||||
application_name: DatabaseClient
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -3,24 +3,9 @@ reds: &reds
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 10
|
||||
frequency: 10
|
||||
variance: 0
|
||||
possible_start_nodes: [client,]
|
||||
target_application: DataManipulationBot
|
||||
start_step: 10
|
||||
frequency: 10
|
||||
variance: 0
|
||||
|
||||
@@ -3,24 +3,9 @@ reds: &reds
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 3
|
||||
frequency: 2
|
||||
variance: 1
|
||||
possible_start_nodes: [client_1]
|
||||
target_application: DataManipulationBot
|
||||
start_step: 3
|
||||
frequency: 2
|
||||
variance: 1
|
||||
|
||||
@@ -54,65 +54,46 @@ agents:
|
||||
- server:eth-1<->switch_1:eth-2
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_SHUTDOWN
|
||||
action: node_shutdown
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
2:
|
||||
action: NODE_SHUTDOWN
|
||||
action: node_shutdown
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
3:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
4:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
5:
|
||||
action: HOST_NIC_DISABLE
|
||||
action: host_nic_disable
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
6:
|
||||
action: HOST_NIC_DISABLE
|
||||
action: host_nic_disable
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
node_name: server
|
||||
nic_num: 1
|
||||
7:
|
||||
action: HOST_NIC_ENABLE
|
||||
action: host_nic_enable
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
8:
|
||||
action: HOST_NIC_ENABLE
|
||||
action: host_nic_enable
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
- node_name: server
|
||||
|
||||
max_folders_per_node: 0
|
||||
max_files_per_folder: 0
|
||||
max_services_per_node: 0
|
||||
max_nics_per_node: 1
|
||||
max_acl_rules: 0
|
||||
ip_list:
|
||||
- 192.168.1.2
|
||||
- 192.168.1.3
|
||||
node_name: server
|
||||
nic_num: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
33
src/primaite/game/agent/actions/__init__.py
Normal file
33
src/primaite/game/agent/actions/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
from primaite.game.agent.actions import (
|
||||
abstract,
|
||||
acl,
|
||||
application,
|
||||
file,
|
||||
folder,
|
||||
host_nic,
|
||||
manager,
|
||||
network,
|
||||
node,
|
||||
service,
|
||||
session,
|
||||
software,
|
||||
)
|
||||
from primaite.game.agent.actions.manager import ActionManager
|
||||
|
||||
__all__ = (
|
||||
"abstract",
|
||||
"acl",
|
||||
"application",
|
||||
"software",
|
||||
"file",
|
||||
"folder",
|
||||
"host_nic",
|
||||
"manager",
|
||||
"network",
|
||||
"node",
|
||||
"service",
|
||||
"session",
|
||||
"ActionManager",
|
||||
)
|
||||
36
src/primaite/game/agent/actions/abstract.py
Normal file
36
src/primaite/game/agent/actions/abstract.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, ClassVar, Dict, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
|
||||
class AbstractAction(BaseModel, ABC):
|
||||
"""Base class for actions."""
|
||||
|
||||
config: "AbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Base configuration schema for Actions."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str = ""
|
||||
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAction]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Cannot create new action under reserved name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
pass
|
||||
157
src/primaite/game/agent/actions/acl.py
Normal file
157
src/primaite/game/agent/actions/acl.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Literal, Union
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
__all__ = (
|
||||
"RouterACLAddRuleAction",
|
||||
"RouterACLRemoveRuleAction",
|
||||
"FirewallACLAddRuleAction",
|
||||
"FirewallACLRemoveRuleAction",
|
||||
)
|
||||
|
||||
|
||||
class ACLAddRuleAbstractAction(AbstractAction, ABC):
|
||||
"""Base abstract class for ACL add rule actions."""
|
||||
|
||||
config: ConfigSchema = "ACLAddRuleAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema base for ACL add rule abstract actions."""
|
||||
|
||||
src_ip: IPV4Address
|
||||
protocol_name: Union[IPProtocol, Literal["ALL"]]
|
||||
permission: Literal["PERMIT", "DENY"]
|
||||
position: int
|
||||
dst_ip: Union[IPV4Address, Literal["ALL"]]
|
||||
src_port: Union[Port, Literal["ALL"]]
|
||||
dst_port: Union[Port, Literal["ALL"]]
|
||||
src_wildcard: Union[IPV4Address, Literal["NONE"]]
|
||||
dst_wildcard: Union[IPV4Address, Literal["NONE"]]
|
||||
|
||||
|
||||
class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"):
|
||||
"""Base abstract class for ACL remove rule actions."""
|
||||
|
||||
config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema base for ACL remove rule abstract actions."""
|
||||
|
||||
position: int
|
||||
|
||||
|
||||
class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_add_rule"):
|
||||
"""Action which adds a rule to a router's ACL."""
|
||||
|
||||
config: "RouterACLAddRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for RouterACLAddRuleAction."""
|
||||
|
||||
target_router: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_router,
|
||||
"acl",
|
||||
"add_rule",
|
||||
config.permission,
|
||||
config.protocol_name,
|
||||
str(config.src_ip),
|
||||
str(config.src_wildcard),
|
||||
config.src_port,
|
||||
str(config.dst_ip),
|
||||
str(config.dst_wildcard),
|
||||
config.dst_port,
|
||||
config.position,
|
||||
]
|
||||
|
||||
|
||||
class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="router_acl_remove_rule"):
|
||||
"""Action which removes a rule from a router's ACL."""
|
||||
|
||||
config: "RouterACLRemoveRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for RouterACLRemoveRuleAction."""
|
||||
|
||||
target_router: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", config.target_router, "acl", "remove_rule", config.position]
|
||||
|
||||
|
||||
class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_acl_add_rule"):
|
||||
"""Action which adds a rule to a firewall port's ACL."""
|
||||
|
||||
config: "FirewallACLAddRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for FirewallACLAddRuleAction."""
|
||||
|
||||
target_firewall_nodename: str
|
||||
firewall_port_name: str
|
||||
firewall_port_direction: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_firewall_nodename,
|
||||
config.firewall_port_name,
|
||||
config.firewall_port_direction,
|
||||
"acl",
|
||||
"add_rule",
|
||||
config.permission,
|
||||
config.protocol_name,
|
||||
str(config.src_ip),
|
||||
str(config.src_wildcard),
|
||||
config.src_port,
|
||||
str(config.dst_ip),
|
||||
str(config.dst_wildcard),
|
||||
config.dst_port,
|
||||
config.position,
|
||||
]
|
||||
|
||||
|
||||
class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="firewall_acl_remove_rule"):
|
||||
"""Action which removes a rule from a firewall port's ACL."""
|
||||
|
||||
config: "FirewallACLRemoveRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for FirewallACLRemoveRuleAction."""
|
||||
|
||||
target_firewall_nodename: str
|
||||
firewall_port_name: str
|
||||
firewall_port_direction: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_firewall_nodename,
|
||||
config.firewall_port_name,
|
||||
config.firewall_port_direction,
|
||||
"acl",
|
||||
"remove_rule",
|
||||
config.position,
|
||||
]
|
||||
137
src/primaite/game/agent/actions/application.py
Normal file
137
src/primaite/game/agent/actions/application.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.abstract import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeApplicationExecuteAction",
|
||||
"NodeApplicationScanAction",
|
||||
"NodeApplicationCloseAction",
|
||||
"NodeApplicationFixAction",
|
||||
"NodeApplicationInstallAction",
|
||||
"NodeApplicationRemoveAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeApplicationAbstractAction(AbstractAction, ABC):
|
||||
"""
|
||||
Base class for application actions.
|
||||
|
||||
Any action which applies to an application and uses node_name and application_name as its only two parameters can
|
||||
inherit from this base class.
|
||||
"""
|
||||
|
||||
config: "NodeApplicationAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration schema for Node Application actions."""
|
||||
|
||||
node_name: str
|
||||
application_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"application",
|
||||
config.application_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="node_application_execute"):
|
||||
"""Action which executes an application."""
|
||||
|
||||
config: "NodeApplicationExecuteAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationExecuteAction."""
|
||||
|
||||
verb: str = "execute"
|
||||
|
||||
|
||||
class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_application_scan"):
|
||||
"""Action which scans an application."""
|
||||
|
||||
config: "NodeApplicationScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationScanAction."""
|
||||
|
||||
verb: str = "scan"
|
||||
|
||||
|
||||
class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node_application_close"):
|
||||
"""Action which closes an application."""
|
||||
|
||||
config: "NodeApplicationCloseAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationCloseAction."""
|
||||
|
||||
verb: str = "close"
|
||||
|
||||
|
||||
class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_application_fix"):
|
||||
"""Action which fixes an application."""
|
||||
|
||||
config: "NodeApplicationFixAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationFixAction."""
|
||||
|
||||
verb: str = "fix"
|
||||
|
||||
|
||||
class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="node_application_install"):
|
||||
"""Action which installs an application."""
|
||||
|
||||
config: "NodeApplicationInstallAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationInstallAction."""
|
||||
|
||||
verb: str = "install"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"software_manager",
|
||||
"application",
|
||||
config.verb,
|
||||
config.application_name,
|
||||
]
|
||||
|
||||
|
||||
class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="node_application_remove"):
|
||||
"""Action which removes/uninstalls an application."""
|
||||
|
||||
config: "NodeApplicationRemoveAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationRemoveAction."""
|
||||
|
||||
verb: str = "uninstall"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"software_manager",
|
||||
"application",
|
||||
config.verb,
|
||||
config.application_name,
|
||||
]
|
||||
189
src/primaite/game/agent/actions/file.py
Normal file
189
src/primaite/game/agent/actions/file.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeFileCreateAction",
|
||||
"NodeFileScanAction",
|
||||
"NodeFileDeleteAction",
|
||||
"NodeFileRestoreAction",
|
||||
"NodeFileCorruptAction",
|
||||
"NodeFileAccessAction",
|
||||
"NodeFileCheckhashAction",
|
||||
"NodeFileRepairAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeFileAbstractAction(AbstractAction, ABC):
|
||||
"""Abstract base class for file actions.
|
||||
|
||||
Any action which applies to a file and uses node_name, folder_name, and file_name as its
|
||||
only three parameters can inherit from this base class.
|
||||
"""
|
||||
|
||||
config: "NodeFileAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeFileAbstractAction."""
|
||||
|
||||
node_name: str
|
||||
folder_name: str
|
||||
file_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
"folder",
|
||||
config.folder_name,
|
||||
"file",
|
||||
config.file_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create"):
|
||||
"""Action which creates a new file in a given folder."""
|
||||
|
||||
config: "NodeFileCreateAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileCreateAction."""
|
||||
|
||||
verb: ClassVar[str] = "create"
|
||||
force: bool = False
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
"file",
|
||||
config.folder_name,
|
||||
config.file_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"):
|
||||
"""Action which scans a file."""
|
||||
|
||||
config: "NodeFileScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete"):
|
||||
"""Action which deletes a file."""
|
||||
|
||||
config: "NodeFileDeleteAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileDeleteAction."""
|
||||
|
||||
verb: ClassVar[str] = "delete"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
"file",
|
||||
config.folder_name,
|
||||
config.file_name,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restore"):
|
||||
"""Action which restores a file."""
|
||||
|
||||
config: "NodeFileRestoreAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileRestoreAction."""
|
||||
|
||||
verb: ClassVar[str] = "restore"
|
||||
|
||||
|
||||
class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrupt"):
|
||||
"""Action which corrupts a file."""
|
||||
|
||||
config: "NodeFileCorruptAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileCorruptAction."""
|
||||
|
||||
verb: ClassVar[str] = "corrupt"
|
||||
|
||||
|
||||
class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access"):
|
||||
"""Action which increases a file's access count."""
|
||||
|
||||
config: "NodeFileAccessAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileAccessAction."""
|
||||
|
||||
verb: ClassVar[str] = "access"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
config.folder_name,
|
||||
config.file_name,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_checkhash"):
|
||||
"""Action which checks the hash of a file."""
|
||||
|
||||
config: "NodeFileCheckhashAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileCheckhashAction."""
|
||||
|
||||
verb: ClassVar[str] = "checkhash"
|
||||
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"):
|
||||
"""Action which repairs a file."""
|
||||
|
||||
config: "NodeFileRepairAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeFileRepairAction."""
|
||||
|
||||
verb: ClassVar[str] = "repair"
|
||||
117
src/primaite/game/agent/actions/folder.py
Normal file
117
src/primaite/game/agent/actions/folder.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeFolderScanAction",
|
||||
"NodeFolderCheckhashAction",
|
||||
"NodeFolderRepairAction",
|
||||
"NodeFolderRestoreAction",
|
||||
"NodeFolderCreateAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction, ABC):
|
||||
"""
|
||||
Base class for folder actions.
|
||||
|
||||
Any action which applies to a folder and uses node_name and folder_name as its only two parameters can inherit from
|
||||
this base class.
|
||||
"""
|
||||
|
||||
config: "NodeFolderAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base configuration schema for NodeFolder actions."""
|
||||
|
||||
node_name: str
|
||||
folder_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
"folder",
|
||||
config.folder_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_scan"):
|
||||
"""Action which scans a folder."""
|
||||
|
||||
config: "NodeFolderScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folder_checkhash"):
|
||||
"""Action which checks the hash of a folder."""
|
||||
|
||||
config: "NodeFolderCheckhashAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderCheckhashAction."""
|
||||
|
||||
verb: ClassVar[str] = "checkhash"
|
||||
|
||||
|
||||
class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_repair"):
|
||||
"""Action which repairs a folder."""
|
||||
|
||||
config: "NodeFolderRepairAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderRepairAction."""
|
||||
|
||||
verb: ClassVar[str] = "repair"
|
||||
|
||||
|
||||
class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_restore"):
|
||||
"""Action which restores a folder."""
|
||||
|
||||
config: "NodeFolderRestoreAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderRestoreAction."""
|
||||
|
||||
verb: ClassVar[str] = "restore"
|
||||
|
||||
|
||||
class NodeFolderCreateAction(NodeFolderAbstractAction, identifier="node_folder_create"):
|
||||
"""Action which creates a new folder."""
|
||||
|
||||
config: "NodeFolderCreateAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderCreateAction."""
|
||||
|
||||
verb: ClassVar[str] = "create"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
"folder",
|
||||
config.folder_name,
|
||||
]
|
||||
62
src/primaite/game/agent/actions/host_nic.py
Normal file
62
src/primaite/game/agent/actions/host_nic.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = ("HostNICEnableAction", "HostNICDisableAction")
|
||||
|
||||
|
||||
class HostNICAbstractAction(AbstractAction, ABC):
|
||||
"""
|
||||
Abstract base class for NIC actions.
|
||||
|
||||
Any action which applies to a NIC and uses node_name and nic_num as its only two parameters can inherit from this
|
||||
base class.
|
||||
"""
|
||||
|
||||
config: "HostNICAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration schema for HostNIC actions."""
|
||||
|
||||
node_name: str
|
||||
nic_num: int
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.nic_num is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"network_interface",
|
||||
config.nic_num,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"):
|
||||
"""Action which enables a NIC."""
|
||||
|
||||
config: "HostNICEnableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for HostNICEnableAction."""
|
||||
|
||||
verb: ClassVar[str] = "enable"
|
||||
|
||||
|
||||
class HostNICDisableAction(HostNICAbstractAction, identifier="host_nic_disable"):
|
||||
"""Action which disables a NIC."""
|
||||
|
||||
config: "HostNICDisableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for HostNICDisableAction."""
|
||||
|
||||
verb: ClassVar[str] = "disable"
|
||||
108
src/primaite/game/agent/actions/manager.py
Normal file
108
src/primaite/game/agent/actions/manager.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
"""yaml example.
|
||||
|
||||
agents:
|
||||
- name: agent_1
|
||||
action_space:
|
||||
actions:
|
||||
- do_nothing
|
||||
- node_service_start
|
||||
- node_service_stop
|
||||
action_map:
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium import spaces
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from primaite.game.agent.actions.abstract import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = ("DoNothingAction", "ActionManager")
|
||||
|
||||
|
||||
class DoNothingAction(AbstractAction, identifier="do_nothing"):
|
||||
"""Do Nothing Action."""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for do_nothingAction."""
|
||||
|
||||
type: str = "do_nothing"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["do_nothing"]
|
||||
|
||||
|
||||
class _ActionMapItem(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
action: str
|
||||
options: Dict
|
||||
|
||||
|
||||
class ActionManager(BaseModel):
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
"""Config Schema for ActionManager."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
action_map: Dict[int, _ActionMapItem] = {}
|
||||
"""Mapping between integer action choices and CAOS actions."""
|
||||
|
||||
@field_validator("action_map", mode="after")
|
||||
def consecutive_action_nums(cls, v: Dict) -> Dict:
|
||||
"""Make sure all numbers between 0 and N are represented as dict keys in action map."""
|
||||
assert all([i in v.keys() for i in range(len(v))])
|
||||
return v
|
||||
|
||||
config: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
|
||||
|
||||
action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
"""Init as empty, populate after model validation."""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.action_map = {n: (v.action, v.options) for n, v in self.config.action_map.items()}
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Produce action in CAOS format.
|
||||
|
||||
The agent chooses an action (as an integer), this is converted into an action in CAOS format
|
||||
The CAOS format is basically an action identifier, followed by parameters stored in a dictionary.
|
||||
"""
|
||||
act_identifier, act_options = self.action_map[action]
|
||||
return act_identifier, act_options
|
||||
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
|
||||
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
|
||||
act_class = AbstractAction._registry[action_identifier]
|
||||
config = act_class.ConfigSchema(**action_options)
|
||||
return act_class.form_request(config=config)
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Return the gymnasium action space for this agent."""
|
||||
return spaces.Discrete(len(self.action_map))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict) -> "ActionManager":
|
||||
"""
|
||||
Construct an ActionManager from a config dictionary.
|
||||
|
||||
The action space config supports must contain the following key:
|
||||
``action_map`` - List of actions available to the agent, formatted as a dictionary where the key is the
|
||||
action number between 0 - N, and the value is the CAOS-formatted action.
|
||||
|
||||
:param cfg: The action space config.
|
||||
:type cfg: Dict
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
return cls(**cfg.get("options", {}), act_map=cfg.get("action_map"))
|
||||
57
src/primaite/game/agent/actions/network.py
Normal file
57
src/primaite/game/agent/actions/network.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction")
|
||||
|
||||
|
||||
class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstract"):
|
||||
"""Base class for Network port actions."""
|
||||
|
||||
config: "NetworkPortAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base configuration schema for NetworkPort actions."""
|
||||
|
||||
target_nodename: str
|
||||
port_num: int
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.target_nodename is None or config.port_num is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_nodename,
|
||||
"network_interface",
|
||||
config.port_num,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_port_enable"):
|
||||
"""Action which enables are port on a router or a firewall."""
|
||||
|
||||
config: "NetworkPortEnableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NetworkPortAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NetworkPortEnableAction."""
|
||||
|
||||
verb: ClassVar[str] = "enable"
|
||||
|
||||
|
||||
class NetworkPortDisableAction(NetworkPortAbstractAction, identifier="network_port_disable"):
|
||||
"""Action which disables are port on a router or a firewall."""
|
||||
|
||||
config: "NetworkPortDisableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NetworkPortAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NetworkPortDisableAction."""
|
||||
|
||||
verb: ClassVar[str] = "disable"
|
||||
186
src/primaite/game/agent/actions/node.py
Normal file
186
src/primaite/game/agent/actions/node.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import abstractmethod
|
||||
from typing import ClassVar, List, Optional, Union
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
__all__ = (
|
||||
"NodeOSScanAction",
|
||||
"NodeShutdownAction",
|
||||
"NodeStartupAction",
|
||||
"NodeResetAction",
|
||||
"NodeNMAPPingScanAction",
|
||||
"NodeNMAPPortScanAction",
|
||||
"NodeNetworkServiceReconAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeAbstractAction(AbstractAction, identifier="node_abstract"):
|
||||
"""
|
||||
Abstract base class for node actions.
|
||||
|
||||
Any action which applies to a node and uses node_name as its only parameter can inherit from this base class.
|
||||
"""
|
||||
|
||||
config: "NodeAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration schema for Node actions."""
|
||||
|
||||
node_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
print(config)
|
||||
return ["network", "node", config.node_name, config.verb]
|
||||
|
||||
|
||||
class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"):
|
||||
"""Action which scans a node's OS."""
|
||||
|
||||
config: "NodeOSScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeOSScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"):
|
||||
"""Action which shuts down a node."""
|
||||
|
||||
config: "NodeShutdownAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeShutdownAction."""
|
||||
|
||||
verb: ClassVar[str] = "shutdown"
|
||||
|
||||
|
||||
class NodeStartupAction(NodeAbstractAction, identifier="node_startup"):
|
||||
"""Action which starts up a node."""
|
||||
|
||||
config: "NodeStartupAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeStartupAction."""
|
||||
|
||||
verb: ClassVar[str] = "startup"
|
||||
|
||||
|
||||
class NodeResetAction(NodeAbstractAction, identifier="node_reset"):
|
||||
"""Action which resets a node."""
|
||||
|
||||
config: "NodeResetAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeResetAction."""
|
||||
|
||||
verb: ClassVar[str] = "reset"
|
||||
|
||||
|
||||
class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_action"):
|
||||
"""Base class for NodeNMAP actions."""
|
||||
|
||||
config: "NodeNMAPAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration Schema for NodeNMAP actions."""
|
||||
|
||||
target_ip_address: Union[str, List[str]]
|
||||
show: bool = False
|
||||
source_node: str
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
# NMAP action requests don't share a common format for their requests
|
||||
# This is just a placeholder to ensure the method is defined.
|
||||
pass
|
||||
|
||||
|
||||
class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"):
|
||||
"""Action which performs an NMAP ping scan."""
|
||||
|
||||
config: "NodeNMAPPingScanAction.ConfigSchema"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: "NodeNMAPPingScanAction.ConfigSchema") -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"ping_scan",
|
||||
{"target_ip_address": config.target_ip_address, "show": config.show},
|
||||
]
|
||||
|
||||
|
||||
class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"):
|
||||
"""Action which performs an NMAP port scan."""
|
||||
|
||||
config: "NodeNMAPPortScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeNMAPPortScanAction."""
|
||||
|
||||
source_node: str
|
||||
target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None
|
||||
target_port: Optional[Union[Port, List[Port]]] = None
|
||||
show: Optional[bool] = (False,)
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"port_scan",
|
||||
{
|
||||
"target_ip_address": config.target_ip_address,
|
||||
"target_port": config.target_port,
|
||||
"target_protocol": config.target_protocol,
|
||||
"show": config.show,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_network_service_recon"):
|
||||
"""Action which performs an NMAP network service recon (ping scan followed by port scan)."""
|
||||
|
||||
config: "NodeNetworkServiceReconAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeNetworkServiceReconAction."""
|
||||
|
||||
target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None
|
||||
target_port: Optional[Union[Port, List[Port]]] = None
|
||||
show: Optional[bool] = (False,)
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"network_service_recon",
|
||||
{
|
||||
"target_ip_address": config.target_ip_address,
|
||||
"target_port": config.target_port,
|
||||
"target_protocol": config.target_protocol,
|
||||
"show": config.show,
|
||||
},
|
||||
]
|
||||
135
src/primaite/game/agent/actions/service.py
Normal file
135
src/primaite/game/agent/actions/service.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeServiceScanAction",
|
||||
"NodeServiceStopAction",
|
||||
"NodeServiceStartAction",
|
||||
"NodeServicePauseAction",
|
||||
"NodeServiceResumeAction",
|
||||
"NodeServiceRestartAction",
|
||||
"NodeServiceDisableAction",
|
||||
"NodeServiceEnableAction",
|
||||
"NodeServiceFixAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstract"):
|
||||
"""Abstract Action for Node Service related actions.
|
||||
|
||||
Any actions which use node_name and service_name can inherit from this class.
|
||||
"""
|
||||
|
||||
config: "NodeServiceAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
node_name: str
|
||||
service_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", config.node_name, "service", config.service_name, config.verb]
|
||||
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"):
|
||||
"""Action which scans a service."""
|
||||
|
||||
config: "NodeServiceScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_stop"):
|
||||
"""Action which stops a service."""
|
||||
|
||||
config: "NodeServiceStopAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceStopAction."""
|
||||
|
||||
verb: ClassVar[str] = "stop"
|
||||
|
||||
|
||||
class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service_start"):
|
||||
"""Action which starts a service."""
|
||||
|
||||
config: "NodeServiceStartAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceStartAction."""
|
||||
|
||||
verb: ClassVar[str] = "start"
|
||||
|
||||
|
||||
class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service_pause"):
|
||||
"""Action which pauses a service."""
|
||||
|
||||
config: "NodeServicePauseAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServicePauseAction."""
|
||||
|
||||
verb: ClassVar[str] = "pause"
|
||||
|
||||
|
||||
class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_service_resume"):
|
||||
"""Action which resumes a service."""
|
||||
|
||||
config: "NodeServiceResumeAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceResumeAction."""
|
||||
|
||||
verb: ClassVar[str] = "resume"
|
||||
|
||||
|
||||
class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_service_restart"):
|
||||
"""Action which restarts a service."""
|
||||
|
||||
config: "NodeServiceRestartAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceRestartAction."""
|
||||
|
||||
verb: ClassVar[str] = "restart"
|
||||
|
||||
|
||||
class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_service_disable"):
|
||||
"""Action which disables a service."""
|
||||
|
||||
config: "NodeServiceDisableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceDisableAction."""
|
||||
|
||||
verb: ClassVar[str] = "disable"
|
||||
|
||||
|
||||
class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_service_enable"):
|
||||
"""Action which enables a service."""
|
||||
|
||||
config: "NodeServiceEnableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceEnableAction."""
|
||||
|
||||
verb: ClassVar[str] = "enable"
|
||||
|
||||
|
||||
class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_fix"):
|
||||
"""Action which fixes a service."""
|
||||
|
||||
config: "NodeServiceFixAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceFixAction."""
|
||||
|
||||
verb: ClassVar[str] = "fix"
|
||||
108
src/primaite/game/agent/actions/session.py
Normal file
108
src/primaite/game/agent/actions/session.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import abstractmethod
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeSessionsRemoteLoginAction",
|
||||
"NodeSessionsRemoteLogoutAction",
|
||||
"NodeAccountChangePasswordAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstract"):
|
||||
"""Base class for NodeSession actions."""
|
||||
|
||||
config: "NodeSessionAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base configuration schema for NodeSessionAbstractActions."""
|
||||
|
||||
node_name: str
|
||||
remote_ip: str
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""
|
||||
Abstract method for request forming.
|
||||
|
||||
Should return the action formatted as a request which can be ingested by the PrimAITE simulation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_session_remote_login"):
|
||||
"""Action which performs a remote session login."""
|
||||
|
||||
config: "NodeSessionsRemoteLoginAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSessionsRemoteLoginAction."""
|
||||
|
||||
username: str
|
||||
password: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.remote_ip is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"Terminal",
|
||||
"node_session_remote_login",
|
||||
config.username,
|
||||
config.password,
|
||||
config.remote_ip,
|
||||
]
|
||||
|
||||
|
||||
class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node_session_remote_logoff"):
|
||||
"""Action which performs a remote session logout."""
|
||||
|
||||
config: "NodeSessionsRemoteLogoutAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSessionsRemoteLogoutAction."""
|
||||
|
||||
verb: str = "remote_logoff"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.remote_ip is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", config.node_name, "service", "Terminal", config.verb, config.remote_ip]
|
||||
|
||||
|
||||
class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="node_account_change_password"):
|
||||
"""Action which changes the password for a user."""
|
||||
|
||||
config: "NodeAccountChangePasswordAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeAccountsChangePasswordAction."""
|
||||
|
||||
username: str
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"UserManager",
|
||||
"change_password",
|
||||
config.username,
|
||||
config.current_password,
|
||||
config.new_password,
|
||||
]
|
||||
241
src/primaite/game/agent/actions/software.py
Normal file
241
src/primaite/game/agent/actions/software.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"ConfigureRansomwareScriptAction",
|
||||
"ConfigureDoSBotAction",
|
||||
"ConfigureC2BeaconAction",
|
||||
"NodeSendRemoteCommandAction",
|
||||
"TerminalC2ServerAction",
|
||||
"RansomwareLaunchC2ServerAction",
|
||||
"ExfiltrationC2ServerAction",
|
||||
"ConfigureDatabaseClientAction",
|
||||
)
|
||||
|
||||
|
||||
class ConfigureRansomwareScriptAction(AbstractAction, identifier="configure_ransomware_script"):
|
||||
"""Action which sets config parameters for a ransomware script on a node."""
|
||||
|
||||
config: "ConfigureRansomwareScriptAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for ConfigureRansomwareScriptAction."""
|
||||
|
||||
node_name: str
|
||||
server_ip_address: Optional[str] = None
|
||||
server_password: Optional[str] = None
|
||||
payload: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
data = dict(
|
||||
server_ip_address=config.server_ip_address,
|
||||
server_password=config.server_password,
|
||||
payload=config.payload,
|
||||
)
|
||||
return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", data]
|
||||
|
||||
|
||||
class RansomwareConfigureC2ServerAction(ConfigureRansomwareScriptAction, identifier="c2_server_ransomware_configure"):
|
||||
"""Action which causes a C2 server to send a command to set options on a ransomware script remotely."""
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigureRansomwareScriptAction.ConfigSchema) -> RequestFormat:
|
||||
data = dict(
|
||||
server_ip_address=config.server_ip_address, server_password=config.server_password, payload=config.payload
|
||||
)
|
||||
return ["network", "node", config.node_name, "application", "C2Server", "ransomware_configure", data]
|
||||
|
||||
|
||||
class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"):
|
||||
"""Action which sets config parameters for a DoS bot on a node."""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
node_name: str
|
||||
target_ip_address: Optional[str] = None
|
||||
target_port: Optional[str] = None
|
||||
payload: Optional[str] = None
|
||||
repeat: Optional[bool] = None
|
||||
port_scan_p_of_success: Optional[float] = None
|
||||
dos_intensity: Optional[float] = None
|
||||
max_sessions: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
data = dict(
|
||||
target_ip_address=config.target_ip_address,
|
||||
target_port=config.target_port,
|
||||
payload=config.payload,
|
||||
repeat=config.repeat,
|
||||
port_scan_p_of_success=config.port_scan_p_of_success,
|
||||
dos_intensity=config.dos_intensity,
|
||||
max_sessions=config.max_sessions,
|
||||
)
|
||||
data = {k: v for k, v in data.items() if v is not None}
|
||||
return ["network", "node", config.node_name, "application", "DoSBot", "configure", data]
|
||||
|
||||
|
||||
class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"):
|
||||
"""Action which configures a C2 Beacon based on the parameters given."""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for ConfigureC2BeaconAction."""
|
||||
|
||||
node_name: str
|
||||
c2_server_ip_address: str
|
||||
keep_alive_frequency: int = Field(default=5, ge=1)
|
||||
masquerade_protocol: str = Field(default="TCP")
|
||||
masquerade_port: str = Field(default="HTTP")
|
||||
|
||||
@classmethod
|
||||
def form_request(self, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
data = dict(
|
||||
c2_server_ip_address=config.c2_server_ip_address,
|
||||
keep_alive_frequency=config.keep_alive_frequency,
|
||||
masquerade_protocol=config.masquerade_protocol,
|
||||
masquerade_port=config.masquerade_port,
|
||||
)
|
||||
return ["network", "node", config.node_name, "application", "C2Beacon", "configure", data]
|
||||
|
||||
|
||||
class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_command"):
|
||||
"""Action which sends a terminal command to a remote node via SSH."""
|
||||
|
||||
config: "NodeSendRemoteCommandAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSendRemoteCommandAction."""
|
||||
|
||||
node_name: str
|
||||
remote_ip: str
|
||||
command: RequestFormat
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"Terminal",
|
||||
"send_remote_command",
|
||||
config.remote_ip,
|
||||
{"command": config.command},
|
||||
]
|
||||
|
||||
|
||||
class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_command"):
|
||||
"""Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed."""
|
||||
|
||||
config: "TerminalC2ServerAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
node_name: str
|
||||
commands: Union[List[RequestFormat], RequestFormat]
|
||||
ip_address: Optional[str]
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
|
||||
command_model = {
|
||||
"commands": config.commands,
|
||||
"ip_address": config.ip_address,
|
||||
"username": config.username,
|
||||
"password": config.password,
|
||||
}
|
||||
return ["network", "node", config.node_name, "application", "C2Server", "terminal_command", command_model]
|
||||
|
||||
|
||||
class RansomwareLaunchC2ServerAction(AbstractAction, identifier="c2_server_ransomware_launch"):
|
||||
"""Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript."""
|
||||
|
||||
config: "RansomwareLaunchC2ServerAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for RansomwareLaunchC2ServerAction."""
|
||||
|
||||
node_name: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
# This action currently doesn't require any further configuration options.
|
||||
return ["network", "node", config.node_name, "application", "C2Server", "ransomware_launch"]
|
||||
|
||||
|
||||
class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfiltrate"):
|
||||
"""Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server."""
|
||||
|
||||
config: "ExfiltrationC2ServerAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
node_name: str
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
target_ip_address: str
|
||||
target_file_name: str
|
||||
target_folder_name: str
|
||||
exfiltration_folder_name: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
|
||||
command_model = {
|
||||
"target_file_name": config.target_file_name,
|
||||
"target_folder_name": config.target_folder_name,
|
||||
"exfiltration_folder_name": config.exfiltration_folder_name,
|
||||
"target_ip_address": config.target_ip_address,
|
||||
"username": config.username,
|
||||
"password": config.password,
|
||||
}
|
||||
return ["network", "node", config.node_name, "application", "C2Server", "exfiltrate", command_model]
|
||||
|
||||
|
||||
class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_database_client"):
|
||||
"""Action which sets config parameters for a database client on a node."""
|
||||
|
||||
config: "ConfigureDatabaseClientAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
node_name: str
|
||||
server_ip_address: Optional[str] = None
|
||||
server_password: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
data = {"server_ip_address": config.server_ip_address, "server_password": config.server_password}
|
||||
return ["network", "node", config.node_name, "application", "DatabaseClient", "configure", data]
|
||||
@@ -1,6 +1,7 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
@@ -20,20 +21,22 @@ class _NotJSONFilter(logging.Filter):
|
||||
|
||||
class AgentLog:
|
||||
"""
|
||||
A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent.
|
||||
An Agent Log class is a simple logger dedicated to managing and writing updates and information for an agent.
|
||||
|
||||
Each log message is written to a file located at: <simulation output directory>/agent_name/agent_name.log
|
||||
Each log message is written to a file located at:
|
||||
<simulation output directory>/agent_name/agent_name.log
|
||||
"""
|
||||
|
||||
def __init__(self, agent_name: str):
|
||||
def __init__(self, agent_name: Optional[str]):
|
||||
"""
|
||||
Constructs a Agent Log instance for a given hostname.
|
||||
|
||||
:param hostname: The hostname associated with the system logs being recorded.
|
||||
:param agent_name: The agent_name associated with the system logs being recorded.
|
||||
"""
|
||||
self.agent_name = agent_name
|
||||
self.current_episode: int = 1
|
||||
super().__init__()
|
||||
self.agent_name = agent_name if agent_name else "unnamed_agent"
|
||||
self.current_timestep: int = 0
|
||||
self.current_episode: int = 1
|
||||
self.setup_logger()
|
||||
|
||||
@property
|
||||
@@ -90,7 +93,7 @@ class AgentLog:
|
||||
|
||||
def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False):
|
||||
if to_terminal or SIM_OUTPUT.write_agent_log_to_terminal:
|
||||
print(f"{self.agent_name}: ({ self.timestep}) ({level}) {msg}")
|
||||
print(f"{self.agent_name}: ({self.timestep}) ({level}) {msg}")
|
||||
|
||||
def debug(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
"""Interface for agents."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.agent_log import AgentLog
|
||||
@@ -15,6 +17,8 @@ from primaite.interface.request import RequestFormat, RequestResponse
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
__all__ = ("AgentHistoryItem", "AbstractAgent", "AbstractScriptedAgent", "ProxyAgent")
|
||||
|
||||
|
||||
class AgentHistoryItem(BaseModel):
|
||||
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
|
||||
@@ -39,89 +43,56 @@ class AgentHistoryItem(BaseModel):
|
||||
reward_info: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
|
||||
start_step: int = 5
|
||||
"The timestep at which an agent begins performing it's actions"
|
||||
frequency: int = 5
|
||||
"The number of timesteps to wait between performing actions"
|
||||
variance: int = 0
|
||||
"The amount the frequency can randomly change to"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_variance_lt_frequency(self) -> "AgentStartSettings":
|
||||
"""
|
||||
Make sure variance is equal to or lower than frequency.
|
||||
|
||||
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
|
||||
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
|
||||
"""
|
||||
if self.variance > self.frequency:
|
||||
raise ValueError(
|
||||
f"Agent start settings error: variance must be lower than frequency "
|
||||
f"{self.variance=}, {self.frequency=}"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class AgentSettings(BaseModel):
|
||||
"""Settings for configuring the operation of an agent."""
|
||||
|
||||
start_settings: Optional[AgentStartSettings] = None
|
||||
"Configuration for when an agent begins performing it's actions"
|
||||
flatten_obs: bool = True
|
||||
"Whether to flatten the observation space before passing it to the agent. True by default."
|
||||
action_masking: bool = False
|
||||
"Whether to return action masks at each step."
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
|
||||
"""Construct agent settings from a config dictionary.
|
||||
|
||||
:param config: A dict of options for the agent settings.
|
||||
:type config: Dict
|
||||
:return: The agent settings.
|
||||
:rtype: AgentSettings
|
||||
"""
|
||||
if config is None:
|
||||
return cls()
|
||||
|
||||
return cls(**config)
|
||||
|
||||
|
||||
class AbstractAgent(ABC):
|
||||
class AbstractAgent(BaseModel, ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
agent_settings: Optional[AgentSettings] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize an agent.
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
:param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
|
||||
:type agent_name: Optional[str]
|
||||
:param action_space: Action space for the agent.
|
||||
:type action_space: Optional[ActionManager]
|
||||
:param observation_space: Observation space for the agent.
|
||||
:type observation_space: Optional[ObservationSpace]
|
||||
:param reward_function: Reward function for the agent.
|
||||
:type reward_function: Optional[RewardFunction]
|
||||
:param agent_settings: Configurable Options for Abstracted Agents
|
||||
:type agent_settings: Optional[AgentSettings]
|
||||
"""
|
||||
self.agent_name: str = agent_name or "unnamed_agent"
|
||||
self.action_manager: Optional[ActionManager] = action_space
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
self.history: List[AgentHistoryItem] = []
|
||||
self.logger = AgentLog(agent_name)
|
||||
class AgentSettingsSchema(BaseModel, ABC):
|
||||
"""Schema for the 'agent_settings' key."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Configuration Schema for AbstractAgents."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
type: str
|
||||
ref: str = ""
|
||||
"""name of the agent."""
|
||||
team: Optional[Literal["BLUE", "GREEN", "RED"]] = None
|
||||
agent_settings: AbstractAgent.AgentSettingsSchema = Field(default=lambda: AbstractAgent.AgentSettingsSchema())
|
||||
action_space: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
|
||||
observation_space: ObservationManager.ConfigSchema = Field(
|
||||
default_factory=lambda: ObservationManager.ConfigSchema()
|
||||
)
|
||||
reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
|
||||
|
||||
config: "AbstractAgent.ConfigSchema" = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
|
||||
|
||||
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
|
||||
history: List[AgentHistoryItem] = []
|
||||
|
||||
action_manager: ActionManager = Field(default_factory=lambda: ActionManager())
|
||||
observation_manager: ObservationManager = Field(default_factory=lambda: ObservationManager())
|
||||
reward_function: RewardFunction = Field(default_factory=lambda: RewardFunction())
|
||||
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Overwrite the default empty action, observation, and rewards with ones defined through the config."""
|
||||
self.action_manager = ActionManager(config=self.config.action_space)
|
||||
self.observation_manager = ObservationManager(config=self.config.observation_space)
|
||||
self.reward_function = RewardFunction(config=self.config.reward_function)
|
||||
return super().model_post_init(__context)
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -159,9 +130,9 @@ class AbstractAgent(ABC):
|
||||
"""
|
||||
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ("DO_NOTHING", {})
|
||||
return ("do_nothing", {})
|
||||
|
||||
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]:
|
||||
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> RequestFormat:
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
@@ -182,36 +153,47 @@ class AbstractAgent(ABC):
|
||||
"""Update the most recent history item with the reward value."""
|
||||
self.history[-1].reward = self.reward_function.current_reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> AbstractAgent:
|
||||
"""Grab the relevant agent class and construct an instance from a config dict."""
|
||||
agent_type = config["type"]
|
||||
agent_class = cls._registry[agent_type]
|
||||
return agent_class(config=config)
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent, identifier="AbstractScriptedAgent"):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
config: "AbstractScriptedAgent.ConfigSchema" = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema())
|
||||
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for AbstractScriptedAgents."""
|
||||
|
||||
type: str = "AbstractScriptedAgent"
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Return an action to be taken in the environment."""
|
||||
return super().get_action(obs=obs, timestep=timestep)
|
||||
|
||||
|
||||
class ProxyAgent(AbstractAgent):
|
||||
class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
|
||||
"""Agent that sends observations to an RL model and receives actions from that model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
agent_settings: Optional[AgentSettings] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
self.action_masking: bool = agent_settings.action_masking if agent_settings else False
|
||||
config: "ProxyAgent.ConfigSchema" = Field(default_factory=lambda: ProxyAgent.ConfigSchema())
|
||||
most_recent_action: ActType = None
|
||||
|
||||
class AgentSettingsSchema(AbstractAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
flatten_obs: bool = False
|
||||
action_masking: bool = False
|
||||
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for Proxy Agent."""
|
||||
|
||||
type: str = "Proxy_Agent"
|
||||
agent_settings: ProxyAgent.AgentSettingsSchema = Field(default_factory=lambda: ProxyAgent.AgentSettingsSchema())
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
@@ -233,3 +215,8 @@ class ProxyAgent(AbstractAgent):
|
||||
The environment is responsible for calling this method when it receives an action from the agent policy.
|
||||
"""
|
||||
self.most_recent_action = action
|
||||
|
||||
@property
|
||||
def flatten_obs(self) -> bool:
|
||||
"""Return agent flatten_obs param."""
|
||||
return self.config.agent_settings.flatten_obs
|
||||
|
||||
@@ -17,5 +17,5 @@ from primaite.game.agent.observations.software_observation import ApplicationObs
|
||||
__all__ = [
|
||||
"ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation",
|
||||
"LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation",
|
||||
"ObservationManager", "ApplicationObservation", "ServiceObservation",]
|
||||
"ObservationManager", "ApplicationObservation", "ServiceObservation", "RouterObservation", "LinkObservation",]
|
||||
# fmt: on
|
||||
|
||||
@@ -24,8 +24,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"""List of IP addresses."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of wildcard strings."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of port numbers."""
|
||||
port_list: Optional[List[str]] = None
|
||||
"""List of port names."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocol names."""
|
||||
num_rules: Optional[int] = None
|
||||
@@ -37,7 +37,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
num_rules: int,
|
||||
ip_list: List[IPv4Address],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[int],
|
||||
port_list: List[str],
|
||||
protocol_list: List[str],
|
||||
) -> None:
|
||||
"""
|
||||
@@ -51,8 +51,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
:type ip_list: List[IPv4Address]
|
||||
:param wildcard_list: List of wildcard strings.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port numbers.
|
||||
:type port_list: List[int]
|
||||
:param port_list: List of port names.
|
||||
:type port_list: List[str]
|
||||
:param protocol_list: List of protocol names.
|
||||
:type protocol_list: List[str]
|
||||
"""
|
||||
@@ -60,7 +60,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
self.num_rules: int = num_rules
|
||||
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
|
||||
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
|
||||
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
||||
self.port_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
|
||||
@@ -190,6 +190,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
if self.files:
|
||||
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
|
||||
|
||||
self.cached_obs: Optional[ObsType] = self.default_observation
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
@@ -204,7 +206,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
return self.default_observation
|
||||
|
||||
if self.file_system_requires_scan:
|
||||
health_status = folder_state["visible_status"]
|
||||
if not folder_state["scanned_this_step"]:
|
||||
health_status = self.cached_obs["health_status"]
|
||||
else:
|
||||
health_status = folder_state["visible_status"]
|
||||
else:
|
||||
health_status = folder_state["health_status"]
|
||||
|
||||
|
||||
@@ -27,13 +27,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
port_list: Optional[List[str]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
include_users: Optional[bool] = True
|
||||
include_users: Optional[bool] = None
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
@@ -41,7 +41,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
where: WhereType,
|
||||
ip_list: List[str],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[int],
|
||||
port_list: List[str],
|
||||
protocol_list: List[str],
|
||||
num_rules: int,
|
||||
include_users: bool,
|
||||
@@ -56,8 +56,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
:type ip_list: List[str]
|
||||
:param wildcard_list: List of wildcard rules.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port numbers.
|
||||
:type port_list: List[int]
|
||||
:param port_list: List of port names.
|
||||
:type port_list: List[str]
|
||||
:param protocol_list: List of protocol types.
|
||||
:type protocol_list: List[str]
|
||||
:param num_rules: Number of rules configured in the firewall.
|
||||
@@ -72,7 +72,6 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
self.ports: List[PortObservation] = [
|
||||
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
|
||||
]
|
||||
# TODO: check what the port nums are for firewall.
|
||||
|
||||
self.internal_inbound_acl = ACLObservation(
|
||||
where=self.where + ["internal_inbound_acl", "acl"],
|
||||
@@ -140,6 +139,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.include_users:
|
||||
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -153,29 +154,35 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
firewall_state = access_from_nested_dict(state, self.where)
|
||||
if firewall_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
obs = {
|
||||
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
|
||||
is_on = firewall_state["operating_state"] == 1
|
||||
if not is_on:
|
||||
obs = {**self.default_observation}
|
||||
|
||||
else:
|
||||
obs = {
|
||||
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
},
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.include_users:
|
||||
sess = firewall_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
if self.include_users:
|
||||
sess = firewall_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
@@ -186,34 +193,36 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
:return: Gymnasium space representing the observation space for firewall status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
|
||||
"ACL": spaces.Dict(
|
||||
{
|
||||
"INTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"DMZ": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"EXTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
return space
|
||||
shape = {
|
||||
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
|
||||
"ACL": spaces.Dict(
|
||||
{
|
||||
"INTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"DMZ": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"EXTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
if self.include_users:
|
||||
shape["users"] = spaces.Dict(
|
||||
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
|
||||
)
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:
|
||||
|
||||
@@ -54,7 +54,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
"""
|
||||
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
include_users: Optional[bool] = True
|
||||
include_users: Optional[bool] = None
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
@@ -191,25 +191,31 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
is_on = node_state["operating_state"] == 1
|
||||
if not is_on:
|
||||
obs = {**self.default_observation}
|
||||
|
||||
else:
|
||||
obs = {}
|
||||
if self.services:
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
if self.applications:
|
||||
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
|
||||
if self.folders:
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
if self.nics:
|
||||
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
|
||||
if self.include_num_access:
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
if self.include_users:
|
||||
sess = node_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
if self.services:
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
if self.applications:
|
||||
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
|
||||
if self.folders:
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
if self.nics:
|
||||
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
|
||||
if self.include_num_access:
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
if self.include_users:
|
||||
sess = node_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
|
||||
@@ -56,7 +56,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
port_list: Optional[List[str]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
|
||||
from pydantic import BaseModel, computed_field, ConfigDict, Field, model_validator, ValidationError
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
|
||||
@@ -140,7 +141,7 @@ class NullObservation(AbstractObservation, identifier="NONE"):
|
||||
return cls()
|
||||
|
||||
|
||||
class ObservationManager:
|
||||
class ObservationManager(BaseModel):
|
||||
"""
|
||||
Manage the observations of an Agent.
|
||||
|
||||
@@ -150,15 +151,66 @@ class ObservationManager:
|
||||
3. Formatting this information so an agent can use it to make decisions.
|
||||
"""
|
||||
|
||||
def __init__(self, obs: AbstractObservation) -> None:
|
||||
"""Initialise observation space.
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
:param observation: Observation object
|
||||
:type observation: AbstractObservation
|
||||
"""
|
||||
self.obs: AbstractObservation = obs
|
||||
self.current_observation: ObsType
|
||||
"""Cached copy of the observation at the time it was most recently calculated."""
|
||||
class ConfigSchema(BaseModel):
|
||||
"""Config Schema for Observation Manager."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str = "NONE"
|
||||
"""Identifier name for the top-level observation."""
|
||||
options: AbstractObservation.ConfigSchema = Field(
|
||||
default_factory=lambda: NullObservation.ConfigSchema(), validate_default=True
|
||||
)
|
||||
"""Options to pass into the top-level observation during creation."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def resolve_obs_options_type(cls, data: Any) -> Any:
|
||||
"""
|
||||
When constructing the model from a dict, resolve the correct observation class based on `type` field.
|
||||
|
||||
Workaround: The `options` field is statically typed as AbstractObservation. Therefore, it falls over when
|
||||
passing in data that adheres to a subclass schema rather than the plain AbstractObservation schema. There is
|
||||
a way to do this properly using discriminated union, but most advice on the internet assumes that the full
|
||||
list of types between which to discriminate is known ahead-of-time. That is not the case for us, because of
|
||||
our plugin architecture.
|
||||
|
||||
We may be able to revisit and implement a better solution when needed using the following resources as
|
||||
research starting points:
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
|
||||
https://github.com/pydantic/pydantic/issues/7366
|
||||
https://github.com/pydantic/pydantic/issues/7462
|
||||
https://github.com/pydantic/pydantic/pull/7983
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
# (TODO: duplicate default definition between here and the actual model)
|
||||
obs_type = data["type"] if "type" in data else "NONE"
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
|
||||
# if no options are passed in, try to create a default schema. Only works if there are no mandatory fields
|
||||
if "options" not in data:
|
||||
data["options"] = obs_class.ConfigSchema()
|
||||
|
||||
# if options passed as a dict, validate against schema
|
||||
elif isinstance(data["options"], dict):
|
||||
data["options"] = obs_class.ConfigSchema(**data["options"])
|
||||
|
||||
return data
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: ObservationManager.ConfigSchema())
|
||||
|
||||
current_observation: ObsType = 0
|
||||
|
||||
@computed_field
|
||||
@cached_property
|
||||
def obs(self) -> AbstractObservation:
|
||||
"""Create the main observation component for the observation manager from the config."""
|
||||
obs_class = AbstractObservation._registry[self.config.type]
|
||||
obs_instance = obs_class.from_config(config=self.config.options)
|
||||
return obs_instance
|
||||
|
||||
def update(self, state: Dict) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -31,7 +31,7 @@ class AbstractObservation(ABC):
|
||||
"""Initialise an observation. This method must be overwritten."""
|
||||
self.default_observation: ObsType
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register an observation type.
|
||||
|
||||
@@ -40,6 +40,8 @@ class AbstractObservation(ABC):
|
||||
:raises ValueError: When attempting to create a component with a name that is already in use.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate observation component type {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@@ -33,13 +33,13 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
port_list: Optional[List[str]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
include_users: Optional[bool] = True
|
||||
include_users: Optional[bool] = None
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
@@ -84,6 +84,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
}
|
||||
if self.ports:
|
||||
self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)}
|
||||
if self.include_users:
|
||||
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -98,16 +100,21 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
if router_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
if self.ports:
|
||||
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
|
||||
if self.include_users:
|
||||
sess = router_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
is_on = router_state["operating_state"] == 1
|
||||
if not is_on:
|
||||
obs = {**self.default_observation}
|
||||
|
||||
else:
|
||||
obs = {}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
if self.ports:
|
||||
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
|
||||
if self.include_users:
|
||||
sess = router_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
@@ -121,6 +128,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
shape = {"ACL": self.acl.space}
|
||||
if self.ports:
|
||||
shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)})
|
||||
if self.include_users:
|
||||
shape["users"] = spaces.Dict(
|
||||
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
|
||||
)
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -30,7 +30,7 @@ the structure:
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Never
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -48,21 +48,17 @@ class AbstractReward(BaseModel):
|
||||
|
||||
config: "AbstractReward.ConfigSchema"
|
||||
|
||||
# def __init__(self, schema_name, **kwargs):
|
||||
# super.__init__(self, **kwargs)
|
||||
# # Create ConfigSchema class
|
||||
# self.config_class = type(schema_name, (BaseModel, ABC), **kwargs)
|
||||
# self.config = self.config_class()
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Config schema for AbstractReward."""
|
||||
|
||||
type: str
|
||||
type: str = ""
|
||||
|
||||
_registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate reward {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
@@ -381,14 +377,19 @@ class SharedReward(AbstractReward, identifier="SHARED_REWARD"):
|
||||
|
||||
|
||||
class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
|
||||
"""Apply a negative reward when taking any action except DONOTHING."""
|
||||
"""Apply a negative reward when taking any action except do_nothing."""
|
||||
|
||||
config: "ActionPenalty.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""Config schema for ActionPenalty."""
|
||||
"""Config schema for ActionPenalty.
|
||||
|
||||
:param action_penalty: Reward to give agents for taking any action except do_nothing
|
||||
:type action_penalty: float
|
||||
:param do_nothing_penalty: Reward to give agent for taking the do_nothing action
|
||||
:type do_nothing_penalty: float
|
||||
"""
|
||||
|
||||
type: str = "ACTION_PENALTY"
|
||||
action_penalty: float = -1.0
|
||||
do_nothing_penalty: float = 0.0
|
||||
|
||||
@@ -402,21 +403,81 @@ class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
if last_action_response.action == "DONOTHING":
|
||||
if last_action_response.action == "do_nothing":
|
||||
return self.config.do_nothing_penalty
|
||||
|
||||
else:
|
||||
return self.config.action_penalty
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
class _SingleComponentConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str
|
||||
options: AbstractReward.ConfigSchema
|
||||
weight: float = 1.0
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def resolve_obs_options_type(cls, data: Any) -> Any:
|
||||
"""
|
||||
When constructing the model from a dict, resolve the correct reward class based on `type` field.
|
||||
|
||||
Workaround: The `options` field is statically typed as AbstractReward. Therefore, it falls over when
|
||||
passing in data that adheres to a subclass schema rather than the plain AbstractReward schema. There is
|
||||
a way to do this properly using discriminated union, but most advice on the internet assumes that the full
|
||||
list of types between which to discriminate is known ahead-of-time. That is not the case for us, because of
|
||||
our plugin architecture.
|
||||
|
||||
We may be able to revisit and implement a better solution when needed using the following resources as
|
||||
research starting points:
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
|
||||
https://github.com/pydantic/pydantic/issues/7366
|
||||
https://github.com/pydantic/pydantic/issues/7462
|
||||
https://github.com/pydantic/pydantic/pull/7983
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
assert "type" in data, ValueError('Reward component definition is missing the "type" key.')
|
||||
rew_type = data["type"]
|
||||
rew_class = AbstractReward._registry[rew_type]
|
||||
|
||||
# if no options are passed in, try to create a default schema. Only works if there are no mandatory fields.
|
||||
if "options" not in data:
|
||||
data["options"] = rew_class.ConfigSchema()
|
||||
|
||||
# if options are passed as a dict, validate against schema
|
||||
elif isinstance(data["options"], dict):
|
||||
data["options"] = rew_class.ConfigSchema(**data["options"])
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class RewardFunction(BaseModel):
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise the reward function object."""
|
||||
self.reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
"attribute reward_components keeps track of reward components and the weights assigned to each."
|
||||
self.current_reward: float = 0.0
|
||||
self.total_reward: float = 0.0
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
"""Config Schema for RewardFunction."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
reward_components: Iterable[_SingleComponentConfig] = []
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
|
||||
|
||||
reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
|
||||
current_reward: float = 0.0
|
||||
total_reward: float = 0.0
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
for rew_config in self.config.reward_components:
|
||||
rew_class = AbstractReward._registry[rew_config.type]
|
||||
rew_instance = rew_class(config=rew_config.options)
|
||||
self.register_component(component=rew_instance, weight=rew_config.weight)
|
||||
|
||||
def register_component(self, component: AbstractReward, weight: float = 1.0) -> None:
|
||||
"""Add a reward component to the reward function.
|
||||
|
||||
@@ -1 +1,6 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
from primaite.game.agent import interface
|
||||
from primaite.game.agent.scripted_agents import abstract_tap, data_manipulation_bot, probabilistic_agent, random_agent
|
||||
|
||||
__all__ = ("abstract_tap", "data_manipulation_bot", "interface", "probabilistic_agent", "random_agent")
|
||||
|
||||
61
src/primaite/game/agent/scripted_agents/abstract_tap.py
Normal file
61
src/primaite/game/agent/scripted_agents/abstract_tap.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
|
||||
|
||||
__all__ = "AbstractTAPAgent"
|
||||
|
||||
|
||||
class AbstractTAPAgent(PeriodicAgent, identifier="AbstractTAP"):
|
||||
"""Base class for TAP agents to inherit from."""
|
||||
|
||||
config: "AbstractTAPAgent.ConfigSchema" = Field(default_factory=lambda: AbstractTAPAgent.ConfigSchema())
|
||||
next_execution_timestep: int = 0
|
||||
|
||||
class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
possible_starting_nodes: List[str] = Field(default_factory=list)
|
||||
|
||||
class ConfigSchema(PeriodicAgent.ConfigSchema):
|
||||
"""Configuration schema for Abstract TAP agents."""
|
||||
|
||||
type: str = "AbstractTAP"
|
||||
agent_settings: AbstractTAPAgent.AgentSettingsSchema = Field(
|
||||
default_factory=lambda: AbstractTAPAgent.AgentSettingsSchema()
|
||||
)
|
||||
|
||||
starting_node: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Return an action to be taken in the environment."""
|
||||
return super().get_action(obs=obs, timestep=timestep)
|
||||
|
||||
@abstractmethod
|
||||
def setup_agent(self) -> None:
|
||||
"""Set up agent."""
|
||||
pass
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.config.agent_settings.variance, self.config.agent_settings.variance
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
self.starting_node = random.choice(self.config.agent_settings.possible_starting_nodes)
|
||||
self.logger.debug(f"Selected starting node: {self.starting_node}")
|
||||
@@ -1,31 +1,35 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import random
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
|
||||
|
||||
__all__ = "DataManipulationAgent"
|
||||
|
||||
|
||||
class DataManipulationAgent(AbstractScriptedAgent):
|
||||
class DataManipulationAgent(PeriodicAgent, identifier="RedDatabaseCorruptingAgent"):
|
||||
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
|
||||
|
||||
next_execution_timestep: int = 0
|
||||
starting_node_idx: int = 0
|
||||
class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.setup_agent()
|
||||
target_application: str = "DataManipulationBot"
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
class ConfigSchema(PeriodicAgent.ConfigSchema):
|
||||
"""Configuration Schema for DataManipulationAgent."""
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
|
||||
type: str = "RedDatabaseCorruptingAgent"
|
||||
agent_settings: "DataManipulationAgent.AgentSettingsSchema" = Field(
|
||||
default_factory=lambda: DataManipulationAgent.AgentSettingsSchema()
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
config: "DataManipulationAgent.ConfigSchema" = Field(default_factory=lambda: DataManipulationAgent.ConfigSchema())
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._set_next_execution_timestep(timestep=self.config.agent_settings.start_step, variance=0)
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute its data manipulation application.
|
||||
@@ -38,21 +42,14 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
self.logger.debug(msg="Performing do NOTHING")
|
||||
return "DONOTHING", {}
|
||||
self.logger.debug(msg="Performing do nothing action")
|
||||
return "do_nothing", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
|
||||
self._set_next_execution_timestep(
|
||||
timestep=timestep + self.config.agent_settings.frequency, variance=self.config.agent_settings.variance
|
||||
)
|
||||
self.logger.info(msg="Performing a data manipulation attack!")
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.action_manager.node_names)
|
||||
self.starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
self.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}")
|
||||
return "node_application_execute", {
|
||||
"node_name": self.start_node,
|
||||
"application_name": self.config.agent_settings.target_application,
|
||||
}
|
||||
|
||||
@@ -1,29 +1,28 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
"""Agents with predefined behaviours."""
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pydantic
|
||||
from gymnasium.core import ObsType
|
||||
from numpy.random import Generator
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
__all__ = "ProbabilisticAgent"
|
||||
|
||||
|
||||
class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent"):
|
||||
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
|
||||
|
||||
class Settings(pydantic.BaseModel):
|
||||
"""Config schema for Probabilistic agent settings."""
|
||||
rng: Generator = Field(default_factory=lambda: np.random.default_rng(np.random.randint(0, 65535)))
|
||||
|
||||
model_config = pydantic.ConfigDict(extra="forbid")
|
||||
"""Strict validation."""
|
||||
action_probabilities: Dict[int, float]
|
||||
class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
action_probabilities: Dict[int, float] = None
|
||||
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
|
||||
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
|
||||
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
|
||||
|
||||
@pydantic.field_validator("action_probabilities", mode="after")
|
||||
@classmethod
|
||||
@@ -44,31 +43,20 @@ class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
)
|
||||
return v
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str,
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
settings: Dict = {},
|
||||
) -> None:
|
||||
# If the action probabilities are not specified, create equal probabilities for all actions
|
||||
if "action_probabilities" not in settings:
|
||||
num_actions = len(action_space.action_map)
|
||||
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration schema for Probabilistic Agent."""
|
||||
|
||||
# The random number seed for np.random is dependent on whether a random number seed is set
|
||||
# in the config file. If there is one it is processed by set_random_seed() in environment.py
|
||||
# and as a consequence the the sequence of rng_seed's used here will be repeatable.
|
||||
self.settings = ProbabilisticAgent.Settings(**settings)
|
||||
rng_seed = np.random.randint(0, 65535)
|
||||
self.rng = np.random.default_rng(rng_seed)
|
||||
type: str = "ProbabilisticAgent"
|
||||
agent_settings: "ProbabilisticAgent.AgentSettingsSchema" = Field(
|
||||
default_factory=lambda: ProbabilisticAgent.AgentSettingsSchema()
|
||||
)
|
||||
|
||||
# convert probabilities from
|
||||
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
|
||||
config: "ProbabilisticAgent.ConfigSchema" = Field(default_factory=lambda: ProbabilisticAgent.ConfigSchema())
|
||||
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
|
||||
@property
|
||||
def probabilities(self) -> Dict[str, int]:
|
||||
"""Convenience method to view the probabilities of the Agent."""
|
||||
return np.asarray(list(self.config.agent_settings.action_probabilities.values()))
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import random
|
||||
from typing import Dict, Optional, Tuple
|
||||
from functools import cached_property
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel
|
||||
from pydantic import computed_field, Field, model_validator
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
__all__ = ("RandomAgent", "PeriodicAgent")
|
||||
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
class RandomAgent(AbstractScriptedAgent, identifier="RandomAgent"):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
config: "RandomAgent.ConfigSchema" = Field(default_factory=lambda: RandomAgent.ConfigSchema())
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration Schema for Random Agents."""
|
||||
|
||||
type: str = "RandomAgent"
|
||||
|
||||
def get_action(self) -> Tuple[str, Dict]:
|
||||
"""Sample the action space randomly.
|
||||
|
||||
:param obs: Current observation for this agent, not used in RandomAgent
|
||||
@@ -27,41 +34,60 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
|
||||
|
||||
class PeriodicAgent(AbstractScriptedAgent):
|
||||
class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
|
||||
"""Agent that does nothing most of the time, but executes application at regular intervals (with variance)."""
|
||||
|
||||
class Settings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
config: "PeriodicAgent.ConfigSchema" = Field(default_factory=lambda: PeriodicAgent.ConfigSchema())
|
||||
|
||||
start_step: int = 20
|
||||
"The timestep at which an agent begins performing it's actions."
|
||||
start_variance: int = 5
|
||||
"Deviation around the start step."
|
||||
class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
start_step: int = 5
|
||||
"The timestep at which an agent begins performing it's actions"
|
||||
frequency: int = 5
|
||||
"The number of timesteps to wait between performing actions."
|
||||
"The number of timesteps to wait between performing actions"
|
||||
variance: int = 0
|
||||
"The amount the frequency can randomly change to."
|
||||
max_executions: int = 999999
|
||||
"Maximum number of times the agent can execute its action."
|
||||
"The amount the frequency can randomly change to"
|
||||
possible_start_nodes: List[str]
|
||||
target_application: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str,
|
||||
action_space: ActionManager,
|
||||
observation_space: ObservationManager,
|
||||
reward_function: RewardFunction,
|
||||
settings: Optional[Settings] = None,
|
||||
) -> None:
|
||||
"""Initialise PeriodicAgent."""
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
@model_validator(mode="after")
|
||||
def check_variance_lt_frequency(self) -> "PeriodicAgent.ConfigSchema":
|
||||
"""
|
||||
Make sure variance is equal to or lower than frequency.
|
||||
|
||||
This is because the calculation for the next execution time is now + (frequency +- variance).
|
||||
If variance were greater than frequency, sometimes the bracketed term would be negative
|
||||
and the attack would never happen again.
|
||||
"""
|
||||
if self.variance >= self.frequency:
|
||||
raise ValueError(
|
||||
f"Agent start settings error: variance must be lower than frequency "
|
||||
f"{self.variance=}, {self.frequency=}"
|
||||
)
|
||||
return self
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration Schema for Periodic Agent."""
|
||||
|
||||
type: str = "PeriodicAgent"
|
||||
"""Name of the agent."""
|
||||
agent_settings: "PeriodicAgent.AgentSettingsSchema" = Field(
|
||||
default_factory=lambda: PeriodicAgent.AgentSettingsSchema()
|
||||
)
|
||||
self.settings = settings or PeriodicAgent.Settings()
|
||||
self._set_next_execution_timestep(timestep=self.settings.start_step, variance=self.settings.start_variance)
|
||||
self.num_executions = 0
|
||||
|
||||
max_executions: int = 999999
|
||||
"Maximum number of times the agent can execute its action."
|
||||
num_executions: int = 0
|
||||
"""Number of times the agent has executed an action."""
|
||||
next_execution_timestep: int = 0
|
||||
"""Timestep of the next action execution by the agent."""
|
||||
|
||||
@computed_field
|
||||
@cached_property
|
||||
def start_node(self) -> str:
|
||||
"""On instantiation, randomly select a start node."""
|
||||
return random.choice(self.config.agent_settings.possible_start_nodes)
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
@@ -76,9 +102,14 @@ class PeriodicAgent(AbstractScriptedAgent):
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
|
||||
if timestep == self.next_execution_timestep and self.num_executions < self.settings.max_executions:
|
||||
if timestep == self.next_execution_timestep and self.num_executions < self.max_executions:
|
||||
self.num_executions += 1
|
||||
self._set_next_execution_timestep(timestep + self.settings.frequency, self.settings.variance)
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
|
||||
self._set_next_execution_timestep(
|
||||
timestep + self.config.agent_settings.frequency, self.config.agent_settings.variance
|
||||
)
|
||||
return "node_application_execute", {
|
||||
"node_name": self.start_node,
|
||||
"application_name": self.config.agent_settings.target_application,
|
||||
}
|
||||
|
||||
return "DONOTHING", {}
|
||||
return "do_nothing", {}
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import random
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
|
||||
class TAP001(AbstractScriptedAgent):
|
||||
"""
|
||||
TAP001 | Mobile Malware -- Ransomware Variant.
|
||||
|
||||
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.setup_agent()
|
||||
|
||||
next_execution_timestep: int = 0
|
||||
starting_node_idx: int = 0
|
||||
installed: bool = False
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute the ransomware application.
|
||||
|
||||
This application acts a wrapper around the kill-chain, similar to green-analyst and
|
||||
the previous UC2 data manipulation bot.
|
||||
|
||||
:param obs: Current observation for this agent.
|
||||
:type obs: ObsType
|
||||
:param timestep: The current simulation timestep, used for scheduling actions
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
return "DONOTHING", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
|
||||
|
||||
if not self.installed:
|
||||
self.installed = True
|
||||
return "NODE_APPLICATION_INSTALL", {
|
||||
"node_id": self.starting_node_idx,
|
||||
"application_name": "RansomwareScript",
|
||||
}
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
for n, act in self.action_manager.action_map.items():
|
||||
if not act[0] == "NODE_APPLICATION_INSTALL":
|
||||
continue
|
||||
if act[1]["node_id"] == self.starting_node_idx:
|
||||
self.ip_address = act[1]["ip_address"]
|
||||
return
|
||||
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.action_manager.node_names)
|
||||
self.starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
@@ -7,14 +7,8 @@ import numpy as np
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import DEFAULT_BANDWIDTH, getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction, SharedReward
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
|
||||
from primaite.game.agent.scripted_agents.tap001 import TAP001
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
|
||||
from primaite.game.agent.rewards import SharedReward
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.network.creation import NetworkNodeAdder
|
||||
@@ -44,7 +38,7 @@ from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from primaite.simulator.system.software import Software
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -258,6 +252,7 @@ class PrimaiteGame:
|
||||
net = sim.network
|
||||
|
||||
simulation_config = cfg.get("simulation", {})
|
||||
defaults_config = cfg.get("defaults", {})
|
||||
network_config = simulation_config.get("network", {})
|
||||
airspace_cfg = network_config.get("airspace", {})
|
||||
frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {})
|
||||
@@ -283,6 +278,18 @@ class PrimaiteGame:
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
# TODO: handle simulation defaults more cleanly
|
||||
if "node_start_up_duration" in defaults_config:
|
||||
new_node.start_up_duration = defaults_config["node_startup_duration"]
|
||||
if "node_shut_down_duration" in defaults_config:
|
||||
new_node.shut_down_duration = defaults_config["node_shut_down_duration"]
|
||||
if "node_scan_duration" in defaults_config:
|
||||
new_node.node_scan_duration = defaults_config["node_scan_duration"]
|
||||
if "folder_scan_duration" in defaults_config:
|
||||
new_node.file_system._default_folder_scan_duration = defaults_config["folder_scan_duration"]
|
||||
if "folder_restore_duration" in defaults_config:
|
||||
new_node.file_system._default_folder_restore_duration = defaults_config["folder_restore_duration"]
|
||||
|
||||
if "users" in node_cfg and new_node.software_manager.software.get("UserManager"):
|
||||
user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa
|
||||
for user_cfg in node_cfg["users"]:
|
||||
@@ -315,12 +322,12 @@ class PrimaiteGame:
|
||||
|
||||
if service_class is not None:
|
||||
_LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}")
|
||||
new_node.software_manager.install(service_class, **service_cfg.get("options", {}))
|
||||
new_node.software_manager.install(service_class)
|
||||
new_service = new_node.software_manager.software[service_class.__name__]
|
||||
|
||||
# fixing duration for the service
|
||||
if "fix_duration" in service_cfg.get("options", {}):
|
||||
new_service.fixing_duration = service_cfg["options"]["fix_duration"]
|
||||
if "fixing_duration" in service_cfg.get("options", {}):
|
||||
new_service.config.fixing_duration = service_cfg["options"]["fixing_duration"]
|
||||
|
||||
_set_software_listen_on_ports(new_service, service_cfg)
|
||||
# start the service
|
||||
@@ -329,6 +336,15 @@ class PrimaiteGame:
|
||||
msg = f"Configuration contains an invalid service type: {service_type}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
# TODO: handle simulation defaults more cleanly
|
||||
if "service_fix_duration" in defaults_config:
|
||||
new_service.fixing_duration = defaults_config["service_fix_duration"]
|
||||
if "service_restart_duration" in defaults_config:
|
||||
new_service.restart_duration = defaults_config["service_restart_duration"]
|
||||
if "service_install_duration" in defaults_config:
|
||||
new_service.install_duration = defaults_config["service_install_duration"]
|
||||
|
||||
# service-dependent options
|
||||
if service_type == "DNSClient":
|
||||
if "options" in service_cfg:
|
||||
@@ -361,74 +377,20 @@ class PrimaiteGame:
|
||||
application_type = application_cfg["type"]
|
||||
|
||||
if application_type in Application._registry:
|
||||
new_node.software_manager.install(Application._registry[application_type])
|
||||
application_class = Application._registry[application_type]
|
||||
application_options = application_cfg.get("options", {})
|
||||
application_options["type"] = application_type
|
||||
new_node.software_manager.install(application_class, software_config=application_options)
|
||||
new_application = new_node.software_manager.software[application_type] # grab the instance
|
||||
|
||||
# fixing duration for the application
|
||||
if "fix_duration" in application_cfg.get("options", {}):
|
||||
new_application.fixing_duration = application_cfg["options"]["fix_duration"]
|
||||
else:
|
||||
msg = f"Configuration contains an invalid application type: {application_type}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
_set_software_listen_on_ports(new_application, application_cfg)
|
||||
|
||||
# run the application
|
||||
new_application.run()
|
||||
|
||||
if application_type == "DataManipulationBot":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload", "DELETE"),
|
||||
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
|
||||
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
|
||||
)
|
||||
elif application_type == "RansomwareScript":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None,
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload", "ENCRYPT"),
|
||||
)
|
||||
elif application_type == "DatabaseClient":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("db_server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
)
|
||||
elif application_type == "WebBrowser":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.target_url = opt.get("target_url")
|
||||
elif application_type == "DoSBot":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
target_ip_address=IPv4Address(opt.get("target_ip_address")),
|
||||
target_port=PORT_LOOKUP[opt.get("target_port", "POSTGRES_SERVER")],
|
||||
payload=opt.get("payload"),
|
||||
repeat=bool(opt.get("repeat")),
|
||||
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
|
||||
dos_intensity=float(opt.get("dos_intensity", "1.0")),
|
||||
max_sessions=int(opt.get("max_sessions", "1000")),
|
||||
)
|
||||
elif application_type == "C2Beacon":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")),
|
||||
keep_alive_frequency=(opt.get("keep_alive_frequency", 5)),
|
||||
masquerade_protocol=PROTOCOL_LOOKUP[
|
||||
(opt.get("masquerade_protocol", PROTOCOL_LOOKUP["TCP"]))
|
||||
],
|
||||
masquerade_port=PORT_LOOKUP[(opt.get("masquerade_port", PORT_LOOKUP["HTTP"]))],
|
||||
)
|
||||
if "network_interfaces" in node_cfg:
|
||||
for nic_num, nic_cfg in node_cfg["network_interfaces"].items():
|
||||
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
|
||||
@@ -470,76 +432,10 @@ class PrimaiteGame:
|
||||
agents_cfg = cfg.get("agents", [])
|
||||
|
||||
for agent_cfg in agents_cfg:
|
||||
agent_ref = agent_cfg["ref"] # noqa: F841
|
||||
agent_type = agent_cfg["type"]
|
||||
action_space_cfg = agent_cfg["action_space"]
|
||||
observation_space_cfg = agent_cfg["observation_space"]
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space = ActionManager.from_config(game, action_space_cfg)
|
||||
|
||||
# CREATE REWARD FUNCTION
|
||||
reward_function = RewardFunction.from_config(reward_function_cfg)
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "ProbabilisticAgent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
settings = agent_cfg.get("agent_settings", {})
|
||||
new_agent = ProbabilisticAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
settings=settings,
|
||||
)
|
||||
elif agent_type == "PeriodicAgent":
|
||||
settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {}))
|
||||
new_agent = PeriodicAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
elif agent_type == "ProxyAgent":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
new_agent = ProxyAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
|
||||
new_agent = DataManipulationAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
elif agent_type == "TAP001":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
new_agent = TAP001(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
else:
|
||||
msg = f"Configuration error: {agent_type} is not a valid agent type."
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
new_agent = AbstractAgent.from_config(agent_cfg)
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
if isinstance(new_agent, ProxyAgent):
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
|
||||
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
|
||||
game.setup_reward_sharing()
|
||||
|
||||
@@ -11,6 +11,15 @@
|
||||
"PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -19,7 +28,7 @@
|
||||
"source": [
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from prettytable import PrettyTable\n"
|
||||
"from prettytable import PrettyTable"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -195,7 +204,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -209,7 +218,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -51,24 +51,15 @@
|
||||
" - ref: CustomC2Agent\n",
|
||||
" team: RED\n",
|
||||
" type: ProxyAgent\n",
|
||||
" observation_space: null\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" - type: DONOTHING\n",
|
||||
" - type: NODE_APPLICATION_INSTALL\n",
|
||||
" - type: NODE_APPLICATION_EXECUTE\n",
|
||||
" - type: CONFIGURE_C2_BEACON\n",
|
||||
" - type: C2_SERVER_RANSOMWARE_LAUNCH\n",
|
||||
" - type: C2_SERVER_RANSOMWARE_CONFIGURE\n",
|
||||
" - type: C2_SERVER_TERMINAL_COMMAND\n",
|
||||
" - type: C2_SERVER_DATA_EXFILTRATE\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
" - node_name: web_server\n",
|
||||
" applications: \n",
|
||||
" applications:\n",
|
||||
" - application_name: C2Beacon\n",
|
||||
" - node_name: client_1\n",
|
||||
" applications: \n",
|
||||
" applications:\n",
|
||||
" - application_name: C2Server\n",
|
||||
" max_folders_per_node: 1\n",
|
||||
" max_files_per_folder: 1\n",
|
||||
@@ -82,15 +73,15 @@
|
||||
" - 0.0.0.1\n",
|
||||
" action_map:\n",
|
||||
" 0:\n",
|
||||
" action: DONOTHING\n",
|
||||
" action: do_nothing\n",
|
||||
" options: {}\n",
|
||||
" 1:\n",
|
||||
" action: NODE_APPLICATION_INSTALL\n",
|
||||
" action: node_application_install\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" application_name: C2Beacon\n",
|
||||
" 2:\n",
|
||||
" action: CONFIGURE_C2_BEACON\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" config:\n",
|
||||
@@ -99,12 +90,12 @@
|
||||
" masquerade_protocol:\n",
|
||||
" masquerade_port:\n",
|
||||
" 3:\n",
|
||||
" action: NODE_APPLICATION_EXECUTE\n",
|
||||
" action: node_application_execute\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" application_id: 0 \n",
|
||||
" application_id: 0\n",
|
||||
" 4:\n",
|
||||
" action: C2_SERVER_TERMINAL_COMMAND\n",
|
||||
" action: c2_server_terminal_command\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" ip_address:\n",
|
||||
@@ -112,20 +103,20 @@
|
||||
" username: admin\n",
|
||||
" password: admin\n",
|
||||
" commands:\n",
|
||||
" - \n",
|
||||
" -\n",
|
||||
" - software_manager\n",
|
||||
" - application\n",
|
||||
" - install\n",
|
||||
" - RansomwareScript\n",
|
||||
" 5:\n",
|
||||
" action: C2_SERVER_RANSOMWARE_CONFIGURE\n",
|
||||
" action: c2_server_ransomware_configure\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" config:\n",
|
||||
" server_ip_address: 192.168.1.14\n",
|
||||
" payload: ENCRYPT\n",
|
||||
" 6:\n",
|
||||
" action: C2_SERVER_DATA_EXFILTRATE\n",
|
||||
" action: c2_server_data_exfiltrate\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" target_file_name: \"database.db\"\n",
|
||||
@@ -134,14 +125,14 @@
|
||||
" target_ip_address: 192.168.1.14\n",
|
||||
" account:\n",
|
||||
" username: admin\n",
|
||||
" password: admin \n",
|
||||
" password: admin\n",
|
||||
"\n",
|
||||
" 7:\n",
|
||||
" action: C2_SERVER_RANSOMWARE_LAUNCH\n",
|
||||
" action: c2_server_ransomware_launch\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" 8:\n",
|
||||
" action: CONFIGURE_C2_BEACON\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" config:\n",
|
||||
@@ -150,7 +141,7 @@
|
||||
" masquerade_protocol: TCP\n",
|
||||
" masquerade_port: DNS\n",
|
||||
" 9:\n",
|
||||
" action: CONFIGURE_C2_BEACON\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" config:\n",
|
||||
@@ -177,7 +168,7 @@
|
||||
" # removing all agents & adding the custom agent.\n",
|
||||
" cfg['agents'] = {}\n",
|
||||
" cfg['agents'] = c2_agent_yaml\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(env_config=cfg)"
|
||||
]
|
||||
@@ -222,7 +213,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Beacon Actions | NODE_APPLICATION_INSTALL\n",
|
||||
"### **Command and Control** | C2 Beacon Actions | node_application_install\n",
|
||||
"\n",
|
||||
"The custom proxy red agent defined at the start of this notebook has been configured to install the C2 Beacon as action ``1`` in it's action map. \n",
|
||||
"\n",
|
||||
@@ -230,10 +221,6 @@
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: NODE_APPLICATION_INSTALL\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: web_server\n",
|
||||
@@ -243,7 +230,7 @@
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 1:\n",
|
||||
" action: NODE_APPLICATION_INSTALL \n",
|
||||
" action: node_application_install \n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Index 0 at the node list.\n",
|
||||
" application_name: C2Beacon\n",
|
||||
@@ -265,7 +252,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Beacon Actions | CONFIGURE_C2_BEACON \n",
|
||||
"### **Command and Control** | C2 Beacon Actions | configure_c2_beacon \n",
|
||||
"\n",
|
||||
"The custom proxy red agent defined at the start of this notebook can configure the C2 Beacon via action ``2`` in it's action map. \n",
|
||||
"\n",
|
||||
@@ -273,10 +260,6 @@
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: CONFIGURE_C2_BEACON\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: web_server\n",
|
||||
@@ -285,7 +268,7 @@
|
||||
" action_map:\n",
|
||||
" ...\n",
|
||||
" 2:\n",
|
||||
" action: CONFIGURE_C2_BEACON\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Node Index\n",
|
||||
" config: # Further information about these config options can be found at the bottom of this notebook.\n",
|
||||
@@ -312,18 +295,14 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Beacon Actions | NODE_APPLICATION_EXECUTE\n",
|
||||
"### **Command and Control** | C2 Beacon Actions | node_application_execute\n",
|
||||
"\n",
|
||||
"The final action is ``NODE_APPLICATION_EXECUTE`` which is used to establish a connection for the C2 application. This action can be called by the Red Agent via action ``3`` in it's action map. \n",
|
||||
"The final action is ``node_application_execute`` which is used to establish a connection for the C2 application. This action can be called by the Red Agent via action ``3`` in it's action map. \n",
|
||||
"\n",
|
||||
"The yaml snippet below shows all the relevant agent options for this action:\n",
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: NODE_APPLICATION_EXECUTE\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: web_server\n",
|
||||
@@ -334,7 +313,7 @@
|
||||
" action_map:\n",
|
||||
" ...\n",
|
||||
" 3:\n",
|
||||
" action: NODE_APPLICATION_EXECUTE\n",
|
||||
" action: node_application_execute\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" application_id: 0\n",
|
||||
@@ -347,7 +326,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(3) "
|
||||
"env.step(3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -390,10 +369,6 @@
|
||||
"\n",
|
||||
"``` yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: C2_SERVER_TERMINAL_COMMAND\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" ...\n",
|
||||
@@ -441,7 +416,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Server Actions | C2_SERVER_RANSOMWARE_CONFIGURE\n",
|
||||
"### **Command and Control** | C2 Server Actions | c2_server_ransomware_configure\n",
|
||||
"\n",
|
||||
"Another action the C2 Server grants is the ability for a Red Agent to configure the RansomwareScript via the C2 Server rather than the note directly.\n",
|
||||
"\n",
|
||||
@@ -451,10 +426,6 @@
|
||||
"\n",
|
||||
"``` yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: C2_SERVER_RANSOMWARE_CONFIGURE\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" ...\n",
|
||||
@@ -464,7 +435,7 @@
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 5:\n",
|
||||
" action: C2_SERVER_RANSOMWARE_CONFIG\n",
|
||||
" action: c2_server_ransomware_configure\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" config:\n",
|
||||
@@ -497,9 +468,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Server Actions | C2_SERVER_DATA_EXFILTRATE\n",
|
||||
"### **Command and Control** | C2 Server Actions | c2_server_data_exfiltrate\n",
|
||||
"\n",
|
||||
"The second to last action available is the ``C2_SERVER_DATA_EXFILTRATE`` which is indexed as action ``6`` in the action map.\n",
|
||||
"The second to last action available is the ``c2_server_data_exfiltrate`` which is indexed as action ``6`` in the action map.\n",
|
||||
"\n",
|
||||
"This action can be used to exfiltrate a target file on a remote node to the C2 Beacon and the C2 Server's host file system via the ``FTP`` services.\n",
|
||||
"\n",
|
||||
@@ -507,10 +478,6 @@
|
||||
"\n",
|
||||
"``` yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: C2_SERVER_DATA_EXFILTRATE\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" ...\n",
|
||||
@@ -520,7 +487,7 @@
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 6:\n",
|
||||
" action: C2_SERVER_DATA_EXFILTRATE\n",
|
||||
" action: c2_server_data_exfiltrate\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" target_file_name: \"database.db\"\n",
|
||||
@@ -567,9 +534,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Server Actions | C2_SERVER_RANSOMWARE_LAUNCH\n",
|
||||
"### **Command and Control** | C2 Server Actions | c2_server_ransomware_launch\n",
|
||||
"\n",
|
||||
"Finally, the last available action is for the C2_SERVER_RANSOMWARE_LAUNCH to start the ransomware script installed on the same node as the C2 beacon.\n",
|
||||
"Finally, the last available action is for the c2_server_ransomware_launch to start the ransomware script installed on the same node as the C2 beacon.\n",
|
||||
"\n",
|
||||
"This action is indexed as action ``7``.\n",
|
||||
"\n",
|
||||
@@ -577,10 +544,6 @@
|
||||
"\n",
|
||||
"``` yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: C2_SERVER_RANSOMWARE_LAUNCH\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" ...\n",
|
||||
@@ -590,7 +553,7 @@
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 7:\n",
|
||||
" action: C2_SERVER_RANSOMWARE_LAUNCH\n",
|
||||
" action: c2_server_ransomware_launch\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
"```\n"
|
||||
@@ -632,7 +595,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"custom_blue_agent_yaml = \"\"\" \n",
|
||||
"custom_blue_agent_yaml = \"\"\"\n",
|
||||
" - ref: defender\n",
|
||||
" team: BLUE\n",
|
||||
" type: ProxyAgent\n",
|
||||
@@ -715,28 +678,23 @@
|
||||
" - type: \"NONE\"\n",
|
||||
" label: ICS\n",
|
||||
" options: {}\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" - type: NODE_APPLICATION_REMOVE\n",
|
||||
" - type: NODE_SHUTDOWN\n",
|
||||
" - type: ROUTER_ACL_ADDRULE\n",
|
||||
" - type: DONOTHING\n",
|
||||
" action_map:\n",
|
||||
" 0:\n",
|
||||
" action: DONOTHING\n",
|
||||
" action: do_nothing\n",
|
||||
" options: {}\n",
|
||||
" 1:\n",
|
||||
" action: NODE_APPLICATION_REMOVE\n",
|
||||
" action: node_application_remove\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" application_name: C2Beacon\n",
|
||||
" 2:\n",
|
||||
" action: NODE_SHUTDOWN\n",
|
||||
" action: node_shutdown\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" 3:\n",
|
||||
" action: ROUTER_ACL_ADDRULE\n",
|
||||
" action: router_acl_add_rule\n",
|
||||
" options:\n",
|
||||
" target_router: router_1\n",
|
||||
" position: 1\n",
|
||||
@@ -747,7 +705,7 @@
|
||||
" dest_port_id: 2\n",
|
||||
" protocol_id: 1\n",
|
||||
" source_wildcard_id: 0\n",
|
||||
" dest_wildcard_id: 0 \n",
|
||||
" dest_wildcard_id: 0\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" options:\n",
|
||||
@@ -796,7 +754,7 @@
|
||||
" # removing all agents & adding the custom agent.\n",
|
||||
" cfg['agents'] = {}\n",
|
||||
" cfg['agents'] = custom_blue\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"blue_env = PrimaiteGymEnv(env_config=cfg)"
|
||||
]
|
||||
@@ -1121,7 +1079,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The code cell below uses the custom blue agent defined at the start of this section perform a NODE_APPLICATION_REMOVE on the C2 beacon:"
|
||||
"The code cell below uses the custom blue agent defined at the start of this section perform a node_application_remove on the C2 beacon:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1130,7 +1088,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Using CAOS ACTION: NODE_APPLICATION_REMOVE & capturing the OBS\n",
|
||||
"# Using CAOS ACTION: node_application_remove & capturing the OBS\n",
|
||||
"post_blue_action_obs, _, _, _, _ = blue_env.step(1)"
|
||||
]
|
||||
},
|
||||
@@ -1216,7 +1174,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The code cell below uses the custom blue agent defined at the start of this section to perform a ``NODE_SHUT_DOWN`` action on the web server."
|
||||
"The code cell below uses the custom blue agent defined at the start of this section to perform a ``node_shut_down`` action on the web server."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1225,7 +1183,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Using CAOS ACTION: NODE_SHUT_DOWN & capturing the OBS\n",
|
||||
"# Using CAOS ACTION: node_shut_down & capturing the OBS\n",
|
||||
"post_blue_action_obs, _, _, _, _ = blue_env.step(2)"
|
||||
]
|
||||
},
|
||||
@@ -1306,7 +1264,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The code cell below uses the custom blue agent defined at the start of this section to perform a ROUTER_ACL_ADDRULE on router 1."
|
||||
"The code cell below uses the custom blue agent defined at the start of this section to perform a router_acl_add_rule on router 1."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1315,7 +1273,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Using CAOS ACTION: ROUTER_ACL_ADDRULE & capturing the OBS\n",
|
||||
"# Using CAOS ACTION: router_acl_add_rule & capturing the OBS\n",
|
||||
"post_blue_action_obs, _, _, _, _ = blue_env.step(3)"
|
||||
]
|
||||
},
|
||||
@@ -1429,11 +1387,11 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As demonstrated earlier, red agents can use the ``CONFIGURE_C2_BEACON`` action to configure these settings mid episode through the configuration options:\n",
|
||||
"As demonstrated earlier, red agents can use the ``configure_c2_beacon`` action to configure these settings mid episode through the configuration options:\n",
|
||||
"\n",
|
||||
"``` YAML\n",
|
||||
"...\n",
|
||||
" action: CONFIGURE_C2_BEACON\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" config:\n",
|
||||
@@ -1468,7 +1426,7 @@
|
||||
" # removing all agents & adding the custom agent.\n",
|
||||
" cfg['agents'] = {}\n",
|
||||
" cfg['agents'] = c2_agent_yaml\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"c2_config_env = PrimaiteGymEnv(env_config=cfg)"
|
||||
]
|
||||
@@ -1555,7 +1513,7 @@
|
||||
"source": [
|
||||
"for i in range(6):\n",
|
||||
" env.step(0)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"c2_server_1.show()"
|
||||
]
|
||||
},
|
||||
@@ -1676,7 +1634,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Comparing the OBS of the default frequency to a timestep frequency of 1 \n",
|
||||
"# Comparing the OBS of the default frequency to a timestep frequency of 1\n",
|
||||
"for i in range(2):\n",
|
||||
" keep_alive_obs, _, _, _, _ = blue_config_env.step(0)\n",
|
||||
" display_obs_diffs(default_obs, keep_alive_obs, blue_config_env.game.step_counter)"
|
||||
@@ -1760,7 +1718,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Capturing default C2 Traffic \n",
|
||||
"# Capturing default C2 Traffic\n",
|
||||
"for i in range(3):\n",
|
||||
" tcp_c2_obs, _, _, _, _ = blue_config_env.step(0)\n",
|
||||
"\n",
|
||||
|
||||
@@ -15,6 +15,15 @@
|
||||
"*(For a full explanation of the Data Manipulation scenario, check out the data manipulation scenario notebook)*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -67,9 +76,9 @@
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" if red_action == 'do_nothing':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
|
||||
" elif red_action == 'node_application_execute':\n",
|
||||
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str"
|
||||
@@ -147,12 +156,7 @@
|
||||
" nodes: {}\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
"\n",
|
||||
" # The agent has two action choices, either do nothing, or execute a pre-scripted attack by using \n",
|
||||
" action_list:\n",
|
||||
" - type: DONOTHING\n",
|
||||
" - type: NODE_APPLICATION_EXECUTE\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" # The agent has access to the DataManipulationBoth on clients 1 and 2.\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
@@ -306,19 +310,9 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"change = yaml.safe_load(\"\"\"\n",
|
||||
"action_space:\n",
|
||||
" action_list:\n",
|
||||
" - type: DONOTHING\n",
|
||||
" - type: NODE_APPLICATION_EXECUTE\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
" - node_name: client_1\n",
|
||||
" applications:\n",
|
||||
" - application_name: DataManipulationBot\n",
|
||||
" max_folders_per_node: 1\n",
|
||||
" max_files_per_folder: 1\n",
|
||||
" max_services_per_node: 1\n",
|
||||
"# TODO:\n",
|
||||
"\"\"\")\n",
|
||||
"#TODO 2869 fix\n",
|
||||
"\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
@@ -444,7 +438,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
||||
@@ -165,13 +165,13 @@
|
||||
"\n",
|
||||
"| node_id | node name |\n",
|
||||
"|---------|------------------|\n",
|
||||
"| 1 | domain_controller|\n",
|
||||
"| 2 | web_server |\n",
|
||||
"| 3 | database_server |\n",
|
||||
"| 4 | backup_server |\n",
|
||||
"| 5 | security_suite |\n",
|
||||
"| 6 | client_1 |\n",
|
||||
"| 7 | client_2 |\n",
|
||||
"| 0 | domain_controller|\n",
|
||||
"| 1 | web_server |\n",
|
||||
"| 2 | database_server |\n",
|
||||
"| 3 | backup_server |\n",
|
||||
"| 4 | security_suite |\n",
|
||||
"| 5 | client_1 |\n",
|
||||
"| 6 | client_2 |\n",
|
||||
"\n",
|
||||
"Service 1 on node 2 (web_server) corresponds to the Web Server service. Other services are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n",
|
||||
"\n",
|
||||
@@ -371,6 +371,15 @@
|
||||
"First, load the required modules"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -449,9 +458,9 @@
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" if red_action == 'do_nothing':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
|
||||
" elif red_action == 'node_application_execute':\n",
|
||||
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str"
|
||||
@@ -547,7 +556,7 @@
|
||||
"\n",
|
||||
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 0.9 when both green agents make successful requests.\n",
|
||||
"\n",
|
||||
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should increase. If you run it enough times, another red attack will happen and the reward will drop again."
|
||||
"Run the following cell until the green action is `node_application_execute` for application 0, then the reward should increase. If you run it enough times, another red attack will happen and the reward will drop again."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,6 +9,15 @@
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -201,7 +201,7 @@
|
||||
"source": [
|
||||
"caos_action = [\n",
|
||||
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
|
||||
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
|
||||
" \"service\", \"Terminal\", \"node_session_remote_login\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
|
||||
"]\n",
|
||||
"game.simulation.apply_request(caos_action)"
|
||||
]
|
||||
@@ -259,7 +259,7 @@
|
||||
"source": [
|
||||
"caos_action = [\n",
|
||||
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
|
||||
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_rt.network_interface[4].ip_address)\n",
|
||||
" \"service\", \"Terminal\", \"node_session_remote_login\", \"admin\", \"admin\", str(some_tech_rt.network_interface[4].ip_address)\n",
|
||||
"]\n",
|
||||
"game.simulation.apply_request(caos_action)"
|
||||
]
|
||||
@@ -396,7 +396,7 @@
|
||||
"source": [
|
||||
"caos_action = [\n",
|
||||
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
|
||||
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
|
||||
" \"service\", \"Terminal\", \"node_session_remote_login\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
|
||||
"]\n",
|
||||
"game.simulation.apply_request(caos_action)"
|
||||
]
|
||||
@@ -25,6 +25,15 @@
|
||||
"Let's set up a minimal network simulation and send some requests to see how it works."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -18,6 +18,15 @@
|
||||
"The Terminal service comes pre-installed on most Nodes (The exception being Switches, as these are currently dumb). "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -18,6 +18,15 @@
|
||||
"#### First, Import packages and read our config file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -32,8 +41,6 @@
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
|
||||
"\n",
|
||||
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
|
||||
"# to copy the files to your user data path.\n",
|
||||
"with open(PRIMAITE_PATHS.user_config_path / 'example_config/data_manipulation_marl.yaml', 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
|
||||
@@ -11,6 +11,15 @@
|
||||
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -95,7 +104,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -109,7 +118,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -18,6 +18,15 @@
|
||||
"#### First, we import the inital packages and read in our configuration file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -168,7 +177,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -182,7 +191,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.8"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -238,7 +238,7 @@
|
||||
"### Episode 2\n",
|
||||
"When we reset the environment again, it moves onto episode 2, where it will bring in greens_1 and reds_1 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n",
|
||||
"\n",
|
||||
"Most green actions will be `NODE_APPLICATION_EXECUTE` while red will `DONOTHING` except at steps 10 and 20."
|
||||
"Most green actions will be `node_application_execute` while red will `DONOTHING` except at steps 10 and 20."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -269,7 +269,7 @@
|
||||
"### Episode 3\n",
|
||||
"When we reset the environment again, it moves onto episode 3, where it will bring in greens_2 and reds_2 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n",
|
||||
"\n",
|
||||
"Now, green will perform `NODE_APPLICATION_EXECUTE` only 5% of the time, while red will perform `NODE_APPLICATION_EXECUTE` more frequently than before."
|
||||
"Now, green will perform `node_application_execute` only 5% of the time, while red will perform `node_application_execute` more frequently than before."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 110 KiB After Width: | Height: | Size: 110 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 69 KiB After Width: | Height: | Size: 69 KiB |
@@ -18,6 +18,15 @@
|
||||
"Import packages and read config file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -89,7 +89,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
:return: Action mask
|
||||
:rtype: List[bool]
|
||||
"""
|
||||
if not self.agent.action_masking:
|
||||
if not self.agent.config.agent_settings.action_masking:
|
||||
return np.asarray([True] * len(self.agent.action_manager.action_map))
|
||||
else:
|
||||
return self.game.action_mask(self._agent_name)
|
||||
|
||||
@@ -44,7 +44,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
)
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
if agent.action_masking:
|
||||
if agent.config.agent_settings.action_masking:
|
||||
self.observation_space[agent_name] = spaces.Dict(
|
||||
{
|
||||
"action_mask": spaces.MultiBinary(agent.action_manager.space.n),
|
||||
@@ -143,7 +143,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
if agent.action_masking:
|
||||
if agent.config.agent_settings.action_masking:
|
||||
all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs}
|
||||
else:
|
||||
all_obs[agent_name] = obs
|
||||
@@ -168,7 +168,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
if self.env.agent.action_masking:
|
||||
if self.env.agent.config.agent_settings.action_masking:
|
||||
self.observation_space = spaces.Dict(
|
||||
{"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space}
|
||||
)
|
||||
@@ -178,7 +178,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
super().reset() # Ensure PRNG seed is set everywhere
|
||||
if self.env.agent.action_masking:
|
||||
if self.env.agent.config.agent_settings.action_masking:
|
||||
obs, *_ = self.env.reset(seed=seed)
|
||||
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
|
||||
return new_obs, *_
|
||||
@@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
# if action masking is enabled, intercept the step method and add action mask to observation
|
||||
if self.env.agent.action_masking:
|
||||
if self.env.agent.config.agent_settings.action_masking:
|
||||
obs, *_ = self.env.step(action)
|
||||
new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
|
||||
return new_obs, *_
|
||||
|
||||
@@ -264,7 +264,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -664,7 +664,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -130,8 +130,8 @@ class File(FileSystemItemABC):
|
||||
|
||||
Return False if corruption is detected, otherwise True
|
||||
"""
|
||||
warnings.warn("NODE_FILE_CHECKHASH is currently not implemented.")
|
||||
self.sys_log.warning("NODE_FILE_CHECKHASH is currently not implemented.")
|
||||
warnings.warn("node_file_checkhash is currently not implemented.")
|
||||
self.sys_log.warning("node_file_checkhash is currently not implemented.")
|
||||
return False
|
||||
|
||||
if self.deleted:
|
||||
|
||||
@@ -30,6 +30,11 @@ class FileSystem(SimComponent):
|
||||
num_file_deletions: int = 0
|
||||
"Number of file deletions in the current step."
|
||||
|
||||
_default_folder_scan_duration: Optional[int] = None
|
||||
"Override default scan duration for folders"
|
||||
_default_folder_restore_duration: Optional[int] = None
|
||||
"Override default restore duration for folders"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# Ensure a default root folder
|
||||
@@ -258,6 +263,11 @@ class FileSystem(SimComponent):
|
||||
name=folder.name, request_type=RequestType(func=folder._request_manager)
|
||||
)
|
||||
self.folders[folder.uuid] = folder
|
||||
# set the folder scan and restore durations.
|
||||
if self._default_folder_scan_duration is not None:
|
||||
folder.scan_duration = self._default_folder_scan_duration
|
||||
if self._default_folder_restore_duration is not None:
|
||||
folder.restore_duration = self._default_folder_restore_duration
|
||||
return folder
|
||||
|
||||
def delete_folder(self, folder_name: str) -> bool:
|
||||
|
||||
@@ -43,6 +43,9 @@ def convert_size(size_bytes: int) -> str:
|
||||
class FileSystemItemHealthStatus(Enum):
|
||||
"""Status of the FileSystemItem."""
|
||||
|
||||
NONE = 0
|
||||
"""File system item health status is not known."""
|
||||
|
||||
GOOD = 1
|
||||
"""File/Folder is OK."""
|
||||
|
||||
@@ -72,7 +75,7 @@ class FileSystemItemABC(SimComponent):
|
||||
health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD
|
||||
"Actual status of the current FileSystemItem"
|
||||
|
||||
visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD
|
||||
visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.NONE
|
||||
"Visible status of the current FileSystemItem"
|
||||
|
||||
previous_hash: Optional[str] = None
|
||||
|
||||
@@ -46,7 +46,7 @@ class Folder(FileSystemItemABC):
|
||||
:param sys_log: The SysLog instance to us to create system logs.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._scanned_this_step: bool = False
|
||||
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
@@ -83,6 +83,7 @@ class Folder(FileSystemItemABC):
|
||||
state = super().describe_state()
|
||||
state["files"] = {file.name: file.describe_state() for uuid, file in self.files.items()}
|
||||
state["deleted_files"] = {file.name: file.describe_state() for uuid, file in self.deleted_files.items()}
|
||||
state["scanned_this_step"] = self._scanned_this_step
|
||||
return state
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
@@ -135,7 +136,7 @@ class Folder(FileSystemItemABC):
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
self._scanned_this_step = False
|
||||
for file in self.files.values():
|
||||
file.pre_timestep(timestep)
|
||||
|
||||
@@ -148,9 +149,17 @@ class Folder(FileSystemItemABC):
|
||||
for file_id in self.files:
|
||||
file = self.get_file_by_id(file_uuid=file_id)
|
||||
file.scan()
|
||||
if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
self.health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
# set folder health to worst file's health by generating a list of file healths. If no files, use 0
|
||||
self.health_status = FileSystemItemHealthStatus(
|
||||
max(
|
||||
[f.health_status.value for f in self.files.values()]
|
||||
or [
|
||||
0,
|
||||
]
|
||||
)
|
||||
)
|
||||
self.visible_health_status = self.health_status
|
||||
self._scanned_this_step = True
|
||||
|
||||
def _reveal_to_red_timestep(self) -> None:
|
||||
"""Apply reveal to red timestep."""
|
||||
@@ -387,8 +396,8 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
Return False if corruption is detected, otherwise True
|
||||
"""
|
||||
warnings.warn("NODE_FOLDER_CHECKHASH is currently not implemented.")
|
||||
self.sys_log.error("NODE_FOLDER_CHECKHASH is currently not implemented.")
|
||||
warnings.warn("node_folder_checkhash is currently not implemented.")
|
||||
self.sys_log.error("node_folder_checkhash is currently not implemented.")
|
||||
return False
|
||||
|
||||
if self.deleted:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, ClassVar, Dict, Literal, Type
|
||||
from typing import Any, ClassVar, Dict, Literal, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
@@ -49,7 +49,7 @@ class NetworkNodeAdder(BaseModel):
|
||||
|
||||
_registry: ClassVar[Dict[str, Type["NetworkNodeAdder"]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str], **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a network node adder class.
|
||||
|
||||
@@ -58,6 +58,8 @@ class NetworkNodeAdder(BaseModel):
|
||||
:raises ValueError: When attempting to register a name that is already reserved.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate node adder {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@@ -824,7 +824,7 @@ class User(SimComponent):
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class UserManager(Service):
|
||||
class UserManager(Service, identifier="UserManager"):
|
||||
"""
|
||||
Manages users within the PrimAITE system, handling creation, authentication, and administration.
|
||||
|
||||
@@ -833,11 +833,18 @@ class UserManager(Service):
|
||||
:param disabled_admins: A dictionary of currently disabled admin users by their usernames
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for UserManager."""
|
||||
|
||||
type: str = "UserManager"
|
||||
|
||||
config: "UserManager.ConfigSchema" = Field(default_factory=lambda: UserManager.ConfigSchema())
|
||||
|
||||
users: Dict[str, User] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initializes a UserManager instanc.
|
||||
Initializes a UserManager instance.
|
||||
|
||||
:param username: The username for the default admin user
|
||||
:param password: The password for the default admin user
|
||||
@@ -1130,13 +1137,20 @@ class RemoteUserSession(UserSession):
|
||||
return state
|
||||
|
||||
|
||||
class UserSessionManager(Service):
|
||||
class UserSessionManager(Service, identifier="UserSessionManager"):
|
||||
"""
|
||||
Manages user sessions on a Node, including local and remote sessions.
|
||||
|
||||
This class handles authentication, session management, and session timeouts for users interacting with the Node.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for UserSessionManager."""
|
||||
|
||||
type: str = "UserSessionManager"
|
||||
|
||||
config: "UserSessionManager.ConfigSchema" = Field(default_factory=lambda: UserSessionManager.ConfigSchema())
|
||||
|
||||
local_session: Optional[UserSession] = None
|
||||
"""The current local user session, if any."""
|
||||
|
||||
@@ -1554,7 +1568,6 @@ class Node(SimComponent, ABC):
|
||||
red_scan_countdown: int = 0
|
||||
"Time steps until reveal to red scan is complete."
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "Node":
|
||||
"""Create Node object from a given configuration dictionary."""
|
||||
@@ -1564,7 +1577,7 @@ class Node(SimComponent, ABC):
|
||||
obj = cls(config=cls.ConfigSchema(**config))
|
||||
return obj
|
||||
|
||||
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a node type.
|
||||
|
||||
@@ -1572,10 +1585,10 @@ class Node(SimComponent, ABC):
|
||||
:type identifier: str
|
||||
:raises ValueError: When attempting to register an node with a name that is already allocated.
|
||||
"""
|
||||
if identifier == "default":
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
identifier = identifier.lower()
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Tried to define new node {identifier}, but this name is already reserved.")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Dict, Optional, Set, Type
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
@@ -21,13 +23,20 @@ class ApplicationOperatingState(Enum):
|
||||
"The application is being installed or updated."
|
||||
|
||||
|
||||
class Application(IOSoftware):
|
||||
class Application(IOSoftware, ABC):
|
||||
"""
|
||||
Represents an Application in the simulation environment.
|
||||
|
||||
Applications are user-facing programs that may perform input/output operations.
|
||||
"""
|
||||
|
||||
class ConfigSchema(IOSoftware.ConfigSchema, ABC):
|
||||
"""Config Schema for Application class."""
|
||||
|
||||
type: str
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: Application.ConfigSchema())
|
||||
|
||||
operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED
|
||||
"The current operating state of the Application."
|
||||
execution_control_status: str = "manual"
|
||||
@@ -44,21 +53,36 @@ class Application(IOSoftware):
|
||||
_registry: ClassVar[Dict[str, Type["Application"]]] = {}
|
||||
"""Registry of application types. Automatically populated when subclasses are defined."""
|
||||
|
||||
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register an application type.
|
||||
|
||||
:param identifier: Uniquely specifies an application class by name. Used for finding items by config.
|
||||
:type identifier: str
|
||||
:type identifier: Optional[str]
|
||||
:raises ValueError: When attempting to register an application with a name that is already allocated.
|
||||
"""
|
||||
if identifier == "default":
|
||||
return
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "Application":
|
||||
"""Create an application from a config dictionary.
|
||||
|
||||
:param config: dict of options for application components constructor
|
||||
:type config: dict
|
||||
:return: The application component.
|
||||
:rtype: Application
|
||||
"""
|
||||
if config["type"] not in cls._registry:
|
||||
raise ValueError(f"Invalid Application type {config['type']}")
|
||||
application_class = cls._registry[config["type"]]
|
||||
application_object = application_class(config=application_class.ConfigSchema(**config))
|
||||
return application_object
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -67,11 +67,19 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
|
||||
Extends the Application class to provide functionality for connecting, querying, and disconnecting from a
|
||||
Database Service. It mainly operates over TCP protocol.
|
||||
|
||||
:ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for DatabaseClient."""
|
||||
|
||||
type: str = "DatabaseClient"
|
||||
db_server_ip: Optional[IPV4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: DatabaseClient.ConfigSchema())
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
"""The IPv4 address of the Database Service server, defaults to None."""
|
||||
server_password: Optional[str] = None
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
||||
@@ -93,6 +101,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
self.server_ip_address = self.config.db_server_ip
|
||||
self.server_password = self.config.server_password
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union
|
||||
|
||||
from prettytable import PrettyTable
|
||||
from pydantic import validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
@@ -52,6 +52,13 @@ class NMAP(Application, identifier="NMAP"):
|
||||
as ping scans to discover active hosts and port scans to detect open ports on those hosts.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for NMAP."""
|
||||
|
||||
type: str = "NMAP"
|
||||
|
||||
config: "NMAP.ConfigSchema" = Field(default_factory=lambda: NMAP.ConfigSchema())
|
||||
|
||||
_active_port_scans: Dict[str, PortScanPayload] = {}
|
||||
_port_scan_responses: Dict[str, PortScanPayload] = {}
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
@@ -45,10 +45,10 @@ class C2Payload(Enum):
|
||||
"""C2 Input Command payload. Used by the C2 Server to send a command to the c2 beacon."""
|
||||
|
||||
OUTPUT = "output_command"
|
||||
"""C2 Output Command. Used by the C2 Beacon to send the results of a Input command to the c2 server."""
|
||||
"""C2 Output Command. Used by the C2 Beacon to send the results of an Input command to the c2 server."""
|
||||
|
||||
|
||||
class AbstractC2(Application, identifier="AbstractC2"):
|
||||
class AbstractC2(Application):
|
||||
"""
|
||||
An abstract command and control (c2) application.
|
||||
|
||||
@@ -60,9 +60,25 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
|
||||
Defaults to masquerading as HTTP (Port 80) via TCP.
|
||||
|
||||
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
|
||||
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""Configuration for AbstractC2."""
|
||||
|
||||
keep_alive_frequency: int = Field(default=5, ge=1)
|
||||
"""The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon."""
|
||||
|
||||
masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"])
|
||||
"""The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP."""
|
||||
|
||||
masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"])
|
||||
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
|
||||
|
||||
listen_on_ports: Set[Port] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: AbstractC2.ConfigSchema())
|
||||
|
||||
c2_connection_active: bool = False
|
||||
"""Indicates if the c2 server and c2 beacon are currently connected."""
|
||||
|
||||
@@ -75,19 +91,6 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
keep_alive_inactivity: int = 0
|
||||
"""Indicates how many timesteps since the last time the c2 application received a keep alive."""
|
||||
|
||||
class _C2Opts(BaseModel):
|
||||
"""A Pydantic Schema for the different C2 configuration options."""
|
||||
|
||||
keep_alive_frequency: int = Field(default=5, ge=1)
|
||||
"""The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon."""
|
||||
|
||||
masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"])
|
||||
"""The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP."""
|
||||
|
||||
masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"])
|
||||
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
|
||||
|
||||
c2_config: _C2Opts = _C2Opts()
|
||||
"""
|
||||
Holds the current configuration settings of the C2 Suite.
|
||||
|
||||
@@ -100,6 +103,12 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
C2 beacon to reconfigure it's configuration settings.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialise the C2 applications to by default listen for HTTP traffic."""
|
||||
kwargs["port"] = PORT_LOOKUP["NONE"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _craft_packet(
|
||||
self, c2_payload: C2Payload, c2_command: Optional[C2Command] = None, command_options: Optional[Dict] = {}
|
||||
) -> C2Packet:
|
||||
@@ -118,13 +127,13 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
:type c2_command: C2Command.
|
||||
:param command_options: The relevant C2 Beacon parameters.F
|
||||
:type command_options: Dict
|
||||
:return: Returns the construct C2Packet
|
||||
:return: Returns the constructed C2Packet
|
||||
:rtype: C2Packet
|
||||
"""
|
||||
constructed_packet = C2Packet(
|
||||
masquerade_protocol=self.c2_config.masquerade_protocol,
|
||||
masquerade_port=self.c2_config.masquerade_port,
|
||||
keep_alive_frequency=self.c2_config.keep_alive_frequency,
|
||||
masquerade_protocol=self.config.masquerade_protocol,
|
||||
masquerade_port=self.config.masquerade_port,
|
||||
keep_alive_frequency=self.config.keep_alive_frequency,
|
||||
payload_type=c2_payload,
|
||||
command=c2_command,
|
||||
payload=command_options,
|
||||
@@ -140,13 +149,6 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
"""
|
||||
return super().describe_state()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialise the C2 applications to by default listen for HTTP traffic."""
|
||||
kwargs["listen_on_ports"] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
|
||||
kwargs["port"] = PORT_LOOKUP["NONE"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _host_ftp_client(self) -> Optional[FTPClient]:
|
||||
"""Return the FTPClient that is installed C2 Application's host.
|
||||
@@ -330,8 +332,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
if self.send(
|
||||
payload=keep_alive_packet,
|
||||
dest_ip_address=self.c2_remote_connection,
|
||||
dest_port=self.c2_config.masquerade_port,
|
||||
ip_protocol=self.c2_config.masquerade_protocol,
|
||||
dest_port=self.config.masquerade_port,
|
||||
ip_protocol=self.config.masquerade_protocol,
|
||||
session_id=session_id,
|
||||
):
|
||||
# Setting the keep_alive_sent guard condition to True. This is used to prevent packet storms.
|
||||
@@ -340,8 +342,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
self.sys_log.info(f"{self.name}: Keep Alive sent to {self.c2_remote_connection}")
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: Keep Alive sent to {self.c2_remote_connection} "
|
||||
f"Masquerade Port: {self.c2_config.masquerade_port} "
|
||||
f"Masquerade Protocol: {self.c2_config.masquerade_protocol} "
|
||||
f"Masquerade Port: {self.config.masquerade_port} "
|
||||
f"Masquerade Protocol: {self.config.masquerade_protocol} "
|
||||
)
|
||||
return True
|
||||
else:
|
||||
@@ -376,15 +378,15 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
|
||||
# Updating the C2 Configuration attribute.
|
||||
|
||||
self.c2_config.masquerade_port = payload.masquerade_port
|
||||
self.c2_config.masquerade_protocol = payload.masquerade_protocol
|
||||
self.c2_config.keep_alive_frequency = payload.keep_alive_frequency
|
||||
self.config.masquerade_port = payload.masquerade_port
|
||||
self.config.masquerade_protocol = payload.masquerade_protocol
|
||||
self.config.keep_alive_frequency = payload.keep_alive_frequency
|
||||
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: C2 Config Resolved Config from Keep Alive:"
|
||||
f"Masquerade Port: {self.c2_config.masquerade_port}"
|
||||
f"Masquerade Protocol: {self.c2_config.masquerade_protocol}"
|
||||
f"Keep Alive Frequency: {self.c2_config.keep_alive_frequency}"
|
||||
f"Masquerade Port: {self.config.masquerade_port}"
|
||||
f"Masquerade Protocol: {self.config.masquerade_protocol}"
|
||||
f"Keep Alive Frequency: {self.config.keep_alive_frequency}"
|
||||
)
|
||||
|
||||
# This statement is intended to catch on the C2 Application that is listening for connection.
|
||||
@@ -410,8 +412,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
self.keep_alive_inactivity = 0
|
||||
self.keep_alive_frequency = 5
|
||||
self.c2_remote_connection = None
|
||||
self.c2_config.masquerade_port = PORT_LOOKUP["HTTP"]
|
||||
self.c2_config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"]
|
||||
self.config.masquerade_port = PORT_LOOKUP["HTTP"]
|
||||
self.config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"]
|
||||
|
||||
@abstractmethod
|
||||
def _confirm_remote_connection(self, timestep: int) -> bool:
|
||||
|
||||
@@ -3,7 +3,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -12,8 +12,9 @@ from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts
|
||||
from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
|
||||
class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
@@ -32,15 +33,30 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
2. Leveraging the terminal application to execute requests (dependent on the command given)
|
||||
3. Sending the RequestResponse back to the C2 Server (Command output)
|
||||
|
||||
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
|
||||
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
|
||||
"""
|
||||
|
||||
class ConfigSchema(AbstractC2.ConfigSchema):
|
||||
"""ConfigSchema for C2Beacon."""
|
||||
|
||||
type: str = "C2Beacon"
|
||||
c2_server_ip_address: Optional[IPV4Address] = None
|
||||
keep_alive_frequency: int = 5
|
||||
masquerade_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"]
|
||||
masquerade_port: Port = PORT_LOOKUP["HTTP"]
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: C2Beacon.ConfigSchema())
|
||||
|
||||
keep_alive_attempted: bool = False
|
||||
"""Indicates if a keep alive has been attempted to be sent this timestep. Used to prevent packet storms."""
|
||||
|
||||
terminal_session: TerminalClientConnection = None
|
||||
"The currently in use terminal session."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "C2Beacon"
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _host_terminal(self) -> Optional[Terminal]:
|
||||
"""Return the Terminal that is installed on the same machine as the C2 Beacon."""
|
||||
@@ -119,10 +135,6 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
rm.add_request("configure", request_type=RequestType(func=_configure))
|
||||
return rm
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "C2Beacon"
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Configure is practically setter method for the ``c2.config`` attribute that also ties into the request manager.
|
||||
@validate_call
|
||||
def configure(
|
||||
@@ -146,7 +158,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
masquerade_port | What port should the C2 traffic use? (TCP or UDP)
|
||||
|
||||
These configuration options are used to reassign the fields in the inherited inner class
|
||||
``c2_config``.
|
||||
``config``.
|
||||
|
||||
If a connection is already in progress then this method also sends a keep alive to the C2
|
||||
Server in order for the C2 Server to sync with the new configuration settings.
|
||||
@@ -162,9 +174,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
:return: Returns True if the configuration was successful, False otherwise.
|
||||
"""
|
||||
self.c2_remote_connection = IPv4Address(c2_server_ip_address)
|
||||
self.c2_config.keep_alive_frequency = keep_alive_frequency
|
||||
self.c2_config.masquerade_port = masquerade_port
|
||||
self.c2_config.masquerade_protocol = masquerade_protocol
|
||||
self.config.keep_alive_frequency = keep_alive_frequency
|
||||
self.config.masquerade_port = masquerade_port
|
||||
self.config.masquerade_protocol = masquerade_protocol
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Configured {self.name} with remote C2 server connection: {c2_server_ip_address=}."
|
||||
)
|
||||
@@ -263,14 +275,12 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
if self.send(
|
||||
payload=output_packet,
|
||||
dest_ip_address=self.c2_remote_connection,
|
||||
dest_port=self.c2_config.masquerade_port,
|
||||
ip_protocol=self.c2_config.masquerade_protocol,
|
||||
dest_port=self.config.masquerade_port,
|
||||
ip_protocol=self.config.masquerade_protocol,
|
||||
session_id=session_id,
|
||||
):
|
||||
self.sys_log.info(f"{self.name}: Command output sent to {self.c2_remote_connection}")
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: on {self.c2_config.masquerade_port} via {self.c2_config.masquerade_protocol}"
|
||||
)
|
||||
self.sys_log.debug(f"{self.name}: on {self.config.masquerade_port} via {self.config.masquerade_protocol}")
|
||||
return True
|
||||
else:
|
||||
self.sys_log.warning(
|
||||
@@ -562,7 +572,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
:rtype bool:
|
||||
"""
|
||||
self.keep_alive_attempted = False # Resetting keep alive sent.
|
||||
if self.keep_alive_inactivity == self.c2_config.keep_alive_frequency:
|
||||
if self.keep_alive_inactivity == self.config.keep_alive_frequency:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Attempting to Send Keep Alive to {self.c2_remote_connection} at timestep {timestep}."
|
||||
)
|
||||
@@ -627,9 +637,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
self.c2_connection_active,
|
||||
self.c2_remote_connection,
|
||||
self.keep_alive_inactivity,
|
||||
self.c2_config.keep_alive_frequency,
|
||||
self.c2_config.masquerade_protocol,
|
||||
self.c2_config.masquerade_port,
|
||||
self.config.keep_alive_frequency,
|
||||
self.config.masquerade_protocol,
|
||||
self.config.masquerade_port,
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -31,9 +31,16 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
1. Sending commands to the C2 Beacon. (Command input)
|
||||
2. Parsing terminal RequestResponses back to the Agent.
|
||||
|
||||
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
|
||||
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
|
||||
"""
|
||||
|
||||
class ConfigSchema(AbstractC2.ConfigSchema):
|
||||
"""ConfigSchema for C2Server."""
|
||||
|
||||
type: str = "C2Server"
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: C2Server.ConfigSchema())
|
||||
|
||||
current_command_output: RequestResponse = None
|
||||
"""The Request Response by the last command send. This attribute is updated by the method _handle_command_output."""
|
||||
|
||||
@@ -251,8 +258,8 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
payload=command_packet,
|
||||
dest_ip_address=self.c2_remote_connection,
|
||||
session_id=self.c2_session.uuid,
|
||||
dest_port=self.c2_config.masquerade_port,
|
||||
ip_protocol=self.c2_config.masquerade_protocol,
|
||||
dest_port=self.config.masquerade_port,
|
||||
ip_protocol=self.config.masquerade_protocol,
|
||||
):
|
||||
self.sys_log.info(f"{self.name}: Successfully sent {given_command}.")
|
||||
self.sys_log.info(f"{self.name}: Awaiting command response {given_command}.")
|
||||
@@ -334,11 +341,11 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
:return: Returns False if the C2 beacon is considered dead. Otherwise True.
|
||||
:rtype bool:
|
||||
"""
|
||||
if self.keep_alive_inactivity > self.c2_config.keep_alive_frequency:
|
||||
if self.keep_alive_inactivity > self.config.keep_alive_frequency:
|
||||
self.sys_log.info(f"{self.name}: C2 Beacon connection considered dead due to inactivity.")
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: Did not receive expected keep alive connection from {self.c2_remote_connection}"
|
||||
f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.c2_config.keep_alive_frequency}"
|
||||
f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.config.keep_alive_frequency}"
|
||||
f"{self.name}: Last Keep Alive received at {(timestep - self.keep_alive_inactivity)}"
|
||||
)
|
||||
self._reset_c2_connection()
|
||||
@@ -389,8 +396,8 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
[
|
||||
self.c2_connection_active,
|
||||
self.c2_remote_connection,
|
||||
self.c2_config.masquerade_protocol,
|
||||
self.c2_config.masquerade_port,
|
||||
self.config.masquerade_protocol,
|
||||
self.config.masquerade_port,
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -3,6 +3,8 @@ from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -10,6 +12,7 @@ from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -40,6 +43,18 @@ class DataManipulationAttackStage(IntEnum):
|
||||
class DataManipulationBot(Application, identifier="DataManipulationBot"):
|
||||
"""A bot that simulates a script which performs a SQL injection attack."""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""Configuration schema for DataManipulationBot."""
|
||||
|
||||
type: str = "DataManipulationBot"
|
||||
server_ip: Optional[IPV4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
payload: str = "DELETE"
|
||||
port_scan_p_of_success: float = 0.1
|
||||
data_manipulation_p_of_success: float = 0.1
|
||||
|
||||
config: "DataManipulationBot.ConfigSchema" = Field(default_factory=lambda: DataManipulationBot.ConfigSchema())
|
||||
|
||||
payload: Optional[str] = None
|
||||
port_scan_p_of_success: float = 0.1
|
||||
data_manipulation_p_of_success: float = 0.1
|
||||
@@ -56,6 +71,12 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"):
|
||||
super().__init__(**kwargs)
|
||||
self._db_connection: Optional[DatabaseClientConnection] = None
|
||||
|
||||
self.server_ip_address = self.config.server_ip
|
||||
self.server_password = self.config.server_password
|
||||
self.payload = self.config.payload
|
||||
self.port_scan_p_of_success = self.config.port_scan_p_of_success
|
||||
self.data_manipulation_p_of_success = self.config.data_manipulation_p_of_success
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
@@ -3,11 +3,14 @@ from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -32,6 +35,20 @@ class DoSAttackStage(IntEnum):
|
||||
class DoSBot(DatabaseClient, identifier="DoSBot"):
|
||||
"""A bot that simulates a Denial of Service attack."""
|
||||
|
||||
class ConfigSchema(DatabaseClient.ConfigSchema):
|
||||
"""ConfigSchema for DoSBot."""
|
||||
|
||||
type: str = "DoSBot"
|
||||
target_ip_address: Optional[IPV4Address] = None
|
||||
target_port: Port = PORT_LOOKUP["POSTGRES_SERVER"]
|
||||
payload: Optional[str] = None
|
||||
repeat: bool = False
|
||||
port_scan_p_of_success: float = 0.1
|
||||
dos_intensity: float = 1.0
|
||||
max_sessions: int = 1000
|
||||
|
||||
config: "DoSBot.ConfigSchema" = Field(default_factory=lambda: DoSBot.ConfigSchema())
|
||||
|
||||
target_ip_address: Optional[IPv4Address] = None
|
||||
"""IP address of the target service."""
|
||||
|
||||
@@ -56,7 +73,13 @@ class DoSBot(DatabaseClient, identifier="DoSBot"):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = "DoSBot"
|
||||
self.max_sessions = 1000 # override normal max sessions
|
||||
self.target_ip_address = self.config.target_ip_address
|
||||
self.target_port = self.config.target_port
|
||||
self.payload = self.config.payload
|
||||
self.repeat = self.config.repeat
|
||||
self.port_scan_p_of_success = self.config.port_scan_p_of_success
|
||||
self.dos_intensity = self.config.dos_intensity
|
||||
self.max_sessions = self.config.max_sessions
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
|
||||
@@ -3,12 +3,14 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
|
||||
@@ -18,6 +20,16 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
|
||||
:ivar payload: The attack stage query payload. (Default ENCRYPT)
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for RansomwareScript."""
|
||||
|
||||
type: str = "RansomwareScript"
|
||||
server_ip: Optional[IPV4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
payload: str = "ENCRYPT"
|
||||
|
||||
config: "RansomwareScript.ConfigSchema" = Field(default_factory=lambda: RansomwareScript.ConfigSchema())
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
"""IP address of node which hosts the database."""
|
||||
server_password: Optional[str] = None
|
||||
@@ -32,6 +44,9 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._db_connection: Optional[DatabaseClientConnection] = None
|
||||
self.server_ip_address = self.config.server_ip
|
||||
self.server_password = self.config.server_password
|
||||
self.payload = self.config.payload
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -30,7 +30,13 @@ class WebBrowser(Application, identifier="WebBrowser"):
|
||||
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
|
||||
"""
|
||||
|
||||
target_url: Optional[str] = None
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for WebBrowser."""
|
||||
|
||||
type: str = "WebBrowser"
|
||||
target_url: Optional[str] = None
|
||||
|
||||
config: "WebBrowser.ConfigSchema" = Field(default_factory=lambda: WebBrowser.ConfigSchema())
|
||||
|
||||
domain_name_ip_address: Optional[IPv4Address] = None
|
||||
"The IP address of the domain name for the webpage."
|
||||
@@ -86,7 +92,7 @@ class WebBrowser(Application, identifier="WebBrowser"):
|
||||
:param: url: The address of the web page the browser requests
|
||||
:type: url: str
|
||||
"""
|
||||
url = url or self.target_url
|
||||
url = url or self.config.target_url
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ class SoftwareManager:
|
||||
return True
|
||||
return False
|
||||
|
||||
def install(self, software_class: Type[IOSoftware], **install_kwargs):
|
||||
def install(self, software_class: Type[IOSoftware], software_config: Optional[IOSoftware.ConfigSchema] = None):
|
||||
"""
|
||||
Install an Application or Service.
|
||||
|
||||
@@ -115,13 +115,22 @@ class SoftwareManager:
|
||||
if software_class in self._software_class_to_name_map:
|
||||
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
|
||||
return
|
||||
software = software_class(
|
||||
software_manager=self,
|
||||
sys_log=self.sys_log,
|
||||
file_system=self.file_system,
|
||||
dns_server=self.dns_server,
|
||||
**install_kwargs,
|
||||
)
|
||||
if software_config is None:
|
||||
software = software_class(
|
||||
software_manager=self,
|
||||
sys_log=self.sys_log,
|
||||
file_system=self.file_system,
|
||||
dns_server=self.dns_server,
|
||||
)
|
||||
else:
|
||||
software = software_class(
|
||||
software_manager=self,
|
||||
sys_log=self.sys_log,
|
||||
file_system=self.file_system,
|
||||
dns_server=self.dns_server,
|
||||
config=software_config,
|
||||
)
|
||||
|
||||
software.parent = self.node
|
||||
if isinstance(software, Application):
|
||||
self.node.applications[software.uuid] = software
|
||||
|
||||
@@ -5,6 +5,7 @@ from abc import abstractmethod
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface
|
||||
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
|
||||
@@ -14,7 +15,7 @@ from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
|
||||
class ARP(Service):
|
||||
class ARP(Service, identifier="ARP"):
|
||||
"""
|
||||
The ARP (Address Resolution Protocol) Service.
|
||||
|
||||
@@ -22,6 +23,13 @@ class ARP(Service):
|
||||
sends ARP requests and replies, and processes incoming ARP packets.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for ARP."""
|
||||
|
||||
type: str = "ARP"
|
||||
|
||||
config: "ARP.ConfigSchema" = Field(default_factory=lambda: ARP.ConfigSchema())
|
||||
|
||||
arp: Dict[IPV4Address, ARPEntry] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
@@ -17,13 +19,21 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseService(Service):
|
||||
class DatabaseService(Service, identifier="DatabaseService"):
|
||||
"""
|
||||
A class for simulating a generic SQL Server service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to simulate a SQL database.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for DatabaseService."""
|
||||
|
||||
type: str = "DatabaseService"
|
||||
backup_server_ip: Optional[IPv4Address] = None
|
||||
|
||||
config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema())
|
||||
|
||||
password: Optional[str] = None
|
||||
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
|
||||
|
||||
@@ -42,6 +52,7 @@ class DatabaseService(Service):
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
self._create_db_file()
|
||||
self.backup_server_ip = self.config.backup_server_ip
|
||||
|
||||
def install(self):
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
@@ -12,9 +14,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DNSClient(Service):
|
||||
class DNSClient(Service, identifier="DNSClient"):
|
||||
"""Represents a DNS Client as a Service."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for DNSClient."""
|
||||
|
||||
type: str = "DNSClient"
|
||||
|
||||
config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema())
|
||||
dns_cache: Dict[str, IPv4Address] = {}
|
||||
"A dict of known mappings between domain/URLs names and IPv4 addresses."
|
||||
dns_server: Optional[IPv4Address] = None
|
||||
|
||||
@@ -3,6 +3,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket
|
||||
@@ -13,9 +14,17 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DNSServer(Service):
|
||||
class DNSServer(Service, identifier="DNSServer"):
|
||||
"""Represents a DNS Server as a Service."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for DNSServer."""
|
||||
|
||||
type: str = "DNSServer"
|
||||
domain_mapping: dict = {}
|
||||
|
||||
config: "DNSServer.ConfigSchema" = Field(default_factory=lambda: DNSServer.ConfigSchema())
|
||||
|
||||
dns_table: Dict[str, IPv4Address] = {}
|
||||
"A dict of mappings between domain names and IPv4 addresses."
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -9,20 +11,28 @@ from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FTPClient(FTPServiceABC):
|
||||
class FTPClient(FTPServiceABC, identifier="FTPClient"):
|
||||
"""
|
||||
A class for simulating an FTP client service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to emulate FTP
|
||||
This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP
|
||||
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
|
||||
"""
|
||||
|
||||
config: "FTPClient.ConfigSchema" = Field(default_factory=lambda: FTPClient.ConfigSchema())
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for FTPClient."""
|
||||
|
||||
type: str = "FTPClient"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPClient"
|
||||
kwargs["port"] = PORT_LOOKUP["FTP"]
|
||||
@@ -108,6 +118,7 @@ class FTPClient(FTPServiceABC):
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: Optional[bool] = False,
|
||||
) -> bool:
|
||||
self._active = True
|
||||
"""
|
||||
Connects the client to a given FTP server.
|
||||
|
||||
@@ -164,6 +175,7 @@ class FTPClient(FTPServiceABC):
|
||||
:param: is_reattempt: Set to True if attempt to disconnect from FTP Server has been attempted. Default False.
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
self._active = True
|
||||
# send a disconnect request payload to FTP server
|
||||
payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.QUIT)
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
@@ -209,6 +221,7 @@ class FTPClient(FTPServiceABC):
|
||||
:param: session_id: The id of the session
|
||||
:type: session_id: Optional[str]
|
||||
"""
|
||||
self._active = True
|
||||
# check if the file to transfer exists on the client
|
||||
file_to_transfer: File = self.file_system.get_file(folder_name=src_folder_name, file_name=src_file_name)
|
||||
if not file_to_transfer:
|
||||
@@ -266,6 +279,7 @@ class FTPClient(FTPServiceABC):
|
||||
:param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"].
|
||||
:type: dest_port: Optional[int]
|
||||
"""
|
||||
self._active = True
|
||||
# check if FTP is currently connected to IP
|
||||
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
|
||||
@@ -317,6 +331,7 @@ class FTPClient(FTPServiceABC):
|
||||
This helps prevent an FTP request loop - FTP client and servers can exist on
|
||||
the same node.
|
||||
"""
|
||||
self._active = True
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,26 +1,36 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FTPServer(FTPServiceABC):
|
||||
class FTPServer(FTPServiceABC, identifier="FTPServer"):
|
||||
"""
|
||||
A class for simulating an FTP server service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to emulate FTP
|
||||
This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP
|
||||
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
|
||||
"""
|
||||
|
||||
config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema())
|
||||
|
||||
server_password: Optional[str] = None
|
||||
"""Password needed to connect to FTP server. Default is None."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for FTPServer."""
|
||||
|
||||
type: str = "FTPServer"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPServer"
|
||||
kwargs["port"] = PORT_LOOKUP["FTP"]
|
||||
|
||||
@@ -3,9 +3,11 @@ from abc import ABC
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import StrictBool
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
|
||||
@@ -16,9 +18,22 @@ class FTPServiceABC(Service, ABC):
|
||||
Contains shared methods between both classes.
|
||||
"""
|
||||
|
||||
_active: StrictBool = False
|
||||
"""Flag that is True on timesteps where service transmits data and False when idle. Used for describe_state."""
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""When a new timestep begins, clear the _active attribute."""
|
||||
self._active = False
|
||||
return super().pre_timestep(timestep)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""Returns a Dict of the FTPService state."""
|
||||
return super().describe_state()
|
||||
state = super().describe_state()
|
||||
|
||||
# override so that the service is shows as running only if actively transmitting data this timestep
|
||||
if self.operating_state == ServiceOperatingState.RUNNING and not self._active:
|
||||
state["operating_state"] = ServiceOperatingState.STOPPED.value
|
||||
return state
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
"""
|
||||
@@ -29,6 +44,7 @@ class FTPServiceABC(Service, ABC):
|
||||
:param: session_id: session ID linked to the FTP Packet. Optional.
|
||||
:type: session_id: Optional[str]
|
||||
"""
|
||||
self._active = True
|
||||
if payload.ftp_command is not None:
|
||||
self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.")
|
||||
|
||||
@@ -51,6 +67,7 @@ class FTPServiceABC(Service, ABC):
|
||||
:param: payload: The FTP Packet that contains the file data
|
||||
:type: FTPPacket
|
||||
"""
|
||||
self._active = True
|
||||
try:
|
||||
file_name = payload.ftp_command_args["dest_file_name"]
|
||||
folder_name = payload.ftp_command_args["dest_folder_name"]
|
||||
@@ -106,6 +123,7 @@ class FTPServiceABC(Service, ABC):
|
||||
:param: is_response: is true if the data being sent is in response to a request. Default False.
|
||||
:type: is_response: bool
|
||||
"""
|
||||
self._active = True
|
||||
# send STOR request
|
||||
payload: FTPPacket = FTPPacket(
|
||||
ftp_command=FTPCommand.STOR,
|
||||
@@ -135,6 +153,7 @@ class FTPServiceABC(Service, ABC):
|
||||
:param: payload: The FTP Packet that contains the file data
|
||||
:type: FTPPacket
|
||||
"""
|
||||
self._active = True
|
||||
try:
|
||||
# find the file
|
||||
file_name = payload.ftp_command_args["src_file_name"]
|
||||
@@ -181,6 +200,7 @@ class FTPServiceABC(Service, ABC):
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
self._active = True
|
||||
self.sys_log.info(f"{self.name}: Sending FTP {payload.ftp_command.name} {payload.ftp_command_args}")
|
||||
|
||||
return super().send(
|
||||
|
||||
@@ -3,6 +3,8 @@ import secrets
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface
|
||||
from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType
|
||||
@@ -14,7 +16,7 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class ICMP(Service):
|
||||
class ICMP(Service, identifier="ICMP"):
|
||||
"""
|
||||
The Internet Control Message Protocol (ICMP) service.
|
||||
|
||||
@@ -22,6 +24,13 @@ class ICMP(Service):
|
||||
network diagnostics, notably the ping command.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for ICMP."""
|
||||
|
||||
type: str = "ICMP"
|
||||
|
||||
config: "ICMP.ConfigSchema" = Field(default_factory=lambda: ICMP.ConfigSchema())
|
||||
|
||||
request_replies: Dict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -3,6 +3,8 @@ from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ntp import NTPPacket
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
@@ -12,9 +14,16 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class NTPClient(Service):
|
||||
class NTPClient(Service, identifier="NTPClient"):
|
||||
"""Represents a NTP client as a service."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for NTPClient."""
|
||||
|
||||
type: str = "NTPClient"
|
||||
|
||||
config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema())
|
||||
|
||||
ntp_server: Optional[IPv4Address] = None
|
||||
"The NTP server the client sends requests to."
|
||||
time: Optional[datetime] = None
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ntp import NTPPacket
|
||||
from primaite.simulator.system.services.service import Service
|
||||
@@ -11,9 +13,16 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class NTPServer(Service):
|
||||
class NTPServer(Service, identifier="NTPServer"):
|
||||
"""Represents a NTP server as a service."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for NTPServer."""
|
||||
|
||||
type: str = "NTPServer"
|
||||
|
||||
config: "NTPServer.ConfigSchema" = Field(default_factory=lambda: NTPServer.ConfigSchema())
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "NTPServer"
|
||||
kwargs["port"] = PORT_LOOKUP["NTP"]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Dict, Optional, Type
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
|
||||
@@ -37,6 +39,13 @@ class Service(IOSoftware):
|
||||
Services are programs that run in the background and may perform input/output operations.
|
||||
"""
|
||||
|
||||
class ConfigSchema(IOSoftware.ConfigSchema, ABC):
|
||||
"""Config Schema for Service class."""
|
||||
|
||||
type: str
|
||||
|
||||
config: "Service.ConfigSchema" = Field(default_factory=lambda: Service.ConfigSchema())
|
||||
|
||||
operating_state: ServiceOperatingState = ServiceOperatingState.STOPPED
|
||||
"The current operating state of the Service."
|
||||
|
||||
@@ -52,7 +61,7 @@ class Service(IOSoftware):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a hostnode type.
|
||||
|
||||
@@ -60,15 +69,30 @@ class Service(IOSoftware):
|
||||
:type identifier: str
|
||||
:raises ValueError: When attempting to register an hostnode with a name that is already allocated.
|
||||
"""
|
||||
if identifier == "default":
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
# Enforce lowercase registry entries because it makes comparisons everywhere else much easier.
|
||||
identifier = identifier.lower()
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "Service":
|
||||
"""Create a service from a config dictionary.
|
||||
|
||||
:param config: dict of options for service components constructor
|
||||
:type config: dict
|
||||
:return: The service component.
|
||||
:rtype: Service
|
||||
"""
|
||||
if config["type"] not in cls._registry:
|
||||
raise ValueError(f"Invalid service type {config['type']}")
|
||||
service_class = cls._registry[config["type"]]
|
||||
service_object = service_class(config=service_class.ConfigSchema(**config))
|
||||
return service_object
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
Checks if the service can perform actions.
|
||||
@@ -232,14 +256,14 @@ class Service(IOSoftware):
|
||||
|
||||
def disable(self) -> bool:
|
||||
"""Disable the service."""
|
||||
self.sys_log.info(f"Disabling Application {self.name}")
|
||||
self.sys_log.info(f"Disabling Service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.DISABLED
|
||||
return True
|
||||
|
||||
def enable(self) -> bool:
|
||||
"""Enable the disabled service."""
|
||||
if self.operating_state == ServiceOperatingState.DISABLED:
|
||||
self.sys_log.info(f"Enabling Application {self.name}")
|
||||
self.sys_log.info(f"Enabling Service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -7,7 +7,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -129,9 +129,16 @@ class RemoteTerminalConnection(TerminalClientConnection):
|
||||
return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id)
|
||||
|
||||
|
||||
class Terminal(Service):
|
||||
class Terminal(Service, identifier="Terminal"):
|
||||
"""Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for Terminal."""
|
||||
|
||||
type: str = "Terminal"
|
||||
|
||||
config: "Terminal.ConfigSchema" = Field(default_factory=lambda: Terminal.ConfigSchema())
|
||||
|
||||
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
|
||||
"""Dictionary of connect requests made to remote nodes."""
|
||||
|
||||
@@ -179,7 +186,7 @@ class Terminal(Service):
|
||||
return RequestResponse(status="failure", data={})
|
||||
|
||||
rm.add_request(
|
||||
"ssh_to_remote",
|
||||
"node_session_remote_login",
|
||||
request_type=RequestType(func=_remote_login),
|
||||
)
|
||||
|
||||
@@ -279,7 +286,6 @@ class Terminal(Service):
|
||||
:param password: Password for login.
|
||||
:return: boolean, True if successful, else False
|
||||
"""
|
||||
# TODO: Un-comment this when UserSessionManager is merged.
|
||||
connection_uuid = self.parent.user_session_manager.local_login(username=username, password=password)
|
||||
if connection_uuid:
|
||||
self.sys_log.info(f"{self.name}: Login request authorised, connection uuid: {connection_uuid}")
|
||||
@@ -406,7 +412,6 @@ class Terminal(Service):
|
||||
if isinstance(payload, SSHPacket):
|
||||
if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
|
||||
# validate & add connection
|
||||
# TODO: uncomment this as part of 2781
|
||||
username = payload.user_account.username
|
||||
password = payload.user_account.password
|
||||
connection_id = self.parent.user_session_manager.remote_login(
|
||||
|
||||
@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.http import (
|
||||
HttpRequestMethod,
|
||||
@@ -19,9 +21,16 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class WebServer(Service):
|
||||
class WebServer(Service, identifier="WebServer"):
|
||||
"""Class used to represent a Web Server Service in simulation."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for WebServer."""
|
||||
|
||||
type: str = "WebServer"
|
||||
|
||||
config: "WebServer.ConfigSchema" = Field(default_factory=lambda: WebServer.ConfigSchema())
|
||||
|
||||
response_codes_this_timestep: List[HttpStatusCode] = []
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import copy
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
@@ -70,7 +70,7 @@ class SoftwareCriticality(Enum):
|
||||
"The highest level of criticality."
|
||||
|
||||
|
||||
class Software(SimComponent):
|
||||
class Software(SimComponent, ABC):
|
||||
"""
|
||||
A base class representing software in a simulator environment.
|
||||
|
||||
@@ -78,14 +78,22 @@ class Software(SimComponent):
|
||||
It outlines the fundamental attributes and behaviors expected of any software in the simulation.
|
||||
"""
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Configurable options for all software."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
starting_health_state: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
|
||||
fixing_duration: int = 2
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: Software.ConfigSchema())
|
||||
|
||||
name: str
|
||||
"The name of the software."
|
||||
health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
"The actual health state of the software."
|
||||
health_state_visible: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
"The health state of the software visible to the red agent."
|
||||
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
|
||||
"The criticality level of the software."
|
||||
fixing_count: int = 0
|
||||
"The count of patches applied to the software, defaults to 0."
|
||||
scanning_count: int = 0
|
||||
@@ -100,11 +108,13 @@ class Software(SimComponent):
|
||||
"The FileSystem of the Node the Software is installed on."
|
||||
folder: Optional[Folder] = None
|
||||
"The folder on the file system the Software uses."
|
||||
fixing_duration: int = 2
|
||||
"The number of ticks it takes to patch the software."
|
||||
_fixing_countdown: Optional[int] = None
|
||||
"Current number of ticks left to patch the software."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.health_state_actual = self.config.starting_health_state # don't remove this
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
@@ -152,7 +162,7 @@ class Software(SimComponent):
|
||||
{
|
||||
"health_state_actual": self.health_state_actual.value,
|
||||
"health_state_visible": self.health_state_visible.value,
|
||||
"criticality": self.criticality.value,
|
||||
"criticality": self.config.criticality.value,
|
||||
"fixing_count": self.fixing_count,
|
||||
"scanning_count": self.scanning_count,
|
||||
"revealed_to_red": self.revealed_to_red,
|
||||
@@ -201,7 +211,7 @@ class Software(SimComponent):
|
||||
def fix(self) -> bool:
|
||||
"""Perform a fix on the software."""
|
||||
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
|
||||
self._fixing_countdown = self.fixing_duration
|
||||
self._fixing_countdown = self.config.fixing_duration
|
||||
self.set_health_state(SoftwareHealthState.FIXING)
|
||||
return True
|
||||
return False
|
||||
@@ -233,7 +243,7 @@ class Software(SimComponent):
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
|
||||
class IOSoftware(Software):
|
||||
class IOSoftware(Software, ABC):
|
||||
"""
|
||||
Represents software in a simulator environment that is capable of input/output operations.
|
||||
|
||||
@@ -243,6 +253,13 @@ class IOSoftware(Software):
|
||||
required.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Software.ConfigSchema, ABC):
|
||||
"""Configuration options for all IO Software."""
|
||||
|
||||
listen_on_ports: Set[Port] = Field(default_factory=set)
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: IOSoftware.ConfigSchema())
|
||||
|
||||
installing_count: int = 0
|
||||
"The number of times the software has been installed. Default is 0."
|
||||
max_sessions: int = 100
|
||||
@@ -260,6 +277,10 @@ class IOSoftware(Software):
|
||||
_connections: Dict[str, Dict] = {}
|
||||
"Active connections."
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.listen_on_ports = self.config.listen_on_ports
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import csv
|
||||
from logging import Logger
|
||||
from typing import Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
@@ -31,7 +31,7 @@ def ipv4_validator(v: Any) -> IPv4Address:
|
||||
|
||||
IPV4Address: Final[Annotated] = Annotated[IPv4Address, BeforeValidator(ipv4_validator)]
|
||||
"""
|
||||
IPv4Address with with IPv4Address with with pre-validation and auto-conversion from str using ipv4_validator..
|
||||
IPv4Address with pre-validation and auto-conversion from str using ipv4_validator..
|
||||
|
||||
This type is essentially an IPv4Address from the standard library's ipaddress module,
|
||||
but with added validation logic. If you use this custom type, the ipv4_validator function
|
||||
|
||||
Reference in New Issue
Block a user