Merge branch 'dev' into bugfix/2299-check_hash_function_corrupts_files_and_folders
This commit is contained in:
@@ -1 +1 @@
|
||||
3.0.0b6
|
||||
3.0.0b9dev
|
||||
|
||||
@@ -114,23 +114,3 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
reset_example_configs.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(
|
||||
config: Optional[str] = None,
|
||||
agent_load_file: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
:param config: The path to the config file. Optional, if None, the example config will be used.
|
||||
:type config: Optional[str]
|
||||
"""
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.main import run
|
||||
|
||||
if not config:
|
||||
config = data_manipulation_config_path()
|
||||
print(config)
|
||||
run(config_path=config, agent_load_path=agent_load_file)
|
||||
|
||||
@@ -1,26 +1,12 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 1
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: true
|
||||
save_sys_logs: false
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
max_episode_length: 128
|
||||
ports:
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
@@ -43,8 +29,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -76,7 +61,14 @@ agents:
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
@@ -86,8 +78,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -119,7 +110,14 @@ agents:
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
|
||||
|
||||
|
||||
@@ -129,10 +127,7 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -165,61 +160,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1:eth-1<->switch_1:eth-8
|
||||
- router_1:eth-2<->switch_2:eth-8
|
||||
- switch_1:eth-1<->domain_controller:eth-1
|
||||
- switch_1:eth-2<->web_server:eth-1
|
||||
- switch_1:eth-3<->database_server:eth-1
|
||||
- switch_1:eth-4<->backup_server:eth-1
|
||||
- switch_1:eth-7<->security_suite:eth-1
|
||||
- switch_2:eth-1<->client_1:eth-1
|
||||
- switch_2:eth-2<->client_2:eth-1
|
||||
- switch_2:eth-7<->security_suite:eth-2
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -232,7 +239,7 @@ agents:
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_SERVICE_PATCH
|
||||
- type: NODE_SERVICE_FIX
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
@@ -246,14 +253,10 @@ agents:
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: NETWORK_ACL_ADDRULE
|
||||
options:
|
||||
target_router_hostname: router_1
|
||||
- type: NETWORK_ACL_REMOVERULE
|
||||
options:
|
||||
target_router_hostname: router_1
|
||||
- type: NETWORK_NIC_ENABLE
|
||||
- type: NETWORK_NIC_DISABLE
|
||||
- type: ROUTER_ACL_ADDRULE
|
||||
- type: ROUTER_ACL_REMOVERULE
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
@@ -309,7 +312,7 @@ agents:
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
@@ -327,7 +330,7 @@ agents:
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
13:
|
||||
action: "NODE_SERVICE_PATCH"
|
||||
action: "NODE_SERVICE_FIX"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 0
|
||||
@@ -337,7 +340,7 @@ agents:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
@@ -465,8 +468,9 @@ agents:
|
||||
node_id: 6
|
||||
|
||||
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
@@ -474,9 +478,12 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
@@ -484,9 +491,12 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
48: # old action num: 24 # block tcp traffic from client 1 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
@@ -494,9 +504,12 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
49: # old action num: 25 # block tcp traffic from client 2 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
@@ -504,9 +517,12 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
50: # old action num: 26
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
@@ -514,9 +530,12 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
51: # old action num: 27
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
@@ -524,123 +543,135 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
52: # old action num: 28
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 0
|
||||
53: # old action num: 29
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 1
|
||||
54: # old action num: 30
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 2
|
||||
55: # old action num: 31
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 3
|
||||
56: # old action num: 32
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 4
|
||||
57: # old action num: 33
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 5
|
||||
58: # old action num: 34
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 6
|
||||
59: # old action num: 35
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 7
|
||||
60: # old action num: 36
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 8
|
||||
61: # old action num: 37
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 9
|
||||
62: # old action num: 38
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
63: # old action num: 39
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
64: # old action num: 40
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
65: # old action num: 41
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
66: # old action num: 42
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
67: # old action num: 43
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
68: # old action num: 44
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
69: # old action num: 45
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
70: # old action num: 46
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
71: # old action num: 47
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
72: # old action num: 48
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
73: # old action num: 49
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
74: # old action num: 50
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
75: # old action num: 51
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
76: # old action num: 52
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
77: # old action num: 53
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
@@ -672,23 +703,15 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_name: domain_controller
|
||||
nic_num: 1
|
||||
- node_name: web_server
|
||||
nic_num: 1
|
||||
- node_name: database_server
|
||||
nic_num: 1
|
||||
- node_name: backup_server
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 1
|
||||
- node_name: client_1
|
||||
nic_num: 1
|
||||
- node_name: client_2
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 2
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
|
||||
|
||||
reward_function:
|
||||
@@ -699,22 +722,17 @@ agents:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
agent_name: client_1_green_user
|
||||
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
agent_name: client_2_green_user
|
||||
|
||||
|
||||
|
||||
agent_settings:
|
||||
@@ -732,8 +750,7 @@ simulation:
|
||||
- DELETE
|
||||
nodes:
|
||||
|
||||
- ref: router_1
|
||||
hostname: router_1
|
||||
- hostname: router_1
|
||||
type: router
|
||||
num_ports: 5
|
||||
ports:
|
||||
@@ -768,74 +785,61 @@ simulation:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- ref: switch_1
|
||||
hostname: switch_1
|
||||
- hostname: switch_1
|
||||
type: switch
|
||||
num_ports: 8
|
||||
|
||||
- ref: switch_2
|
||||
hostname: switch_2
|
||||
- hostname: switch_2
|
||||
type: switch
|
||||
num_ports: 8
|
||||
|
||||
- ref: domain_controller
|
||||
hostname: domain_controller
|
||||
- hostname: domain_controller
|
||||
type: server
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- ref: domain_controller_dns_server
|
||||
type: DNSServer
|
||||
- type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
- ref: web_server
|
||||
hostname: web_server
|
||||
- hostname: web_server
|
||||
type: server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
- type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
|
||||
|
||||
- ref: database_server
|
||||
hostname: database_server
|
||||
- hostname: database_server
|
||||
type: server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: database_service
|
||||
type: DatabaseService
|
||||
- type: DatabaseService
|
||||
options:
|
||||
backup_server_ip: 192.168.1.16
|
||||
- ref: database_ftp_client
|
||||
type: FTPClient
|
||||
- type: FTPClient
|
||||
|
||||
- ref: backup_server
|
||||
hostname: backup_server
|
||||
- hostname: backup_server
|
||||
type: server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: FTPServer
|
||||
- type: FTPServer
|
||||
|
||||
- ref: security_suite
|
||||
hostname: security_suite
|
||||
- hostname: security_suite
|
||||
type: server
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
@@ -846,110 +850,88 @@ simulation:
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
- ref: client_1
|
||||
hostname: client_1
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
- ref: client_1_web_browser
|
||||
type: WebBrowser
|
||||
- type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- ref: client_1_database_client
|
||||
type: DatabaseClient
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
services:
|
||||
- ref: client_1_dns_client
|
||||
type: DNSClient
|
||||
- type: DNSClient
|
||||
|
||||
- ref: client_2
|
||||
hostname: client_2
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- ref: client_2_web_browser
|
||||
type: WebBrowser
|
||||
- type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
- ref: client_2_database_client
|
||||
type: DatabaseClient
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
|
||||
- type: DNSClient
|
||||
|
||||
links:
|
||||
- ref: router_1___switch_1
|
||||
endpoint_a_ref: router_1
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: switch_1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 8
|
||||
- ref: router_1___switch_2
|
||||
endpoint_a_ref: router_1
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: switch_2
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 8
|
||||
- ref: switch_1___domain_controller
|
||||
endpoint_a_ref: switch_1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: domain_controller
|
||||
endpoint_b_hostname: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___web_server
|
||||
endpoint_a_ref: switch_1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: web_server
|
||||
endpoint_b_hostname: web_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___database_server
|
||||
endpoint_a_ref: switch_1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b_ref: database_server
|
||||
endpoint_b_hostname: database_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___backup_server
|
||||
endpoint_a_ref: switch_1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b_ref: backup_server
|
||||
endpoint_b_hostname: backup_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___security_suite
|
||||
endpoint_a_ref: switch_1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_hostname: security_suite
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_1
|
||||
endpoint_a_ref: switch_2
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: client_1
|
||||
endpoint_b_hostname: client_1
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_2
|
||||
endpoint_a_ref: switch_2
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: client_2
|
||||
endpoint_b_hostname: client_2
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___security_suite
|
||||
endpoint_a_ref: switch_2
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_hostname: security_suite
|
||||
endpoint_b_port: 2
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -156,12 +156,12 @@ class NodeServiceEnableAction(NodeServiceAbstractAction):
|
||||
self.verb: str = "enable"
|
||||
|
||||
|
||||
class NodeServicePatchAction(NodeServiceAbstractAction):
|
||||
"""Action which patches a service."""
|
||||
class NodeServiceFixAction(NodeServiceAbstractAction):
|
||||
"""Action which fixes a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb: str = "patch"
|
||||
self.verb: str = "fix"
|
||||
|
||||
|
||||
class NodeApplicationAbstractAction(AbstractAction):
|
||||
@@ -195,6 +195,69 @@ class NodeApplicationExecuteAction(NodeApplicationAbstractAction):
|
||||
self.verb: str = "execute"
|
||||
|
||||
|
||||
class NodeApplicationScanAction(NodeApplicationAbstractAction):
|
||||
"""Action which scans an application."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications)
|
||||
self.verb: str = "scan"
|
||||
|
||||
|
||||
class NodeApplicationCloseAction(NodeApplicationAbstractAction):
|
||||
"""Action which closes an application."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications)
|
||||
self.verb: str = "close"
|
||||
|
||||
|
||||
class NodeApplicationFixAction(NodeApplicationAbstractAction):
|
||||
"""Action which fixes an application."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications)
|
||||
self.verb: str = "fix"
|
||||
|
||||
|
||||
class NodeApplicationInstallAction(AbstractAction):
|
||||
"""Action which installs an application."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes}
|
||||
|
||||
def form_request(self, node_id: int, application_name: str, ip_address: str) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
node_name,
|
||||
"software_manager",
|
||||
"application",
|
||||
"install",
|
||||
application_name,
|
||||
ip_address,
|
||||
]
|
||||
|
||||
|
||||
class NodeApplicationRemoveAction(AbstractAction):
|
||||
"""Action which removes/uninstalls an application."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes}
|
||||
|
||||
def form_request(self, node_id: int, application_name: str) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", node_name, "software_manager", "application", "uninstall", application_name]
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction):
|
||||
"""
|
||||
Base class for folder actions.
|
||||
@@ -381,25 +444,22 @@ class NodeResetAction(NodeAbstractAction):
|
||||
self.verb: str = "reset"
|
||||
|
||||
|
||||
class NetworkACLAddRuleAction(AbstractAction):
|
||||
class RouterACLAddRuleAction(AbstractAction):
|
||||
"""Action which adds a rule to a router's ACL."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: "ActionManager",
|
||||
target_router_hostname: str,
|
||||
max_acl_rules: int,
|
||||
num_ips: int,
|
||||
num_ports: int,
|
||||
num_protocols: int,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Init method for NetworkACLAddRuleAction.
|
||||
"""Init method for RouterACLAddRuleAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param target_router_hostname: hostname of the router to which the ACL rule should be added.
|
||||
:type target_router_hostname: str
|
||||
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
|
||||
:type max_acl_rules: int
|
||||
:param num_ips: Number of IP addresses in the simulation.
|
||||
@@ -420,14 +480,16 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
"dest_port_id": num_ports,
|
||||
"protocol_id": num_protocols,
|
||||
}
|
||||
self.target_router_name: str = target_router_hostname
|
||||
|
||||
def form_request(
|
||||
self,
|
||||
target_router_nodename: str,
|
||||
position: int,
|
||||
permission: int,
|
||||
source_ip_id: int,
|
||||
source_wildcard_id: int,
|
||||
dest_ip_id: int,
|
||||
dest_wildcard_id: int,
|
||||
source_port_id: int,
|
||||
dest_port_id: int,
|
||||
protocol_id: int,
|
||||
@@ -437,7 +499,149 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
permission_str = "UNUSED"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
elif permission == 1:
|
||||
permission_str = "ALLOW"
|
||||
permission_str = "PERMIT"
|
||||
elif permission == 2:
|
||||
permission_str = "DENY"
|
||||
else:
|
||||
_LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
|
||||
|
||||
if protocol_id == 0:
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
|
||||
if protocol_id == 1:
|
||||
protocol = "ALL"
|
||||
else:
|
||||
protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
|
||||
# subtract 2 to account for UNUSED=0 and ALL=1.
|
||||
|
||||
if source_ip_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif source_ip_id == 1:
|
||||
src_ip = "ALL"
|
||||
else:
|
||||
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
|
||||
if source_port_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif source_port_id == 1:
|
||||
src_port = "ALL"
|
||||
else:
|
||||
src_port = self.manager.get_port_by_idx(source_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if dest_ip_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif dest_ip_id == 1:
|
||||
dst_ip = "ALL"
|
||||
else:
|
||||
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
|
||||
|
||||
if dest_port_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif dest_port_id == 1:
|
||||
dst_port = "ALL"
|
||||
else:
|
||||
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
target_router_nodename,
|
||||
"acl",
|
||||
"add_rule",
|
||||
permission_str,
|
||||
protocol,
|
||||
str(src_ip),
|
||||
src_wildcard,
|
||||
src_port,
|
||||
str(dst_ip),
|
||||
dst_wildcard,
|
||||
dst_port,
|
||||
position,
|
||||
]
|
||||
|
||||
|
||||
class RouterACLRemoveRuleAction(AbstractAction):
|
||||
"""Action which removes a rule from a router's ACL."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None:
|
||||
"""Init method for RouterACLRemoveRuleAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
|
||||
:type max_acl_rules: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"position": max_acl_rules}
|
||||
|
||||
def form_request(self, target_router_nodename: str, position: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", target_router_nodename, "acl", "remove_rule", position]
|
||||
|
||||
|
||||
class FirewallACLAddRuleAction(AbstractAction):
|
||||
"""Action which adds a rule to a firewall port's ACL."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: "ActionManager",
|
||||
max_acl_rules: int,
|
||||
num_ips: int,
|
||||
num_ports: int,
|
||||
num_protocols: int,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Init method for FirewallACLAddRuleAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
|
||||
:type max_acl_rules: int
|
||||
:param num_ips: Number of IP addresses in the simulation.
|
||||
:type num_ips: int
|
||||
:param num_ports: Number of ports in the simulation.
|
||||
:type num_ports: int
|
||||
:param num_protocols: Number of protocols in the simulation.
|
||||
:type num_protocols: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
num_permissions = 3
|
||||
self.shape: Dict[str, int] = {
|
||||
"position": max_acl_rules,
|
||||
"permission": num_permissions,
|
||||
"source_ip_id": num_ips,
|
||||
"dest_ip_id": num_ips,
|
||||
"source_port_id": num_ports,
|
||||
"dest_port_id": num_ports,
|
||||
"protocol_id": num_protocols,
|
||||
}
|
||||
|
||||
def form_request(
|
||||
self,
|
||||
target_firewall_nodename: str,
|
||||
firewall_port_name: str,
|
||||
firewall_port_direction: str,
|
||||
position: int,
|
||||
permission: int,
|
||||
source_ip_id: int,
|
||||
source_wildcard_id: int,
|
||||
dest_ip_id: int,
|
||||
dest_wildcard_id: int,
|
||||
source_port_id: int,
|
||||
dest_port_id: int,
|
||||
protocol_id: int,
|
||||
) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if permission == 0:
|
||||
permission_str = "UNUSED"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
elif permission == 1:
|
||||
permission_str = "PERMIT"
|
||||
elif permission == 2:
|
||||
permission_str = "DENY"
|
||||
else:
|
||||
@@ -468,7 +672,7 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
src_port = self.manager.get_port_by_idx(source_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if source_ip_id == 0:
|
||||
if dest_ip_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif dest_ip_id == 1:
|
||||
dst_ip = "ALL"
|
||||
@@ -483,46 +687,60 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
else:
|
||||
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
|
||||
dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
|
||||
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
self.target_router_name,
|
||||
target_firewall_nodename,
|
||||
firewall_port_name,
|
||||
firewall_port_direction,
|
||||
"acl",
|
||||
"add_rule",
|
||||
permission_str,
|
||||
protocol,
|
||||
str(src_ip),
|
||||
src_wildcard,
|
||||
src_port,
|
||||
str(dst_ip),
|
||||
dst_wildcard,
|
||||
dst_port,
|
||||
position,
|
||||
]
|
||||
|
||||
|
||||
class NetworkACLRemoveRuleAction(AbstractAction):
|
||||
"""Action which removes a rule from a router's ACL."""
|
||||
class FirewallACLRemoveRuleAction(AbstractAction):
|
||||
"""Action which removes a rule from a firewall port's ACL."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", target_router_hostname: str, max_acl_rules: int, **kwargs) -> None:
|
||||
"""Init method for NetworkACLRemoveRuleAction.
|
||||
def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None:
|
||||
"""Init method for RouterACLRemoveRuleAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param target_router_hostname: Hostname of the router from which the ACL rule should be removed.
|
||||
:type target_router_hostname: str
|
||||
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
|
||||
:type max_acl_rules: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"position": max_acl_rules}
|
||||
self.target_router_name: str = target_router_hostname
|
||||
|
||||
def form_request(self, position: int) -> List[str]:
|
||||
def form_request(
|
||||
self, target_firewall_nodename: str, firewall_port_name: str, firewall_port_direction: str, position: int
|
||||
) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", self.target_router_name, "acl", "remove_rule", position]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
target_firewall_nodename,
|
||||
firewall_port_name,
|
||||
firewall_port_direction,
|
||||
"acl",
|
||||
"remove_rule",
|
||||
position,
|
||||
]
|
||||
|
||||
|
||||
class NetworkNICAbstractAction(AbstractAction):
|
||||
class HostNICAbstractAction(AbstractAction):
|
||||
"""
|
||||
Abstract base class for NIC actions.
|
||||
|
||||
@@ -531,7 +749,7 @@ class NetworkNICAbstractAction(AbstractAction):
|
||||
"""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
"""Init method for NetworkNICAbstractAction.
|
||||
"""Init method for HostNICAbstractAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
@@ -553,7 +771,7 @@ class NetworkNICAbstractAction(AbstractAction):
|
||||
return ["network", "node", node_name, "network_interface", nic_num, self.verb]
|
||||
|
||||
|
||||
class NetworkNICEnableAction(NetworkNICAbstractAction):
|
||||
class HostNICEnableAction(HostNICAbstractAction):
|
||||
"""Action which enables a NIC."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
@@ -561,7 +779,7 @@ class NetworkNICEnableAction(NetworkNICAbstractAction):
|
||||
self.verb: str = "enable"
|
||||
|
||||
|
||||
class NetworkNICDisableAction(NetworkNICAbstractAction):
|
||||
class HostNICDisableAction(HostNICAbstractAction):
|
||||
"""Action which disables a NIC."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
@@ -569,6 +787,44 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
|
||||
self.verb: str = "disable"
|
||||
|
||||
|
||||
class NetworkPortEnableAction(AbstractAction):
|
||||
"""Action which enables are port on a router or a firewall."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", max_nics_per_node: int, **kwargs) -> None:
|
||||
"""Init method for NetworkPortEnableAction.
|
||||
|
||||
:param max_nics_per_node: Maximum number of NICs per node.
|
||||
:type max_nics_per_node: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
|
||||
|
||||
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if target_nodename is None or port_id is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", target_nodename, "network_interface", port_id, "enable"]
|
||||
|
||||
|
||||
class NetworkPortDisableAction(AbstractAction):
|
||||
"""Action which disables are port on a router or a firewall."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", max_nics_per_node: int, **kwargs) -> None:
|
||||
"""Init method for NetworkPortDisableAction.
|
||||
|
||||
:param max_nics_per_node: Maximum number of NICs per node.
|
||||
:type max_nics_per_node: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
|
||||
|
||||
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if target_nodename is None or port_id is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", target_nodename, "network_interface", port_id, "disable"]
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
@@ -582,8 +838,13 @@ class ActionManager:
|
||||
"NODE_SERVICE_RESTART": NodeServiceRestartAction,
|
||||
"NODE_SERVICE_DISABLE": NodeServiceDisableAction,
|
||||
"NODE_SERVICE_ENABLE": NodeServiceEnableAction,
|
||||
"NODE_SERVICE_PATCH": NodeServicePatchAction,
|
||||
"NODE_SERVICE_FIX": NodeServiceFixAction,
|
||||
"NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction,
|
||||
"NODE_APPLICATION_SCAN": NodeApplicationScanAction,
|
||||
"NODE_APPLICATION_CLOSE": NodeApplicationCloseAction,
|
||||
"NODE_APPLICATION_FIX": NodeApplicationFixAction,
|
||||
"NODE_APPLICATION_INSTALL": NodeApplicationInstallAction,
|
||||
"NODE_APPLICATION_REMOVE": NodeApplicationRemoveAction,
|
||||
"NODE_FILE_SCAN": NodeFileScanAction,
|
||||
"NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
|
||||
"NODE_FILE_DELETE": NodeFileDeleteAction,
|
||||
@@ -598,10 +859,14 @@ class ActionManager:
|
||||
"NODE_SHUTDOWN": NodeShutdownAction,
|
||||
"NODE_STARTUP": NodeStartupAction,
|
||||
"NODE_RESET": NodeResetAction,
|
||||
"NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction,
|
||||
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
|
||||
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
|
||||
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
|
||||
"ROUTER_ACL_ADDRULE": RouterACLAddRuleAction,
|
||||
"ROUTER_ACL_REMOVERULE": RouterACLRemoveRuleAction,
|
||||
"FIREWALL_ACL_ADDRULE": FirewallACLAddRuleAction,
|
||||
"FIREWALL_ACL_REMOVERULE": FirewallACLRemoveRuleAction,
|
||||
"HOST_NIC_ENABLE": HostNICEnableAction,
|
||||
"HOST_NIC_DISABLE": HostNICDisableAction,
|
||||
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
|
||||
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
|
||||
}
|
||||
"""Dictionary which maps action type strings to the corresponding action class."""
|
||||
|
||||
@@ -617,7 +882,8 @@ class ActionManager:
|
||||
max_acl_rules: int = 10, # allows calculating shape
|
||||
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
|
||||
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
|
||||
ip_address_list: List[str] = [], # to allow us to map an index to an ip address.
|
||||
ip_list: List[str] = [], # to allow us to map an index to an ip address.
|
||||
wildcard_list: List[str] = [], # to allow mapping from wildcard index to
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
) -> None:
|
||||
"""Init method for ActionManager.
|
||||
@@ -643,8 +909,8 @@ class ActionManager:
|
||||
:type protocols: List[str]
|
||||
:param ports: List of ports that are available in the simulation. Used for calculating action shape.
|
||||
:type ports: List[str]
|
||||
:param ip_address_list: List of IP addresses that known to this agent. Used for calculating action shape.
|
||||
:type ip_address_list: Optional[List[str]]
|
||||
:param ip_list: List of IP addresses that known to this agent. Used for calculating action shape.
|
||||
:type ip_list: Optional[List[str]]
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
@@ -705,8 +971,10 @@ class ActionManager:
|
||||
self.protocols: List[str] = protocols
|
||||
self.ports: List[str] = ports
|
||||
|
||||
self.ip_address_list: List[str] = ip_address_list
|
||||
|
||||
self.ip_address_list: List[str] = ip_list
|
||||
self.wildcard_list: List[str] = wildcard_list
|
||||
if self.wildcard_list == []:
|
||||
self.wildcard_list = ["NONE"]
|
||||
# action_args are settings which are applied to the action space as a whole.
|
||||
global_action_args = {
|
||||
"num_nodes": len(self.node_names),
|
||||
@@ -743,7 +1011,8 @@ class ActionManager:
|
||||
{0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})}
|
||||
"""
|
||||
if act_map is None:
|
||||
self.action_map = self._enumerate_actions()
|
||||
# raise RuntimeError("Action map must be specified in the config file.")
|
||||
pass
|
||||
else:
|
||||
self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()}
|
||||
# make sure all numbers between 0 and N are represented as dict keys in action map
|
||||
@@ -940,6 +1209,24 @@ class ActionManager:
|
||||
raise RuntimeError(msg)
|
||||
return self.ip_address_list[ip_idx]
|
||||
|
||||
def get_wildcard_by_idx(self, wildcard_idx: int) -> str:
|
||||
"""
|
||||
Get the IP wildcard corresponding to the given index.
|
||||
|
||||
:param ip_idx: The index of the IP wildcard to retrieve.
|
||||
:type ip_idx: int
|
||||
:return: The wildcard address.
|
||||
:rtype: str
|
||||
"""
|
||||
if wildcard_idx >= len(self.wildcard_list):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this"
|
||||
f" is out of range for its action space. Wildcard list: {self.wildcard_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.wildcard_list[wildcard_idx]
|
||||
|
||||
def get_port_by_idx(self, port_idx: int) -> str:
|
||||
"""
|
||||
Get the port corresponding to the given index.
|
||||
@@ -998,37 +1285,14 @@ class ActionManager:
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
# If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from
|
||||
# the nodes in the simulation.
|
||||
# TODO: refactor. Options:
|
||||
# 1: This should be pulled out into it's own function for clarity
|
||||
# 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to
|
||||
# go through the nodes here.
|
||||
ip_address_order = cfg["options"].pop("ip_address_order", {})
|
||||
ip_address_list = []
|
||||
for entry in ip_address_order:
|
||||
node_name = entry["node_name"]
|
||||
nic_num = entry["nic_num"]
|
||||
node_obj = game.simulation.network.get_node_by_hostname(node_name)
|
||||
ip_address = node_obj.network_interface[nic_num].ip_address
|
||||
ip_address_list.append(ip_address)
|
||||
|
||||
if not ip_address_list:
|
||||
node_names = [n["node_name"] for n in cfg.get("nodes", {})]
|
||||
for node_name in node_names:
|
||||
node_obj = game.simulation.network.get_node_by_hostname(node_name)
|
||||
if node_obj is None:
|
||||
continue
|
||||
network_interfaces = node_obj.network_interfaces
|
||||
for nic_uuid, nic_obj in network_interfaces.items():
|
||||
ip_address_list.append(nic_obj.ip_address)
|
||||
if "ip_list" not in cfg["options"]:
|
||||
cfg["options"]["ip_list"] = []
|
||||
|
||||
obj = cls(
|
||||
actions=cfg["action_list"],
|
||||
**cfg["options"],
|
||||
protocols=game.options.protocols,
|
||||
ports=game.options.ports,
|
||||
ip_address_list=ip_address_list,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Interface for agents."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -8,11 +8,31 @@ from pydantic import BaseModel, model_validator
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AgentActionHistoryItem(BaseModel):
|
||||
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
|
||||
|
||||
timestep: int
|
||||
"""Timestep of this action."""
|
||||
|
||||
action: str
|
||||
"""CAOS Action name."""
|
||||
|
||||
parameters: Dict[str, Any]
|
||||
"""CAOS parameters for the given action."""
|
||||
|
||||
request: RequestFormat
|
||||
"""The request that was sent to the simulation based on the CAOS action chosen."""
|
||||
|
||||
response: RequestResponse
|
||||
"""The response sent back by the simulator for this action."""
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
|
||||
@@ -90,6 +110,7 @@ class AbstractAgent(ABC):
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
self.action_history: List[AgentActionHistoryItem] = []
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -109,7 +130,7 @@ class AbstractAgent(ABC):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.update(state)
|
||||
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -120,8 +141,6 @@ class AbstractAgent(ABC):
|
||||
|
||||
:param obs: Observation of the environment.
|
||||
:type obs: ObsType
|
||||
:param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted?
|
||||
:type reward: float, optional
|
||||
:param timestep: The current timestep in the simulation, used for non-RL agents. Optional
|
||||
:type timestep: int
|
||||
:return: Action to be taken in the environment.
|
||||
@@ -138,9 +157,15 @@ class AbstractAgent(ABC):
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
def reset_agent_for_episode(self) -> None:
|
||||
"""Agent reset logic should go here."""
|
||||
pass
|
||||
def process_action_response(
|
||||
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
|
||||
) -> None:
|
||||
"""Process the response from the most recent action."""
|
||||
self.action_history.append(
|
||||
AgentActionHistoryItem(
|
||||
timestep=timestep, action=action, parameters=parameters, request=request, response=response
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# flake8: noqa
|
||||
# Pre-import all the observations when we load up the observations module so that they can be resolved by the parser.
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
|
||||
from primaite.game.agent.observations.firewall_observation import FirewallObservation
|
||||
from primaite.game.agent.observations.host_observations import HostObservation
|
||||
from primaite.game.agent.observations.link_observation import LinkObservation, LinksObservation
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation, PortObservation
|
||||
from primaite.game.agent.observations.node_observations import NodesObservation
|
||||
from primaite.game.agent.observations.observation_manager import NestedObservation, NullObservation, ObservationManager
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.router_observation import RouterObservation
|
||||
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
|
||||
|
||||
# fmt: off
|
||||
__all__ = [
|
||||
"ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation",
|
||||
"LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation",
|
||||
"ObservationManager", "ApplicationObservation", "ServiceObservation",]
|
||||
# fmt: on
|
||||
|
||||
187
src/primaite/game/agent/observations/acl_observation.py
Normal file
187
src/primaite/game/agent/observations/acl_observation.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"""ACL observation, provides information about access control lists within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ACLObservation."""
|
||||
|
||||
ip_list: Optional[List[IPv4Address]] = None
|
||||
"""List of IP addresses."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of wildcard strings."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of port numbers."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocol names."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of ACL rules."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
num_rules: int,
|
||||
ip_list: List[IPv4Address],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[int],
|
||||
protocol_list: List[str],
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an ACL observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL.
|
||||
:type where: WhereType
|
||||
:param num_rules: Number of ACL rules.
|
||||
:type num_rules: int
|
||||
:param ip_list: List of IP addresses.
|
||||
:type ip_list: List[IPv4Address]
|
||||
:param wildcard_list: List of wildcard strings.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port numbers.
|
||||
:type port_list: List[int]
|
||||
:param protocol_list: List of protocol names.
|
||||
:type protocol_list: List[str]
|
||||
"""
|
||||
self.where = where
|
||||
self.num_rules: int = num_rules
|
||||
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
|
||||
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
|
||||
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
"source_wildcard_id": 0,
|
||||
"source_port_id": 0,
|
||||
"dest_ip_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"dest_port_id": 0,
|
||||
"protocol_id": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing ACL rules.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
obs = {}
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
"source_wildcard_id": 0,
|
||||
"source_port_id": 0,
|
||||
"dest_ip_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"dest_port_id": 0,
|
||||
"protocol_id": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.ip_to_id[src_ip]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_id = 1 if dst_ip is None else self.ip_to_id[dst_ip]
|
||||
src_wildcard = rule_state["src_wildcard_mask"]
|
||||
src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
|
||||
dst_wildcard = rule_state["dst_wildcard_mask"]
|
||||
dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = self.port_to_id.get(src_port, 1)
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = self.port_to_id.get(dst_port, 1)
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = self.protocol_to_id.get(protocol, 1)
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_ip_id": src_node_id,
|
||||
"source_wildcard_id": src_wildcard_id,
|
||||
"source_port_id": src_port_id,
|
||||
"dest_ip_id": dst_node_id,
|
||||
"dest_wildcard_id": dst_wildcard_id,
|
||||
"dest_port_id": dst_port_id,
|
||||
"protocol_id": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for ACL rules.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
|
||||
"source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
|
||||
"source_port_id": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
|
||||
"dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
|
||||
"dest_port_id": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation:
|
||||
"""
|
||||
Create an ACL observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the ACL observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this ACL's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed ACL observation instance.
|
||||
:rtype: ACLObservation
|
||||
"""
|
||||
return cls(
|
||||
where=parent_where + ["acl", "acl"],
|
||||
num_rules=config.num_rules,
|
||||
ip_list=config.ip_list,
|
||||
wildcard_list=config.wildcard_list,
|
||||
port_list=config.port_list,
|
||||
protocol_list=config.protocol_list,
|
||||
)
|
||||
@@ -1,188 +0,0 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.node_observations import NodeObservation
|
||||
from primaite.game.agent.observations.observations import (
|
||||
AbstractObservation,
|
||||
AclObservation,
|
||||
ICSObservation,
|
||||
LinkObservation,
|
||||
NullObservation,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class UC2BlueObservation(AbstractObservation):
|
||||
"""Container for all observations used by the blue agent in UC2.
|
||||
|
||||
TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
|
||||
for the purpose of compiling several observation components.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nodes: List[NodeObservation],
|
||||
links: List[LinkObservation],
|
||||
acl: AclObservation,
|
||||
ics: ICSObservation,
|
||||
where: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Initialise UC2 blue observation.
|
||||
|
||||
:param nodes: List of node observations
|
||||
:type nodes: List[NodeObservation]
|
||||
:param links: List of link observations
|
||||
:type links: List[LinkObservation]
|
||||
:param acl: The Access Control List observation
|
||||
:type acl: AclObservation
|
||||
:param ics: The ICS observation
|
||||
:type ics: ICSObservation
|
||||
:param where: Where in the simulation state dict to find information. Not used in this particular observation
|
||||
because it only compiles other observations and doesn't contribute any new information, defaults to None
|
||||
:type where: Optional[List[str]], optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
self.links: List[LinkObservation] = links
|
||||
self.acl: AclObservation = acl
|
||||
self.ics: ICSObservation = ics
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
|
||||
"ACL": self.acl.default_observation,
|
||||
"ICS": self.ics.default_observation,
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
obs["ICS"] = self.ics.observe(state)
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
|
||||
"ACL": self.acl.space,
|
||||
"ICS": self.ics.space,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
|
||||
"""Create UC2 blue observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
|
||||
links, ACL and ICS observations.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed UC2 blue observation
|
||||
:rtype: UC2BlueObservation
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
|
||||
num_services_per_node = config["num_services_per_node"]
|
||||
num_folders_per_node = config["num_folders_per_node"]
|
||||
num_files_per_folder = config["num_files_per_folder"]
|
||||
num_nics_per_node = config["num_nics_per_node"]
|
||||
nodes = [
|
||||
NodeObservation.from_config(
|
||||
config=n,
|
||||
game=game,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
for n in node_configs
|
||||
]
|
||||
|
||||
link_configs = config["links"]
|
||||
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
|
||||
|
||||
acl_config = config["acl"]
|
||||
acl = AclObservation.from_config(config=acl_config, game=game)
|
||||
|
||||
ics_config = config["ics"]
|
||||
ics = ICSObservation.from_config(config=ics_config, game=game)
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
|
||||
return new
|
||||
|
||||
|
||||
class UC2RedObservation(AbstractObservation):
|
||||
"""Container for all observations used by the red agent in UC2."""
|
||||
|
||||
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
|
||||
"""
|
||||
Create UC2 red observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 red observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
|
||||
return cls(nodes=nodes, where=["network"])
|
||||
|
||||
|
||||
class UC2GreenObservation(NullObservation):
|
||||
"""Green agent observation. As the green agent's actions don't depend on the observation, this is empty."""
|
||||
|
||||
pass
|
||||
@@ -1,126 +1,168 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
"""File observation, provides status information about a file within the simulation environment."""
|
||||
|
||||
class FileObservation(AbstractObservation):
|
||||
"""Observation of a file on a node in the network."""
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for FileObservation."""
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
file_name: str
|
||||
"""Name of the file, used for querying simulation state dictionary."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether to include the number of accesses to the file in the observation."""
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool) -> None:
|
||||
"""
|
||||
Initialise file observation.
|
||||
Initialise a file observation instance.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
|
||||
:type where: Optional[List[str]]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this file.
|
||||
A typical location for a file might be
|
||||
['network', 'nodes', <node_hostname>, 'file_system', 'folder', <folder_name>, 'files', <file_name>].
|
||||
:type where: WhereType
|
||||
:param include_num_access: Whether to include the number of accesses to the file in the observation.
|
||||
:type include_num_access: bool
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.default_observation: spaces.Space = {"health_status": 0}
|
||||
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
|
||||
self.where: WhereType = where
|
||||
self.include_num_access: bool = include_num_access
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
self.default_observation: ObsType = {"health_status": 0}
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_access"] = 0
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
|
||||
def _categorise_num_access(self, num_access: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_access > self.high_threshold:
|
||||
return 3
|
||||
elif num_access > self.med_threshold:
|
||||
return 2
|
||||
elif num_access > self.low_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Observation containing the health status of the file and optionally the number of accesses.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"health_status": file_state["visible_status"]}
|
||||
obs = {"health_status": file_state["visible_status"]}
|
||||
if self.include_num_access:
|
||||
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:return: Gymnasium space representing the observation space for file status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
space = {"health_status": spaces.Discrete(6)}
|
||||
if self.include_num_access:
|
||||
space["num_access"] = spaces.Discrete(4)
|
||||
return spaces.Dict(space)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
|
||||
"""Create file observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this file observation.
|
||||
:type config: Dict
|
||||
:param game: _description_
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: _description_, defaults to None
|
||||
:type parent_where: _type_, optional
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation:
|
||||
"""
|
||||
return cls(where=parent_where + ["files", config["file_name"]])
|
||||
Create a file observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the file observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this file's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed file observation instance.
|
||||
:rtype: FileObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access)
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
|
||||
"""Folder observation, including files inside of the folder."""
|
||||
class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
"""Folder observation, provides status information about a folder within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for FolderObservation."""
|
||||
|
||||
folder_name: str
|
||||
"""Name of the folder, used for querying simulation state dictionary."""
|
||||
files: List[FileObservation.ConfigSchema] = []
|
||||
"""List of file configurations within the folder."""
|
||||
num_files: Optional[int] = None
|
||||
"""Number of spaces for file observations in this folder."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether files in this folder should include the number of accesses in their observation."""
|
||||
|
||||
def __init__(
|
||||
self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2
|
||||
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
|
||||
) -> None:
|
||||
"""Initialise folder Observation, including files inside the folder.
|
||||
"""
|
||||
Initialise a folder observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
|
||||
:type where: Optional[List[str]]
|
||||
:param max_files: As size of the space must remain static, define max files that can be in this folder
|
||||
, defaults to 5
|
||||
:type max_files: int, optional
|
||||
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
|
||||
that even if new files are created, the existing files will always occupy the same space in the observation
|
||||
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
|
||||
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
|
||||
name, it will take the position defined in this dict. Defaults to {}
|
||||
:type file_positions: Dict[int, str], optional
|
||||
A typical location for a folder might be ['network', 'nodes', <node_hostname>, 'folders', <folder_name>].
|
||||
:type where: WhereType
|
||||
:param files: List of file observation instances within the folder.
|
||||
:type files: Iterable[FileObservation]
|
||||
:param num_files: Number of files expected in the folder.
|
||||
:type num_files: int
|
||||
:param include_num_access: Whether to include the number of accesses to files in the observation.
|
||||
:type include_num_access: bool
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.where: WhereType = where
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
while len(self.files) < num_files_per_folder:
|
||||
self.files.append(FileObservation())
|
||||
while len(self.files) > num_files_per_folder:
|
||||
while len(self.files) < num_files:
|
||||
self.files.append(FileObservation(where=None, include_num_access=include_num_access))
|
||||
while len(self.files) > num_files:
|
||||
truncated_file = self.files.pop()
|
||||
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"health_status": 0,
|
||||
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
|
||||
}
|
||||
if self.files:
|
||||
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the health status of the folder and status of files within the folder.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
folder_state = access_from_nested_dict(state, self.where)
|
||||
if folder_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
@@ -130,48 +172,42 @@ class FolderObservation(AbstractObservation):
|
||||
obs = {}
|
||||
|
||||
obs["health_status"] = health_status
|
||||
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
|
||||
if self.files:
|
||||
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:return: Gymnasium space representing the observation space for folder status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"health_status": spaces.Discrete(6),
|
||||
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
|
||||
}
|
||||
)
|
||||
shape = {"health_status": spaces.Discrete(6)}
|
||||
if self.files:
|
||||
shape["FILES"] = spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)})
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
|
||||
) -> "FolderObservation":
|
||||
"""Create folder observation from a config. Also creates child file observations.
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation:
|
||||
"""
|
||||
Create a folder observation from a configuration schema.
|
||||
|
||||
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
|
||||
folder and the files inside of it.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param config: Configuration schema containing the necessary information for the folder observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
|
||||
parent node. A typical location for a node ``where`` can be:
|
||||
['network','nodes',<node_hostname>,'file_system']
|
||||
:type parent_where: Optional[List[str]]
|
||||
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_files_per_folder: int, optional
|
||||
:return: Constructed folder observation
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed folder observation instance.
|
||||
:rtype: FolderObservation
|
||||
"""
|
||||
where = parent_where + ["folders", config["folder_name"]]
|
||||
where = parent_where + ["file_system", "folders", config.folder_name]
|
||||
|
||||
file_configs = config["files"]
|
||||
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
|
||||
# pass down shared/common config items
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
|
||||
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
|
||||
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
|
||||
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)
|
||||
|
||||
220
src/primaite/game/agent/observations/firewall_observation.py
Normal file
220
src/primaite/game/agent/observations/firewall_observation.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
"""Firewall observation, provides status information about a firewall within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for FirewallObservation."""
|
||||
|
||||
hostname: str
|
||||
"""Hostname of the firewall node, used for querying simulation state dictionary."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
ip_list: List[str],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[int],
|
||||
protocol_list: List[str],
|
||||
num_rules: int,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a firewall observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this firewall.
|
||||
A typical location for a firewall might be ['network', 'nodes', <firewall_hostname>].
|
||||
:type where: WhereType
|
||||
:param ip_list: List of IP addresses.
|
||||
:type ip_list: List[str]
|
||||
:param wildcard_list: List of wildcard rules.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port numbers.
|
||||
:type port_list: List[int]
|
||||
:param protocol_list: List of protocol types.
|
||||
:type protocol_list: List[str]
|
||||
:param num_rules: Number of rules configured in the firewall.
|
||||
:type num_rules: int
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
self.ports: List[PortObservation] = [
|
||||
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
|
||||
]
|
||||
# TODO: check what the port nums are for firewall.
|
||||
|
||||
self.internal_inbound_acl = ACLObservation(
|
||||
where=self.where + ["internal_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.internal_outbound_acl = ACLObservation(
|
||||
where=self.where + ["internal_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.dmz_inbound_acl = ACLObservation(
|
||||
where=self.where + ["dmz_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.dmz_outbound_acl = ACLObservation(
|
||||
where=self.where + ["dmz_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.external_inbound_acl = ACLObservation(
|
||||
where=self.where + ["external_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.external_outbound_acl = ACLObservation(
|
||||
where=self.where + ["external_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
|
||||
self.default_observation = {
|
||||
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.internal_outbound_acl.default_observation,
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.dmz_outbound_acl.default_observation,
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.external_outbound_acl.default_observation,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
obs = {
|
||||
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
},
|
||||
},
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for firewall status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
|
||||
"ACL": spaces.Dict(
|
||||
{
|
||||
"INTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"DMZ": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"EXTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:
|
||||
"""
|
||||
Create a firewall observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the firewall observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this firewall's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <firewall_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed firewall observation instance.
|
||||
:rtype: FirewallObservation
|
||||
"""
|
||||
return cls(
|
||||
where=parent_where + [config.hostname],
|
||||
ip_list=config.ip_list,
|
||||
wildcard_list=config.wildcard_list,
|
||||
port_list=config.port_list,
|
||||
protocol_list=config.protocol_list,
|
||||
num_rules=config.num_rules,
|
||||
)
|
||||
251
src/primaite/game/agent/observations/host_observations.py
Normal file
251
src/primaite/game/agent/observations/host_observations.py
Normal file
@@ -0,0 +1,251 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.file_system_observations import FolderObservation
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
"""Host observation, provides status information about a host within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for HostObservation."""
|
||||
|
||||
hostname: str
|
||||
"""Hostname of the host, used for querying simulation state dictionary."""
|
||||
services: List[ServiceObservation.ConfigSchema] = []
|
||||
"""List of services to observe on the host."""
|
||||
applications: List[ApplicationObservation.ConfigSchema] = []
|
||||
"""List of applications to observe on the host."""
|
||||
folders: List[FolderObservation.ConfigSchema] = []
|
||||
"""List of folders to observe on the host."""
|
||||
network_interfaces: List[NICObservation.ConfigSchema] = []
|
||||
"""List of network interfaces to observe on the host."""
|
||||
num_services: Optional[int] = None
|
||||
"""Number of spaces for service observations on this host."""
|
||||
num_applications: Optional[int] = None
|
||||
"""Number of spaces for application observations on this host."""
|
||||
num_folders: Optional[int] = None
|
||||
"""Number of spaces for folder observations on this host."""
|
||||
num_files: Optional[int] = None
|
||||
"""Number of spaces for file observations on this host."""
|
||||
num_nics: Optional[int] = None
|
||||
"""Number of spaces for network interface observations on this host."""
|
||||
include_nmne: Optional[bool] = None
|
||||
"""Whether network interface observations should include number of malicious network events."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether to include the number of accesses to files observations on this host."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
services: List[ServiceObservation],
|
||||
applications: List[ApplicationObservation],
|
||||
folders: List[FolderObservation],
|
||||
network_interfaces: List[NICObservation],
|
||||
num_services: int,
|
||||
num_applications: int,
|
||||
num_folders: int,
|
||||
num_files: int,
|
||||
num_nics: int,
|
||||
include_nmne: bool,
|
||||
include_num_access: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a host observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this host.
|
||||
A typical location for a host might be ['network', 'nodes', <hostname>].
|
||||
:type where: WhereType
|
||||
:param services: List of service observations on the host.
|
||||
:type services: List[ServiceObservation]
|
||||
:param applications: List of application observations on the host.
|
||||
:type applications: List[ApplicationObservation]
|
||||
:param folders: List of folder observations on the host.
|
||||
:type folders: List[FolderObservation]
|
||||
:param network_interfaces: List of network interface observations on the host.
|
||||
:type network_interfaces: List[NICObservation]
|
||||
:param num_services: Number of services to observe.
|
||||
:type num_services: int
|
||||
:param num_applications: Number of applications to observe.
|
||||
:type num_applications: int
|
||||
:param num_folders: Number of folders to observe.
|
||||
:type num_folders: int
|
||||
:param num_files: Number of files.
|
||||
:type num_files: int
|
||||
:param num_nics: Number of network interfaces.
|
||||
:type num_nics: int
|
||||
:param include_nmne: Flag to include network metrics and errors.
|
||||
:type include_nmne: bool
|
||||
:param include_num_access: Flag to include the number of accesses to files.
|
||||
:type include_num_access: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
self.include_num_access = include_num_access
|
||||
|
||||
# Ensure lists have lengths equal to specified counts by truncating or padding
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services:
|
||||
self.services.append(ServiceObservation(where=None))
|
||||
while len(self.services) > num_services:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.applications: List[ApplicationObservation] = applications
|
||||
while len(self.applications) < num_applications:
|
||||
self.applications.append(ApplicationObservation(where=None))
|
||||
while len(self.applications) > num_applications:
|
||||
truncated_application = self.applications.pop()
|
||||
msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.folders: List[FolderObservation] = folders
|
||||
while len(self.folders) < num_folders:
|
||||
self.folders.append(
|
||||
FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access)
|
||||
)
|
||||
while len(self.folders) > num_folders:
|
||||
truncated_folder = self.folders.pop()
|
||||
msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.nics: List[NICObservation] = network_interfaces
|
||||
while len(self.nics) < num_nics:
|
||||
self.nics.append(NICObservation(where=None, include_nmne=include_nmne))
|
||||
while len(self.nics) > num_nics:
|
||||
truncated_nic = self.nics.pop()
|
||||
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation: ObsType = {
|
||||
"operating_status": 0,
|
||||
}
|
||||
if self.services:
|
||||
self.default_observation["SERVICES"] = {i + 1: s.default_observation for i, s in enumerate(self.services)}
|
||||
if self.applications:
|
||||
self.default_observation["APPLICATIONS"] = {
|
||||
i + 1: a.default_observation for i, a in enumerate(self.applications)
|
||||
}
|
||||
if self.folders:
|
||||
self.default_observation["FOLDERS"] = {i + 1: f.default_observation for i, f in enumerate(self.folders)}
|
||||
if self.nics:
|
||||
self.default_observation["NICS"] = {i + 1: n.default_observation for i, n in enumerate(self.nics)}
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_file_creations"] = 0
|
||||
self.default_observation["num_file_deletions"] = 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status information about the host.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
node_state = access_from_nested_dict(state, self.where)
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
if self.services:
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
if self.applications:
|
||||
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
|
||||
if self.folders:
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
if self.nics:
|
||||
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
|
||||
if self.include_num_access:
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for host status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
shape = {
|
||||
"operating_status": spaces.Discrete(5),
|
||||
}
|
||||
if self.services:
|
||||
shape["SERVICES"] = spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)})
|
||||
if self.applications:
|
||||
shape["APPLICATIONS"] = spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)})
|
||||
if self.folders:
|
||||
shape["FOLDERS"] = spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)})
|
||||
if self.nics:
|
||||
shape["NICS"] = spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)})
|
||||
if self.include_num_access:
|
||||
shape["num_file_creations"] = spaces.Discrete(4)
|
||||
shape["num_file_deletions"] = spaces.Discrete(4)
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> HostObservation:
|
||||
"""
|
||||
Create a host observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the host observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this host.
|
||||
A typical location might be ['network', 'nodes', <hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed host observation instance.
|
||||
:rtype: HostObservation
|
||||
"""
|
||||
if parent_where == []:
|
||||
where = ["network", "nodes", config.hostname]
|
||||
else:
|
||||
where = parent_where + [config.hostname]
|
||||
|
||||
# Pass down shared/common config items
|
||||
for folder_config in config.folders:
|
||||
folder_config.include_num_access = config.include_num_access
|
||||
folder_config.num_files = config.num_files
|
||||
for nic_config in config.network_interfaces:
|
||||
nic_config.include_nmne = config.include_nmne
|
||||
|
||||
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
|
||||
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
|
||||
folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders]
|
||||
nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces]
|
||||
# If list of network interfaces is not defined, assume we want to
|
||||
# monitor the first N interfaces. Network interface numbering starts at 1.
|
||||
count = 1
|
||||
while len(nics) < config.num_nics:
|
||||
nic_config = NICObservation.ConfigSchema(nic_num=count, include_nmne=config.include_nmne)
|
||||
nics.append(NICObservation.from_config(config=nic_config, parent_where=where))
|
||||
count += 1
|
||||
|
||||
return cls(
|
||||
where=where,
|
||||
services=services,
|
||||
applications=applications,
|
||||
folders=folders,
|
||||
network_interfaces=nics,
|
||||
num_services=config.num_services,
|
||||
num_applications=config.num_applications,
|
||||
num_folders=config.num_folders,
|
||||
num_files=config.num_files,
|
||||
num_nics=config.num_nics,
|
||||
include_nmne=config.include_nmne,
|
||||
include_num_access=config.include_num_access,
|
||||
)
|
||||
152
src/primaite/game/agent/observations/link_observation.py
Normal file
152
src/primaite/game/agent/observations/link_observation.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation, identifier="LINK"):
|
||||
"""Link observation, providing information about a specific link within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for LinkObservation."""
|
||||
|
||||
link_reference: str
|
||||
"""Reference identifier for the link."""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialise a link observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this link.
|
||||
A typical location for a link might be ['network', 'links', <link_reference>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation: ObsType = {"PROTOCOLS": {"ALL": 0}}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing information about the link.
|
||||
:rtype: Any
|
||||
"""
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
self.where[-1] = "<->".join(self.where[-1].split("<->")[::-1]) # try swapping endpoint A and B
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for link status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> LinkObservation:
|
||||
"""
|
||||
Create a link observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the link observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this link.
|
||||
A typical location might be ['network', 'links', <link_reference>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed link observation instance.
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
link_reference = config.link_reference
|
||||
if parent_where == []:
|
||||
where = ["network", "links", link_reference]
|
||||
else:
|
||||
where = parent_where + ["links", link_reference]
|
||||
return cls(where=where)
|
||||
|
||||
|
||||
class LinksObservation(AbstractObservation, identifier="LINKS"):
|
||||
"""Collection of link observations representing multiple links within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for LinksObservation."""
|
||||
|
||||
link_references: List[str]
|
||||
"""List of reference identifiers for the links."""
|
||||
|
||||
def __init__(self, where: WhereType, links: List[LinkObservation]) -> None:
|
||||
"""
|
||||
Initialise a links observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for these links.
|
||||
A typical location for links might be ['network', 'links'].
|
||||
:type where: WhereType
|
||||
:param links: List of link observations.
|
||||
:type links: List[LinkObservation]
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
self.links: List[LinkObservation] = links
|
||||
self.default_observation: ObsType = {i + 1: l.default_observation for i, l in enumerate(self.links)}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing information about multiple links.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
return {i + 1: l.observe(state) for i, l in enumerate(self.links)}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for multiple links.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({i + 1: l.space for i, l in enumerate(self.links)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> LinksObservation:
|
||||
"""
|
||||
Create a links observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the links observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about these links.
|
||||
A typical location might be ['network'].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed links observation instance.
|
||||
:rtype: LinksObservation
|
||||
"""
|
||||
where = parent_where + ["network"]
|
||||
link_cfgs = [LinkObservation.ConfigSchema(link_reference=ref) for ref in config.link_references]
|
||||
links = [LinkObservation.from_config(c, parent_where=where) for c in link_cfgs]
|
||||
return cls(where=where, links=links)
|
||||
@@ -1,97 +1,53 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.simulator.network.nmne import CAPTURE_NMNE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class NicObservation(AbstractObservation):
|
||||
"""Observation of a Network Interface Card (NIC) in the network."""
|
||||
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
"""Status information about a network interface within the simulation environment."""
|
||||
|
||||
low_nmne_threshold: int = 0
|
||||
"""The minimum number of malicious network events to be considered low."""
|
||||
med_nmne_threshold: int = 5
|
||||
"""The minimum number of malicious network events to be considered medium."""
|
||||
high_nmne_threshold: int = 10
|
||||
"""The minimum number of malicious network events to be considered high."""
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NICObservation."""
|
||||
|
||||
global CAPTURE_NMNE
|
||||
|
||||
@property
|
||||
def default_observation(self) -> Dict:
|
||||
"""The default NIC observation dict."""
|
||||
data = {"nic_status": 0}
|
||||
if CAPTURE_NMNE:
|
||||
data.update({"NMNE": {"inbound": 0, "outbound": 0}})
|
||||
|
||||
return data
|
||||
nic_num: int
|
||||
"""Number of the network interface."""
|
||||
include_nmne: Optional[bool] = None
|
||||
"""Whether to include number of malicious network events (NMNE) in the observation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[Tuple[str]] = None,
|
||||
low_nmne_threshold: Optional[int] = 0,
|
||||
med_nmne_threshold: Optional[int] = 5,
|
||||
high_nmne_threshold: Optional[int] = 10,
|
||||
where: WhereType,
|
||||
include_nmne: bool,
|
||||
) -> None:
|
||||
"""Initialise NIC observation.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<node_hostname>,'NICs',<nic_number>]
|
||||
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
Initialise a network interface observation instance.
|
||||
|
||||
global CAPTURE_NMNE
|
||||
if CAPTURE_NMNE:
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this interface.
|
||||
A typical location for a network interface might be
|
||||
['network', 'nodes', <node_hostname>, 'NICs', <nic_num>].
|
||||
:type where: WhereType
|
||||
:param include_nmne: Flag to determine whether to include NMNE information in the observation.
|
||||
:type include_nmne: bool
|
||||
"""
|
||||
self.where = where
|
||||
self.include_nmne: bool = include_nmne
|
||||
|
||||
self.default_observation: ObsType = {"nic_status": 0}
|
||||
if self.include_nmne:
|
||||
self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}})
|
||||
self.nmne_inbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
|
||||
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
|
||||
self._validate_nmne_categories(
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
med_nmne_threshold=med_nmne_threshold,
|
||||
high_nmne_threshold=high_nmne_threshold,
|
||||
)
|
||||
|
||||
def _validate_nmne_categories(
|
||||
self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10
|
||||
):
|
||||
"""
|
||||
Validates the nmne threshold config.
|
||||
|
||||
If the configuration is valid, the thresholds will be set, otherwise, an exception is raised.
|
||||
|
||||
:param: low_nmne_threshold: The minimum number of malicious network events to be considered low
|
||||
:param: med_nmne_threshold: The minimum number of malicious network events to be considered medium
|
||||
:param: high_nmne_threshold: The minimum number of malicious network events to be considered high
|
||||
"""
|
||||
if high_nmne_threshold <= med_nmne_threshold:
|
||||
raise Exception(
|
||||
f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater "
|
||||
f"than medium nmne count ({med_nmne_threshold})"
|
||||
)
|
||||
|
||||
if med_nmne_threshold <= low_nmne_threshold:
|
||||
raise Exception(
|
||||
f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater "
|
||||
f"than low nmne count ({low_nmne_threshold})"
|
||||
)
|
||||
|
||||
self.high_nmne_threshold = high_nmne_threshold
|
||||
self.med_nmne_threshold = med_nmne_threshold
|
||||
self.low_nmne_threshold = low_nmne_threshold
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_nmne_threshold = 10
|
||||
self.med_nmne_threshold = 5
|
||||
self.low_nmne_threshold = 0
|
||||
|
||||
def _categorise_mne_count(self, nmne_count: int) -> int:
|
||||
"""
|
||||
@@ -116,73 +72,120 @@ class NicObservation(AbstractObservation):
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status of the network interface and optionally NMNE information.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
else:
|
||||
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
if CAPTURE_NMNE:
|
||||
obs_dict.update({"NMNE": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
|
||||
obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
|
||||
self.nmne_inbound_last_step = inbound_count
|
||||
self.nmne_outbound_last_step = outbound_count
|
||||
return obs_dict
|
||||
|
||||
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
if self.include_nmne:
|
||||
obs.update({"NMNE": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
|
||||
obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
|
||||
self.nmne_inbound_last_step = inbound_count
|
||||
self.nmne_outbound_last_step = outbound_count
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for network interface status and NMNE information.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
if CAPTURE_NMNE:
|
||||
if self.include_nmne:
|
||||
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
|
||||
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
"""Create NIC observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this NIC observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
|
||||
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:return: Constructed NIC observation
|
||||
:rtype: NicObservation
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation:
|
||||
"""
|
||||
low_nmne_threshold = None
|
||||
med_nmne_threshold = None
|
||||
high_nmne_threshold = None
|
||||
Create a network interface observation from a configuration schema.
|
||||
|
||||
if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"):
|
||||
threshold = game.options.thresholds["nmne"]
|
||||
:param config: Configuration schema containing the necessary information for the network interface observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed network interface observation instance.
|
||||
:rtype: NICObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne)
|
||||
|
||||
low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None
|
||||
med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None
|
||||
high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None
|
||||
|
||||
return cls(
|
||||
where=parent_where + ["NICs", config["nic_num"]],
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
med_nmne_threshold=med_nmne_threshold,
|
||||
high_nmne_threshold=high_nmne_threshold,
|
||||
)
|
||||
class PortObservation(AbstractObservation, identifier="PORT"):
|
||||
"""Port observation, provides status information about a network port within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for PortObservation."""
|
||||
|
||||
port_id: int
|
||||
"""Identifier of the port, used for querying simulation state dictionary."""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialise a port observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this port.
|
||||
A typical location for a port might be ['network', 'nodes', <node_hostname>, 'NICs', <port_id>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation: ObsType = {"operating_status": 0}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the operating status of the port.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
port_state = access_from_nested_dict(state, self.where)
|
||||
if port_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"operating_status": 1 if port_state["enabled"] else 2}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for port status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(3)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation:
|
||||
"""
|
||||
Create a port observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the port observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this port's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed port observation instance.
|
||||
:rtype: PortObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["NICs", config.port_id])
|
||||
|
||||
@@ -1,200 +1,216 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import model_validator
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.file_system_observations import FolderObservation
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.software_observation import ServiceObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.game.agent.observations.firewall_observation import FirewallObservation
|
||||
from primaite.game.agent.observations.host_observations import HostObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.observations.router_observation import RouterObservation
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
"""Nodes observation, provides status information about nodes within the simulation environment."""
|
||||
|
||||
class NodeObservation(AbstractObservation):
|
||||
"""Observation of a node in the network. Includes services, folders and NICs."""
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NodesObservation."""
|
||||
|
||||
hosts: List[HostObservation.ConfigSchema] = []
|
||||
"""List of configurations for host observations."""
|
||||
routers: List[RouterObservation.ConfigSchema] = []
|
||||
"""List of configurations for router observations."""
|
||||
firewalls: List[FirewallObservation.ConfigSchema] = []
|
||||
"""List of configurations for firewall observations."""
|
||||
num_services: Optional[int] = None
|
||||
"""Number of services."""
|
||||
num_applications: Optional[int] = None
|
||||
"""Number of applications."""
|
||||
num_folders: Optional[int] = None
|
||||
"""Number of folders."""
|
||||
num_files: Optional[int] = None
|
||||
"""Number of files."""
|
||||
num_nics: Optional[int] = None
|
||||
"""Number of network interface cards (NICs)."""
|
||||
include_nmne: Optional[bool] = None
|
||||
"""Flag to include nmne."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Flag to include the number of accesses."""
|
||||
num_ports: Optional[int] = None
|
||||
"""Number of ports."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def force_optional_fields(self) -> NodesObservation.ConfigSchema:
|
||||
"""Check that options are specified only if they are needed for the nodes that are part of the config."""
|
||||
# check for hosts:
|
||||
host_fields = (
|
||||
self.num_services,
|
||||
self.num_applications,
|
||||
self.num_folders,
|
||||
self.num_files,
|
||||
self.num_nics,
|
||||
self.include_nmne,
|
||||
self.include_num_access,
|
||||
)
|
||||
router_fields = (
|
||||
self.num_ports,
|
||||
self.ip_list,
|
||||
self.wildcard_list,
|
||||
self.port_list,
|
||||
self.protocol_list,
|
||||
self.num_rules,
|
||||
)
|
||||
firewall_fields = (self.ip_list, self.wildcard_list, self.port_list, self.protocol_list, self.num_rules)
|
||||
if len(self.hosts) > 0 and any([x is None for x in host_fields]):
|
||||
raise ValueError("Configuration error: Host observation options were not fully specified.")
|
||||
if len(self.routers) > 0 and any([x is None for x in router_fields]):
|
||||
raise ValueError("Configuration error: Router observation options were not fully specified.")
|
||||
if len(self.firewalls) > 0 and any([x is None for x in firewall_fields]):
|
||||
raise ValueError("Configuration error: Firewall observation options were not fully specified.")
|
||||
return self
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[Tuple[str]] = None,
|
||||
services: List[ServiceObservation] = [],
|
||||
folders: List[FolderObservation] = [],
|
||||
network_interfaces: List[NicObservation] = [],
|
||||
logon_status: bool = False,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
where: WhereType,
|
||||
hosts: List[HostObservation],
|
||||
routers: List[RouterObservation],
|
||||
firewalls: List[FirewallObservation],
|
||||
) -> None:
|
||||
"""
|
||||
Configurable observation for a node in the simulation.
|
||||
Initialise a nodes observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary for find relevant information for this observation.
|
||||
A typical location for a node looks like this:
|
||||
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
|
||||
:type where: List[str], optional
|
||||
:param services: Mapping between position in observation space and service name, defaults to {}
|
||||
:type services: Dict[int,str], optional
|
||||
:param max_services: Max number of services that can be presented in observation space for this node
|
||||
, defaults to 2
|
||||
:type max_services: int, optional
|
||||
:param folders: Mapping between position in observation space and folder name, defaults to {}
|
||||
:type folders: Dict[int,str], optional
|
||||
:param max_folders: Max number of folders in this node's obs space, defaults to 2
|
||||
:type max_folders: int, optional
|
||||
:param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {}
|
||||
:type network_interfaces: Dict[int,str], optional
|
||||
:param max_nics: Max number of network interfaces in this node's obs space, defaults to 5
|
||||
:type max_nics: int, optional
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for nodes.
|
||||
A typical location for nodes might be ['network', 'nodes'].
|
||||
:type where: WhereType
|
||||
:param hosts: List of host observations.
|
||||
:type hosts: List[HostObservation]
|
||||
:param routers: List of router observations.
|
||||
:type routers: List[RouterObservation]
|
||||
:param firewalls: List of firewall observations.
|
||||
:type firewalls: List[FirewallObservation]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.where: WhereType = where
|
||||
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services_per_node:
|
||||
# add empty service observation without `where` parameter so it always returns default (blank) observation
|
||||
self.services.append(ServiceObservation())
|
||||
while len(self.services) > num_services_per_node:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
_LOGGER.warning(msg)
|
||||
# truncate service list
|
||||
self.hosts: List[HostObservation] = hosts
|
||||
self.routers: List[RouterObservation] = routers
|
||||
self.firewalls: List[FirewallObservation] = firewalls
|
||||
|
||||
self.folders: List[FolderObservation] = folders
|
||||
# add empty folder observation without `where` parameter that will always return default (blank) observations
|
||||
while len(self.folders) < num_folders_per_node:
|
||||
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
|
||||
while len(self.folders) > num_folders_per_node:
|
||||
truncated_folder = self.folders.pop()
|
||||
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.network_interfaces: List[NicObservation] = network_interfaces
|
||||
while len(self.network_interfaces) < num_nics_per_node:
|
||||
self.network_interfaces.append(NicObservation())
|
||||
while len(self.network_interfaces) > num_nics_per_node:
|
||||
truncated_nic = self.network_interfaces.pop()
|
||||
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.logon_status: bool = logon_status
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
|
||||
"operating_status": 0,
|
||||
self.default_observation = {
|
||||
**{f"HOST{i}": host.default_observation for i, host in enumerate(self.hosts)},
|
||||
**{f"ROUTER{i}": router.default_observation for i, router in enumerate(self.routers)},
|
||||
**{f"FIREWALL{i}": firewall.default_observation for i, firewall in enumerate(self.firewalls)},
|
||||
}
|
||||
if self.logon_status:
|
||||
self.default_observation["logon_status"] = 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Observation containing status information about nodes.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
node_state = access_from_nested_dict(state, self.where)
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
obs["NICS"] = {
|
||||
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
|
||||
obs = {
|
||||
**{f"HOST{i}": host.observe(state) for i, host in enumerate(self.hosts)},
|
||||
**{f"ROUTER{i}": router.observe(state) for i, router in enumerate(self.routers)},
|
||||
**{f"FIREWALL{i}": firewall.observe(state) for i, firewall in enumerate(self.firewalls)},
|
||||
}
|
||||
|
||||
if self.logon_status:
|
||||
obs["logon_status"] = 0
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
space_shape = {
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(5),
|
||||
"NICS": spaces.Dict(
|
||||
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
|
||||
),
|
||||
}
|
||||
if self.logon_status:
|
||||
space_shape["logon_status"] = spaces.Discrete(3)
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
return spaces.Dict(space_shape)
|
||||
:return: Gymnasium space representing the observation space for nodes.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
space = spaces.Dict(
|
||||
{
|
||||
**{f"HOST{i}": host.space for i, host in enumerate(self.hosts)},
|
||||
**{f"ROUTER{i}": router.space for i, router in enumerate(self.routers)},
|
||||
**{f"FIREWALL{i}": firewall.space for i, firewall in enumerate(self.firewalls)},
|
||||
}
|
||||
)
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: Dict,
|
||||
game: "PrimaiteGame",
|
||||
parent_where: Optional[List[str]] = None,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
) -> "NodeObservation":
|
||||
"""Create node observation from a config. Also creates child service, folder and NIC observations.
|
||||
|
||||
:param config: Dictionary containing the configuration for this node observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
|
||||
network. A typical location for it would be: ['network',]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:param num_services_per_node: How many spaces for services are in this node observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_services_per_node: int, optional
|
||||
:param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_folders_per_node: int, optional
|
||||
:param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_files_per_folder: int, optional
|
||||
:return: Constructed node observation
|
||||
:rtype: NodeObservation
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NodesObservation:
|
||||
"""
|
||||
node_hostname = config["node_hostname"]
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", node_hostname]
|
||||
else:
|
||||
where = parent_where + ["nodes", node_hostname]
|
||||
Create a nodes observation from a configuration schema.
|
||||
|
||||
svc_configs = config.get("services", {})
|
||||
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
|
||||
folder_configs = config.get("folders", {})
|
||||
folders = [
|
||||
FolderObservation.from_config(
|
||||
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
|
||||
)
|
||||
for c in folder_configs
|
||||
]
|
||||
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
|
||||
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
|
||||
network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
|
||||
logon_status = config.get("logon_status", False)
|
||||
return cls(
|
||||
where=where,
|
||||
services=services,
|
||||
folders=folders,
|
||||
network_interfaces=network_interfaces,
|
||||
logon_status=logon_status,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
:param config: Configuration schema containing the necessary information for nodes observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about nodes.
|
||||
A typical location for nodes might be ['network', 'nodes'].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed nodes observation instance.
|
||||
:rtype: NodesObservation
|
||||
"""
|
||||
if not parent_where:
|
||||
where = ["network", "nodes"]
|
||||
else:
|
||||
where = parent_where + ["nodes"]
|
||||
|
||||
for host_config in config.hosts:
|
||||
if host_config.num_services is None:
|
||||
host_config.num_services = config.num_services
|
||||
if host_config.num_applications is None:
|
||||
host_config.num_applications = config.num_applications
|
||||
if host_config.num_folders is None:
|
||||
host_config.num_folders = config.num_folders
|
||||
if host_config.num_files is None:
|
||||
host_config.num_files = config.num_files
|
||||
if host_config.num_nics is None:
|
||||
host_config.num_nics = config.num_nics
|
||||
if host_config.include_nmne is None:
|
||||
host_config.include_nmne = config.include_nmne
|
||||
if host_config.include_num_access is None:
|
||||
host_config.include_num_access = config.include_num_access
|
||||
|
||||
for router_config in config.routers:
|
||||
if router_config.num_ports is None:
|
||||
router_config.num_ports = config.num_ports
|
||||
if router_config.ip_list is None:
|
||||
router_config.ip_list = config.ip_list
|
||||
if router_config.wildcard_list is None:
|
||||
router_config.wildcard_list = config.wildcard_list
|
||||
if router_config.port_list is None:
|
||||
router_config.port_list = config.port_list
|
||||
if router_config.protocol_list is None:
|
||||
router_config.protocol_list = config.protocol_list
|
||||
if router_config.num_rules is None:
|
||||
router_config.num_rules = config.num_rules
|
||||
|
||||
for firewall_config in config.firewalls:
|
||||
if firewall_config.ip_list is None:
|
||||
firewall_config.ip_list = config.ip_list
|
||||
if firewall_config.wildcard_list is None:
|
||||
firewall_config.wildcard_list = config.wildcard_list
|
||||
if firewall_config.port_list is None:
|
||||
firewall_config.port_list = config.port_list
|
||||
if firewall_config.protocol_list is None:
|
||||
firewall_config.protocol_list = config.protocol_list
|
||||
if firewall_config.num_rules is None:
|
||||
firewall_config.num_rules = config.num_rules
|
||||
|
||||
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
|
||||
firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls]
|
||||
|
||||
return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)
|
||||
|
||||
@@ -1,16 +1,142 @@
|
||||
from typing import Dict, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
|
||||
|
||||
from primaite.game.agent.observations.agent_observations import (
|
||||
UC2BlueObservation,
|
||||
UC2GreenObservation,
|
||||
UC2RedObservation,
|
||||
)
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
"""Observation type that allows combining other observations into a gymnasium.spaces.Dict space."""
|
||||
|
||||
class NestedObservationItem(BaseModel):
|
||||
"""One list item of the config."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str
|
||||
"""Select observation class. It maps to the identifier of the obs class by checking the registry."""
|
||||
label: str
|
||||
"""Dict key in the final observation space."""
|
||||
options: Dict
|
||||
"""Options to pass to the observation class from_config method."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_model(self) -> "NestedObservation.NestedObservationItem":
|
||||
"""Make sure tha the config options match up with the selected observation type."""
|
||||
obs_subclass_name = self.type
|
||||
obs_options = self.options
|
||||
if obs_subclass_name not in AbstractObservation._registry:
|
||||
raise ValueError(f"Observation of type {obs_subclass_name} could not be found.")
|
||||
obs_schema = AbstractObservation._registry[obs_subclass_name].ConfigSchema
|
||||
try:
|
||||
obs_schema(**obs_options)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Observation options did not match schema, got this error: {e}")
|
||||
return self
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NestedObservation."""
|
||||
|
||||
components: List[NestedObservation.NestedObservationItem] = []
|
||||
"""List of observation components to be part of this space."""
|
||||
|
||||
def __init__(self, components: Dict[str, AbstractObservation]) -> None:
|
||||
"""Initialise nested observation."""
|
||||
self.components: Dict[str, AbstractObservation] = components
|
||||
"""Maps label: observation object"""
|
||||
|
||||
self.default_observation = {label: obs.default_observation for label, obs in self.components.items()}
|
||||
"""Default observation is just the default observations of constituents."""
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status information about the host.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
return {label: obs.observe(state) for label, obs in self.components.items()}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the nested observation space.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({label: obs.space for label, obs in self.components.items()})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NestedObservation:
|
||||
"""
|
||||
Read the Nested observation config and create all defined subcomponents.
|
||||
|
||||
Example configuration that utilises NestedObservation:
|
||||
This lets us have different options for different types of hosts.
|
||||
|
||||
```yaml
|
||||
observation_space:
|
||||
- type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
|
||||
- type: HOSTS
|
||||
label: COMPUTERS # What is the dictionary key called
|
||||
options:
|
||||
hosts:
|
||||
- client_1
|
||||
- client_2
|
||||
num_services: 0
|
||||
num_applications: 5
|
||||
... # other options
|
||||
|
||||
- type: HOSTS
|
||||
label: SERVERS # What is the dictionary key called
|
||||
options:
|
||||
hosts:
|
||||
- hostname: database_server
|
||||
- hostname: web_server
|
||||
num_services: 4
|
||||
num_applications: 0
|
||||
num_folders: 2
|
||||
num_files: 2
|
||||
|
||||
```
|
||||
"""
|
||||
instances = dict()
|
||||
for component in config.components:
|
||||
obs_class = AbstractObservation._registry[component.type]
|
||||
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options))
|
||||
instances[component.label] = obs_instance
|
||||
return cls(components=instances)
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation, identifier="NONE"):
|
||||
"""Empty observation that acts as a placeholder."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise the empty observation."""
|
||||
self.default_observation = 0
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""Simply return 0."""
|
||||
return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Essentially empty space."""
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: NullObservation.ConfigSchema, parent_where: WhereType = []) -> NullObservation:
|
||||
"""Instantiate a NullObservation. Accepts parameters to comply with API."""
|
||||
return cls()
|
||||
|
||||
|
||||
class ObservationManager:
|
||||
@@ -23,18 +149,15 @@ class ObservationManager:
|
||||
3. Formatting this information so an agent can use it to make decisions.
|
||||
"""
|
||||
|
||||
# TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
|
||||
# to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
|
||||
# refactor.
|
||||
|
||||
def __init__(self, observation: AbstractObservation) -> None:
|
||||
def __init__(self, obs: AbstractObservation) -> None:
|
||||
"""Initialise observation space.
|
||||
|
||||
:param observation: Observation object
|
||||
:type observation: AbstractObservation
|
||||
"""
|
||||
self.obs: AbstractObservation = observation
|
||||
self.obs: AbstractObservation = obs
|
||||
self.current_observation: ObsType
|
||||
"""Cached copy of the observation at the time it was most recently calculated."""
|
||||
|
||||
def update(self, state: Dict) -> Dict:
|
||||
"""
|
||||
@@ -52,22 +175,22 @@ class ObservationManager:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
|
||||
"""Create observation space from a config.
|
||||
def from_config(cls, config: Optional[Dict]) -> "ObservationManager":
|
||||
"""
|
||||
Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
It should contain the key 'type' which selects which observation class to use (from a choice of:
|
||||
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
|
||||
The other key is 'options' which are passed to the constructor of the selected observation class.
|
||||
If None, a blank observation space is created.
|
||||
Otherwise, this must be a Dict with a type field and options field.
|
||||
type: string that corresponds to one of the observation identifiers that are provided when subclassing
|
||||
AbstractObservation
|
||||
options: this must adhere to the chosen observation type's ConfigSchema nested class.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
if config["type"] == "UC2BlueObservation":
|
||||
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2RedObservation":
|
||||
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2GreenObservation":
|
||||
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
|
||||
else:
|
||||
raise ValueError("Observation space type invalid")
|
||||
if config is None:
|
||||
return cls(NullObservation())
|
||||
obs_type = config["type"]
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]))
|
||||
obs_manager = cls(observation)
|
||||
return obs_manager
|
||||
|
||||
@@ -1,22 +1,48 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Any, Dict, Iterable, Optional, Type, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
WhereType = Optional[Iterable[Union[str, int]]]
|
||||
|
||||
|
||||
class AbstractObservation(ABC):
|
||||
"""Abstract class for an observation space component."""
|
||||
|
||||
class ConfigSchema(ABC, BaseModel):
|
||||
"""Config schema for observations."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
_registry: Dict[str, Type["AbstractObservation"]] = {}
|
||||
"""Registry of observation components, with their name as key.
|
||||
|
||||
Automatically populated when subclasses are defined. Used for defining from_config.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise an observation. This method must be overwritten."""
|
||||
self.default_observation: ObsType
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register an observation type.
|
||||
|
||||
:param identifier: Identifier used to uniquely specify observation component types.
|
||||
:type identifier: str
|
||||
:raises ValueError: When attempting to create a component with a name that is already in use.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate observation component type {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@abstractmethod
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""
|
||||
@@ -37,273 +63,6 @@ class AbstractObservation(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame"):
|
||||
"""Create this observation space component form a serialised format.
|
||||
|
||||
The `game` parameter is for a the PrimaiteGame object that spawns this component.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation):
|
||||
"""Observation of a link in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise link observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'servics', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
|
||||
"""Create link observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this link observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed link observation
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class AclObservation(AbstractObservation):
|
||||
"""Observation of an Access Control List (ACL) in the network."""
|
||||
|
||||
# TODO: should where be optional, and we can use where=None to pad the observation space?
|
||||
# definitely the current approach does not support tracking files that aren't specified by name, for example
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self,
|
||||
node_ip_to_id: Dict[str, int],
|
||||
ports: List[int],
|
||||
protocols: List[str],
|
||||
where: Optional[Tuple[str]] = None,
|
||||
num_rules: int = 10,
|
||||
) -> None:
|
||||
"""Initialise ACL observation.
|
||||
|
||||
:param node_ip_to_id: Mapping between IP address and ID.
|
||||
:type node_ip_to_id: Dict[str, int]
|
||||
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
|
||||
:type ports: List[int]
|
||||
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
|
||||
:type protocols: list[str]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<router_hostname>,'acl','acl']
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
:param num_rules: , defaults to 10
|
||||
:type num_rules: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.num_rules: int = num_rules
|
||||
self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
"List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
|
||||
"List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
|
||||
"""Generate ACL observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this ACL observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Observation object
|
||||
:rtype: AclObservation
|
||||
"""
|
||||
max_acl_rules = config["options"]["max_acl_rules"]
|
||||
node_ip_to_idx = {}
|
||||
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
node_ref = ip_map_config["node_hostname"]
|
||||
nic_num = ip_map_config["nic_num"]
|
||||
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
|
||||
nic_obj = node_obj.network_interface[nic_num]
|
||||
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
|
||||
router_hostname = config["router_hostname"]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=game.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
where=["network", "nodes", router_hostname, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation):
|
||||
"""Null observation, returns a single 0 value for the observation space."""
|
||||
|
||||
def __init__(self, where: Optional[List[str]] = None):
|
||||
"""Initialise null observation."""
|
||||
self.default_observation: Dict = {}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
"""
|
||||
Create null observation from a config.
|
||||
|
||||
The parameters are ignored, they are here to match the signature of the other observation classes.
|
||||
"""
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation":
|
||||
"""Create this observation space component form a serialised format."""
|
||||
return cls()
|
||||
|
||||
|
||||
class ICSObservation(NullObservation):
|
||||
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
|
||||
|
||||
pass
|
||||
|
||||
145
src/primaite/game/agent/observations/router_observation.py
Normal file
145
src/primaite/game/agent/observations/router_observation.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
"""Router observation, provides status information about a router within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for RouterObservation."""
|
||||
|
||||
hostname: str
|
||||
"""Hostname of the router, used for querying simulation state dictionary."""
|
||||
ports: Optional[List[PortObservation.ConfigSchema]] = None
|
||||
"""Configuration of port observations for this router."""
|
||||
num_ports: Optional[int] = None
|
||||
"""Number of port observations configured for this router."""
|
||||
acl: Optional[ACLObservation.ConfigSchema] = None
|
||||
"""Configuration of ACL observation on this router."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
ports: List[PortObservation],
|
||||
num_ports: int,
|
||||
acl: ACLObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a router observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this router.
|
||||
A typical location for a router might be ['network', 'nodes', <node_hostname>].
|
||||
:type where: WhereType
|
||||
:param ports: List of port observations representing the ports of the router.
|
||||
:type ports: List[PortObservation]
|
||||
:param num_ports: Number of ports for the router.
|
||||
:type num_ports: int
|
||||
:param acl: ACL observation representing the access control list of the router.
|
||||
:type acl: ACLObservation
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
self.ports: List[PortObservation] = ports
|
||||
self.acl: ACLObservation = acl
|
||||
self.num_ports: int = num_ports
|
||||
|
||||
while len(self.ports) < num_ports:
|
||||
self.ports.append(PortObservation(where=None))
|
||||
while len(self.ports) > num_ports:
|
||||
self.ports.pop()
|
||||
msg = "Too many ports in router observation. Truncating."
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"ACL": self.acl.default_observation,
|
||||
}
|
||||
if self.ports:
|
||||
self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status of ports and ACL configuration of the router.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
router_state = access_from_nested_dict(state, self.where)
|
||||
if router_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
if self.ports:
|
||||
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for router status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
shape = {"ACL": self.acl.space}
|
||||
if self.ports:
|
||||
shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)})
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation:
|
||||
"""
|
||||
Create a router observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the router observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this router's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed router observation instance.
|
||||
:rtype: RouterObservation
|
||||
"""
|
||||
where = parent_where + [config.hostname]
|
||||
|
||||
if config.acl is None:
|
||||
config.acl = ACLObservation.ConfigSchema()
|
||||
if config.acl.num_rules is None:
|
||||
config.acl.num_rules = config.num_rules
|
||||
if config.acl.ip_list is None:
|
||||
config.acl.ip_list = config.ip_list
|
||||
if config.acl.wildcard_list is None:
|
||||
config.acl.wildcard_list = config.wildcard_list
|
||||
if config.acl.port_list is None:
|
||||
config.acl.port_list = config.port_list
|
||||
if config.acl.protocol_list is None:
|
||||
config.acl.protocol_list = config.protocol_list
|
||||
|
||||
if config.ports is None:
|
||||
config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)]
|
||||
|
||||
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
|
||||
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
|
||||
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
|
||||
@@ -1,45 +1,43 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
|
||||
"""Service observation, shows status of a service in the simulation environment."""
|
||||
|
||||
class ServiceObservation(AbstractObservation):
|
||||
"""Observation of a service in the network."""
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ServiceObservation."""
|
||||
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
|
||||
"Default observation is what should be returned when the service doesn't exist."
|
||||
service_name: str
|
||||
"""Name of the service, used for querying simulation state dictionary"""
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise service observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'services', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
Initialise a service observation instance.
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this service.
|
||||
A typical location for a service might be ['network', 'nodes', <node_hostname>, 'services', <service_name>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0}
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Observation containing the operating status and health status of the service.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
service_state = access_from_nested_dict(state, self.where)
|
||||
if service_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
@@ -50,114 +48,116 @@ class ServiceObservation(AbstractObservation):
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for service status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
) -> "ServiceObservation":
|
||||
"""Create service observation from a config.
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation:
|
||||
"""
|
||||
Create a service observation from a configuration schema.
|
||||
|
||||
:param config: Dictionary containing the configuration for this service observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
|
||||
:type parent_where: Optional[List[str]], optional
|
||||
:return: Constructed service observation
|
||||
:param config: Configuration schema containing the necessary information for the service observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this service's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed service observation instance.
|
||||
:rtype: ServiceObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", config["service_name"]])
|
||||
return cls(where=parent_where + ["services", config.service_name])
|
||||
|
||||
|
||||
class ApplicationObservation(AbstractObservation):
|
||||
"""Observation of an application in the network."""
|
||||
class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
"""Application observation, shows the status of an application within the simulation environment."""
|
||||
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
"Default observation is what should be returned when the application doesn't exist."
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ApplicationObservation."""
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise application observation.
|
||||
application_name: str
|
||||
"""Name of the application, used for querying simulation state dictionary"""
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'applications', <application_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
Initialise an application observation instance.
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this application.
|
||||
A typical location for an application might be
|
||||
['network', 'nodes', <node_hostname>, 'applications', <application_name>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
|
||||
def _categorise_num_executions(self, num_executions: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_executions > self.high_threshold:
|
||||
return 3
|
||||
elif num_executions > self.med_threshold:
|
||||
return 2
|
||||
elif num_executions > self.low_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Obs containing the operating status, health status, and number of executions of the application.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
app_state = access_from_nested_dict(state, self.where)
|
||||
if app_state is NOT_PRESENT_IN_STATE:
|
||||
application_state = access_from_nested_dict(state, self.where)
|
||||
if application_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": app_state["operating_state"],
|
||||
"health_status": app_state["health_state_visible"],
|
||||
"num_executions": self._categorise_num_executions(app_state["num_executions"]),
|
||||
"operating_status": application_state["operating_state"],
|
||||
"health_status": application_state["health_state_visible"],
|
||||
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
|
||||
}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for application status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"operating_status": spaces.Discrete(7),
|
||||
"health_status": spaces.Discrete(6),
|
||||
"health_status": spaces.Discrete(5),
|
||||
"num_executions": spaces.Discrete(4),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
) -> "ApplicationObservation":
|
||||
"""Create application observation from a config.
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation:
|
||||
"""
|
||||
Create an application observation from a configuration schema.
|
||||
|
||||
:param config: Dictionary containing the configuration for this service observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
|
||||
:type parent_where: Optional[List[str]], optional
|
||||
:return: Constructed service observation
|
||||
:param config: Configuration schema containing the necessary information for the application observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this application's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed application observation instance.
|
||||
:rtype: ApplicationObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", config["application_name"]])
|
||||
|
||||
@classmethod
|
||||
def _categorise_num_executions(cls, num_executions: int) -> int:
|
||||
"""
|
||||
Categorise the number of executions of an application.
|
||||
|
||||
Helps classify the number of application executions into different categories.
|
||||
|
||||
Current categories:
|
||||
- 0: Application is never executed
|
||||
- 1: Application is executed a low number of times (1-5)
|
||||
- 2: Application is executed often (6-10)
|
||||
- 3: Application is executed a high number of times (more than 10)
|
||||
|
||||
:param: num_executions: Number of times the application is executed
|
||||
"""
|
||||
if num_executions > 10:
|
||||
return 3
|
||||
elif num_executions > 5:
|
||||
return 2
|
||||
elif num_executions > 0:
|
||||
return 1
|
||||
return 0
|
||||
return cls(where=parent_where + ["applications", config.application_name])
|
||||
|
||||
@@ -26,19 +26,25 @@ the structure:
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Tuple, Type
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
from typing_extensions import Never
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
WhereType = Optional[Iterable[Union[str, int]]]
|
||||
|
||||
|
||||
class AbstractReward:
|
||||
"""Base class for reward function components."""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -58,7 +64,7 @@ class AbstractReward:
|
||||
class DummyReward(AbstractReward):
|
||||
"""Dummy reward function component which always returns 0."""
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -98,7 +104,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
file_name,
|
||||
]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -106,7 +112,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
"""
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if database_file_state is NOT_PRESENT_IN_STATE:
|
||||
_LOGGER.info(
|
||||
_LOGGER.debug(
|
||||
f"Could not calculate {self.__class__} reward because "
|
||||
"simulation state did not contain enough information."
|
||||
)
|
||||
@@ -153,7 +159,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
"""
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -203,19 +209,30 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
:param node_hostname: Hostname of the node which has the web browser.
|
||||
:type node_hostname: str
|
||||
"""
|
||||
self._node = node_hostname
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state.
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:type state: Dict
|
||||
When the green agent requests to execute the browser application, and that request fails, this reward
|
||||
component will keep track of that information. In that case, it doesn't matter whether the last webpage
|
||||
had a 200 status code, because there has been an unsuccessful request since.
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
|
||||
# if agent couldn't even get as far as sending the request (because for example the node was off), then
|
||||
# apply a penalty
|
||||
if self._last_request_failed:
|
||||
return -1.0
|
||||
|
||||
# If the last request did actually go through, then check if the webpage also loaded
|
||||
web_browser_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
|
||||
_LOGGER.info(
|
||||
_LOGGER.debug(
|
||||
"Web browser reward could not be calculated because the web browser history on node",
|
||||
f"{self._node} was not reported in the simulation state. Returning 0.0",
|
||||
)
|
||||
@@ -252,19 +269,32 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
:param node_hostname: Hostname of the node where the database client sits.
|
||||
:type node_hostname: str
|
||||
"""
|
||||
self._node = node_hostname
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state.
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:type state: Dict
|
||||
When the green agent requests to execute the database client application, and that request fails, this reward
|
||||
component will keep track of that information. In that case, it doesn't matter whether the last successful
|
||||
request returned was able to connect to the database server, because there has been an unsuccessful request
|
||||
since.
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
|
||||
# if agent couldn't even get as far as sending the request (because for example the node was off), then
|
||||
# apply a penalty
|
||||
if self._last_request_failed:
|
||||
return -1.0
|
||||
|
||||
# If the last request was actually sent, then check if the connection was established.
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
|
||||
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
|
||||
return 0.0
|
||||
last_connection_successful = db_state["last_connection_successful"]
|
||||
if last_connection_successful is False:
|
||||
return -1.0
|
||||
@@ -284,6 +314,51 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
return cls(node_hostname=node_hostname)
|
||||
|
||||
|
||||
class SharedReward(AbstractReward):
|
||||
"""Adds another agent's reward to the overall reward."""
|
||||
|
||||
def __init__(self, agent_name: Optional[str] = None) -> None:
|
||||
"""
|
||||
Initialise the shared reward.
|
||||
|
||||
The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
|
||||
correctly.
|
||||
|
||||
:param agent_name: The name whose reward is an input
|
||||
:type agent_name: Optional[str]
|
||||
"""
|
||||
self.agent_name = agent_name
|
||||
"""Agent whose reward to track."""
|
||||
|
||||
def default_callback(agent_name: str) -> Never:
|
||||
"""
|
||||
Default callback to prevent calling this reward until it's properly initialised.
|
||||
|
||||
SharedReward should not be used until the game layer replaces self.callback with a reference to the
|
||||
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
|
||||
an error.
|
||||
"""
|
||||
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
|
||||
|
||||
self.callback: Callable[[str], float] = default_callback
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Simply access the other agent's reward and return it."""
|
||||
return self.callback(self.agent_name)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "SharedReward":
|
||||
"""
|
||||
Build the SharedReward object from config.
|
||||
|
||||
:param config: Configuration dictionary
|
||||
:type config: Dict
|
||||
"""
|
||||
agent_name = config.get("agent_name")
|
||||
return cls(agent_name=agent_name)
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
@@ -293,6 +368,7 @@ class RewardFunction:
|
||||
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
|
||||
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
|
||||
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
|
||||
"SHARED_REWARD": SharedReward,
|
||||
}
|
||||
"""List of reward class identifiers."""
|
||||
|
||||
@@ -313,7 +389,7 @@ class RewardFunction:
|
||||
"""
|
||||
self.reward_components.append((component, weight))
|
||||
|
||||
def update(self, state: Dict) -> float:
|
||||
def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Calculate the overall reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -323,7 +399,7 @@ class RewardFunction:
|
||||
for comp_and_weight in self.reward_components:
|
||||
comp = comp_and_weight[0]
|
||||
weight = comp_and_weight[1]
|
||||
total += weight * comp.calculate(state=state)
|
||||
total += weight * comp.calculate(state=state, last_action_response=last_action_response)
|
||||
self.current_reward = total
|
||||
return self.current_reward
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.reset_agent_for_episode()
|
||||
self.setup_agent()
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
@@ -43,9 +43,8 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
|
||||
def reset_agent_for_episode(self) -> None:
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
super().reset_agent_for_episode()
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
from typing import Dict, Tuple
|
||||
import random
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
@@ -19,3 +24,60 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
|
||||
|
||||
class PeriodicAgent(AbstractScriptedAgent):
|
||||
"""Agent that does nothing most of the time, but executes application at regular intervals (with variance)."""
|
||||
|
||||
class Settings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
|
||||
start_step: int = 20
|
||||
"The timestep at which an agent begins performing it's actions."
|
||||
start_variance: int = 5
|
||||
"Deviation around the start step."
|
||||
frequency: int = 5
|
||||
"The number of timesteps to wait between performing actions."
|
||||
variance: int = 0
|
||||
"The amount the frequency can randomly change to."
|
||||
max_executions: int = 999999
|
||||
"Maximum number of times the agent can execute its action."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str,
|
||||
action_space: ActionManager,
|
||||
observation_space: ObservationManager,
|
||||
reward_function: RewardFunction,
|
||||
settings: Optional[Settings] = None,
|
||||
) -> None:
|
||||
"""Initialise PeriodicAgent."""
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.settings = settings or PeriodicAgent.Settings()
|
||||
self._set_next_execution_timestep(timestep=self.settings.start_step, variance=self.settings.start_variance)
|
||||
self.num_executions = 0
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
:param timestep: The timestep when the next execute action should be taken.
|
||||
:type timestep: int
|
||||
:param variance: Uniform random variance applied to the timestep
|
||||
:type variance: int
|
||||
"""
|
||||
random_increment = random.randint(-variance, variance)
|
||||
self.next_execution_timestep = timestep + random_increment
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
|
||||
if timestep == self.next_execution_timestep and self.num_executions < self.settings.max_executions:
|
||||
self.num_executions += 1
|
||||
self._set_next_execution_timestep(timestep + self.settings.frequency, self.settings.variance)
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
|
||||
|
||||
return "DONOTHING", {}
|
||||
|
||||
78
src/primaite/game/agent/scripted_agents/tap001.py
Normal file
78
src/primaite/game/agent/scripted_agents/tap001.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import random
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
|
||||
class TAP001(AbstractScriptedAgent):
|
||||
"""
|
||||
TAP001 | Mobile Malware -- Ransomware Variant.
|
||||
|
||||
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.setup_agent()
|
||||
|
||||
next_execution_timestep: int = 0
|
||||
starting_node_idx: int = 0
|
||||
installed: bool = False
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute the ransomware application.
|
||||
|
||||
This application acts a wrapper around the kill-chain, similar to green-analyst and
|
||||
the previous UC2 data manipulation bot.
|
||||
|
||||
:param obs: Current observation for this agent.
|
||||
:type obs: ObsType
|
||||
:param timestep: The current simulation timestep, used for scheduling actions
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
return "DONOTHING", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
|
||||
|
||||
if not self.installed:
|
||||
self.installed = True
|
||||
return "NODE_APPLICATION_INSTALL", {
|
||||
"node_id": self.starting_node_idx,
|
||||
"application_name": "RansomwareScript",
|
||||
"ip_address": self.ip_address,
|
||||
}
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
for n, act in self.action_manager.action_map.items():
|
||||
if not act[0] == "NODE_APPLICATION_INSTALL":
|
||||
continue
|
||||
if act[1]["node_id"] == self.starting_node_idx:
|
||||
self.ip_address = act[1]["ip_address"]
|
||||
return
|
||||
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.action_manager.node_names)
|
||||
self.starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Hashable, Sequence
|
||||
from typing import Any, Dict, Hashable, Optional, Sequence
|
||||
|
||||
NOT_PRESENT_IN_STATE = object()
|
||||
"""
|
||||
@@ -7,7 +7,7 @@ the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is
|
||||
"""
|
||||
|
||||
|
||||
def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
|
||||
def access_from_nested_dict(dictionary: Dict, keys: Optional[Sequence[Hashable]]) -> Any:
|
||||
"""
|
||||
Access an item from a deeply dictionary with a list of keys.
|
||||
|
||||
@@ -21,6 +21,8 @@ def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
|
||||
:return: The value in the dictionary
|
||||
:rtype: Any
|
||||
"""
|
||||
if keys is None:
|
||||
return NOT_PRESENT_IN_STATE
|
||||
key_list = [*keys] # copy keys to a new list to prevent editing original list
|
||||
if len(key_list) == 0:
|
||||
return dictionary
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""PrimAITE game - Encapsulates the simulation and agents."""
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -8,22 +8,28 @@ from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.rewards import RewardFunction, SharedReward
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
|
||||
from primaite.game.agent.scripted_agents.tap001 import TAP001
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator.network.airspace import AIR_SPACE
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
|
||||
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
|
||||
from primaite.simulator.network.hardware.nodes.network.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
@@ -41,6 +47,7 @@ APPLICATION_TYPES_MAPPING = {
|
||||
"DatabaseClient": DatabaseClient,
|
||||
"DataManipulationBot": DataManipulationBot,
|
||||
"DoSBot": DoSBot,
|
||||
"RansomwareScript": RansomwareScript,
|
||||
}
|
||||
"""List of available applications that can be installed on nodes in the PrimAITE Simulation."""
|
||||
|
||||
@@ -100,21 +107,12 @@ class PrimaiteGame:
|
||||
self.options: PrimaiteGameOptions
|
||||
"""Special options that apply for the entire game."""
|
||||
|
||||
self.ref_map_nodes: Dict[str, str] = {}
|
||||
"""Mapping from unique node reference name to node object. Used when parsing config files."""
|
||||
|
||||
self.ref_map_services: Dict[str, str] = {}
|
||||
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
|
||||
|
||||
self.ref_map_applications: Dict[str, str] = {}
|
||||
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
|
||||
|
||||
self.ref_map_links: Dict[str, str] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
|
||||
self.save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
|
||||
self._reward_calculation_order: List[str] = [name for name in self.agents]
|
||||
"""Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards."""
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step of the simulation/agent loop.
|
||||
@@ -135,49 +133,55 @@ class PrimaiteGame:
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
|
||||
|
||||
# Get the current state of the simulation
|
||||
sim_state = self.get_sim_state()
|
||||
|
||||
# Update agents' observations and rewards based on the current state
|
||||
self.update_agents(sim_state)
|
||||
self.pre_timestep()
|
||||
|
||||
if self.step_counter == 0:
|
||||
state = self.get_sim_state()
|
||||
for agent in self.agents.values():
|
||||
agent.update_observation(state=state)
|
||||
# Apply all actions to simulation as requests
|
||||
self.apply_agent_actions()
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
|
||||
# Get the current state of the simulation
|
||||
sim_state = self.get_sim_state()
|
||||
|
||||
# Update agents' observations and rewards based on the current state, and the response from the last action
|
||||
self.update_agents(state=sim_state)
|
||||
|
||||
def get_sim_state(self) -> Dict:
|
||||
"""Get the current state of the simulation."""
|
||||
return self.simulation.describe_state()
|
||||
|
||||
def update_agents(self, state: Dict) -> None:
|
||||
"""Update agents' observations and rewards based on the current state."""
|
||||
for _, agent in self.agents.items():
|
||||
agent.update_observation(state)
|
||||
agent.update_reward(state)
|
||||
for agent_name in self._reward_calculation_order:
|
||||
agent = self.agents[agent_name]
|
||||
if self.step_counter > 0: # can't get reward before first action
|
||||
agent.update_reward(state=state)
|
||||
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
|
||||
agent.reward_function.total_reward += agent.reward_function.current_reward
|
||||
|
||||
def apply_agent_actions(self) -> Dict[str, Tuple[str, Dict]]:
|
||||
"""
|
||||
Apply all actions to simulation as requests.
|
||||
|
||||
:return: A recap of each agent's actions, in CAOS format.
|
||||
:rtype: Dict[str, Tuple[str, Dict]]
|
||||
|
||||
"""
|
||||
agent_actions = {}
|
||||
def apply_agent_actions(self) -> None:
|
||||
"""Apply all actions to simulation as requests."""
|
||||
for _, agent in self.agents.items():
|
||||
obs = agent.observation_manager.current_observation
|
||||
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
|
||||
request = agent.format_request(action_choice, options)
|
||||
action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
|
||||
request = agent.format_request(action_choice, parameters)
|
||||
response = self.simulation.apply_request(request)
|
||||
agent_actions[agent.agent_name] = {
|
||||
"action": action_choice,
|
||||
"parameters": options,
|
||||
"response": response.model_dump(),
|
||||
}
|
||||
return agent_actions
|
||||
agent.process_action_response(
|
||||
timestep=self.step_counter,
|
||||
action=action_choice,
|
||||
parameters=parameters,
|
||||
request=request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
def pre_timestep(self) -> None:
|
||||
"""Apply any pre-timestep logic that helps make sure we have the correct observations."""
|
||||
self.simulation.pre_timestep(self.step_counter)
|
||||
|
||||
def advance_timestep(self) -> None:
|
||||
"""Advance timestep."""
|
||||
@@ -206,8 +210,8 @@ class PrimaiteGame:
|
||||
"""Create a PrimaiteGame object from a config dictionary.
|
||||
|
||||
The config dictionary should have the following top-level keys:
|
||||
1. training_config: options for training the RL agent.
|
||||
2. game_config: options for the game itself. Used by PrimaiteGame.
|
||||
1. io_settings: options for logging data during training
|
||||
2. game_config: options for the game itself, such as agents.
|
||||
3. simulation: defines the network topology and the initial state of the simulation.
|
||||
|
||||
The specification for each of the three major areas is described in a separate documentation page.
|
||||
@@ -218,6 +222,7 @@ class PrimaiteGame:
|
||||
:return: A PrimaiteGame object.
|
||||
:rtype: PrimaiteGame
|
||||
"""
|
||||
AIR_SPACE.clear()
|
||||
game = cls()
|
||||
game.options = PrimaiteGameOptions(**cfg["game"])
|
||||
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
|
||||
@@ -233,7 +238,6 @@ class PrimaiteGame:
|
||||
links_cfg = network_config.get("links", [])
|
||||
|
||||
for node_cfg in nodes_cfg:
|
||||
node_ref = node_cfg["ref"]
|
||||
n_type = node_cfg["type"]
|
||||
if n_type == "computer":
|
||||
new_node = Computer(
|
||||
@@ -269,18 +273,29 @@ class PrimaiteGame:
|
||||
new_node = Router.from_config(node_cfg)
|
||||
elif n_type == "firewall":
|
||||
new_node = Firewall.from_config(node_cfg)
|
||||
elif n_type == "wireless_router":
|
||||
new_node = WirelessRouter.from_config(node_cfg)
|
||||
elif n_type == "printer":
|
||||
new_node = Printer(
|
||||
hostname=node_cfg["hostname"],
|
||||
ip_address=node_cfg["ip_address"],
|
||||
subnet_mask=node_cfg["subnet_mask"],
|
||||
operating_state=NodeOperatingState.ON
|
||||
if not (p := node_cfg.get("operating_state"))
|
||||
else NodeOperatingState[p.upper()],
|
||||
)
|
||||
else:
|
||||
_LOGGER.warning(f"invalid node type {n_type} in config")
|
||||
msg = f"invalid node type {n_type} in config"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if "services" in node_cfg:
|
||||
for service_cfg in node_cfg["services"]:
|
||||
new_service = None
|
||||
service_ref = service_cfg["ref"]
|
||||
service_type = service_cfg["type"]
|
||||
if service_type in SERVICE_TYPES_MAPPING:
|
||||
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
|
||||
new_node.software_manager.install(SERVICE_TYPES_MAPPING[service_type])
|
||||
new_service = new_node.software_manager.software[service_type]
|
||||
game.ref_map_services[service_ref] = new_service.uuid
|
||||
|
||||
# start the service
|
||||
new_service.start()
|
||||
@@ -316,13 +331,11 @@ class PrimaiteGame:
|
||||
if "applications" in node_cfg:
|
||||
for application_cfg in node_cfg["applications"]:
|
||||
new_application = None
|
||||
application_ref = application_cfg["ref"]
|
||||
application_type = application_cfg["type"]
|
||||
|
||||
if application_type in APPLICATION_TYPES_MAPPING:
|
||||
new_node.software_manager.install(APPLICATION_TYPES_MAPPING[application_type])
|
||||
new_application = new_node.software_manager.software[application_type]
|
||||
game.ref_map_applications[application_ref] = new_application.uuid
|
||||
else:
|
||||
msg = f"Configuration contains an invalid application type: {application_type}"
|
||||
_LOGGER.error(msg)
|
||||
@@ -341,6 +354,19 @@ class PrimaiteGame:
|
||||
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
|
||||
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
|
||||
)
|
||||
elif application_type == "RansomwareScript":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload", "ENCRYPT"),
|
||||
c2_beacon_p_of_success=float(opt.get("c2_beacon_p_of_success", "0.5")),
|
||||
target_scan_p_of_success=float(opt.get("target_scan_p_of_success", "0.1")),
|
||||
ransomware_encrypt_p_of_success=float(
|
||||
opt.get("ransomware_encrypt_p_of_success", "0.1")
|
||||
),
|
||||
)
|
||||
elif application_type == "DatabaseClient":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
@@ -376,7 +402,6 @@ class PrimaiteGame:
|
||||
# run through the power on step if the node is to be turned on at the start
|
||||
if new_node.operating_state == NodeOperatingState.ON:
|
||||
new_node.power_on()
|
||||
game.ref_map_nodes[node_ref] = new_node.uuid
|
||||
|
||||
# set start up and shut down duration
|
||||
new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3))
|
||||
@@ -384,8 +409,9 @@ class PrimaiteGame:
|
||||
|
||||
# 2. create links between nodes
|
||||
for link_cfg in links_cfg:
|
||||
node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
|
||||
node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
|
||||
node_a = net.get_node_by_hostname(link_cfg["endpoint_a_hostname"])
|
||||
node_b = net.get_node_by_hostname(link_cfg["endpoint_b_hostname"])
|
||||
|
||||
if isinstance(node_a, Switch):
|
||||
endpoint_a = node_a.network_interface[link_cfg["endpoint_a_port"]]
|
||||
else:
|
||||
@@ -394,8 +420,7 @@ class PrimaiteGame:
|
||||
endpoint_b = node_b.network_interface[link_cfg["endpoint_b_port"]]
|
||||
else:
|
||||
endpoint_b = node_b.network_interface[link_cfg["endpoint_b_port"]]
|
||||
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
|
||||
game.ref_map_links[link_cfg["ref"]] = new_link.uuid
|
||||
net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
|
||||
|
||||
# 3. create agents
|
||||
agents_cfg = cfg.get("agents", [])
|
||||
@@ -408,7 +433,7 @@ class PrimaiteGame:
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg, game)
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space = ActionManager.from_config(game, action_space_cfg)
|
||||
@@ -427,6 +452,16 @@ class PrimaiteGame:
|
||||
reward_function=reward_function,
|
||||
settings=settings,
|
||||
)
|
||||
elif agent_type == "PeriodicAgent":
|
||||
settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {}))
|
||||
new_agent = PeriodicAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
elif agent_type == "ProxyAgent":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
new_agent = ProxyAgent(
|
||||
@@ -447,13 +482,64 @@ class PrimaiteGame:
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
elif agent_type == "TAP001":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
new_agent = TAP001(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
else:
|
||||
msg = f"Configuration error: {agent_type} is not a valid agent type."
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
|
||||
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
|
||||
game.setup_reward_sharing()
|
||||
|
||||
# Set the NMNE capture config
|
||||
set_nmne_config(network_config.get("nmne_config", {}))
|
||||
game.update_agents(game.get_sim_state())
|
||||
|
||||
return game
|
||||
|
||||
def setup_reward_sharing(self):
|
||||
"""Do necessary setup to enable reward sharing between agents.
|
||||
|
||||
This method ensures that there are no cycles in the reward sharing. A cycle would be for example if agent_1
|
||||
depends on agent_2 and agent_2 depends on agent_1. It would cause an infinite loop.
|
||||
|
||||
Also, SharedReward requires us to pass it a callback method that will provide the reward of the agent who is
|
||||
sharing their reward. This callback is provided by this setup method.
|
||||
|
||||
Finally, this method sorts the agents in order in which rewards will be evaluated to make sure that any rewards
|
||||
that rely on the value of another reward are evaluated later.
|
||||
|
||||
:raises RuntimeError: If the reward sharing is specified with a cyclic dependency.
|
||||
"""
|
||||
# construct dependency graph in the reward sharing between agents.
|
||||
graph = {}
|
||||
for name, agent in self.agents.items():
|
||||
graph[name] = set()
|
||||
for comp, weight in agent.reward_function.reward_components:
|
||||
if isinstance(comp, SharedReward):
|
||||
comp: SharedReward
|
||||
graph[name].add(comp.agent_name)
|
||||
|
||||
# while constructing the graph, we might as well set up the reward sharing itself.
|
||||
comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward
|
||||
|
||||
# make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing.
|
||||
if graph_has_cycle(graph):
|
||||
raise RuntimeError(
|
||||
(
|
||||
"Detected cycle in agent reward sharing. Check the agent reward function ",
|
||||
"configuration: reward sharing can only go one way.",
|
||||
)
|
||||
)
|
||||
|
||||
# sort the agents so the rewards that depend on other rewards are always evaluated later
|
||||
self._reward_calculation_order = topological_sort(graph)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from random import random
|
||||
from typing import Any, Iterable, Mapping
|
||||
|
||||
|
||||
def simulate_trial(p_of_success: float) -> bool:
|
||||
@@ -14,3 +15,80 @@ def simulate_trial(p_of_success: float) -> bool:
|
||||
:returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False.
|
||||
"""
|
||||
return random() < p_of_success
|
||||
|
||||
|
||||
def graph_has_cycle(graph: Mapping[Any, Iterable[Any]]) -> bool:
|
||||
"""Detect cycles in a directed graph.
|
||||
|
||||
Provide the graph as a dictionary that describes which nodes are linked. For example:
|
||||
{0: {1,2}, 1:{2,3}, 3:{0}} here there's a cycle 0 -> 1 -> 3 -> 0
|
||||
{'a': ('b','c'), c:('b')} here there is no cycle
|
||||
|
||||
:param graph: a mapping from node to a set of nodes to which it is connected.
|
||||
:type graph: Mapping[Any, Iterable[Any]]
|
||||
:return: Whether the graph has any cycles
|
||||
:rtype: bool
|
||||
"""
|
||||
visited = set()
|
||||
currently_visiting = set()
|
||||
|
||||
def depth_first_search(node: Any) -> bool:
|
||||
"""Perform depth-first search (DFS) traversal to detect cycles starting from a given node."""
|
||||
if node in currently_visiting:
|
||||
return True # Cycle detected
|
||||
if node in visited:
|
||||
return False # Already visited, no need to explore further
|
||||
|
||||
visited.add(node)
|
||||
currently_visiting.add(node)
|
||||
|
||||
for neighbour in graph.get(node, []):
|
||||
if depth_first_search(neighbour):
|
||||
return True # Cycle detected
|
||||
|
||||
currently_visiting.remove(node)
|
||||
return False
|
||||
|
||||
# Start DFS traversal from each node
|
||||
for node in graph:
|
||||
if depth_first_search(node):
|
||||
return True # Cycle detected
|
||||
|
||||
return False # No cycles found
|
||||
|
||||
|
||||
def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]:
|
||||
"""
|
||||
Perform topological sorting on a directed graph.
|
||||
|
||||
This guarantees that if there's a directed edge from node A to node B, then A appears before B.
|
||||
|
||||
:param graph: A dictionary representing the directed graph, where keys are node identifiers
|
||||
and values are lists of outgoing edges from each node.
|
||||
:type graph: dict[int, list[Any]]
|
||||
|
||||
:return: A topologically sorted list of node identifiers.
|
||||
:rtype: list[Any]
|
||||
"""
|
||||
visited: set[Any] = set()
|
||||
stack: list[Any] = []
|
||||
|
||||
def dfs(node: Any) -> None:
|
||||
"""
|
||||
Depth-first search traversal to visit nodes and their neighbors.
|
||||
|
||||
:param node: The current node to visit.
|
||||
:type node: Any
|
||||
"""
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
for neighbour in graph.get(node, []):
|
||||
dfs(neighbour)
|
||||
stack.append(node)
|
||||
|
||||
# Perform DFS traversal from each node
|
||||
for node in graph:
|
||||
dfs(node)
|
||||
|
||||
return stack
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Dict, ForwardRef, Literal
|
||||
from typing import Dict, ForwardRef, List, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
|
||||
|
||||
RequestFormat = List[Union[str, int, float]]
|
||||
|
||||
RequestResponse = ForwardRef("RequestResponse")
|
||||
"""This makes it possible to type-hint RequestResponse.from_bool return type."""
|
||||
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The main PrimAITE session runner module."""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.load import data_manipulation_config_path, load
|
||||
from primaite.session.session import PrimaiteSession
|
||||
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(
|
||||
config_path: Optional[Union[str, Path]] = "",
|
||||
agent_load_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
"""
|
||||
cfg = load(config_path)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
|
||||
sess.start_session()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config")
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.config:
|
||||
args.config = data_manipulation_config_path()
|
||||
|
||||
run(args.config)
|
||||
@@ -22,6 +22,7 @@
|
||||
"# Imports\n",
|
||||
"\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"import yaml\n",
|
||||
"from pprint import pprint"
|
||||
@@ -62,12 +63,12 @@
|
||||
"source": [
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info[0]\n",
|
||||
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
|
||||
" client = \"client 1\" if red_info[1]['node_id'] == 0 else \"client 2\"\n",
|
||||
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str"
|
||||
]
|
||||
@@ -361,7 +362,7 @@
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
" cfg['simulation']['network']\n",
|
||||
" for node in cfg['simulation']['network']['nodes']:\n",
|
||||
" if node['ref'] in ['client_1', 'client_2']:\n",
|
||||
" if node['hostname'] in ['client_1', 'client_2']:\n",
|
||||
" node['applications'] = change['applications']\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(game_config = cfg)\n",
|
||||
@@ -406,7 +407,7 @@
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
" cfg['simulation']['network']\n",
|
||||
" for node in cfg['simulation']['network']['nodes']:\n",
|
||||
" if node['ref'] in ['client_1', 'client_2']:\n",
|
||||
" if node['hostname'] in ['client_1', 'client_2']:\n",
|
||||
" node['applications'] = change['applications']\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(game_config = cfg)\n",
|
||||
|
||||
@@ -208,7 +208,7 @@
|
||||
"|--|--|\n",
|
||||
"|0|UNUSED|\n",
|
||||
"|1|GOOD|\n",
|
||||
"|2|PATCHING|\n",
|
||||
"|2|FIXING|\n",
|
||||
"|3|COMPROMISED|\n",
|
||||
"|4|OVERWHELMED|\n",
|
||||
"\n",
|
||||
@@ -352,7 +352,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -364,7 +364,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -373,7 +373,7 @@
|
||||
"# Imports\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
|
||||
"import yaml\n",
|
||||
"from pprint import pprint\n"
|
||||
]
|
||||
@@ -389,162 +389,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-03-13 16:52:48,201: Resetting environment, episode 0, avg. reward: 0.0\n",
|
||||
"2024-03-13 16:52:48,205: Saving agent action log to C:\\Users\\NickTodd\\primaite\\3.0.0b6\\sessions\\2024-03-13\\16-52-48\\agent_actions\\episode_0.json\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"env created successfully\n",
|
||||
"{'ACL': {1: {'dest_node_id': 0,\n",
|
||||
" 'dest_port': 0,\n",
|
||||
" 'permission': 0,\n",
|
||||
" 'position': 0,\n",
|
||||
" 'protocol': 0,\n",
|
||||
" 'source_node_id': 0,\n",
|
||||
" 'source_port': 0},\n",
|
||||
" 2: {'dest_node_id': 0,\n",
|
||||
" 'dest_port': 0,\n",
|
||||
" 'permission': 0,\n",
|
||||
" 'position': 1,\n",
|
||||
" 'protocol': 0,\n",
|
||||
" 'source_node_id': 0,\n",
|
||||
" 'source_port': 0},\n",
|
||||
" 3: {'dest_node_id': 0,\n",
|
||||
" 'dest_port': 0,\n",
|
||||
" 'permission': 0,\n",
|
||||
" 'position': 2,\n",
|
||||
" 'protocol': 0,\n",
|
||||
" 'source_node_id': 0,\n",
|
||||
" 'source_port': 0},\n",
|
||||
" 4: {'dest_node_id': 0,\n",
|
||||
" 'dest_port': 0,\n",
|
||||
" 'permission': 0,\n",
|
||||
" 'position': 3,\n",
|
||||
" 'protocol': 0,\n",
|
||||
" 'source_node_id': 0,\n",
|
||||
" 'source_port': 0},\n",
|
||||
" 5: {'dest_node_id': 0,\n",
|
||||
" 'dest_port': 0,\n",
|
||||
" 'permission': 0,\n",
|
||||
" 'position': 4,\n",
|
||||
" 'protocol': 0,\n",
|
||||
" 'source_node_id': 0,\n",
|
||||
" 'source_port': 0},\n",
|
||||
" 6: {'dest_node_id': 0,\n",
|
||||
" 'dest_port': 0,\n",
|
||||
" 'permission': 0,\n",
|
||||
" 'position': 5,\n",
|
||||
" 'protocol': 0,\n",
|
||||
" 'source_node_id': 0,\n",
|
||||
" 'source_port': 0},\n",
|
||||
" 7: {'dest_node_id': 0,\n",
|
||||
" 'dest_port': 0,\n",
|
||||
" 'permission': 0,\n",
|
||||
" 'position': 6,\n",
|
||||
" 'protocol': 0,\n",
|
||||
" 'source_node_id': 0,\n",
|
||||
" 'source_port': 0},\n",
|
||||
" 8: {'dest_node_id': 0,\n",
|
||||
" 'dest_port': 0,\n",
|
||||
" 'permission': 0,\n",
|
||||
" 'position': 7,\n",
|
||||
" 'protocol': 0,\n",
|
||||
" 'source_node_id': 0,\n",
|
||||
" 'source_port': 0},\n",
|
||||
" 9: {'dest_node_id': 0,\n",
|
||||
" 'dest_port': 0,\n",
|
||||
" 'permission': 0,\n",
|
||||
" 'position': 8,\n",
|
||||
" 'protocol': 0,\n",
|
||||
" 'source_node_id': 0,\n",
|
||||
" 'source_port': 0},\n",
|
||||
" 10: {'dest_node_id': 0,\n",
|
||||
" 'dest_port': 0,\n",
|
||||
" 'permission': 0,\n",
|
||||
" 'position': 9,\n",
|
||||
" 'protocol': 0,\n",
|
||||
" 'source_node_id': 0,\n",
|
||||
" 'source_port': 0}},\n",
|
||||
" 'ICS': 0,\n",
|
||||
" 'LINKS': {1: {'PROTOCOLS': {'ALL': 1}},\n",
|
||||
" 2: {'PROTOCOLS': {'ALL': 1}},\n",
|
||||
" 3: {'PROTOCOLS': {'ALL': 1}},\n",
|
||||
" 4: {'PROTOCOLS': {'ALL': 1}},\n",
|
||||
" 5: {'PROTOCOLS': {'ALL': 1}},\n",
|
||||
" 6: {'PROTOCOLS': {'ALL': 1}},\n",
|
||||
" 7: {'PROTOCOLS': {'ALL': 1}},\n",
|
||||
" 8: {'PROTOCOLS': {'ALL': 1}},\n",
|
||||
" 9: {'PROTOCOLS': {'ALL': 1}},\n",
|
||||
" 10: {'PROTOCOLS': {'ALL': 0}}},\n",
|
||||
" 'NODES': {1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
|
||||
" 'health_status': 0}},\n",
|
||||
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 1},\n",
|
||||
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 0}},\n",
|
||||
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n",
|
||||
" 'operating_status': 1},\n",
|
||||
" 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
|
||||
" 'health_status': 0}},\n",
|
||||
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 1},\n",
|
||||
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 0}},\n",
|
||||
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n",
|
||||
" 'operating_status': 1},\n",
|
||||
" 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}},\n",
|
||||
" 'health_status': 1}},\n",
|
||||
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 1},\n",
|
||||
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 0}},\n",
|
||||
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
|
||||
" 'operating_status': 1},\n",
|
||||
" 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
|
||||
" 'health_status': 0}},\n",
|
||||
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 1},\n",
|
||||
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 0}},\n",
|
||||
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
|
||||
" 'operating_status': 1},\n",
|
||||
" 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
|
||||
" 'health_status': 0}},\n",
|
||||
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 1},\n",
|
||||
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 0}},\n",
|
||||
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
|
||||
" 'operating_status': 1},\n",
|
||||
" 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
|
||||
" 'health_status': 0}},\n",
|
||||
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 1},\n",
|
||||
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 0}},\n",
|
||||
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
|
||||
" 'operating_status': 1},\n",
|
||||
" 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
|
||||
" 'health_status': 0}},\n",
|
||||
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 1},\n",
|
||||
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
|
||||
" 'nic_status': 0}},\n",
|
||||
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
|
||||
" 'operating_status': 1}}}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# create the env\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
@@ -565,20 +412,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"res = FileSystemItemHealthStatus.GOOD\n",
|
||||
"res = FileSystemItemHealthStatus.GOOD\n",
|
||||
"res = FileSystemItemHealthStatus.COMPROMISED\n",
|
||||
"res = FileSystemItemHealthStatus.COMPROMISED\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test NODE_FOLDER_CHECKHASH\n",
|
||||
"res = env.game.simulation.network.get_node_by_hostname('database_server').file_system.get_folder(folder_name = 'database').health_status\n",
|
||||
@@ -618,12 +454,12 @@
|
||||
"source": [
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info[0]\n",
|
||||
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
|
||||
" client = \"client 1\" if red_info[1]['node_id'] == 0 else \"client 2\"\n",
|
||||
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str"
|
||||
]
|
||||
@@ -643,7 +479,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now the reward is -1, let's have a look at blue agent's observation."
|
||||
"Now the reward is -0.8, let's have a look at blue agent's observation."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -704,9 +540,9 @@
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}\")\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user'].action}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user'].action}\" )\n",
|
||||
"print(f\"Blue reward:{reward}\" )"
|
||||
]
|
||||
},
|
||||
@@ -714,7 +550,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The patching takes two steps, so the reward hasn't changed yet. Let's do nothing for another timestep, the reward should improve.\n",
|
||||
"The fixing takes two steps, so the reward hasn't changed yet. Let's do nothing for another timestep, the reward should improve.\n",
|
||||
"\n",
|
||||
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n",
|
||||
"\n",
|
||||
@@ -727,9 +563,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
|
||||
"obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
"print(f\"step: {env.game.step_counter}\")\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
|
||||
"print(f\"Blue reward:{reward:.2f}\" )"
|
||||
@@ -751,24 +587,26 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(13) # Patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(50) # Block client 1\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(51) # Block client 2\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"for step in range(30):\n",
|
||||
"while abs(reward - 0.8) > 1e-5:\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
" if env.game.step_counter > 10000:\n",
|
||||
" break # make sure there's no infinite loop if something went wrong"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, even though the red agent executes an attack, the reward stays at 0.8."
|
||||
"Now, even though the red agent executes an attack, the reward will stay at 0.8."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -784,7 +622,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs['ACL']"
|
||||
"obs['NODES']['ROUTER0']"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -800,13 +638,30 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 1 has NMNEs, let's unblock client 2\n",
|
||||
" env.step(58) # remove ACL rule 6\n",
|
||||
"elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" env.step(57) # remove ACL rule 5\n",
|
||||
"else:\n",
|
||||
" print(\"something went wrong, neither client has NMNEs\")"
|
||||
"env.step(58) # Remove the ACL rule that blocks client 1\n",
|
||||
"env.step(57) # Remove the ACL rule that blocks client 2\n",
|
||||
"\n",
|
||||
"tries = 0\n",
|
||||
"while True:\n",
|
||||
" tries += 1\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0)\n",
|
||||
"\n",
|
||||
" if obs['NODES']['HOST5']['NICS'][1]['NMNE']['outbound'] == 1:\n",
|
||||
" # client 1 has NMNEs, let's block it\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(50) # block client 1\n",
|
||||
" print(\"blocking client 1\")\n",
|
||||
" break\n",
|
||||
" elif obs['NODES']['HOST6']['NICS'][1]['NMNE']['outbound'] == 1:\n",
|
||||
" # client 2 has NMNEs, so let's block it\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(51) # block client 2\n",
|
||||
" print(\"blocking client 2\")\n",
|
||||
" break\n",
|
||||
" if tries>100:\n",
|
||||
" print(\"Error: NMNE never increased\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
"env.step(13) # Patch the database\n",
|
||||
"print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -824,14 +679,14 @@
|
||||
"source": [
|
||||
"for step in range(30):\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode."
|
||||
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode. (except the red agent will move between `client_1` and `client_2`.)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,5 +1,21 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Training an SB3 Agent\n",
|
||||
"\n",
|
||||
"This notebook will demonstrate how to use primaite to create and train a PPO agent, using a pre-defined configuration file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### First, we import the inital packages and read in our configuration file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -27,7 +43,14 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n"
|
||||
" cfg = yaml.safe_load(f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Using the given configuration, we generate the environment our agent will train in."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -40,12 +63,10 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from stable_baselines3 import PPO"
|
||||
"Lets define training parameters for the agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -54,7 +75,13 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = PPO('MlpPolicy', gym)\n"
|
||||
"from stable_baselines3 import PPO\n",
|
||||
"\n",
|
||||
"EPISODE_LEN = 128\n",
|
||||
"NUM_EPISODES = 10\n",
|
||||
"NO_STEPS = EPISODE_LEN * NUM_EPISODES\n",
|
||||
"BATCH_SIZE = 32\n",
|
||||
"LEARNING_RATE = 3e-4"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -63,7 +90,14 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.learn(total_timesteps=10)\n"
|
||||
"model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"With the agent configured, let's train for our defined number of episodes."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -72,7 +106,14 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save(\"deleteme\")"
|
||||
"model.learn(total_timesteps=NO_STEPS)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, let's save the agent to a zip file that can be used in future evaluation."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -80,7 +121,44 @@
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"model.save(\"PrimAITE-PPO-example-agent\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, we load the saved agent and run it in evaluation mode."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"eval_model = PPO(\"MlpPolicy\", gym)\n",
|
||||
"eval_model = PPO.load(\"PrimAITE-PPO-example-agent\", gym)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally, evaluate the agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from stable_baselines3.common.evaluation import evaluate_policy\n",
|
||||
"\n",
|
||||
"evaluate_policy(eval_model, gym, n_eval_episodes=10)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -99,7 +177,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -26,8 +26,13 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
def __init__(self, game_config: Dict):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
|
||||
self.game_config: Dict = game_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
|
||||
"""Current game."""
|
||||
self._agent_name = next(iter(self.game.rl_agents))
|
||||
@@ -36,9 +41,6 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
|
||||
@property
|
||||
def agent(self) -> ProxyAgent:
|
||||
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
|
||||
@@ -46,37 +48,36 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
# make ProxyAgent store the action chosen by the RL policy
|
||||
step = self.game.step_counter
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
self.game.pre_timestep()
|
||||
self.game.apply_agent_actions()
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
|
||||
self.game.update_agents(state)
|
||||
|
||||
next_obs = self._get_obs()
|
||||
next_obs = self._get_obs() # this doesn't update observation, just gets the current observation
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
|
||||
info = {
|
||||
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
|
||||
} # tell us what all the agents did for convenience.
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.store_agent_actions(
|
||||
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
|
||||
)
|
||||
self._write_step_metadata_json(step, action, state, reward)
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
|
||||
def _write_step_metadata_json(self, step: int, action: int, state: Dict, reward: int):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{self.game.step_counter}.json"
|
||||
path = output_dir / f"step_{step}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.episode_counter,
|
||||
"step": self.game.step_counter,
|
||||
"step": step,
|
||||
"action": int(action),
|
||||
"reward": int(reward),
|
||||
"state": state,
|
||||
@@ -91,13 +92,13 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.write_agent_actions(episode=self.episode_counter)
|
||||
self.io.clear_agent_actions()
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
self.game.update_agents(state=state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
@@ -124,6 +125,12 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
else:
|
||||
return self.agent.observation_manager.current_observation
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||
@@ -147,6 +154,10 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Perform a step in the environment."""
|
||||
return self.env.step(action)
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
self.env.close()
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
@@ -160,6 +171,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""
|
||||
self.game_config: Dict = env_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.io = PrimaiteIO.from_config(env_config.get("io_settings"))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
|
||||
"""Reference to the primaite game"""
|
||||
self._agent_ids = list(self.game.rl_agents.keys())
|
||||
@@ -179,9 +192,6 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
)
|
||||
|
||||
self.io = PrimaiteIO.from_config(env_config.get("io_settings"))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
@@ -192,8 +202,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.write_agent_actions(episode=self.episode_counter)
|
||||
self.io.clear_agent_actions()
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
@@ -214,10 +224,12 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
identifier.
|
||||
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
|
||||
"""
|
||||
step = self.game.step_counter
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
self.game.pre_timestep()
|
||||
self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
@@ -235,22 +247,18 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(actions, state, rewards)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.store_agent_actions(
|
||||
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
|
||||
)
|
||||
self._write_step_metadata_json(step, actions, state, rewards)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
|
||||
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{self.game.step_counter}.json"
|
||||
path = output_dir / f"step_{step}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.episode_counter,
|
||||
"step": self.game.step_counter,
|
||||
"step": step,
|
||||
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
|
||||
"reward": rewards,
|
||||
"state": state,
|
||||
@@ -267,3 +275,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
return obs
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
@@ -29,10 +29,12 @@ class PrimaiteIO:
|
||||
"""Whether to save a log of all agents' actions every step."""
|
||||
save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
save_pcap_logs: bool = False
|
||||
save_pcap_logs: bool = True
|
||||
"""Whether to save PCAP logs."""
|
||||
save_sys_logs: bool = False
|
||||
save_sys_logs: bool = True
|
||||
"""Whether to save system logs."""
|
||||
write_sys_log_to_terminal: bool = False
|
||||
"""Whether to write the sys log to the terminal."""
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None) -> None:
|
||||
"""
|
||||
@@ -47,8 +49,7 @@ class PrimaiteIO:
|
||||
SIM_OUTPUT.path = self.session_path / "simulation_output"
|
||||
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
|
||||
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
|
||||
|
||||
self.agent_action_log: List[Dict] = []
|
||||
SIM_OUTPUT.write_sys_log_to_terminal = self.settings.write_sys_log_to_terminal
|
||||
|
||||
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
|
||||
"""Create a folder for the session and return the path to it."""
|
||||
@@ -72,51 +73,29 @@ class PrimaiteIO:
|
||||
"""Return the path where agent actions will be saved."""
|
||||
return self.session_path / "agent_actions" / f"episode_{episode}.json"
|
||||
|
||||
def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None:
|
||||
"""Cache agent actions for a particular step.
|
||||
|
||||
:param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected
|
||||
format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters]
|
||||
CAOS action is a string representing one the CAOS actions.
|
||||
parameters is a dict of parameter names and values for that particular CAOS action.
|
||||
For example:
|
||||
{
|
||||
'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}),
|
||||
'defender': ('DO_NOTHING', {})
|
||||
}
|
||||
:type agent_actions: Dict
|
||||
:param timestep: Simulation timestep when these actions occurred.
|
||||
:type timestep: int
|
||||
"""
|
||||
self.agent_action_log.append(
|
||||
[
|
||||
{
|
||||
"episode": episode,
|
||||
"timestep": timestep,
|
||||
"agent_actions": agent_actions,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
def write_agent_actions(self, episode: int) -> None:
|
||||
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
|
||||
"""Take the contents of the agent action log and write it to a file.
|
||||
|
||||
:param episode: Episode number
|
||||
:type episode: int
|
||||
"""
|
||||
data = {}
|
||||
longest_history = max([len(hist) for hist in agent_actions.values()])
|
||||
for i in range(longest_history):
|
||||
data[i] = {"timestep": i, "episode": episode}
|
||||
data[i].update({name: acts[i] for name, acts in agent_actions.items() if len(acts) > i})
|
||||
|
||||
path = self.generate_agent_actions_save_path(episode=episode)
|
||||
path.parent.mkdir(exist_ok=True, parents=True)
|
||||
path.touch()
|
||||
_LOGGER.info(f"Saving agent action log to {path}")
|
||||
with open(path, "w") as file:
|
||||
json.dump(self.agent_action_log, fp=file, indent=1)
|
||||
|
||||
def clear_agent_actions(self) -> None:
|
||||
"""Reset the agent action log back to an empty dictionary."""
|
||||
self.agent_action_log = []
|
||||
json.dump(data, fp=file, indent=1, default=lambda x: x.model_dump())
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "PrimaiteIO":
|
||||
"""Create an instance of PrimaiteIO based on a configuration dict."""
|
||||
new = cls()
|
||||
config = config or {}
|
||||
new = cls(settings=cls.Settings(**config))
|
||||
|
||||
return new
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from primaite.session.policy.rllib import RaySingleAgentPolicy
|
||||
from primaite.session.policy.sb3 import SB3Policy
|
||||
|
||||
__all__ = ["SB3Policy", "RaySingleAgentPolicy"]
|
||||
@@ -1,82 +0,0 @@
|
||||
"""Base class and common logic for RL policies."""
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class PolicyABC(ABC):
|
||||
"""Base class for reinforcement learning agents."""
|
||||
|
||||
_registry: Dict[str, Type["PolicyABC"]] = {}
|
||||
"""
|
||||
Registry of policy types, keyed by name.
|
||||
|
||||
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a policy subclass.
|
||||
|
||||
:param name: Identifier used by from_config to create an instance of the policy.
|
||||
:type name: str
|
||||
:raises ValueError: When attempting to create a policy with a duplicate name.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate policy name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, session: "PrimaiteSession") -> None:
|
||||
"""
|
||||
Initialize a reinforcement learning policy.
|
||||
|
||||
:param session: The session context.
|
||||
:type session: PrimaiteSession
|
||||
:param agents: The agents to train.
|
||||
:type agents: List[RLAgent]
|
||||
"""
|
||||
self.session: "PrimaiteSession" = session
|
||||
"""Reference to the session."""
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> None:
|
||||
"""Load agent from a file."""
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the agent."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
|
||||
"""
|
||||
Create an RL policy from a config by calling the relevant subclass's from_config method.
|
||||
|
||||
Subclasses should not call super().from_config(), they should just handle creation form config.
|
||||
"""
|
||||
# Assume that basically the contents of training_config are passed into here.
|
||||
# I should really define a config schema class using pydantic.
|
||||
|
||||
PolicyType = cls._registry[config.rl_framework]
|
||||
return PolicyType.from_config(config=config, session=session)
|
||||
@@ -1,111 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, TYPE_CHECKING
|
||||
|
||||
from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
import ray
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms import ppo
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
"""Single agent RL policy using Ray RLLib."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
super().__init__(session=session)
|
||||
|
||||
self.config = {
|
||||
"env": PrimaiteRayEnv,
|
||||
"env_config": {"game": session.game},
|
||||
"disable_env_checking": True,
|
||||
"num_rollout_workers": 0,
|
||||
}
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
self.config["training_iterations"] = n_episodes * timesteps_per_episode
|
||||
self.config["train_batch_size"] = 128
|
||||
self._algo = ppo.PPO(config=self.config)
|
||||
_LOGGER.info("Starting RLLIB training session")
|
||||
self._algo.train()
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
for ep in range(n_episodes):
|
||||
obs, info = self.session.env.reset()
|
||||
for step in range(self.session.game.options.max_episode_length):
|
||||
action = self._algo.compute_single_action(observation=obs, explore=False)
|
||||
obs, rew, term, trunc, info = self.session.env.step(action)
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save the policy to a file."""
|
||||
self._algo.save(save_path)
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load policy parameters from a file."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
|
||||
"""Create a policy from a config."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
|
||||
class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
|
||||
"""Mutli agent RL policy using Ray RLLib."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
|
||||
"""Initialise multi agent policy wrapper."""
|
||||
super().__init__(session=session)
|
||||
|
||||
self.config = (
|
||||
PPOConfig()
|
||||
.environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
|
||||
.rollouts(num_rollout_workers=0)
|
||||
.multi_agent(
|
||||
policies={agent.agent_name for agent in session.game.rl_agents},
|
||||
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
|
||||
)
|
||||
.training(train_batch_size=128)
|
||||
)
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
checkpoint_freq = self.session.io_manager.settings.checkpoint_interval
|
||||
tune.Tuner(
|
||||
"PPO",
|
||||
run_config=air.RunConfig(
|
||||
stop={"training_iteration": n_episodes * timesteps_per_episode},
|
||||
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq),
|
||||
),
|
||||
param_space=self.config,
|
||||
).fit()
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load policy parameters from a file."""
|
||||
return NotImplemented
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate trained policy."""
|
||||
return NotImplemented
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save policy parameters to a file."""
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
|
||||
"""Create policy from config."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
@@ -1,79 +0,0 @@
|
||||
"""Stable baselines 3 policy."""
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
|
||||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
|
||||
from stable_baselines3.common.callbacks import CheckpointCallback
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
|
||||
|
||||
from primaite.session.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
"""Single agent RL policy using stable baselines 3."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
"""Initialize a stable baselines 3 policy."""
|
||||
super().__init__(session=session)
|
||||
|
||||
self._agent_class: Type[Union[PPO, A2C]]
|
||||
if algorithm == "PPO":
|
||||
self._agent_class = PPO
|
||||
policy = PPO_MLP
|
||||
elif algorithm == "A2C":
|
||||
self._agent_class = A2C
|
||||
policy = A2C_MLP
|
||||
else:
|
||||
raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy")
|
||||
self._agent = self._agent_class(
|
||||
policy=policy,
|
||||
env=self.session.env,
|
||||
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
if self.session.save_checkpoints:
|
||||
checkpoint_callback = CheckpointCallback(
|
||||
save_freq=timesteps_per_episode * self.session.checkpoint_interval,
|
||||
save_path=self.session.io_manager.generate_model_save_path("sb3"),
|
||||
name_prefix="sb3_model",
|
||||
)
|
||||
else:
|
||||
checkpoint_callback = None
|
||||
self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
_ = evaluate_policy(
|
||||
self._agent,
|
||||
self.session.env,
|
||||
n_eval_episodes=n_episodes,
|
||||
deterministic=deterministic,
|
||||
return_episode_rewards=True,
|
||||
)
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""
|
||||
Save the current policy parameters.
|
||||
|
||||
Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
|
||||
refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
|
||||
Therefore, this method is only used to save the final model.
|
||||
"""
|
||||
self._agent.save(save_path)
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load agent from a checkpoint."""
|
||||
self._agent = self._agent_class.load(model_path, env=self.session.env)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
|
||||
"""Create an agent from config file."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
@@ -1,119 +0,0 @@
|
||||
# raise DeprecationWarning("This module is deprecated")
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.io import PrimaiteIO
|
||||
|
||||
# from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.policy.policy import PolicyABC
|
||||
|
||||
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
n_learn_episodes: int
|
||||
n_eval_episodes: Optional[int] = None
|
||||
max_steps_per_episode: int
|
||||
# checkpoint_freq: Optional[int] = None
|
||||
deterministic_eval: bool
|
||||
seed: Optional[int]
|
||||
n_agents: int
|
||||
agent_references: List[str]
|
||||
|
||||
|
||||
class SessionMode(Enum):
|
||||
"""Helper to keep track of the current session mode."""
|
||||
|
||||
TRAIN = "train"
|
||||
EVAL = "eval"
|
||||
MANUAL = "manual"
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
|
||||
|
||||
def __init__(self, game_cfg: Dict):
|
||||
"""Initialise PrimaiteSession object."""
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
|
||||
self.mode: SessionMode = SessionMode.MANUAL
|
||||
"""Current session mode."""
|
||||
|
||||
self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv]
|
||||
"""The environment that the RL algorithm can consume."""
|
||||
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
|
||||
self.io_manager: Optional["PrimaiteIO"] = None
|
||||
"""IO manager for the session."""
|
||||
|
||||
self.game_cfg: Dict = game_cfg
|
||||
"""Primaite Game object for managing main simulation loop and agents."""
|
||||
|
||||
self.save_checkpoints: bool = False
|
||||
"""Whether to save checkpoints."""
|
||||
|
||||
self.checkpoint_interval: int = 10
|
||||
"""If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training/eval session."""
|
||||
print("Starting Primaite Session")
|
||||
self.mode = SessionMode.TRAIN
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
max_steps_per_episode = self.training_options.max_steps_per_episode
|
||||
|
||||
deterministic_eval = self.training_options.deterministic_eval
|
||||
self.policy.learn(
|
||||
n_episodes=n_learn_episodes,
|
||||
timesteps_per_episode=max_steps_per_episode,
|
||||
)
|
||||
self.save_models()
|
||||
|
||||
self.mode = SessionMode.EVAL
|
||||
if n_eval_episodes > 0:
|
||||
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
|
||||
|
||||
self.mode = SessionMode.MANUAL
|
||||
|
||||
def save_models(self) -> None:
|
||||
"""Save the RL models."""
|
||||
save_path = self.io_manager.generate_model_save_path("temp_model_name")
|
||||
self.policy.save(save_path)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary."""
|
||||
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
|
||||
io_manager = PrimaiteIO.from_config(cfg.get("io_settings", {}))
|
||||
|
||||
sess = cls(game_cfg=cfg)
|
||||
sess.io_manager = io_manager
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
sess.save_checkpoints = cfg.get("io_settings", {}).get("save_checkpoints")
|
||||
sess.checkpoint_interval = cfg.get("io_settings", {}).get("checkpoint_interval")
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
if sess.training_options.rl_framework == "RLLIB_single_agent":
|
||||
sess.env = PrimaiteRayEnv(env_config=cfg)
|
||||
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
|
||||
sess.env = PrimaiteRayMARLEnv(env_config=cfg)
|
||||
elif sess.training_options.rl_framework == "SB3":
|
||||
sess.env = PrimaiteGymEnv(game_config=cfg)
|
||||
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
if agent_load_path:
|
||||
sess.policy.load(Path(agent_load_path))
|
||||
|
||||
return sess
|
||||
@@ -14,6 +14,7 @@ class _SimOutput:
|
||||
)
|
||||
self.save_pcap_logs: bool = False
|
||||
self.save_sys_logs: bool = False
|
||||
self.write_sys_log_to_terminal: bool = False
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
"source": [
|
||||
"# Build a simulation using the Python API\n",
|
||||
"\n",
|
||||
"Currently, this notbook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects.\n"
|
||||
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -58,7 +58,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.simulator.network.hardware.base import Node\n"
|
||||
"from primaite.simulator.network.hardware.nodes.host.computer import Computer\n",
|
||||
"from primaite.simulator.network.hardware.nodes.host.server import Server"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -67,9 +68,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"my_pc = Node(hostname=\"primaite_pc\",)\n",
|
||||
"my_pc = Computer(hostname=\"Computer\", ip_address=\"192.168.1.10\", subnet_mask=\"255.255.255.0\")\n",
|
||||
"net.add_node(my_pc)\n",
|
||||
"my_server = Node(hostname=\"google_server\")\n",
|
||||
"my_server = Server(hostname=\"Server\", ip_address=\"192.168.1.11\", subnet_mask=\"255.255.255.0\")\n",
|
||||
"net.add_node(my_server)\n"
|
||||
]
|
||||
},
|
||||
@@ -86,7 +87,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.simulator.network.hardware.base import NIC, Link, Switch\n"
|
||||
"from primaite.simulator.network.hardware.nodes.host.host_node import NIC\n",
|
||||
"from primaite.simulator.network.hardware.nodes.network.switch import Switch\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -95,19 +97,17 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"my_swtich = Switch(hostname=\"switch1\", num_ports=12)\n",
|
||||
"net.add_node(my_swtich)\n",
|
||||
"my_switch = Switch(hostname=\"switch1\", num_ports=12)\n",
|
||||
"net.add_node(my_switch)\n",
|
||||
"\n",
|
||||
"pc_nic = NIC(ip_address=\"130.1.1.1\", gateway=\"130.1.1.255\", subnet_mask=\"255.255.255.0\")\n",
|
||||
"my_pc.connect_nic(pc_nic)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"server_nic = NIC(ip_address=\"130.1.1.2\", gateway=\"130.1.1.255\", subnet_mask=\"255.255.255.0\")\n",
|
||||
"my_server.connect_nic(server_nic)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"net.connect(pc_nic, my_swtich.switch_ports[1])\n",
|
||||
"net.connect(server_nic, my_swtich.switch_ports[2])\n"
|
||||
"net.connect(pc_nic, my_switch.network_interface[1])\n",
|
||||
"net.connect(server_nic, my_switch.network_interface[2])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -124,7 +124,8 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.simulator.file_system.file_type import FileType\n",
|
||||
"from primaite.simulator.file_system.file_system import File"
|
||||
"from primaite.simulator.file_system.file_system import File\n",
|
||||
"from primaite.simulator.system.core.sys_log import SysLog"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -134,7 +135,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"my_pc_downloads_folder = my_pc.file_system.create_folder(\"downloads\")\n",
|
||||
"my_pc_downloads_folder.add_file(File(name=\"firefox_installer.zip\",file_type=FileType.ZIP))"
|
||||
"my_pc_downloads_folder.add_file(File(name=\"firefox_installer.zip\",folder_id=\"Test\", folder_name=\"downloads\" ,file_type=FileType.ZIP, sys_log=SysLog(hostname=\"Test\")))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -160,9 +161,12 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"from primaite.simulator.system.applications.application import Application, ApplicationOperatingState\n",
|
||||
"from primaite.simulator.system.software import SoftwareHealthState, SoftwareCriticality\n",
|
||||
"from primaite.simulator.network.transmission.transport_layer import Port\n",
|
||||
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
|
||||
"from primaite.simulator.file_system.file_system import FileSystem\n",
|
||||
"\n",
|
||||
"# no applications exist yet so we will create our own.\n",
|
||||
"class MSPaint(Application):\n",
|
||||
@@ -176,7 +180,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, ports={Port.HTTP}, operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual')"
|
||||
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port.HTTP, protocol = IPProtocol.NONE,operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -257,9 +261,8 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "03b2013a-b7d1-47ee-b08c-8dab83833720",
|
||||
"id": "0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# PrimAITE Router Simulation Demo\n",
|
||||
@@ -12,7 +12,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"id": "c8bb5698-e746-4e90-9c2f-efe962acdfa0",
|
||||
"id": "1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" +------------+\n",
|
||||
@@ -48,7 +48,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "415d487c-6457-497d-85d6-99439b3541e7",
|
||||
"id": "2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## The Network\n",
|
||||
@@ -60,7 +60,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "de57ac8c-5b28-4847-a759-2ceaf5593329",
|
||||
"id": "3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -72,7 +72,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a1e2e4df-67c0-4584-ab27-47e2c7c7fcd2",
|
||||
"id": "4",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -83,7 +83,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fb052c56-e9ca-4093-9115-d0c440b5ff53",
|
||||
"id": "5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Most of the Network components have a `.show()` function that prints a table of information about that object. We can view the Nodes and Links on the Network by calling `network.show()`."
|
||||
@@ -92,7 +92,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cc199741-ef2e-47f5-b2f0-e20049ccf40f",
|
||||
"id": "6",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -103,7 +103,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "76d2b7e9-280b-4741-a8b3-a84bed219fac",
|
||||
"id": "7",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -115,7 +115,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "84113002-843e-4cab-b899-667b50f25f6b",
|
||||
"id": "8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Router Nodes\n",
|
||||
@@ -125,7 +125,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bf63a178-eee5-4669-bf64-13aea7ecf6cb",
|
||||
"id": "9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `router.show()` displays the Ethernet interfaces on the Router. If you need a table in markdown format, pass `markdown=True`."
|
||||
@@ -134,7 +134,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e76d1854-961e-438c-b40f-77fd9c3abe38",
|
||||
"id": "10",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -145,7 +145,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e000540c-687c-4254-870c-1d814603bdbf",
|
||||
"id": "11",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `router.arp.show()` displays the Router ARP Cache."
|
||||
@@ -154,7 +154,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "92de8b42-92d7-4934-9c12-50bf724c9eb2",
|
||||
"id": "12",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -165,7 +165,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a9ff7ee8-9482-44de-9039-b684866bdc82",
|
||||
"id": "13",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `router.acl.show()` displays the Access Control List."
|
||||
@@ -174,7 +174,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5922282a-d22b-4e55-9176-f3f3654c849f",
|
||||
"id": "14",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -185,7 +185,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "71c87884-f793-4c9f-b004-5b0df86cf585",
|
||||
"id": "15",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `router.router_table.show()` displays the static routes the Router provides."
|
||||
@@ -194,7 +194,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "327203be-f475-4727-82a1-e992d3b70ed8",
|
||||
"id": "16",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -205,7 +205,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eef561a8-3d39-4c8b-bbc8-e8b10b8ed25f",
|
||||
"id": "17",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `router.sys_log.show()` displays the Router system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
|
||||
@@ -214,7 +214,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3d0aa004-b10c-445f-aaab-340e0e716c74",
|
||||
"id": "18",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -225,7 +225,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "25630c90-c54e-4b5d-8bf4-ad1b0722e126",
|
||||
"id": "19",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Switch Nodes\n",
|
||||
@@ -235,16 +235,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4879394d-2981-40de-a229-e19b09a34e6e",
|
||||
"id": "20",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `switch.show()` displays the Switch orts on the Switch."
|
||||
"Calling `switch.show()` displays the Switch ports on the Switch."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e7fd439b-5442-4e9d-9e7d-86dacb77f458",
|
||||
"id": "21",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -255,29 +255,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "beb8dbd6-7250-4ac9-9fa2-d2a9c0e5fd19",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"Calling `switch.arp.show()` displays the Switch ARP Cache."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d06e1310-4a77-4315-a59f-cb1b49ca2352",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"switch_1\").arp.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fda75ac3-8123-4234-8f36-86547891d8df",
|
||||
"id": "22",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `switch.sys_log.show()` displays the Switch system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
|
||||
@@ -286,7 +264,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0d984b7-a7c1-4bbd-aa5a-9d3caecb08dc",
|
||||
"id": "23",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -297,7 +275,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2f1d99ad-db4f-4baf-8a35-e1d95f269586",
|
||||
"id": "24",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Computer/Server Nodes\n",
|
||||
@@ -307,7 +285,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c9e2251a-1b47-46e5-840f-7fec3e39c5aa",
|
||||
"id": "25",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -318,7 +296,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "656c37f6-b145-42af-9714-8d2886d0eff8",
|
||||
"id": "26",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -329,7 +307,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f1097a49-a3da-4d79-a06d-ae8af452918f",
|
||||
"id": "27",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `computer.arp.show()` displays the Computer/Server ARP Cache."
|
||||
@@ -338,7 +316,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "66b267d6-2308-486a-b9aa-cb8d3bcf0753",
|
||||
"id": "28",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -349,16 +327,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0d1fcad8-5b1a-4d8b-a49f-aa54a95fcaf0",
|
||||
"id": "29",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `switch.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
|
||||
"Calling `computer.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b5debe8-ef1b-445d-8fa9-6a45568f21f3",
|
||||
"id": "30",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -369,7 +347,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fcfa1773-798c-4ada-9318-c3ad928217da",
|
||||
"id": "31",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Basic Network Comms Check\n",
|
||||
@@ -380,7 +358,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "495b7de4-b6ce-41a6-9114-f74752ab4491",
|
||||
"id": "32",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -391,7 +369,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3e13922a-217f-4f4e-99b6-57a07613cade",
|
||||
"id": "33",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We'll first ping client_1's default gateway."
|
||||
@@ -400,7 +378,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a38abb71-994e-49e8-8f51-e9a550e95b99",
|
||||
"id": "34",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -412,7 +390,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8388e1e9-30e3-4534-8e5a-c6e9144149d2",
|
||||
"id": "35",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -423,7 +401,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "02c76d5c-d954-49db-912d-cb9c52f46375",
|
||||
"id": "36",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we'll ping the interface of the 192.168.1.0/24 Network on the Router (port 1)."
|
||||
@@ -432,7 +410,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ff8e976a-c16b-470c-8923-325713a30d6c",
|
||||
"id": "37",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -443,7 +421,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "80280404-a5ab-452f-8a02-771a0d7496b1",
|
||||
"id": "38",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"And finally, we'll ping the web server."
|
||||
@@ -452,7 +430,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c4163f8d-6a72-410c-9f5c-4f881b7de45e",
|
||||
"id": "39",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -463,7 +441,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1194c045-ba77-4427-be30-ed7b5b224850",
|
||||
"id": "40",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To confirm that the ping was received and processed by the web_server, we can view the sys log"
|
||||
@@ -472,7 +450,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e79a523a-5780-45b6-8798-c434e0e522bd",
|
||||
"id": "41",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -483,17 +461,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5928f6dd-1006-45e3-99f3-8f311a875faa",
|
||||
"id": "42",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Advanced Network Usage\n",
|
||||
"\n",
|
||||
"We can now use the Network to perform some more advaced things."
|
||||
"We can now use the Network to perform some more advanced things."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5e023ef3-7d18-4006-96ee-042a06a481fc",
|
||||
"id": "43",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's attempt to prevent client_2 from being able to ping the web server. First, we'll confirm that it can ping the server first..."
|
||||
@@ -502,7 +480,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "603cf913-e261-49da-a7dd-85e1bb6dec56",
|
||||
"id": "44",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -513,7 +491,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5cf962a4-20e6-44ae-9748-7fc5267ae111",
|
||||
"id": "45",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we look at the client_2 sys log we can see that the four ICMP echo requests were sent and four ICMP each replies were received:"
|
||||
@@ -522,7 +500,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e047de00-3de4-4823-b26a-2c8d64c7a663",
|
||||
"id": "46",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -533,7 +511,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bdc4741d-6e3e-4aec-a69c-c2e9653bd02c",
|
||||
"id": "47",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we'll add an ACL to block ICMP from 192.168.10.22"
|
||||
@@ -542,7 +520,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6db355ae-b99a-441b-a2c4-4ffe78f46bff",
|
||||
"id": "48",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -550,7 +528,7 @@
|
||||
"source": [
|
||||
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
|
||||
"from primaite.simulator.network.transmission.transport_layer import Port\n",
|
||||
"from primaite.simulator.network.hardware.nodes.router import ACLAction\n",
|
||||
"from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n",
|
||||
"network.get_node_by_hostname(\"router_1\").acl.add_rule(\n",
|
||||
" action=ACLAction.DENY,\n",
|
||||
" protocol=IPProtocol.ICMP,\n",
|
||||
@@ -562,7 +540,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a345e000-8842-4827-af96-adc0fbe390fb",
|
||||
"id": "49",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -573,7 +551,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3a5bfd9f-04cb-493e-a86c-cd268563a262",
|
||||
"id": "50",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we attempt (and fail) to ping the web server"
|
||||
@@ -582,7 +560,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a4f4ff31-590f-40fb-b13d-efaa8c2720b6",
|
||||
"id": "51",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -593,7 +571,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "83e56497-097b-45cb-964e-b15c72547b38",
|
||||
"id": "52",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can check that the ping was actually sent by client_2 by viewing the sys log"
|
||||
@@ -602,7 +580,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f62b8a4e-fd3b-4059-b108-3d4a0b18f2a0",
|
||||
"id": "53",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -613,7 +591,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c7040311-a879-4620-86a0-55d0774156e5",
|
||||
"id": "54",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can check the router sys log to see why the traffic was blocked"
|
||||
@@ -622,7 +600,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7e53d776-99da-4d2c-a2a7-bd7ce27bff4c",
|
||||
"id": "55",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -633,7 +611,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aba0bc7d-da57-477b-b34a-3688b5aab2c6",
|
||||
"id": "56",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now a final check to ensure that client_1 can still ping the web_server."
|
||||
@@ -642,7 +620,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d542734b-7582-4af7-8254-bda3de50d091",
|
||||
"id": "57",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -654,7 +632,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d78e9fe3-02c6-4792-944f-5622e26e0412",
|
||||
"id": "58",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
|
||||
@@ -7,12 +7,10 @@ from uuid import uuid4
|
||||
from pydantic import BaseModel, ConfigDict, Field, validate_call
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
RequestFormat = List[Union[str, int, float]]
|
||||
|
||||
|
||||
class RequestPermissionValidator(BaseModel):
|
||||
"""
|
||||
@@ -228,6 +226,15 @@ class SimComponent(BaseModel):
|
||||
return
|
||||
return self._request_manager(request, context)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply any logic that needs to happen at the beginning of the timestep to ensure correct observations/rewards.
|
||||
|
||||
:param timestep: what's the current time
|
||||
:type timestep: int
|
||||
"""
|
||||
pass
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep evolution to this component.
|
||||
|
||||
@@ -103,6 +103,10 @@ class File(FileSystemItemABC):
|
||||
"""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
# reset the number of accesses to 0
|
||||
self.num_access = 0
|
||||
|
||||
|
||||
@@ -427,15 +427,21 @@ class FileSystem(SimComponent):
|
||||
"""Apply time step to FileSystem and its child folders and files."""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
|
||||
# apply timestep to folders
|
||||
for folder_id in self.folders:
|
||||
self.folders[folder_id].apply_timestep(timestep=timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
# reset number of file creations
|
||||
self.num_file_creations = 0
|
||||
|
||||
# reset number of file deletions
|
||||
self.num_file_deletions = 0
|
||||
|
||||
# apply timestep to folders
|
||||
for folder_id in self.folders:
|
||||
self.folders[folder_id].apply_timestep(timestep=timestep)
|
||||
for folder in self.folders.values():
|
||||
folder.pre_timestep(timestep)
|
||||
|
||||
###############################################################
|
||||
# Agent actions
|
||||
|
||||
@@ -128,6 +128,13 @@ class Folder(FileSystemItemABC):
|
||||
for file_id in self.files:
|
||||
self.files[file_id].apply_timestep(timestep=timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
for file in self.files.values():
|
||||
file.pre_timestep(timestep)
|
||||
|
||||
def _scan_timestep(self) -> None:
|
||||
"""Apply the scan action timestep."""
|
||||
if self.scan_countdown >= 0:
|
||||
|
||||
@@ -157,7 +157,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
|
||||
return
|
||||
|
||||
if not self._connected_node:
|
||||
_LOGGER.error(f"Interface {self} cannot be enabled as it is not connected to a Node")
|
||||
_LOGGER.warning(f"Interface {self} cannot be enabled as it is not connected to a Node")
|
||||
return
|
||||
|
||||
if self._connected_node.operating_state != NodeOperatingState.ON:
|
||||
@@ -271,7 +271,7 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC)
|
||||
# Update the state with information from Layer3Interface
|
||||
state.update(Layer3Interface.describe_state(self))
|
||||
|
||||
state["frequency"] = self.frequency
|
||||
state["frequency"] = self.frequency.value
|
||||
|
||||
return state
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -8,6 +9,7 @@ from prettytable import MARKDOWN, PrettyTable
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.hardware.base import Link, Node, WiredNetworkInterface
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Printer
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
@@ -85,6 +87,16 @@ class Network(SimComponent):
|
||||
for link_id in self.links:
|
||||
self.links[link_id].apply_timestep(timestep=timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
for node in self.nodes.values():
|
||||
node.pre_timestep(timestep)
|
||||
|
||||
for link in self.links.values():
|
||||
link.pre_timestep(timestep)
|
||||
|
||||
@property
|
||||
def router_nodes(self) -> List[Node]:
|
||||
"""The Routers in the Network."""
|
||||
@@ -110,6 +122,16 @@ class Network(SimComponent):
|
||||
"""The Firewalls in the Network."""
|
||||
return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"]
|
||||
|
||||
@property
|
||||
def printer_nodes(self) -> List[Node]:
|
||||
"""The printers on the network."""
|
||||
return [node for node in self.nodes.values() if isinstance(node, Printer)]
|
||||
|
||||
@property
|
||||
def wireless_router_nodes(self) -> List[Node]:
|
||||
"""The Routers in the Network."""
|
||||
return [node for node in self.nodes.values() if node.__class__.__name__ == "WirelessRouter"]
|
||||
|
||||
def show(self, nodes: bool = True, ip_addresses: bool = True, links: bool = True, markdown: bool = False):
|
||||
"""
|
||||
Print tables describing the Network.
|
||||
@@ -128,6 +150,8 @@ class Network(SimComponent):
|
||||
"Switch": self.switch_nodes,
|
||||
"Server": self.server_nodes,
|
||||
"Computer": self.computer_nodes,
|
||||
"Printer": self.printer_nodes,
|
||||
"Wireless Router": self.wireless_router_nodes,
|
||||
}
|
||||
if nodes:
|
||||
table = PrettyTable(["Node", "Type", "Operating State"])
|
||||
@@ -150,14 +174,17 @@ class Network(SimComponent):
|
||||
for node in nodes:
|
||||
for i, port in node.network_interface.items():
|
||||
if hasattr(port, "ip_address"):
|
||||
port_str = port.port_name if port.port_name else port.port_num
|
||||
table.add_row(
|
||||
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
|
||||
)
|
||||
if port.ip_address != IPv4Address("127.0.0.1"):
|
||||
port_str = port.port_name if port.port_name else port.port_num
|
||||
table.add_row(
|
||||
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
|
||||
)
|
||||
print(table)
|
||||
|
||||
if links:
|
||||
table = PrettyTable(["Endpoint A", "Endpoint B", "is Up", "Bandwidth (MBits)", "Current Load"])
|
||||
table = PrettyTable(
|
||||
["Endpoint A", "A Port", "Endpoint B", "B Port", "is Up", "Bandwidth (MBits)", "Current Load"]
|
||||
)
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
@@ -170,7 +197,9 @@ class Network(SimComponent):
|
||||
table.add_row(
|
||||
[
|
||||
link.endpoint_a.parent.hostname,
|
||||
str(link.endpoint_a),
|
||||
link.endpoint_b.parent.hostname,
|
||||
str(link.endpoint_b),
|
||||
link.is_up,
|
||||
link.bandwidth,
|
||||
link.current_load_percent,
|
||||
@@ -208,18 +237,19 @@ class Network(SimComponent):
|
||||
}
|
||||
)
|
||||
# Update the links one-by-one. The key is a 4-tuple of `hostname_a, port_a, hostname_b, port_b`
|
||||
for uuid, link in self.links.items():
|
||||
for _, link in self.links.items():
|
||||
node_a = link.endpoint_a._connected_node
|
||||
node_b = link.endpoint_b._connected_node
|
||||
hostname_a = node_a.hostname if node_a else None
|
||||
hostname_b = node_b.hostname if node_b else None
|
||||
port_a = link.endpoint_a.port_num
|
||||
port_b = link.endpoint_b.port_num
|
||||
state["links"][uuid] = link.describe_state()
|
||||
state["links"][uuid]["hostname_a"] = hostname_a
|
||||
state["links"][uuid]["hostname_b"] = hostname_b
|
||||
state["links"][uuid]["port_a"] = port_a
|
||||
state["links"][uuid]["port_b"] = port_b
|
||||
link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}"
|
||||
state["links"][link_key] = link.describe_state()
|
||||
state["links"][link_key]["hostname_a"] = hostname_a
|
||||
state["links"][link_key]["hostname_b"] = hostname_b
|
||||
state["links"][link_key]["port_a"] = port_a
|
||||
state["links"][link_key]["port_b"] = port_b
|
||||
|
||||
return state
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
|
||||
|
||||
def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
|
||||
def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int:
|
||||
"""
|
||||
Calculate the minimum number of network switches required to connect a given number of nodes.
|
||||
|
||||
@@ -18,7 +18,7 @@ def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
|
||||
to accommodate all nodes under this constraint.
|
||||
|
||||
:param num_nodes: The total number of nodes that need to be connected in the network.
|
||||
:param max_switch_ports: The maximum number of ports available on each switch. Defaults to 24.
|
||||
:param max_network_interface: The maximum number of ports available on each switch. Defaults to 24.
|
||||
|
||||
:return: The minimum number of switches required to connect all PCs.
|
||||
|
||||
@@ -33,11 +33,11 @@ def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
|
||||
3
|
||||
"""
|
||||
# Reduce the effective number of switch ports by 1 to leave space for the router
|
||||
effective_switch_ports = max_switch_ports - 1
|
||||
effective_network_interface = max_network_interface - 1
|
||||
|
||||
# Calculate the number of fully utilised switches and any additional switch for remaining PCs
|
||||
full_switches = num_nodes // effective_switch_ports
|
||||
extra_pcs = num_nodes % effective_switch_ports
|
||||
full_switches = num_nodes // effective_network_interface
|
||||
extra_pcs = num_nodes % effective_network_interface
|
||||
|
||||
# Return the total number of switches required
|
||||
return full_switches + (1 if extra_pcs > 0 else 0)
|
||||
@@ -77,7 +77,7 @@ def create_office_lan(
|
||||
|
||||
# Calculate the required number of switches
|
||||
num_of_switches = num_of_switches_required(num_nodes=num_pcs)
|
||||
effective_switch_ports = 23 # One port less for router connection
|
||||
effective_network_interface = 23 # One port less for router connection
|
||||
if pcs_ip_block_start <= num_of_switches:
|
||||
raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}")
|
||||
|
||||
@@ -116,7 +116,7 @@ def create_office_lan(
|
||||
# Add PCs to the LAN and connect them to switches
|
||||
for i in range(1, num_pcs + 1):
|
||||
# Add a new edge switch if the current one is full
|
||||
if switch_port == effective_switch_ports:
|
||||
if switch_port == effective_network_interface:
|
||||
switch_n += 1
|
||||
switch_port = 0
|
||||
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
|
||||
|
||||
@@ -5,7 +5,7 @@ import secrets
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -35,8 +35,11 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.processes.process import Process
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import IOSoftware
|
||||
from primaite.utils.validators import IPV4Address
|
||||
|
||||
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -108,7 +111,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().setup_for_episode(episode=episode)
|
||||
self.nmne = {}
|
||||
if episode and self.pcap:
|
||||
if episode and self.pcap and SIM_OUTPUT.save_pcap_logs:
|
||||
self.pcap.current_episode = episode
|
||||
self.pcap.setup_logger()
|
||||
self.enable()
|
||||
@@ -261,6 +264,9 @@ class NetworkInterface(SimComponent, ABC):
|
||||
"""
|
||||
return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.uuid)
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep evolution to this component.
|
||||
@@ -297,7 +303,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
return True
|
||||
|
||||
if not self._connected_node:
|
||||
_LOGGER.error(f"Interface {self} cannot be enabled as it is not connected to a Node")
|
||||
_LOGGER.warning(f"Interface {self} cannot be enabled as it is not connected to a Node")
|
||||
return False
|
||||
|
||||
if self._connected_node.operating_state != NodeOperatingState.ON:
|
||||
@@ -343,11 +349,11 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
:param link: The Link instance to connect to this network interface.
|
||||
"""
|
||||
if self._connected_link:
|
||||
_LOGGER.error(f"Cannot connect Link to network interface {self} as it already has a connection")
|
||||
_LOGGER.warning(f"Cannot connect Link to network interface {self} as it already has a connection")
|
||||
return
|
||||
|
||||
if self._connected_link == link:
|
||||
_LOGGER.error(f"Cannot connect Link to network interface {self} as it is already connected")
|
||||
_LOGGER.warning(f"Cannot connect Link to network interface {self} as it is already connected")
|
||||
return
|
||||
|
||||
self._connected_link = link
|
||||
@@ -519,12 +525,10 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
|
||||
"""
|
||||
super().enable()
|
||||
try:
|
||||
pass
|
||||
self._connected_node.default_gateway_hello()
|
||||
return True
|
||||
except AttributeError:
|
||||
pass
|
||||
return False
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
@@ -660,6 +664,10 @@ class Link(SimComponent):
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""Apply a timestep to the simulation."""
|
||||
super().apply_timestep(timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
self.current_load = 0.0
|
||||
|
||||
|
||||
@@ -845,12 +853,62 @@ class Node(SimComponent):
|
||||
)
|
||||
rm.add_request("os", RequestType(func=self._os_request_manager, validator=_node_is_on))
|
||||
|
||||
self._software_request_manager = RequestManager()
|
||||
rm.add_request("software_manager", RequestType(func=self._software_request_manager, validator=_node_is_on))
|
||||
self._application_manager = RequestManager()
|
||||
self._software_request_manager.add_request(
|
||||
name="application", request_type=RequestType(func=self._application_manager)
|
||||
)
|
||||
|
||||
self._application_manager.add_request(
|
||||
name="install",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.application_install_action(
|
||||
application=self._read_application_type(request[0]), ip_address=request[1]
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
self._application_manager.add_request(
|
||||
name="uninstall",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.application_uninstall_action(application=self._read_application_type(request[0]))
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
def _install_system_software(self):
|
||||
"""Install System Software - software that is usually provided with the OS."""
|
||||
pass
|
||||
|
||||
def _read_application_type(self, application_class_str: str) -> Type[IOSoftwareClass]:
|
||||
"""Wrapper that converts the string from the request manager into the appropriate class for the application."""
|
||||
if application_class_str == "DoSBot":
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
|
||||
|
||||
return DoSBot
|
||||
elif application_class_str == "DataManipulationBot":
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import (
|
||||
DataManipulationBot,
|
||||
)
|
||||
|
||||
return DataManipulationBot
|
||||
elif application_class_str == "WebBrowser":
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
|
||||
return WebBrowser
|
||||
elif application_class_str == "RansomwareScript":
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
|
||||
return RansomwareScript
|
||||
else:
|
||||
return 0
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -891,8 +949,9 @@ class Node(SimComponent):
|
||||
table.align = "l"
|
||||
table.title = f"{self.hostname} Open Ports"
|
||||
for port in self.software_manager.get_open_ports():
|
||||
table.add_row([port.value, port.name])
|
||||
print(table)
|
||||
if port.value > 0:
|
||||
table.add_row([port.value, port.name])
|
||||
print(table.get_string(sortby="Port"))
|
||||
|
||||
@property
|
||||
def has_enabled_network_interface(self) -> bool:
|
||||
@@ -917,12 +976,15 @@ class Node(SimComponent):
|
||||
table.align = "l"
|
||||
table.title = f"{self.hostname} Network Interface Cards"
|
||||
for port, network_interface in self.network_interface.items():
|
||||
ip_address = ""
|
||||
if hasattr(network_interface, "ip_address"):
|
||||
ip_address = f"{network_interface.ip_address}/{network_interface.ip_network.prefixlen}"
|
||||
table.add_row(
|
||||
[
|
||||
port,
|
||||
type(network_interface),
|
||||
network_interface.__class__.__name__,
|
||||
network_interface.mac_address,
|
||||
f"{network_interface.ip_address}/{network_interface.ip_network.prefixlen}",
|
||||
ip_address,
|
||||
network_interface.speed,
|
||||
"Enabled" if network_interface.enabled else "Disabled",
|
||||
]
|
||||
@@ -1023,6 +1085,23 @@ class Node(SimComponent):
|
||||
|
||||
self.file_system.apply_timestep(timestep=timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
for network_interface in self.network_interfaces.values():
|
||||
network_interface.pre_timestep(timestep=timestep)
|
||||
|
||||
for process_id in self.processes:
|
||||
self.processes[process_id].pre_timestep(timestep=timestep)
|
||||
|
||||
for service_id in self.services:
|
||||
self.services[service_id].pre_timestep(timestep=timestep)
|
||||
|
||||
for application_id in self.applications:
|
||||
self.applications[application_id].pre_timestep(timestep=timestep)
|
||||
|
||||
self.file_system.pre_timestep(timestep=timestep)
|
||||
|
||||
def scan(self) -> bool:
|
||||
"""
|
||||
Scan the node and all the items within it.
|
||||
@@ -1259,6 +1338,77 @@ class Node(SimComponent):
|
||||
_LOGGER.info(f"Removed application {application.name} from node {self.hostname}")
|
||||
self._application_request_manager.remove_request(application.name)
|
||||
|
||||
def application_install_action(self, application: Application, ip_address: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Install an application on this node and configure it.
|
||||
|
||||
This method is useful for allowing agents to take this action.
|
||||
|
||||
:param application: Application object that has not been installed on any node yet.
|
||||
:type application: Application
|
||||
:param ip_address: IP address used to configure the application
|
||||
(target IP for the DoSBot or server IP for the DataManipulationBot)
|
||||
:type ip_address: str
|
||||
:return: True if the application is installed successfully, otherwise False.
|
||||
"""
|
||||
if application in self:
|
||||
_LOGGER.warning(
|
||||
f"Can't add application {application.__name__}" + f"to node {self.hostname}. It's already installed."
|
||||
)
|
||||
return True
|
||||
|
||||
self.software_manager.install(application)
|
||||
application_instance = self.software_manager.software.get(str(application.__name__))
|
||||
self.applications[application_instance.uuid] = application_instance
|
||||
self.sys_log.info(f"Installed application {application_instance.name}")
|
||||
_LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}")
|
||||
self._application_request_manager.add_request(
|
||||
application_instance.name, RequestType(func=application_instance._request_manager)
|
||||
)
|
||||
|
||||
# Configure application if additional parameters are given
|
||||
if ip_address:
|
||||
if application_instance.name == "DoSBot":
|
||||
application_instance.configure(target_ip_address=IPv4Address(ip_address))
|
||||
elif application_instance.name == "DataManipulationBot":
|
||||
application_instance.configure(server_ip_address=IPv4Address(ip_address))
|
||||
elif application_instance.name == "RansomwareScript":
|
||||
application_instance.configure(server_ip_address=IPv4Address(ip_address))
|
||||
else:
|
||||
pass
|
||||
|
||||
if application_instance.name in self.software_manager.software:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def application_uninstall_action(self, application: Application) -> bool:
|
||||
"""
|
||||
Uninstall and completely remove application from this node.
|
||||
|
||||
This method is useful for allowing agents to take this action.
|
||||
|
||||
:param application: Application object that is currently associated with this node.
|
||||
:type application: Application
|
||||
:return: True if the application is uninstalled successfully, otherwise False.
|
||||
"""
|
||||
if application.__name__ not in self.software_manager.software:
|
||||
_LOGGER.warning(
|
||||
f"Can't remove application {application.__name__}" + f"from node {self.hostname}. It's not installed."
|
||||
)
|
||||
return True
|
||||
|
||||
application_instance = self.software_manager.software.get(
|
||||
str(application.__name__)
|
||||
) # This works because we can't have two applications with the same name on the same node
|
||||
# self.uninstall_application(application_instance)
|
||||
self.software_manager.uninstall(application_instance.name)
|
||||
|
||||
if application_instance.name not in self.software_manager.software:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _shut_down_actions(self):
|
||||
"""Actions to perform when the node is shut down."""
|
||||
# Turn off all the services in the node
|
||||
@@ -1290,4 +1440,6 @@ class Node(SimComponent):
|
||||
def __contains__(self, item: Any) -> bool:
|
||||
if isinstance(item, Service):
|
||||
return item.uuid in self.services
|
||||
elif isinstance(item, Application):
|
||||
return item.uuid in self.applications
|
||||
return None
|
||||
|
||||
@@ -316,6 +316,16 @@ class HostNode(Node):
|
||||
super().__init__(**kwargs)
|
||||
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
|
||||
|
||||
@property
|
||||
def arp(self) -> Optional[ARP]:
|
||||
"""
|
||||
Return the ARP Cache of the HostNode.
|
||||
|
||||
:return: ARP Cache for given HostNode
|
||||
:rtype: Optional[ARP]
|
||||
"""
|
||||
return self.software_manager.software.get("ARP")
|
||||
|
||||
def _install_system_software(self):
|
||||
"""
|
||||
Installs the system software and network services typically found on an operating system.
|
||||
|
||||
@@ -28,3 +28,9 @@ class Server(HostNode):
|
||||
* Applications:
|
||||
* Web Browser
|
||||
"""
|
||||
|
||||
|
||||
class Printer(HostNode):
|
||||
"""Printer? I don't even know her!."""
|
||||
|
||||
# TODO: Implement printer-specific behaviour
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Final, Optional, Union
|
||||
from typing import Dict, Final, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.network.router import (
|
||||
AccessControlList,
|
||||
@@ -67,22 +68,34 @@ class Firewall(Router):
|
||||
:ivar str hostname: The Firewall hostname.
|
||||
"""
|
||||
|
||||
internal_inbound_acl: Optional[AccessControlList] = None
|
||||
internal_inbound_acl: AccessControlList = Field(
|
||||
default_factory=lambda: AccessControlList(name="Internal Inbound", implicit_action=ACLAction.DENY)
|
||||
)
|
||||
"""Access Control List for managing entering the internal network."""
|
||||
|
||||
internal_outbound_acl: Optional[AccessControlList] = None
|
||||
internal_outbound_acl: AccessControlList = Field(
|
||||
default_factory=lambda: AccessControlList(name="Internal Outbound", implicit_action=ACLAction.DENY)
|
||||
)
|
||||
"""Access Control List for managing traffic leaving the internal network."""
|
||||
|
||||
dmz_inbound_acl: Optional[AccessControlList] = None
|
||||
dmz_inbound_acl: AccessControlList = Field(
|
||||
default_factory=lambda: AccessControlList(name="DMZ Inbound", implicit_action=ACLAction.DENY)
|
||||
)
|
||||
"""Access Control List for managing traffic entering the DMZ."""
|
||||
|
||||
dmz_outbound_acl: Optional[AccessControlList] = None
|
||||
dmz_outbound_acl: AccessControlList = Field(
|
||||
default_factory=lambda: AccessControlList(name="DMZ Outbound", implicit_action=ACLAction.DENY)
|
||||
)
|
||||
"""Access Control List for managing traffic leaving the DMZ."""
|
||||
|
||||
external_inbound_acl: Optional[AccessControlList] = None
|
||||
external_inbound_acl: AccessControlList = Field(
|
||||
default_factory=lambda: AccessControlList(name="External Inbound", implicit_action=ACLAction.PERMIT)
|
||||
)
|
||||
"""Access Control List for managing traffic entering from an external network."""
|
||||
|
||||
external_outbound_acl: Optional[AccessControlList] = None
|
||||
external_outbound_acl: AccessControlList = Field(
|
||||
default_factory=lambda: AccessControlList(name="External Outbound", implicit_action=ACLAction.PERMIT)
|
||||
)
|
||||
"""Access Control List for managing traffic leaving towards an external network."""
|
||||
|
||||
def __init__(self, hostname: str, **kwargs):
|
||||
@@ -100,29 +113,85 @@ class Firewall(Router):
|
||||
self.connect_nic(
|
||||
RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="dmz")
|
||||
)
|
||||
# Update ACL objects with firewall's hostname and syslog to allow accurate logging
|
||||
self.internal_inbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.internal_inbound_acl.name = f"{hostname} - Internal Inbound"
|
||||
|
||||
# Initialise ACLs for internal and dmz interfaces with a default DENY policy
|
||||
self.internal_inbound_acl = AccessControlList(
|
||||
sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=f"{hostname} - Internal Inbound"
|
||||
self.internal_outbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.internal_outbound_acl.name = f"{hostname} - Internal Outbound"
|
||||
|
||||
self.dmz_inbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.dmz_inbound_acl.name = f"{hostname} - DMZ Inbound"
|
||||
|
||||
self.dmz_outbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.dmz_outbound_acl.name = f"{hostname} - DMZ Outbound"
|
||||
|
||||
self.external_inbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.external_inbound_acl.name = f"{hostname} - External Inbound"
|
||||
|
||||
self.external_outbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.external_outbound_acl.name = f"{hostname} - External Outbound"
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
self._internal_acl_request_manager = RequestManager()
|
||||
rm.add_request("internal", RequestType(func=self._internal_acl_request_manager))
|
||||
|
||||
self._dmz_acl_request_manager = RequestManager()
|
||||
rm.add_request("dmz", RequestType(func=self._dmz_acl_request_manager))
|
||||
|
||||
self._external_acl_request_manager = RequestManager()
|
||||
rm.add_request("external", RequestType(func=self._external_acl_request_manager))
|
||||
|
||||
self._internal_inbound_acl_request_manager = RequestManager()
|
||||
self._internal_outbound_acl_request_manager = RequestManager()
|
||||
self._internal_acl_request_manager.add_request(
|
||||
"inbound", RequestType(func=self._internal_inbound_acl_request_manager)
|
||||
)
|
||||
self.internal_outbound_acl = AccessControlList(
|
||||
sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=f"{hostname} - Internal Outbound"
|
||||
)
|
||||
self.dmz_inbound_acl = AccessControlList(
|
||||
sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=f"{hostname} - DMZ Inbound"
|
||||
)
|
||||
self.dmz_outbound_acl = AccessControlList(
|
||||
sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=f"{hostname} - DMZ Outbound"
|
||||
self._internal_acl_request_manager.add_request(
|
||||
"outbound", RequestType(func=self._internal_outbound_acl_request_manager)
|
||||
)
|
||||
|
||||
# external ACLs should have a default PERMIT policy
|
||||
self.external_inbound_acl = AccessControlList(
|
||||
sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Inbound"
|
||||
self.dmz_inbound_acl_request_manager = RequestManager()
|
||||
self.dmz_outbound_acl_request_manager = RequestManager()
|
||||
self._dmz_acl_request_manager.add_request("inbound", RequestType(func=self.dmz_inbound_acl_request_manager))
|
||||
self._dmz_acl_request_manager.add_request("outbound", RequestType(func=self.dmz_outbound_acl_request_manager))
|
||||
|
||||
self.external_inbound_acl_request_manager = RequestManager()
|
||||
self.external_outbound_acl_request_manager = RequestManager()
|
||||
self._external_acl_request_manager.add_request(
|
||||
"inbound", RequestType(func=self.external_inbound_acl_request_manager)
|
||||
)
|
||||
self.external_outbound_acl = AccessControlList(
|
||||
sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound"
|
||||
self._external_acl_request_manager.add_request(
|
||||
"outbound", RequestType(func=self.external_outbound_acl_request_manager)
|
||||
)
|
||||
|
||||
self._internal_inbound_acl_request_manager.add_request(
|
||||
"acl", RequestType(func=self.internal_inbound_acl._request_manager)
|
||||
)
|
||||
self._internal_outbound_acl_request_manager.add_request(
|
||||
"acl", RequestType(func=self.internal_outbound_acl._request_manager)
|
||||
)
|
||||
|
||||
self.dmz_inbound_acl_request_manager.add_request("acl", RequestType(func=self.dmz_inbound_acl._request_manager))
|
||||
self.dmz_outbound_acl_request_manager.add_request(
|
||||
"acl", RequestType(func=self.dmz_outbound_acl._request_manager)
|
||||
)
|
||||
|
||||
self.external_inbound_acl_request_manager.add_request(
|
||||
"acl", RequestType(func=self.external_inbound_acl._request_manager)
|
||||
)
|
||||
self.external_outbound_acl_request_manager.add_request(
|
||||
"acl", RequestType(func=self.external_outbound_acl._request_manager)
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the Firewall.
|
||||
@@ -530,7 +599,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
@@ -543,7 +614,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
@@ -556,7 +629,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
@@ -569,7 +644,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
@@ -582,7 +659,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
@@ -595,7 +674,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface, Node
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.system.services.arp.arp import ARP
|
||||
|
||||
|
||||
class NetworkNode(Node):
|
||||
@@ -28,3 +30,13 @@ class NetworkNode(Node):
|
||||
:type from_network_interface: NetworkInterface
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def arp(self) -> Optional[ARP]:
|
||||
"""
|
||||
Return the ARP Cache of the NetworkNode.
|
||||
|
||||
:return: ARP Cache for given NetworkNode
|
||||
:rtype: Optional[ARP]
|
||||
"""
|
||||
return self.software_manager.software.get("ARP")
|
||||
|
||||
@@ -18,6 +18,7 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.session_manager import SessionManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.arp.arp import ARP
|
||||
from primaite.simulator.system.services.icmp.icmp import ICMP
|
||||
@@ -147,8 +148,10 @@ class ACLRule(SimComponent):
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.name if self.protocol else None
|
||||
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
|
||||
state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None
|
||||
state["src_port"] = self.src_port.name if self.src_port else None
|
||||
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
|
||||
state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None
|
||||
state["dst_port"] = self.dst_port.name if self.dst_port else None
|
||||
state["match_count"] = self.match_count
|
||||
return state
|
||||
@@ -275,7 +278,7 @@ class AccessControlList(SimComponent):
|
||||
:ivar int max_acl_rules: The maximum number of ACL rules that can be added to the list. Defaults to 25.
|
||||
"""
|
||||
|
||||
sys_log: SysLog
|
||||
sys_log: Optional[SysLog] = None
|
||||
implicit_action: ACLAction
|
||||
implicit_rule: ACLRule
|
||||
max_acl_rules: int = 25
|
||||
@@ -319,10 +322,12 @@ class AccessControlList(SimComponent):
|
||||
action=ACLAction[request[0]],
|
||||
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
|
||||
src_port=None if request[3] == "ALL" else Port[request[3]],
|
||||
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
|
||||
dst_port=None if request[5] == "ALL" else Port[request[5]],
|
||||
position=int(request[6]),
|
||||
src_wildcard_mask=None if request[3] == "NONE" else IPv4Address(request[3]),
|
||||
src_port=None if request[4] == "ALL" else Port[request[4]],
|
||||
dst_ip_address=None if request[5] == "ALL" else IPv4Address(request[5]),
|
||||
dst_wildcard_mask=None if request[6] == "NONE" else IPv4Address(request[6]),
|
||||
dst_port=None if request[7] == "ALL" else Port[request[7]],
|
||||
position=int(request[8]),
|
||||
)
|
||||
)
|
||||
),
|
||||
@@ -624,11 +629,12 @@ class RouteTable(SimComponent):
|
||||
"""
|
||||
pass
|
||||
|
||||
@validate_call()
|
||||
def add_route(
|
||||
self,
|
||||
address: Union[IPv4Address, str],
|
||||
subnet_mask: Union[IPv4Address, str],
|
||||
next_hop_ip_address: Union[IPv4Address, str],
|
||||
address: Union[IPV4Address, str],
|
||||
subnet_mask: Union[IPV4Address, str],
|
||||
next_hop_ip_address: Union[IPV4Address, str],
|
||||
metric: float = 0.0,
|
||||
):
|
||||
"""
|
||||
@@ -647,7 +653,8 @@ class RouteTable(SimComponent):
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
def set_default_route_next_hop_ip_address(self, ip_address: IPv4Address):
|
||||
@validate_call()
|
||||
def set_default_route_next_hop_ip_address(self, ip_address: IPV4Address):
|
||||
"""
|
||||
Sets the next-hop IP address for the default route in a routing table.
|
||||
|
||||
@@ -660,7 +667,7 @@ class RouteTable(SimComponent):
|
||||
"""
|
||||
if not self.default_route:
|
||||
self.default_route = RouteEntry(
|
||||
ip_address=IPv4Address("0.0.0.0"),
|
||||
address=IPv4Address("0.0.0.0"),
|
||||
subnet_mask=IPv4Address("0.0.0.0"),
|
||||
next_hop_ip_address=ip_address,
|
||||
)
|
||||
@@ -767,6 +774,13 @@ class RouterARP(ARP):
|
||||
is_reattempt=True,
|
||||
is_default_route_attempt=is_default_route_attempt,
|
||||
)
|
||||
elif route and route == self.router.route_table.default_route:
|
||||
self.send_arp_request(self.router.route_table.default_route.next_hop_ip_address)
|
||||
return self._get_arp_cache_mac_address(
|
||||
ip_address=self.router.route_table.default_route.next_hop_ip_address,
|
||||
is_reattempt=True,
|
||||
is_default_route_attempt=True,
|
||||
)
|
||||
else:
|
||||
if self.router.route_table.default_route:
|
||||
if not is_default_route_attempt:
|
||||
@@ -817,6 +831,12 @@ class RouterARP(ARP):
|
||||
return network_interface
|
||||
|
||||
if not is_reattempt:
|
||||
if self.router.ip_is_in_router_interface_subnet(ip_address):
|
||||
self.send_arp_request(ip_address)
|
||||
return self._get_arp_cache_network_interface(
|
||||
ip_address=ip_address, is_reattempt=True, is_default_route_attempt=is_default_route_attempt
|
||||
)
|
||||
|
||||
route = self.router.route_table.find_best_route(ip_address)
|
||||
if route and route != self.router.route_table.default_route:
|
||||
self.send_arp_request(route.next_hop_ip_address)
|
||||
@@ -825,6 +845,13 @@ class RouterARP(ARP):
|
||||
is_reattempt=True,
|
||||
is_default_route_attempt=is_default_route_attempt,
|
||||
)
|
||||
elif route and route == self.router.route_table.default_route:
|
||||
self.send_arp_request(self.router.route_table.default_route.next_hop_ip_address)
|
||||
return self._get_arp_cache_network_interface(
|
||||
ip_address=self.router.route_table.default_route.next_hop_ip_address,
|
||||
is_reattempt=True,
|
||||
is_default_route_attempt=True,
|
||||
)
|
||||
else:
|
||||
if self.router.route_table.default_route:
|
||||
if not is_default_route_attempt:
|
||||
@@ -1016,6 +1043,144 @@ class RouterInterface(IPWiredNetworkInterface):
|
||||
return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}/{self.ip_address}"
|
||||
|
||||
|
||||
class RouterSessionManager(SessionManager):
|
||||
"""
|
||||
Manages network sessions, including session creation, lookup, and communication with other components.
|
||||
|
||||
The RouterSessionManager is a Router/Firewall specific implementation of SessionManager. It overrides the
|
||||
resolve_outbound_network_interface and resolve_outbound_transmission_details functions, allowing them to leverage
|
||||
the route table instead of the default gateway.
|
||||
|
||||
:param sys_log: A reference to the system log component.
|
||||
"""
|
||||
|
||||
def resolve_outbound_network_interface(self, dst_ip_address: IPv4Address) -> Optional[RouterInterface]:
|
||||
"""
|
||||
Resolves the appropriate outbound network interface for a given destination IP address.
|
||||
|
||||
This method determines the most suitable network interface for sending a packet to the specified
|
||||
destination IP address. It considers only enabled network interfaces and checks if the destination
|
||||
IP address falls within the subnet of each interface. If no suitable local network interface is found,
|
||||
the method defaults to performing a route table look-up to determine if there is a dedicated route or a default
|
||||
route it can use.
|
||||
|
||||
The search process prioritises local network interfaces based on the IP network to which they belong.
|
||||
If the destination IP address does not match any local subnet, the method assumes that the destination
|
||||
is outside the local network and hence, routes the packet according to route table look-up.
|
||||
|
||||
:param dst_ip_address: The destination IP address for which the outbound interface is to be resolved.
|
||||
:type dst_ip_address: IPv4Address
|
||||
:return: The network interface through which the packet should be sent to reach the destination IP address,
|
||||
or the default gateway's network interface if the destination is not within any local subnet.
|
||||
:rtype: Optional[RouterInterface]
|
||||
"""
|
||||
network_interface = super().resolve_outbound_network_interface(dst_ip_address)
|
||||
if not network_interface:
|
||||
route = self.node.route_table.find_best_route(dst_ip_address)
|
||||
if not route:
|
||||
return None
|
||||
network_interface = super().resolve_outbound_network_interface(route.next_hop_ip_address)
|
||||
return network_interface
|
||||
|
||||
def resolve_outbound_transmission_details(
|
||||
self,
|
||||
dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
src_port: Optional[Port] = None,
|
||||
dst_port: Optional[Port] = None,
|
||||
protocol: Optional[IPProtocol] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Tuple[
|
||||
Optional[RouterInterface],
|
||||
Optional[str],
|
||||
IPv4Address,
|
||||
Optional[Port],
|
||||
Optional[Port],
|
||||
Optional[IPProtocol],
|
||||
bool,
|
||||
]:
|
||||
"""
|
||||
Resolves the necessary details for outbound transmission based on the provided parameters.
|
||||
|
||||
This method determines whether the payload should be broadcast or unicast based on the destination IP address
|
||||
and resolves the outbound network interface and destination MAC address accordingly.
|
||||
|
||||
The method first checks if `session_id` is provided and uses the session details if available. For broadcast
|
||||
transmissions, it finds a suitable network interface and uses a broadcast MAC address. For unicast
|
||||
transmissions, it attempts to resolve the destination MAC address using ARP and finds the appropriate
|
||||
outbound network interface. If the destination IP address is outside the local network and no specific MAC
|
||||
address is resolved, it defaults to performing a route table look-up to determine if there is a dedicated route
|
||||
or a default route it can use.
|
||||
|
||||
:param dst_ip_address: The destination IP address or network. If an IPv4Network is provided, the method
|
||||
treats the transmission as a broadcast to that network. Optional.
|
||||
:type dst_ip_address: Optional[Union[IPv4Address, IPv4Network]]
|
||||
:param src_port: The source port number for the transmission. Optional.
|
||||
:type src_port: Optional[Port]
|
||||
:param dst_port: The destination port number for the transmission. Optional.
|
||||
:type dst_port: Optional[Port]
|
||||
:param protocol: The IP protocol to be used for the transmission. Optional.
|
||||
:type protocol: Optional[IPProtocol]
|
||||
:param session_id: The session ID associated with the transmission. If provided, the session details override
|
||||
other parameters. Optional.
|
||||
:type session_id: Optional[str]
|
||||
:return: A tuple containing the resolved outbound network interface, destination MAC address, destination IP
|
||||
address, source port, destination port, protocol, and a boolean indicating whether the transmission is a
|
||||
broadcast.
|
||||
:rtype: Tuple[Optional[RouterInterface], Optional[str], IPv4Address, Optional[Port], Optional[Port],
|
||||
Optional[IPProtocol], bool]
|
||||
"""
|
||||
if dst_ip_address and not isinstance(dst_ip_address, (IPv4Address, IPv4Network)):
|
||||
dst_ip_address = IPv4Address(dst_ip_address)
|
||||
is_broadcast = False
|
||||
outbound_network_interface = None
|
||||
dst_mac_address = None
|
||||
|
||||
# Use session details if session_id is provided
|
||||
if session_id:
|
||||
session = self.sessions_by_uuid[session_id]
|
||||
|
||||
dst_ip_address = session.with_ip_address
|
||||
protocol = session.protocol
|
||||
src_port = session.src_port
|
||||
dst_port = session.dst_port
|
||||
|
||||
# Determine if the payload is for broadcast or unicast
|
||||
|
||||
# Handle broadcast transmission
|
||||
if isinstance(dst_ip_address, IPv4Network):
|
||||
is_broadcast = True
|
||||
dst_ip_address = dst_ip_address.broadcast_address
|
||||
if dst_ip_address:
|
||||
# Find a suitable NIC for the broadcast
|
||||
for network_interface in self.node.network_interfaces.values():
|
||||
if dst_ip_address in network_interface.ip_network and network_interface.enabled:
|
||||
dst_mac_address = "ff:ff:ff:ff:ff:ff"
|
||||
outbound_network_interface = network_interface
|
||||
break
|
||||
else:
|
||||
# Resolve MAC address for unicast transmission
|
||||
use_route_table = True
|
||||
for network_interface in self.node.network_interfaces.values():
|
||||
if dst_ip_address in network_interface.ip_network and network_interface.enabled:
|
||||
dst_mac_address = self.software_manager.arp.get_arp_cache_mac_address(dst_ip_address)
|
||||
break
|
||||
|
||||
if dst_mac_address:
|
||||
use_route_table = False
|
||||
outbound_network_interface = self.software_manager.arp.get_arp_cache_network_interface(dst_ip_address)
|
||||
|
||||
if use_route_table:
|
||||
route = self.node.route_table.find_best_route(dst_ip_address)
|
||||
if not route:
|
||||
raise Exception("cannot use route to resolve outbound details")
|
||||
|
||||
dst_mac_address = self.software_manager.arp.get_arp_cache_mac_address(route.next_hop_ip_address)
|
||||
outbound_network_interface = self.software_manager.arp.get_arp_cache_network_interface(
|
||||
route.next_hop_ip_address
|
||||
)
|
||||
return outbound_network_interface, dst_mac_address, dst_ip_address, src_port, dst_port, protocol, is_broadcast
|
||||
|
||||
|
||||
class Router(NetworkNode):
|
||||
"""
|
||||
Represents a network router, managing routing and forwarding of IP packets across network interfaces.
|
||||
@@ -1049,6 +1214,10 @@ class Router(NetworkNode):
|
||||
if not kwargs.get("route_table"):
|
||||
kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"])
|
||||
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
|
||||
self.session_manager = RouterSessionManager(sys_log=self.sys_log)
|
||||
self.session_manager.node = self
|
||||
self.software_manager.session_manager = self.session_manager
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
for i in range(1, self.num_ports + 1):
|
||||
network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
|
||||
self.connect_nic(network_interface)
|
||||
@@ -1068,8 +1237,7 @@ class Router(NetworkNode):
|
||||
icmp: RouterICMP = self.software_manager.icmp # noqa
|
||||
icmp.router = self
|
||||
self.software_manager.install(RouterARP)
|
||||
arp: RouterARP = self.software_manager.arp # noqa
|
||||
arp.router = self
|
||||
self.arp.router = self
|
||||
|
||||
def _set_default_acl(self):
|
||||
"""
|
||||
@@ -1313,6 +1481,8 @@ class Router(NetworkNode):
|
||||
frame.ethernet.src_mac_addr = network_interface.mac_address
|
||||
frame.ethernet.dst_mac_addr = target_mac
|
||||
network_interface.send_frame(frame)
|
||||
else:
|
||||
self.sys_log.error(f"Frame dropped as there is no route to {frame.ip.dst_ip_address}")
|
||||
|
||||
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
|
||||
"""
|
||||
@@ -1393,6 +1563,13 @@ class Router(NetworkNode):
|
||||
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
|
||||
- src_ip_address (str, optional): IP address octet written in base 10
|
||||
- dst_ip_address (str, optional): IP address octet written in base 10
|
||||
- routes (list[dict]): List of route dicts with values:
|
||||
- address (str): The destination address of the route.
|
||||
- subnet_mask (str): The subnet mask of the route.
|
||||
- next_hop_ip_address (str): The next hop IP for the route.
|
||||
- metric (int): The metric of the route. Optional.
|
||||
- default_route:
|
||||
- next_hop_ip_address (str): The next hop IP for the route.
|
||||
|
||||
Example config:
|
||||
```
|
||||
@@ -1403,6 +1580,10 @@ class Router(NetworkNode):
|
||||
1: {
|
||||
'ip_address' : '192.168.1.1',
|
||||
'subnet_mask' : '255.255.255.0',
|
||||
},
|
||||
2: {
|
||||
'ip_address' : '192.168.0.1',
|
||||
'subnet_mask' : '255.255.255.252',
|
||||
}
|
||||
},
|
||||
'acl' : {
|
||||
@@ -1410,6 +1591,10 @@ class Router(NetworkNode):
|
||||
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
|
||||
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
|
||||
},
|
||||
'routes' : [
|
||||
{'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'}
|
||||
],
|
||||
'default_route': {'next_hop_ip_address': '192.168.0.2'}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1440,7 +1625,9 @@ class Router(NetworkNode):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
if "routes" in cfg:
|
||||
@@ -1451,4 +1638,8 @@ class Router(NetworkNode):
|
||||
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
|
||||
metric=float(route.get("metric", 0)),
|
||||
)
|
||||
if "default_route" in cfg:
|
||||
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
|
||||
if next_hop_ip_address:
|
||||
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
|
||||
return router
|
||||
|
||||
@@ -100,13 +100,8 @@ class Switch(NetworkNode):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not self.network_interface:
|
||||
self.network_interface = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
|
||||
for port_num, port in self.network_interface.items():
|
||||
port._connected_node = self
|
||||
port.port_num = port_num
|
||||
port.parent = self
|
||||
port.port_num = port_num
|
||||
for i in range(1, self.num_ports + 1):
|
||||
self.connect_nic(SwitchPort())
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from pydantic import validate_call
|
||||
|
||||
from primaite.simulator.network.airspace import AirSpaceFrequency, IPWirelessNetworkInterface
|
||||
from primaite.simulator.network.hardware.nodes.network.router import Router, RouterInterface
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.utils.validators import IPV4Address
|
||||
|
||||
|
||||
@@ -209,3 +213,68 @@ class WirelessRouter(Router):
|
||||
raise NotImplementedError(
|
||||
"Please use the 'configure_wireless_access_point' and 'configure_router_interface' functions."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict) -> "WirelessRouter":
|
||||
"""Generate the wireless router from config.
|
||||
|
||||
Schema:
|
||||
- hostname (str): unique name for this router.
|
||||
- router_interface (dict): The values should be another dict specifying
|
||||
- ip_address (str)
|
||||
- subnet_mask (str)
|
||||
- wireless_access_point (dict): Dict with
|
||||
- ip address,
|
||||
- subnet mask,
|
||||
- frequency, (string: either WIFI_2_4 or WIFI_5)
|
||||
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
|
||||
where the rule will be added (lower number is resolved first). The values should describe valid ACL
|
||||
Rules as:
|
||||
- action (str): either PERMIT or DENY
|
||||
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
|
||||
- src_ip_address (str, optional): IP address octet written in base 10
|
||||
- dst_ip_address (str, optional): IP address octet written in base 10
|
||||
|
||||
:param cfg: Config dictionary
|
||||
:type cfg: Dict
|
||||
:return: WirelessRouter instance.
|
||||
:rtype: WirelessRouter
|
||||
"""
|
||||
operating_state = (
|
||||
NodeOperatingState.ON if not (p := cfg.get("operating_state")) else NodeOperatingState[p.upper()]
|
||||
)
|
||||
router = cls(hostname=cfg["hostname"], operating_state=operating_state)
|
||||
if "router_interface" in cfg:
|
||||
ip_address = cfg["router_interface"]["ip_address"]
|
||||
subnet_mask = cfg["router_interface"]["subnet_mask"]
|
||||
router.configure_router_interface(ip_address=ip_address, subnet_mask=subnet_mask)
|
||||
if "wireless_access_point" in cfg:
|
||||
ip_address = cfg["wireless_access_point"]["ip_address"]
|
||||
subnet_mask = cfg["wireless_access_point"]["subnet_mask"]
|
||||
frequency = AirSpaceFrequency[cfg["wireless_access_point"]["frequency"]]
|
||||
router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency)
|
||||
|
||||
if "acl" in cfg:
|
||||
for r_num, r_cfg in cfg["acl"].items():
|
||||
router.acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
if "routes" in cfg:
|
||||
for route in cfg.get("routes"):
|
||||
router.route_table.add_route(
|
||||
address=IPv4Address(route.get("address")),
|
||||
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
|
||||
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
|
||||
metric=float(route.get("metric", 0)),
|
||||
)
|
||||
return router
|
||||
|
||||
@@ -6,7 +6,7 @@ CAPTURE_NMNE: bool = True
|
||||
NMNE_CAPTURE_KEYWORDS: List[str] = []
|
||||
"""List of keywords to identify malicious network events."""
|
||||
|
||||
# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically
|
||||
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
|
||||
CAPTURE_BY_DIRECTION: Final[bool] = True
|
||||
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
|
||||
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
|
||||
|
||||
@@ -8,7 +8,7 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket
|
||||
from primaite.simulator.network.protocols.packet import DataPacket
|
||||
from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol
|
||||
from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader
|
||||
from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader
|
||||
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader
|
||||
from primaite.simulator.network.utils import convert_bytes_to_megabits
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -141,3 +141,37 @@ class Frame(BaseModel):
|
||||
def size_Mbits(self) -> float: # noqa - Keep it as MBits as this is how they're expressed
|
||||
"""The daa transfer size of the Frame in Mbits."""
|
||||
return convert_bytes_to_megabits(self.size)
|
||||
|
||||
@property
|
||||
def is_broadcast(self) -> bool:
|
||||
"""
|
||||
Determines if the Frame is a broadcast frame.
|
||||
|
||||
A Frame is considered a broadcast frame if the destination MAC address is set to the broadcast address
|
||||
"ff:ff:ff:ff:ff:ff".
|
||||
|
||||
:return: True if the destination MAC address is a broadcast address, otherwise False.
|
||||
"""
|
||||
return self.ethernet.dst_mac_addr.lower() == "ff:ff:ff:ff:ff:ff"
|
||||
|
||||
@property
|
||||
def is_arp(self) -> bool:
|
||||
"""
|
||||
Checks if the Frame is an ARP (Address Resolution Protocol) packet.
|
||||
|
||||
This is determined by checking if the destination port of the TCP header is equal to the ARP port.
|
||||
|
||||
:return: True if the Frame is an ARP packet, otherwise False.
|
||||
"""
|
||||
return self.udp.dst_port == Port.ARP
|
||||
|
||||
@property
|
||||
def is_icmp(self) -> bool:
|
||||
"""
|
||||
Determines if the Frame is an ICMP (Internet Control Message Protocol) packet.
|
||||
|
||||
This check is performed by verifying if the 'icmp' attribute of the Frame instance is present (not None).
|
||||
|
||||
:return: True if the Frame is an ICMP packet (i.e., has an ICMP header), otherwise False.
|
||||
"""
|
||||
return self.icmp is not None
|
||||
|
||||
@@ -11,6 +11,9 @@ class Port(Enum):
|
||||
.. _List of Ports:
|
||||
"""
|
||||
|
||||
UNUSED = -1
|
||||
"An unused port stub."
|
||||
|
||||
NONE = 0
|
||||
"Place holder for a non-port."
|
||||
WOL = 9
|
||||
|
||||
@@ -63,3 +63,8 @@ class Simulation(SimComponent):
|
||||
"""Apply a timestep to the simulation."""
|
||||
super().apply_timestep(timestep)
|
||||
self.network.apply_timestep(timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
self.network.pre_timestep(timestep)
|
||||
|
||||
@@ -3,6 +3,8 @@ from enum import Enum
|
||||
from typing import Any, Dict, Set
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -38,6 +40,17 @@ class Application(IOSoftware):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request("close", RequestType(func=lambda request, context: RequestResponse.from_bool(self.close())))
|
||||
return rm
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -67,7 +80,10 @@ class Application(IOSoftware):
|
||||
"""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
|
||||
self.num_executions = 0 # reset number of executions
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
self.num_executions = 0
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
@@ -83,7 +99,7 @@ class Application(IOSoftware):
|
||||
|
||||
if self.operating_state is not self.operating_state.RUNNING:
|
||||
# service is not running
|
||||
_LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
|
||||
_LOGGER.debug(f"Cannot perform action: {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -104,11 +120,12 @@ class Application(IOSoftware):
|
||||
"""The main application loop."""
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
def close(self) -> bool:
|
||||
"""Close the Application."""
|
||||
if self.operating_state == ApplicationOperatingState.RUNNING:
|
||||
self.sys_log.info(f"Closed Application{self.name}")
|
||||
self.operating_state = ApplicationOperatingState.CLOSED
|
||||
return True
|
||||
|
||||
def install(self) -> None:
|
||||
"""Install Application."""
|
||||
|
||||
@@ -29,6 +29,9 @@ class DatabaseClient(Application):
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
_last_connection_successful: Optional[bool] = None
|
||||
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
||||
last_query_response: Optional[Dict] = None
|
||||
"""Keep track of the latest query response. Used to determine rewards."""
|
||||
_server_connection_id: Optional[str] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DatabaseClient"
|
||||
@@ -49,10 +52,9 @@ class DatabaseClient(Application):
|
||||
def execute(self) -> bool:
|
||||
"""Execution definition for db client: perform a select query."""
|
||||
self.num_executions += 1 # trying to connect counts as an execution
|
||||
if self.connections:
|
||||
can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1])
|
||||
else:
|
||||
can_connect = self.check_connection(connection_id=str(uuid4()))
|
||||
if not self._server_connection_id:
|
||||
self.connect()
|
||||
can_connect = self.check_connection(connection_id=self._server_connection_id)
|
||||
self._last_connection_successful = can_connect
|
||||
return can_connect
|
||||
|
||||
@@ -78,17 +80,21 @@ class DatabaseClient(Application):
|
||||
self.server_password = server_password
|
||||
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
|
||||
|
||||
def connect(self, connection_id: Optional[str] = None) -> bool:
|
||||
def connect(self) -> bool:
|
||||
"""Connect to a Database Service."""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if not connection_id:
|
||||
connection_id = str(uuid4())
|
||||
if not self._server_connection_id:
|
||||
self._server_connection_id = str(uuid4())
|
||||
|
||||
self.connected = self._connect(
|
||||
server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
|
||||
server_ip_address=self.server_ip_address,
|
||||
password=self.server_password,
|
||||
connection_id=self._server_connection_id,
|
||||
)
|
||||
if not self.connected:
|
||||
self._server_connection_id = None
|
||||
return self.connected
|
||||
|
||||
def check_connection(self, connection_id: str) -> bool:
|
||||
@@ -123,7 +129,7 @@ class DatabaseClient(Application):
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if is_reattempt:
|
||||
if self.connections.get(connection_id):
|
||||
if self._server_connection_id:
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
|
||||
)
|
||||
@@ -147,31 +153,28 @@ class DatabaseClient(Application):
|
||||
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
|
||||
)
|
||||
|
||||
def disconnect(self, connection_id: Optional[str] = None) -> bool:
|
||||
def disconnect(self) -> bool:
|
||||
"""Disconnect from the Database Service."""
|
||||
if not self._can_perform_action():
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
# if there are no connections - nothing to disconnect
|
||||
if not len(self.connections):
|
||||
if not self._server_connection_id:
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
|
||||
return False
|
||||
|
||||
# if no connection provided, disconnect the first connection
|
||||
if not connection_id:
|
||||
connection_id = list(self.connections.keys())[0]
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_id},
|
||||
payload={"type": "disconnect", "connection_id": self._server_connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
self.remove_connection(connection_id=connection_id)
|
||||
self.remove_connection(connection_id=self._server_connection_id)
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
|
||||
f"{self.name}: DatabaseClient disconnected {self._server_connection_id} from {self.server_ip_address}"
|
||||
)
|
||||
self.connected = False
|
||||
|
||||
@@ -219,18 +222,23 @@ class DatabaseClient(Application):
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if connection_id is None:
|
||||
if self.connections:
|
||||
connection_id = list(self.connections.keys())[-1]
|
||||
# TODO: if the most recent connection dies, it should be automatically cleared.
|
||||
else:
|
||||
connection_id = str(uuid4())
|
||||
# reset last query response
|
||||
self.last_query_response = None
|
||||
|
||||
if not self.connections.get(connection_id):
|
||||
if not self.connect(connection_id=connection_id):
|
||||
return False
|
||||
connection_id: str
|
||||
|
||||
if not connection_id:
|
||||
connection_id = self._server_connection_id
|
||||
|
||||
if not connection_id:
|
||||
self.connect()
|
||||
connection_id = self._server_connection_id
|
||||
|
||||
if not connection_id:
|
||||
msg = "Cannot run sql query, could not establish connection with the server."
|
||||
self.parent.sys_log.error(msg)
|
||||
return False
|
||||
|
||||
# Initialise the tracker of this ID to False
|
||||
uuid = str(uuid4())
|
||||
self._query_success_tracker[uuid] = False
|
||||
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)
|
||||
@@ -252,6 +260,7 @@ class DatabaseClient(Application):
|
||||
# add connection
|
||||
self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id)
|
||||
elif payload["type"] == "sql":
|
||||
self.last_query_response = payload
|
||||
query_id = payload.get("uuid")
|
||||
status_code = payload.get("status_code")
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
|
||||
@@ -0,0 +1,316 @@
|
||||
from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class RansomwareAttackStage(IntEnum):
|
||||
"""
|
||||
Enumeration representing different attack stages of the ransomware script.
|
||||
|
||||
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle
|
||||
in the simulation.
|
||||
Each stage represents a specific phase in the attack process.
|
||||
"""
|
||||
|
||||
NOT_STARTED = 0
|
||||
"Indicates that the attack has not started yet."
|
||||
DOWNLOAD = 1
|
||||
"Installing the Encryption Script - Testing"
|
||||
INSTALL = 2
|
||||
"The stage where logon procedures are simulated."
|
||||
ACTIVATE = 3
|
||||
"Operating Status Changes"
|
||||
PROPAGATE = 4
|
||||
"Represents the stage of performing a horizontal port scan on the target."
|
||||
COMMAND_AND_CONTROL = 5
|
||||
"Represents the stage of setting up a rely C2 Beacon (Not Implemented)"
|
||||
PAYLOAD = 6
|
||||
"Stage of actively attacking the target."
|
||||
SUCCEEDED = 7
|
||||
"Indicates the attack has been successfully completed."
|
||||
FAILED = 8
|
||||
"Signifies that the attack has failed."
|
||||
|
||||
|
||||
class RansomwareScript(Application):
|
||||
"""Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network.
|
||||
|
||||
:ivar payload: The attack stage query payload. (Default Corrupt)
|
||||
:ivar target_scan_p_of_success: The probability of success for the target scan stage.
|
||||
:ivar c2_beacon_p_of_success: The probability of success for the c2_beacon stage
|
||||
:ivar ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
|
||||
:ivar repeat: Whether to repeat attacking once finished.
|
||||
"""
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
"""IP address of node which hosts the database."""
|
||||
server_password: Optional[str] = None
|
||||
"""Password required to access the database."""
|
||||
payload: Optional[str] = "ENCRYPT"
|
||||
"Payload String for the payload stage"
|
||||
target_scan_p_of_success: float = 0.9
|
||||
"Probability of the target scan succeeding: Default 0.9"
|
||||
c2_beacon_p_of_success: float = 0.9
|
||||
"Probability of the c2 beacon setup stage succeeding: Default 0.9"
|
||||
ransomware_encrypt_p_of_success: float = 0.9
|
||||
"Probability of the ransomware attack succeeding: Default 0.9"
|
||||
repeat: bool = False
|
||||
"If true, the Denial of Service bot will keep performing the attack."
|
||||
attack_stage: RansomwareAttackStage = RansomwareAttackStage.NOT_STARTED
|
||||
"The ransomware attack stage. See RansomwareAttackStage Class"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "RansomwareScript"
|
||||
kwargs["port"] = Port.NONE
|
||||
kwargs["protocol"] = IPProtocol.NONE
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
|
||||
|
||||
:return: Current state of this object and child objects.
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
@property
|
||||
def _host_db_client(self) -> DatabaseClient:
|
||||
"""Return the database client that is installed on the same machine as the Ransomware Script."""
|
||||
db_client = self.software_manager.software.get("DatabaseClient")
|
||||
if db_client is None:
|
||||
_LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.")
|
||||
return db_client
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
name="execute",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())),
|
||||
)
|
||||
return rm
|
||||
|
||||
def _activate(self):
|
||||
"""
|
||||
Simulate the install process as the initial stage of the attack.
|
||||
|
||||
Advances the attack stage to 'ACTIVATE' attack state.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.INSTALL:
|
||||
self.sys_log.info(f"{self.name}: Activated!")
|
||||
self.attack_stage = RansomwareAttackStage.ACTIVATE
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep to the bot, triggering the application loop.
|
||||
|
||||
:param timestep: The timestep value to update the bot's state.
|
||||
"""
|
||||
pass
|
||||
|
||||
def run(self) -> bool:
|
||||
"""Calls the parent classes execute method before starting the application loop."""
|
||||
super().run()
|
||||
return True
|
||||
|
||||
def _application_loop(self) -> bool:
|
||||
"""
|
||||
The main application loop of the script, handling the attack process.
|
||||
|
||||
This is the core loop where the bot sequentially goes through the stages of the attack.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
if self.server_ip_address and self.payload:
|
||||
self.sys_log.info(f"{self.name}: Running")
|
||||
self.attack_stage = RansomwareAttackStage.NOT_STARTED
|
||||
self._local_download()
|
||||
self._install()
|
||||
self._activate()
|
||||
self._perform_target_scan()
|
||||
self._setup_beacon()
|
||||
self._perform_ransomware_encrypt()
|
||||
|
||||
if self.repeat and self.attack_stage in (
|
||||
RansomwareAttackStage.SUCCEEDED,
|
||||
RansomwareAttackStage.FAILED,
|
||||
):
|
||||
self.attack_stage = RansomwareAttackStage.NOT_STARTED
|
||||
return True
|
||||
else:
|
||||
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
|
||||
return False
|
||||
|
||||
def configure(
|
||||
self,
|
||||
server_ip_address: IPv4Address,
|
||||
server_password: Optional[str] = None,
|
||||
payload: Optional[str] = None,
|
||||
target_scan_p_of_success: Optional[float] = None,
|
||||
c2_beacon_p_of_success: Optional[float] = None,
|
||||
ransomware_encrypt_p_of_success: Optional[float] = None,
|
||||
repeat: bool = True,
|
||||
):
|
||||
"""
|
||||
Configure the Ransomware Script to communicate with a DatabaseService.
|
||||
|
||||
:param server_ip_address: The IP address of the Node the DatabaseService is on.
|
||||
:param server_password: The password on the DatabaseService.
|
||||
:param payload: The attack stage query (Encrypt / Delete)
|
||||
:param target_scan_p_of_success: The probability of success for the target scan stage.
|
||||
:param c2_beacon_p_of_success: The probability of success for the c2_beacon stage
|
||||
:param ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
|
||||
:param repeat: Whether to repeat attacking once finished.
|
||||
"""
|
||||
if server_ip_address:
|
||||
self.server_ip_address = server_ip_address
|
||||
if server_password:
|
||||
self.server_password = server_password
|
||||
if payload:
|
||||
self.payload = payload
|
||||
if target_scan_p_of_success:
|
||||
self.target_scan_p_of_success = target_scan_p_of_success
|
||||
if c2_beacon_p_of_success:
|
||||
self.c2_beacon_p_of_success = c2_beacon_p_of_success
|
||||
if ransomware_encrypt_p_of_success:
|
||||
self.ransomware_encrypt_p_of_success = ransomware_encrypt_p_of_success
|
||||
if repeat:
|
||||
self.repeat = repeat
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, "
|
||||
f"{repeat=}."
|
||||
)
|
||||
|
||||
def _install(self):
|
||||
"""
|
||||
Simulate the install stage in the kill-chain.
|
||||
|
||||
Advances the attack stage to 'ACTIVATE' if successful.
|
||||
|
||||
From this attack stage onwards.
|
||||
the ransomware application is now visible from this point onwardin the observation space.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.DOWNLOAD:
|
||||
self.sys_log.info(f"{self.name}: Malware installed on the local file system")
|
||||
downloads_folder = self.file_system.get_folder(folder_name="downloads")
|
||||
ransomware_file = downloads_folder.get_file(file_name="ransom_script.pdf")
|
||||
ransomware_file.num_access += 1
|
||||
self.attack_stage = RansomwareAttackStage.INSTALL
|
||||
|
||||
def _setup_beacon(self):
|
||||
"""
|
||||
Simulates setting up a c2 beacon; currently a pseudo step for increasing red variance.
|
||||
|
||||
Advances the attack stage to 'COMMAND AND CONTROL` if successful.
|
||||
|
||||
:param p_of_sucess: Probability of a successful c2 setup (Advancing this step),
|
||||
by default the success rate is 0.5
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.PROPAGATE:
|
||||
self.sys_log.info(f"{self.name} Attempting to set up C&C Beacon - Scan 1/2")
|
||||
if simulate_trial(self.c2_beacon_p_of_success):
|
||||
self.sys_log.info(f"{self.name} C&C Successful setup - Scan 2/2")
|
||||
c2c_setup = True # TODO Implement the c2c step via an FTP Application/Service
|
||||
if c2c_setup:
|
||||
self.attack_stage = RansomwareAttackStage.COMMAND_AND_CONTROL
|
||||
|
||||
def _perform_target_scan(self):
|
||||
"""
|
||||
Perform a simulated port scan to check for open SQL ports.
|
||||
|
||||
Advances the attack stage to `PROPAGATE` if successful.
|
||||
|
||||
:param p_of_success: Probability of successful port scan, by default 0.1.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.ACTIVATE:
|
||||
# perform a port scan to identify that the SQL port is open on the server
|
||||
self.sys_log.info(f"{self.name}: Scanning for vulnerable databases - Scan 0/2")
|
||||
if simulate_trial(self.target_scan_p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Found a target database! Scan 1/2")
|
||||
port_is_open = True # TODO Implement a NNME Triggering scan as a seperate Red Application
|
||||
if port_is_open:
|
||||
self.attack_stage = RansomwareAttackStage.PROPAGATE
|
||||
|
||||
def attack(self) -> bool:
|
||||
"""Perform the attack steps after opening the application."""
|
||||
if not self._can_perform_action():
|
||||
_LOGGER.debug("Ransomware application is unable to perform it's actions.")
|
||||
self.run()
|
||||
self.num_executions += 1
|
||||
return self._application_loop()
|
||||
|
||||
def _perform_ransomware_encrypt(self):
|
||||
"""
|
||||
Execute the Ransomware Encrypt payload on the target.
|
||||
|
||||
Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful.
|
||||
:param p_of_success: Probability of successfully performing ransomware encryption, by default 0.1.
|
||||
"""
|
||||
if self._host_db_client is None:
|
||||
self.sys_log.info(f"{self.name}: Failed to connect to db_client - Ransomware Script")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
return
|
||||
|
||||
self._host_db_client.server_ip_address = self.server_ip_address
|
||||
self._host_db_client.server_password = self.server_password
|
||||
if self.attack_stage == RansomwareAttackStage.COMMAND_AND_CONTROL:
|
||||
if simulate_trial(self.ransomware_encrypt_p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Attempting to launch payload")
|
||||
if not len(self._host_db_client.connections):
|
||||
self._host_db_client.connect()
|
||||
if len(self._host_db_client.connections):
|
||||
self._host_db_client.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
|
||||
attack_successful = True
|
||||
if attack_successful:
|
||||
self.sys_log.info(f"{self.name}: Payload Successful")
|
||||
self.attack_stage = RansomwareAttackStage.SUCCEEDED
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Payload failed")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
else:
|
||||
self.sys_log.error("Attack Attempted to launch too quickly")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
|
||||
def _local_download(self):
|
||||
"""Downloads itself via the onto the local file_system."""
|
||||
if self.attack_stage == RansomwareAttackStage.NOT_STARTED:
|
||||
if self._local_download_verify():
|
||||
self.attack_stage = RansomwareAttackStage.DOWNLOAD
|
||||
else:
|
||||
self.sys_log.info("Malware failed to create a installation location")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
else:
|
||||
self.sys_log.info("Malware failed to download")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
|
||||
def _local_download_verify(self) -> bool:
|
||||
"""Verifies a download location - Creates one if needed."""
|
||||
for folder in self.file_system.folders:
|
||||
if self.file_system.folders[folder].name == "downloads":
|
||||
self.file_system.num_file_creations += 1
|
||||
return True
|
||||
|
||||
self.file_system.create_folder("downloads")
|
||||
self.file_system.create_file(folder_name="downloads", file_name="ransom_script.pdf")
|
||||
return True
|
||||
@@ -49,8 +49,9 @@ class PacketCapture:
|
||||
|
||||
self.current_episode: int = 1
|
||||
|
||||
self.setup_logger(outbound=False)
|
||||
self.setup_logger(outbound=True)
|
||||
if SIM_OUTPUT.save_pcap_logs:
|
||||
self.setup_logger(outbound=False)
|
||||
self.setup_logger(outbound=True)
|
||||
|
||||
def setup_logger(self, outbound: bool = False):
|
||||
"""Set up the logger configuration."""
|
||||
@@ -108,8 +109,9 @@ class PacketCapture:
|
||||
|
||||
:param frame: The PCAP frame to capture.
|
||||
"""
|
||||
msg = frame.model_dump_json()
|
||||
self.inbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
if SIM_OUTPUT.save_pcap_logs:
|
||||
msg = frame.model_dump_json()
|
||||
self.inbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
|
||||
def capture_outbound(self, frame): # noqa - I'll have a circular import and cant use if TYPE_CHECKING ;(
|
||||
"""
|
||||
@@ -117,5 +119,6 @@ class PacketCapture:
|
||||
|
||||
:param frame: The PCAP frame to capture.
|
||||
"""
|
||||
msg = frame.model_dump_json()
|
||||
self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
if SIM_OUTPUT.save_pcap_logs:
|
||||
msg = frame.model_dump_json()
|
||||
self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
|
||||
@@ -72,7 +72,6 @@ class SessionManager:
|
||||
Manages network sessions, including session creation, lookup, and communication with other components.
|
||||
|
||||
:param sys_log: A reference to the system log component.
|
||||
:param arp_cache: A reference to the ARP cache component.
|
||||
"""
|
||||
|
||||
def __init__(self, sys_log: SysLog):
|
||||
|
||||
@@ -88,6 +88,10 @@ class SysLog:
|
||||
root.mkdir(exist_ok=True, parents=True)
|
||||
return root / f"{self.hostname}_sys.log"
|
||||
|
||||
def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False):
|
||||
if to_terminal or SIM_OUTPUT.write_sys_log_to_terminal:
|
||||
print(f"{self.hostname}: ({level}) {msg}")
|
||||
|
||||
def debug(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
Logs a message with the DEBUG level.
|
||||
@@ -97,8 +101,7 @@ class SysLog:
|
||||
"""
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.debug(msg)
|
||||
if to_terminal:
|
||||
print(msg)
|
||||
self._write_to_terminal(msg, "DEBUG", to_terminal)
|
||||
|
||||
def info(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
@@ -109,8 +112,7 @@ class SysLog:
|
||||
"""
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.info(msg)
|
||||
if to_terminal:
|
||||
print(msg)
|
||||
self._write_to_terminal(msg, "INFO", to_terminal)
|
||||
|
||||
def warning(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
@@ -121,8 +123,7 @@ class SysLog:
|
||||
"""
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.warning(msg)
|
||||
if to_terminal:
|
||||
print(msg)
|
||||
self._write_to_terminal(msg, "WARNING", to_terminal)
|
||||
|
||||
def error(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
@@ -133,8 +134,7 @@ class SysLog:
|
||||
"""
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.error(msg)
|
||||
if to_terminal:
|
||||
print(msg)
|
||||
self._write_to_terminal(msg, "ERROR", to_terminal)
|
||||
|
||||
def critical(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
@@ -145,5 +145,4 @@ class SysLog:
|
||||
"""
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.critical(msg)
|
||||
if to_terminal:
|
||||
print(msg)
|
||||
self._write_to_terminal(msg, "CRITICAL", to_terminal)
|
||||
|
||||
@@ -65,6 +65,10 @@ class ARP(Service):
|
||||
"""Clears the arp cache."""
|
||||
self.arp.clear()
|
||||
|
||||
def get_default_gateway_network_interface(self) -> Optional[NetworkInterface]:
|
||||
"""Not used at the parent ARP level. Should return None when there is no override by child class."""
|
||||
return None
|
||||
|
||||
def add_arp_cache_entry(
|
||||
self, ip_address: IPV4Address, mac_address: str, network_interface: NetworkInterface, override: bool = False
|
||||
):
|
||||
|
||||
@@ -104,14 +104,30 @@ class DatabaseService(Service):
|
||||
self.sys_log.error("Unable to restore database backup.")
|
||||
return False
|
||||
|
||||
old_visible_state = SoftwareHealthState.GOOD
|
||||
|
||||
# get db file regardless of whether or not it was deleted
|
||||
db_file = self.file_system.get_file(folder_name="database", file_name="database.db", include_deleted=True)
|
||||
|
||||
if db_file is None:
|
||||
self.sys_log.error("Database file not initialised.")
|
||||
return False
|
||||
|
||||
# if the file was deleted, get the old visible health state
|
||||
if db_file.deleted:
|
||||
old_visible_state = db_file.visible_health_status
|
||||
else:
|
||||
old_visible_state = self.db_file.visible_health_status
|
||||
self.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
self.file_system.copy_file(src_folder_name="downloads", src_file_name="database.db", dst_folder_name="database")
|
||||
|
||||
if self.db_file is None:
|
||||
self.sys_log.error("Copying database backup failed.")
|
||||
return False
|
||||
|
||||
self.db_file.visible_health_status = old_visible_state
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
return True
|
||||
@@ -125,8 +141,7 @@ class DatabaseService(Service):
|
||||
"""Returns the database file."""
|
||||
return self.file_system.get_file(folder_name="database", file_name="database.db")
|
||||
|
||||
@property
|
||||
def folder(self) -> Folder:
|
||||
def _return_database_folder(self) -> Folder:
|
||||
"""Returns the database folder."""
|
||||
return self.file_system.get_folder_by_id(self.db_file.folder_id)
|
||||
|
||||
@@ -171,7 +186,10 @@ class DatabaseService(Service):
|
||||
}
|
||||
|
||||
def _process_sql(
|
||||
self, query: Literal["SELECT", "DELETE", "INSERT"], query_id: str, connection_id: Optional[str] = None
|
||||
self,
|
||||
query: Literal["SELECT", "DELETE", "INSERT", "ENCRYPT"],
|
||||
query_id: str,
|
||||
connection_id: Optional[str] = None,
|
||||
) -> Dict[str, Union[int, List[Any]]]:
|
||||
"""
|
||||
Executes the given SQL query and returns the result.
|
||||
@@ -180,6 +198,7 @@ class DatabaseService(Service):
|
||||
- SELECT : returns the data
|
||||
- DELETE : deletes the data
|
||||
- INSERT : inserts the data
|
||||
- ENCRYPT : corrupts the data
|
||||
|
||||
:param query: The SQL query to be executed.
|
||||
:return: Dictionary containing status code and data fetched.
|
||||
@@ -188,10 +207,18 @@ class DatabaseService(Service):
|
||||
|
||||
if not self.db_file:
|
||||
self.sys_log.info(f"{self.name}: Failed to run {query} because the database file is missing.")
|
||||
return {"status_code": 404, "data": False}
|
||||
return {"status_code": 404, "type": "sql", "data": False}
|
||||
|
||||
if query == "SELECT":
|
||||
if self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
|
||||
if self.db_file.health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": False,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
elif self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
@@ -200,7 +227,7 @@ class DatabaseService(Service):
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
return {"status_code": 404, "type": "sql", "data": False}
|
||||
elif query == "DELETE":
|
||||
self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED
|
||||
return {
|
||||
@@ -210,6 +237,20 @@ class DatabaseService(Service):
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
elif query == "ENCRYPT":
|
||||
self.file_system.num_file_creations += 1
|
||||
self.db_file.health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
self.db_file.num_access += 1
|
||||
database_folder = self._return_database_folder()
|
||||
database_folder.health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
self.file_system.num_file_deletions += 1
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": False,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
elif query == "INSERT":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
return {
|
||||
@@ -220,7 +261,7 @@ class DatabaseService(Service):
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
return {"status_code": 404, "type": "sql", "data": False}
|
||||
elif query == "SELECT * FROM pg_stat_activity":
|
||||
# Check if the connection is active.
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
@@ -304,8 +345,8 @@ class DatabaseService(Service):
|
||||
self.backup_database()
|
||||
return super().apply_timestep(timestep)
|
||||
|
||||
def _update_patch_status(self) -> None:
|
||||
"""Perform a database restore when the patching countdown is finished."""
|
||||
super()._update_patch_status()
|
||||
if self._patching_countdown is None:
|
||||
def _update_fix_status(self) -> None:
|
||||
"""Perform a database restore when the FIXING countdown is finished."""
|
||||
super()._update_fix_status()
|
||||
if self._fixing_countdown is None:
|
||||
self.restore_backup()
|
||||
|
||||
@@ -87,13 +87,9 @@ class NTPClient(Service):
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
if not isinstance(payload, NTPPacket):
|
||||
_LOGGER.debug(f"{payload} is not a NTPPacket")
|
||||
_LOGGER.debug(f"{self.name}: Failed to parse NTP update")
|
||||
return False
|
||||
if payload.ntp_reply.ntp_datetime:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: \
|
||||
Received time update from NTP server{payload.ntp_reply.ntp_datetime}"
|
||||
)
|
||||
self.time = payload.ntp_reply.ntp_datetime
|
||||
return True
|
||||
|
||||
@@ -124,5 +120,3 @@ class NTPClient(Service):
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
# request time from server
|
||||
self.request_time()
|
||||
else:
|
||||
self.sys_log.debug(f"{self.name} ntp client not running")
|
||||
|
||||
@@ -59,7 +59,7 @@ class Service(IOSoftware):
|
||||
|
||||
if self.operating_state is not ServiceOperatingState.RUNNING:
|
||||
# service is not running
|
||||
_LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
|
||||
_LOGGER.debug(f"Cannot perform action: {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -43,8 +43,8 @@ class SoftwareHealthState(Enum):
|
||||
"Unused state."
|
||||
GOOD = 1
|
||||
"The software is in a good and healthy condition."
|
||||
PATCHING = 2
|
||||
"The software is undergoing patching or updates."
|
||||
FIXING = 2
|
||||
"The software is undergoing FIXING or updates."
|
||||
COMPROMISED = 3
|
||||
"The software's security has been compromised."
|
||||
OVERWHELMED = 4
|
||||
@@ -82,13 +82,13 @@ class Software(SimComponent):
|
||||
"The health state of the software visible to the red agent."
|
||||
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
|
||||
"The criticality level of the software."
|
||||
patching_count: int = 0
|
||||
fixing_count: int = 0
|
||||
"The count of patches applied to the software, defaults to 0."
|
||||
scanning_count: int = 0
|
||||
"The count of times the software has been scanned, defaults to 0."
|
||||
revealed_to_red: bool = False
|
||||
"Indicates if the software has been revealed to red agent, defaults is False."
|
||||
software_manager: "SoftwareManager" = None
|
||||
software_manager: Optional["SoftwareManager"] = None
|
||||
"An instance of Software Manager that is used by the parent node."
|
||||
sys_log: SysLog = None
|
||||
"An instance of SysLog that is used by the parent node."
|
||||
@@ -96,9 +96,9 @@ class Software(SimComponent):
|
||||
"The FileSystem of the Node the Software is installed on."
|
||||
folder: Optional[Folder] = None
|
||||
"The folder on the file system the Software uses."
|
||||
patching_duration: int = 2
|
||||
fixing_duration: int = 2
|
||||
"The number of ticks it takes to patch the software."
|
||||
_patching_countdown: Optional[int] = None
|
||||
_fixing_countdown: Optional[int] = None
|
||||
"Current number of ticks left to patch the software."
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
@@ -117,9 +117,9 @@ class Software(SimComponent):
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
"patch",
|
||||
"fix",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.patch()),
|
||||
func=lambda request, context: RequestResponse.from_bool(self.fix()),
|
||||
),
|
||||
)
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
|
||||
@@ -149,7 +149,7 @@ class Software(SimComponent):
|
||||
"health_state_actual": self.health_state_actual.value,
|
||||
"health_state_visible": self.health_state_visible.value,
|
||||
"criticality": self.criticality.value,
|
||||
"patching_count": self.patching_count,
|
||||
"fixing_count": self.fixing_count,
|
||||
"scanning_count": self.scanning_count,
|
||||
"revealed_to_red": self.revealed_to_red,
|
||||
}
|
||||
@@ -194,21 +194,21 @@ class Software(SimComponent):
|
||||
self.health_state_visible = self.health_state_actual
|
||||
return True
|
||||
|
||||
def patch(self) -> bool:
|
||||
"""Perform a patch on the software."""
|
||||
def fix(self) -> bool:
|
||||
"""Perform a fix on the software."""
|
||||
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
|
||||
self._patching_countdown = self.patching_duration
|
||||
self.set_health_state(SoftwareHealthState.PATCHING)
|
||||
self._fixing_countdown = self.fixing_duration
|
||||
self.set_health_state(SoftwareHealthState.FIXING)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _update_patch_status(self) -> None:
|
||||
"""Update the patch status of the software."""
|
||||
self._patching_countdown -= 1
|
||||
if self._patching_countdown <= 0:
|
||||
def _update_fix_status(self) -> None:
|
||||
"""Update the fix status of the software."""
|
||||
self._fixing_countdown -= 1
|
||||
if self._fixing_countdown <= 0:
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
self._patching_countdown = None
|
||||
self.patching_count += 1
|
||||
self._fixing_countdown = None
|
||||
self.fixing_count += 1
|
||||
|
||||
def reveal_to_red(self) -> None:
|
||||
"""Reveals the software to the red agent."""
|
||||
@@ -221,8 +221,12 @@ class Software(SimComponent):
|
||||
:param timestep: The current timestep of the simulation.
|
||||
"""
|
||||
super().apply_timestep(timestep)
|
||||
if self.health_state_actual == SoftwareHealthState.PATCHING:
|
||||
self._update_patch_status()
|
||||
if self.health_state_actual == SoftwareHealthState.FIXING:
|
||||
self._update_fix_status()
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
|
||||
class IOSoftware(Software):
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# flake8: noqa
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# flake8: noqa
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# flake8: noqa
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import csv
|
||||
from logging import Logger
|
||||
|
||||
Reference in New Issue
Block a user