Finalise actions interface
This commit is contained in:
@@ -27,11 +27,11 @@ game_config:
|
||||
- type: LOGON
|
||||
- type: LOGOFF
|
||||
applications:
|
||||
- application_ref: client_2_web_browser
|
||||
actions:
|
||||
- type: EXECUTE
|
||||
execution_definition:
|
||||
target_address: arcd.com
|
||||
# - application_ref: client_2_web_browser
|
||||
# actions:
|
||||
# - type: EXECUTE
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
reward_function: null
|
||||
agent_settings:
|
||||
start_step: 5
|
||||
@@ -42,7 +42,8 @@ game_config:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
observation_space:
|
||||
network:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
observations:
|
||||
@@ -85,13 +86,307 @@ game_config:
|
||||
|
||||
- ref: defender
|
||||
team: blue
|
||||
type: GATE_RL_AGENT
|
||||
type: GATERLAgent
|
||||
observation_space:
|
||||
network:
|
||||
type: UC2BlueObservation
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: router_1 #TODO: more sub-options here
|
||||
- node_ref: switch_1
|
||||
- node_ref: switch_2
|
||||
- node_ref: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
router_node_ref: router_1
|
||||
ics: null
|
||||
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- DONOTHING
|
||||
- NODE_SERVICE_SCAN
|
||||
- NODE_SERVICE_STOP
|
||||
# - NODE_SERVICE_START
|
||||
# - NODE_SERVICE_PAUSE
|
||||
# - NODE_SERVICE_RESUME
|
||||
# - NODE_SERVICE_RESTART
|
||||
# - NODE_SERVICE_DISABLE
|
||||
# - NODE_SERVICE_ENABLE
|
||||
# - NODE_FILE_SCAN
|
||||
# - NODE_FILE_CHECKHASH
|
||||
# - NODE_FILE_DELETE
|
||||
# - NODE_FILE_REPAIR
|
||||
# - NODE_FILE_RESTORE
|
||||
# - NODE_FOLDER_SCAN
|
||||
# - NODE_FOLDER_CHECKHASH
|
||||
# - NODE_FOLDER_REPAIR
|
||||
# - NODE_FOLDER_RESTORE
|
||||
# - NODE_OS_SCAN
|
||||
# - NODE_SHUTDOWN
|
||||
# - NODE_STARTUP
|
||||
# - NODE_RESET
|
||||
# - NETWORK_ACL_ADDRULE
|
||||
# - NETWORK_ACL_REMOVERULE
|
||||
# - NETWORK_NIC_ENABLE
|
||||
- NETWORK_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
- action: DONOTHING
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
- action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
# stop webapp service
|
||||
2:
|
||||
- action: NODE_SERVICE_STOP
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
# start webapp service
|
||||
3:
|
||||
- action: "NODE_SERVICE_START"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
4:
|
||||
- action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
5:
|
||||
- action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
6:
|
||||
- action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
7:
|
||||
- action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
8:
|
||||
- action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
9:
|
||||
- action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
- file_id: 1
|
||||
10:
|
||||
- action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
- file_id: 1
|
||||
11:
|
||||
- action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
- file_id: 1
|
||||
12:
|
||||
- action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
- file_id: 1
|
||||
13:
|
||||
- action: "NODE_FILE_RESTORE"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
- file_id: 1
|
||||
14:
|
||||
- action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
15:
|
||||
- action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
16:
|
||||
- action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
17:
|
||||
- action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
18:
|
||||
- action: "NODE_OS_SCAN"
|
||||
options:
|
||||
- node_id: 3
|
||||
19:
|
||||
- action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
- node_id: 6
|
||||
20:
|
||||
- action: "NODE_STARTUP"
|
||||
options:
|
||||
- node_id: 6
|
||||
21:
|
||||
- action: "NODE_RESET"
|
||||
options:
|
||||
- node_id: 6
|
||||
22:
|
||||
- action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 6
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
23:
|
||||
- action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 5
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
24:
|
||||
- action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 4
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
25:
|
||||
- action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 3
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
26:
|
||||
- action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 2
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
27:
|
||||
- action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 1
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
28:
|
||||
- action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 0
|
||||
29:
|
||||
- action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 1
|
||||
30:
|
||||
- action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 2
|
||||
31:
|
||||
- action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 3
|
||||
32:
|
||||
- action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 4
|
||||
33:
|
||||
- action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 5
|
||||
34:
|
||||
- action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 6
|
||||
35:
|
||||
- action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 7
|
||||
36:
|
||||
- action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 8
|
||||
37:
|
||||
- action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 9
|
||||
38:
|
||||
- action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
- node_id: 6
|
||||
- nic_index: 1
|
||||
39:
|
||||
- action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
- node_id: 6
|
||||
- nic_index: 1
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: router_1
|
||||
- node_ref: switch_1
|
||||
- node_ref: switch_2
|
||||
- node_ref: domain_controller
|
||||
- node_ref: web_server
|
||||
- node_ref: database_server
|
||||
@@ -99,18 +394,13 @@ game_config:
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
links:
|
||||
- link_ref: ... #
|
||||
acl: ... #
|
||||
ics: ... #
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
|
||||
|
||||
action_space:
|
||||
actions:
|
||||
- type: DO_NOTHING
|
||||
network:
|
||||
nodes:
|
||||
- node_ref: router_1
|
||||
reward_function:
|
||||
# ...
|
||||
agent_settings:
|
||||
@@ -175,7 +465,6 @@ simulation:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
|
||||
- ref: web_server
|
||||
type: server
|
||||
hostname: web_server
|
||||
@@ -200,7 +489,6 @@ simulation:
|
||||
- ref: database_service
|
||||
type: DatabaseService
|
||||
|
||||
|
||||
- ref: backup_server
|
||||
type: server
|
||||
hostname: backup_server
|
||||
@@ -224,7 +512,6 @@ simulation:
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
|
||||
- ref: client_1
|
||||
type: computer
|
||||
hostname: client_1
|
||||
@@ -251,7 +538,6 @@ simulation:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
|
||||
links:
|
||||
- ref: router_1___switch_1
|
||||
endpoint_a_ref: router_1
|
||||
|
||||
489
sandbox.ipynb
489
sandbox.ipynb
@@ -2,9 +2,18 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The autoreload extension is already loaded. To reload it, use:\n",
|
||||
" %reload_ext autoreload\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
@@ -12,7 +21,381 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.game.session import PrimaiteSession\n",
|
||||
"from primaite.simulator.sim_container import Simulation\n",
|
||||
"from primaite.game.agent.interface import AbstractAgent\n",
|
||||
"from primaite.simulator.network.networks import arcd_uc2_network\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sess = PrimaiteSession()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network = sess.simulation.network"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ipaddress import IPv4Address\n",
|
||||
"\n",
|
||||
"from primaite.simulator.network.container import Network\n",
|
||||
"from primaite.simulator.network.hardware.base import NIC\n",
|
||||
"from primaite.simulator.network.hardware.nodes.computer import Computer\n",
|
||||
"from primaite.simulator.network.hardware.nodes.router import ACLAction, Router\n",
|
||||
"from primaite.simulator.network.hardware.nodes.server import Server\n",
|
||||
"from primaite.simulator.network.hardware.nodes.switch import Switch\n",
|
||||
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
|
||||
"from primaite.simulator.network.transmission.transport_layer import Port\n",
|
||||
"from primaite.simulator.system.applications.database_client import DatabaseClient\n",
|
||||
"from primaite.simulator.system.services.database_service import DatabaseService\n",
|
||||
"from primaite.simulator.system.services.dns_client import DNSClient\n",
|
||||
"from primaite.simulator.system.services.dns_server import DNSServer\n",
|
||||
"from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2023-10-02 15:10:20,422: Added node 6abb7664-4d17-45ff-a3c7-dbcccffcfd6d to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
|
||||
"2023-10-02 15:10:20,424: Added node 3edbc521-3c80-47e3-8017-dbc38fb00a73 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
|
||||
"2023-10-02 15:10:20,428: Added node 94457fb9-04a1-4dc1-9ff7-b64df0da7424 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
|
||||
"2023-10-02 15:10:20,432: Added node 0d311d72-139c-41bf-aef7-fa9b01b124d7 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
|
||||
"2023-10-02 15:10:20,439: Added node 6161e785-f377-48de-aa4f-20d3646da635 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
|
||||
"2023-10-02 15:10:20,444: Added node 55a9e9f8-ee3a-4c28-9b6d-c0c0d78a3f6a to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
|
||||
"2023-10-02 15:10:20,447: Added node 2f04ca45-3439-489a-81f7-41cca5ae8adc to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
|
||||
"2023-10-02 15:10:20,531: Added node 98660c30-8e48-4b96-967a-d62ca71b4d6d to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
|
||||
"2023-10-02 15:10:20,545: Added node 1a184184-b204-40de-986a-4d7459036dbe to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
|
||||
"2023-10-02 15:10:20,551: Added node 17b92b9a-6805-4677-85f7-c0c0521a6e25 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
|
||||
"2023-10-02 15:10:20,555::ERROR::primaite.simulator.network.hardware.base::175::NIC da:f3:1b:87:24:20/192.168.10.110 cannot be enabled as it is not connected to a Link\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"router_1 = Router(hostname=\"router_1\", num_ports=5)\n",
|
||||
"router_1.power_on()\n",
|
||||
"router_1.configure_port(port=1, ip_address=\"192.168.1.1\", subnet_mask=\"255.255.255.0\")\n",
|
||||
"router_1.configure_port(port=2, ip_address=\"192.168.10.1\", subnet_mask=\"255.255.255.0\")\n",
|
||||
"\n",
|
||||
"# Switch 1\n",
|
||||
"switch_1 = Switch(hostname=\"switch_1\", num_ports=8)\n",
|
||||
"switch_1.power_on()\n",
|
||||
"network.connect(endpoint_a=router_1.ethernet_ports[1], endpoint_b=switch_1.switch_ports[8])\n",
|
||||
"router_1.enable_port(1)\n",
|
||||
"\n",
|
||||
"# Switch 2\n",
|
||||
"switch_2 = Switch(hostname=\"switch_2\", num_ports=8)\n",
|
||||
"switch_2.power_on()\n",
|
||||
"network.connect(endpoint_a=router_1.ethernet_ports[2], endpoint_b=switch_2.switch_ports[8])\n",
|
||||
"router_1.enable_port(2)\n",
|
||||
"\n",
|
||||
"# Client 1\n",
|
||||
"client_1 = Computer(\n",
|
||||
" hostname=\"client_1\",\n",
|
||||
" ip_address=\"192.168.10.21\",\n",
|
||||
" subnet_mask=\"255.255.255.0\",\n",
|
||||
" default_gateway=\"192.168.10.1\",\n",
|
||||
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
|
||||
")\n",
|
||||
"client_1.power_on()\n",
|
||||
"client_1.software_manager.install(DNSClient)\n",
|
||||
"client_1_dns_client_service: DNSServer = client_1.software_manager.software[\"DNSClient\"] # noqa\n",
|
||||
"client_1_dns_client_service.start()\n",
|
||||
"network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])\n",
|
||||
"client_1.software_manager.install(DataManipulationBot)\n",
|
||||
"db_manipulation_bot: DataManipulationBot = client_1.software_manager.software[\"DataManipulationBot\"]\n",
|
||||
"db_manipulation_bot.configure(server_ip_address=IPv4Address(\"192.168.1.14\"), payload=\"DROP TABLE IF EXISTS user;\")\n",
|
||||
"\n",
|
||||
"# Client 2\n",
|
||||
"client_2 = Computer(\n",
|
||||
" hostname=\"client_2\",\n",
|
||||
" ip_address=\"192.168.10.22\",\n",
|
||||
" subnet_mask=\"255.255.255.0\",\n",
|
||||
" default_gateway=\"192.168.10.1\",\n",
|
||||
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
|
||||
")\n",
|
||||
"client_2.power_on()\n",
|
||||
"client_2.software_manager.install(DNSClient)\n",
|
||||
"client_2_dns_client_service: DNSServer = client_2.software_manager.software[\"DNSClient\"] # noqa\n",
|
||||
"client_2_dns_client_service.start()\n",
|
||||
"network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])\n",
|
||||
"\n",
|
||||
"# Domain Controller\n",
|
||||
"domain_controller = Server(\n",
|
||||
" hostname=\"domain_controller\",\n",
|
||||
" ip_address=\"192.168.1.10\",\n",
|
||||
" subnet_mask=\"255.255.255.0\",\n",
|
||||
" default_gateway=\"192.168.1.1\",\n",
|
||||
")\n",
|
||||
"domain_controller.power_on()\n",
|
||||
"domain_controller.software_manager.install(DNSServer)\n",
|
||||
"\n",
|
||||
"network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])\n",
|
||||
"\n",
|
||||
"# Database Server\n",
|
||||
"database_server = Server(\n",
|
||||
" hostname=\"database_server\",\n",
|
||||
" ip_address=\"192.168.1.14\",\n",
|
||||
" subnet_mask=\"255.255.255.0\",\n",
|
||||
" default_gateway=\"192.168.1.1\",\n",
|
||||
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
|
||||
")\n",
|
||||
"database_server.power_on()\n",
|
||||
"network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3])\n",
|
||||
"\n",
|
||||
"ddl = \"\"\"\n",
|
||||
"CREATE TABLE IF NOT EXISTS user (\n",
|
||||
"id INTEGER PRIMARY KEY AUTOINCREMENT,\n",
|
||||
"name VARCHAR(50) NOT NULL,\n",
|
||||
"email VARCHAR(50) NOT NULL,\n",
|
||||
"age INT,\n",
|
||||
"city VARCHAR(50),\n",
|
||||
"occupation VARCHAR(50)\n",
|
||||
");\"\"\"\n",
|
||||
"\n",
|
||||
"user_insert_statements = [\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');\", # noqa\n",
|
||||
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');\", # noqa\n",
|
||||
"]\n",
|
||||
"database_server.software_manager.install(DatabaseService)\n",
|
||||
"database_service: DatabaseService = database_server.software_manager.software[\"DatabaseService\"] # noqa\n",
|
||||
"database_service.start()\n",
|
||||
"database_service._process_sql(ddl, None) # noqa\n",
|
||||
"for insert_statement in user_insert_statements:\n",
|
||||
" database_service._process_sql(insert_statement, None) # noqa\n",
|
||||
"\n",
|
||||
"# Web Server\n",
|
||||
"web_server = Server(\n",
|
||||
" hostname=\"web_server\",\n",
|
||||
" ip_address=\"192.168.1.12\",\n",
|
||||
" subnet_mask=\"255.255.255.0\",\n",
|
||||
" default_gateway=\"192.168.1.1\",\n",
|
||||
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
|
||||
")\n",
|
||||
"web_server.power_on()\n",
|
||||
"web_server.software_manager.install(DatabaseClient)\n",
|
||||
"\n",
|
||||
"database_client: DatabaseClient = web_server.software_manager.software[\"DatabaseClient\"]\n",
|
||||
"database_client.configure(server_ip_address=IPv4Address(\"192.168.1.14\"))\n",
|
||||
"network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])\n",
|
||||
"database_client.run()\n",
|
||||
"database_client.connect()\n",
|
||||
"\n",
|
||||
"# register the web_server to a domain\n",
|
||||
"dns_server_service: DNSServer = domain_controller.software_manager.software[\"DNSServer\"] # noqa\n",
|
||||
"dns_server_service.start()\n",
|
||||
"dns_server_service.dns_register(\"arcd.com\", web_server.ip_address)\n",
|
||||
"\n",
|
||||
"# Backup Server\n",
|
||||
"backup_server = Server(\n",
|
||||
" hostname=\"backup_server\",\n",
|
||||
" ip_address=\"192.168.1.16\",\n",
|
||||
" subnet_mask=\"255.255.255.0\",\n",
|
||||
" default_gateway=\"192.168.1.1\",\n",
|
||||
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
|
||||
")\n",
|
||||
"backup_server.power_on()\n",
|
||||
"network.connect(endpoint_b=backup_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[4])\n",
|
||||
"\n",
|
||||
"# Security Suite\n",
|
||||
"security_suite = Server(\n",
|
||||
" hostname=\"security_suite\",\n",
|
||||
" ip_address=\"192.168.1.110\",\n",
|
||||
" subnet_mask=\"255.255.255.0\",\n",
|
||||
" default_gateway=\"192.168.1.1\",\n",
|
||||
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
|
||||
")\n",
|
||||
"security_suite.power_on()\n",
|
||||
"network.connect(endpoint_b=security_suite.ethernet_port[1], endpoint_a=switch_1.switch_ports[7])\n",
|
||||
"security_suite.connect_nic(NIC(ip_address=\"192.168.10.110\", subnet_mask=\"255.255.255.0\"))\n",
|
||||
"network.connect(endpoint_b=security_suite.ethernet_port[2], endpoint_a=switch_2.switch_ports[7])\n",
|
||||
"\n",
|
||||
"router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)\n",
|
||||
"\n",
|
||||
"router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)\n",
|
||||
"\n",
|
||||
"# Allow PostgreSQL requests\n",
|
||||
"router_1.acl.add_rule(\n",
|
||||
" action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Allow DNS requests\n",
|
||||
"router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"node_uuid_list = list(sess.simulation.network.nodes.keys())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.game.agent.actions import ActionManager"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"actman = ActionManager(sess.simulation, [\"DONOTHING\", \"NODE_SERVICE_SCAN\", \"NODE_SERVICE_STOP\", \"NODE_FOLDER_SCAN\"],node_uuid_list,act_map={\n",
|
||||
" 0:{\n",
|
||||
" \"action\": \"DONOTHING\",\n",
|
||||
" \"options\": {}\n",
|
||||
" },\n",
|
||||
" 1:{\n",
|
||||
" \"action\": \"NODE_SERVICE_SCAN\",\n",
|
||||
" \"options\": {\"node_id\":0, \"service_id\":0},\n",
|
||||
" },\n",
|
||||
" 2:{\n",
|
||||
" \"action\": \"NODE_SERVICE_SCAN\",\n",
|
||||
" \"options\": {\"node_id\":1, \"service_id\":0},\n",
|
||||
" },\n",
|
||||
" 3:{\n",
|
||||
" \"action\": \"NODE_FOLDER_SCAN\",\n",
|
||||
" \"options\": {\"node_id\":4, \"folder_id\":0},\n",
|
||||
" }\n",
|
||||
"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"act_id, act_options = actman.get_action(3)\n",
|
||||
"my_trial_act = actman.form_request(action_identifier=act_id, action_options=act_options)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sess.simulation.apply_action(my_trial_act)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['network',\n",
|
||||
" 'node',\n",
|
||||
" '6161e785-f377-48de-aa4f-20d3646da635',\n",
|
||||
" 'file_system',\n",
|
||||
" 'folder',\n",
|
||||
" '5aefe92b-923c-4684-b3bf-e78dd18d4771',\n",
|
||||
" 'scan']"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"my_trial_act"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sess.step()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sess.step_counter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from gym import spaces"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sp = spaces.Tuple( (spaces.MultiDiscrete([3, 2]), spaces.MultiDiscrete([3, 2]), spaces.MultiDiscrete([3, 2]),))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sp.sample()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -39,40 +422,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2023-09-26 12:19:50,895: Added node 0fb262e1-a714-420a-aec7-be37f0deeb75 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
|
||||
"2023-09-26 12:19:50,898: Added node 310ca8d7-01e0-401e-b604-705c290e5376 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
|
||||
"2023-09-26 12:19:50,900: Added node b3b08f1f-7805-47b2-bdb6-3d83098cd740 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
|
||||
"2023-09-26 12:19:50,903: Added node adb37f3e-2307-4123-bff3-01f125883be8 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
|
||||
"2023-09-26 12:19:50,906: Added node 1a490716-2ccd-452d-b87e-324d29120b59 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
|
||||
"2023-09-26 12:19:50,911: Added node 033460d8-0249-4bdd-aaf0-751b24cc0a1e to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
|
||||
"2023-09-26 12:19:50,914: Added node 1e7e4e49-78bf-4031-8372-ee71902720f3 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
|
||||
"2023-09-26 12:19:50,916: Added node c9f24a13-e5c8-437b-9234-b0c3f8120e2c to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
|
||||
"2023-09-26 12:19:50,920: Added node c881f3ee-2176-493b-a6c2-cad829bf0b6d to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
|
||||
"2023-09-26 12:19:50,922: Added node a3ea75d8-bc2c-4713-92a4-2588b4f43ed6 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"service type not found DatabaseBackup\n",
|
||||
"service type not found WebBrowser\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# import yaml\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from typing import Dict\n",
|
||||
"from primaite.game.agent.interface import AbstractAgent\n",
|
||||
"from primaite.game.agent.observations import AclObservation, FileObservation, FolderObservation, ICSObservation, LinkObservation, NicObservation, NodeObservation, NullObservation, ServiceObservation, UC2BlueObservation, UC2RedObservation\n",
|
||||
"from primaite.simulator.network.hardware.base import NIC, Link, Node\n",
|
||||
"from primaite.simulator.system.services.service import Service\n",
|
||||
"\n",
|
||||
@@ -130,8 +489,8 @@
|
||||
" subnet_mask=port_cfg['subnet_mask'])\n",
|
||||
" if 'acl' in node_cfg:\n",
|
||||
" for r_num, r_cfg in node_cfg['acl'].items():\n",
|
||||
" # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, so that we can do\n",
|
||||
" # both of these things once: check if a key is defined, access and convert it to a \n",
|
||||
" # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating \n",
|
||||
" # this: 'r_cfg.get('src_port')'\n",
|
||||
" # Port/IPProtocol. TODO Refactor\n",
|
||||
" new_node.acl.add_rule(\n",
|
||||
" action = ACLAction[r_cfg['action']],\n",
|
||||
@@ -211,8 +570,70 @@
|
||||
" action_space_cfg = agent_cfg['action_space']\n",
|
||||
" observation_space_cfg = agent_cfg['observation_space']\n",
|
||||
" reward_function_cfg = agent_cfg['reward_function']\n",
|
||||
" \n",
|
||||
" # CREATE OBSERVATION SPACE\n",
|
||||
" if observation_space_cfg is None:\n",
|
||||
" obs_space = NullObservation()\n",
|
||||
" elif observation_space_cfg['type'] == 'UC2BlueObservation':\n",
|
||||
" node_obs_list = []\n",
|
||||
" link_obs_list = []\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" #node ip to index maps ip addresses to node id, as there are potentially multiple nics on a node, there are multiple ip addresses\n",
|
||||
" node_ip_to_index = {}\n",
|
||||
" for node_idx, node_cfg in enumerate(nodes_cfg):\n",
|
||||
" n_ref = node_cfg['ref']\n",
|
||||
" n_obj = net.nodes[ref_map_nodes[n_ref]]\n",
|
||||
" for nic_uuid, nic_obj in n_obj.nics.items():\n",
|
||||
" node_ip_to_index[nic_obj.ip_address] = node_idx + 2\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" for node_obs_cfg in observation_space_cfg['options']['nodes']:\n",
|
||||
" node_ref = node_obs_cfg['node_ref']\n",
|
||||
" folder_obs_list = []\n",
|
||||
" service_obs_list = []\n",
|
||||
" if 'services' in node_obs_cfg:\n",
|
||||
" for service_obs_cfg in node_obs_cfg['services']:\n",
|
||||
" service_obs_list.append(ServiceObservation(where=['network','nodes',ref_map_nodes[node_ref],'services',ref_map_services[service_obs_cfg['service_ref']]]))\n",
|
||||
" if 'folders' in node_obs_cfg:\n",
|
||||
" for folder_obs_cfg in node_obs_cfg['folders']:\n",
|
||||
" file_obs_list = []\n",
|
||||
" if 'files' in folder_obs_cfg:\n",
|
||||
" for file_obs_cfg in folder_obs_cfg['files']:\n",
|
||||
" file_obs_list.append(FileObservation(where=['network','nodes',ref_map_nodes[node_ref], 'folders',folder_obs_cfg['folder_name'], 'files', file_obs_cfg['file_name']]))\n",
|
||||
" folder_obs_list.append(FolderObservation(where=['network','nodes',ref_map_nodes[node_ref], 'folders',folder_obs_cfg['folder_name']], files=file_obs_list))\n",
|
||||
" nic_obs_list = []\n",
|
||||
" for nic_uuid in net.nodes[ref_map_nodes[node_obs_cfg['node_ref']]].nics.keys():\n",
|
||||
" nic_obs_list.append(NicObservation(where=['network','nodes',ref_map_nodes[node_ref],'NICs',nic_uuid]))\n",
|
||||
" node_obs_list.append(NodeObservation(where=['network','nodes',ref_map_nodes[node_ref]], services=service_obs_list, folders=folder_obs_list,nics=nic_obs_list, logon_status=False))\n",
|
||||
" for link_obs_cfg in observation_space_cfg['options']['links']:\n",
|
||||
" link_ref = link_obs_cfg['link_ref']\n",
|
||||
" link_obs_list.append(LinkObservation(where=['network' ,'links', ref_map_links[link_ref]]))\n",
|
||||
"\n",
|
||||
" acl_obs = AclObservation(node_ip_to_id=node_ip_to_index, ports=game_cfg['ports'], protocols=game_cfg['ports'], where=['network','nodes',observation_space_cfg['options']['acl']['router_node_ref']])\n",
|
||||
" obs_space = UC2BlueObservation(nodes=node_obs_list,links=link_obs_list,acl=acl_obs, ics=ICSObservation())\n",
|
||||
" elif observation_space_cfg['type'] == 'UC2RedObservation':\n",
|
||||
" obs_space = UC2RedObservation.from_config(observation_space_cfg['options'], sim=sim)\n",
|
||||
" else:\n",
|
||||
" print(\"observation space config not specified correctly.\")\n",
|
||||
" obs_space = NullObservation()\n",
|
||||
" \n",
|
||||
" # CREATE ACTION SPACE\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # CREATE REWARD FUNCTION\n",
|
||||
"\n",
|
||||
" # CREATE AGENT\n",
|
||||
" if agent_type == 'GreenWebBrowsingAgent':\n",
|
||||
" new_agent = GreenWebBrowsingAgent()\n",
|
||||
" ...\n",
|
||||
" elif agent_type == 'GATERLAgent':\n",
|
||||
" ...\n",
|
||||
" elif agent_type == 'RedDatabaseCorruptingAgent':\n",
|
||||
" ...\n",
|
||||
" else:\n",
|
||||
" print(\"agent type not found\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" #4. set up agents' actions and observation spaces.\n",
|
||||
@@ -228,7 +649,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(s.simulation.describe_state())"
|
||||
"s.agents"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,21 +1,374 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import itertools
|
||||
|
||||
|
||||
class AbstractAction(BaseModel):
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
from gym import spaces
|
||||
|
||||
class ExecutionDefiniton(ABC):
|
||||
"""
|
||||
Converter from actions to simulator requests.
|
||||
|
||||
Allows adding extra data/context that defines in more detail what an action means.
|
||||
"""
|
||||
|
||||
"""
|
||||
Examples:
|
||||
('node', 'service', 'scan', 2, 0) means scan the first service on node index 2
|
||||
-> ['network', 'nodes', <node-idx-2-uuid>, 'services', <svc-idx-0-uuid>, 'scan'w]
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class AbstractAction(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, action: Any) -> List[str]:
|
||||
"""_summary_
|
||||
|
||||
:param action: _description_
|
||||
:type action: Any
|
||||
:return: _description_
|
||||
:rtype: List[str]
|
||||
def __init__(self, manager:"ActionManager", **kwargs) -> None:
|
||||
"""
|
||||
Init method for action.
|
||||
|
||||
All action init functions should accept **kwargs as a way of ignoring extra arguments.
|
||||
|
||||
Since many parameters are defined for the action space as a whole (such as max files per folder, max services
|
||||
per node), we need to pass those options to every action that gets created. To pervent verbosity, these
|
||||
parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply.
|
||||
"""
|
||||
self.name:str = ""
|
||||
"""Human-readable action identifier used for printing, logging, and reporting."""
|
||||
self.shape = (0,)
|
||||
"""Tuple describing number of options for each parameter of this action. Can be passed to
|
||||
gym.spaces.MultiDiscrete to form a valid space."""
|
||||
self.manager:ActionManager = manager
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def form_request(self) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return []
|
||||
|
||||
|
||||
class DoNothingAction(AbstractAction):
|
||||
def __init__(self, manager:"ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.name = "DONOTHING"
|
||||
self.shape = (1,)
|
||||
|
||||
def form_request(self) -> List[str]:
|
||||
return ["do_nothing"]
|
||||
|
||||
class NodeServiceAbstractAction(AbstractAction):
|
||||
"""
|
||||
Base class for service actions.
|
||||
|
||||
Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from
|
||||
this base class.
|
||||
"""
|
||||
@abstractmethod
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Tuple[int] = (num_nodes, num_services)
|
||||
self.verb:str
|
||||
|
||||
def form_request(self, node_id:int, service_id:int) -> List[str]:
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
service_uuid = self.manager.get_service_uuid_by_idx(node_id, service_id)
|
||||
if node_uuid is None or service_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ['network', 'node', node_uuid, 'services', service_uuid, self.verb]
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "scan"
|
||||
|
||||
class NodeServiceStopAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "stop"
|
||||
|
||||
class NodeServiceStartAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "start"
|
||||
|
||||
class NodeServicePauseAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "pause"
|
||||
|
||||
class NodeServiceResumeAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "resume"
|
||||
|
||||
class NodeServiceRestartAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "restart"
|
||||
|
||||
class NodeServiceDisableAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "disable"
|
||||
|
||||
class NodeServiceEnableAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "enable"
|
||||
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction):
|
||||
@abstractmethod
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape = (num_nodes, num_folders)
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id:int, folder_id:int) -> List[str]:
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
if node_uuid is None or folder_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, self.verb]
|
||||
|
||||
class NodeFolderScanAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
self.verb:str = "scan"
|
||||
|
||||
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
self.verb:str = "checkhash"
|
||||
|
||||
class NodeFolderRepairAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
self.verb:str = "repair"
|
||||
|
||||
class NodeFolderRestoreAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
self.verb:str = "restore"
|
||||
|
||||
|
||||
class NodeFileAbstractAction(AbstractAction):
|
||||
@abstractmethod
|
||||
def __init__(self, manager:"ActionManager", num_nodes:int, num_folders:int, num_files:int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape:Tuple[int] = (num_nodes, num_folders, num_files)
|
||||
self.verb:str
|
||||
|
||||
def form_request(self, node_id:int, folder_id:int, file_id:int) -> List[str]:
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
|
||||
if node_uuid is None or folder_uuid is None or file_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, 'files', file_uuid, self.verb]
|
||||
|
||||
class NodeFileScanAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
self.verb = "scan"
|
||||
|
||||
class NodeFileCheckhashAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
self.verb = "checkhash"
|
||||
|
||||
class NodeFileDeleteAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
self.verb = "delete"
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
self.verb = "repair"
|
||||
|
||||
class NodeFileRestoreAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
self.verb = "restore"
|
||||
|
||||
class NodeAbstractAction(AbstractAction):
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Tuple[int] = (num_nodes,)
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id:int) -> List[str]:
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
return ["network", "node", node_uuid, self.verb]
|
||||
|
||||
class NodeOSScanAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = 'scan'
|
||||
|
||||
class NodeShutdownAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = 'shutdown'
|
||||
|
||||
class NodeStartupAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = 'start'
|
||||
|
||||
class NodeResetAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = 'reset'
|
||||
|
||||
class NetworkACLAddRuleAction(AbstractAction):
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
num_permissions = 2
|
||||
self.shape: Tuple[int] = (max_acl_rules, num_permissions, num_nics, num_nics, num_ports, num_ports, num_protocols)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class ActionManager:
|
||||
# let the action manager handle the conversion of action spaces into a single discrete integer space.
|
||||
#
|
||||
|
||||
|
||||
# when action space is created, it will take subspaces and generate an action map by enumerating all possibilities,
|
||||
# BUT, the action map can be provided in the config, in which case it will use that.
|
||||
|
||||
# action map is basically just a mapping between integer and CAOS action (incl. parameter values)
|
||||
# for example the action map can be:
|
||||
# 0: DONOTHING
|
||||
# 1: NODE, FILE, SCAN, NODEID=2, FOLDERID=1, FILEID=0
|
||||
# 2: ......
|
||||
__act_class_identifiers:Dict[str,type] = {
|
||||
"DONOTHING": DoNothingAction,
|
||||
"NODE_SERVICE_SCAN": NodeServiceScanAction,
|
||||
"NODE_SERVICE_STOP": NodeServiceStopAction,
|
||||
# "NODE_SERVICE_START": NodeServiceStartAction,
|
||||
# "NODE_SERVICE_PAUSE": NodeServicePauseAction,
|
||||
# "NODE_SERVICE_RESUME": NodeServiceResumeAction,
|
||||
# "NODE_SERVICE_RESTART": NodeServiceRestartAction,
|
||||
# "NODE_SERVICE_DISABLE": NodeServiceDisableAction,
|
||||
# "NODE_SERVICE_ENABLE": NodeServiceEnableAction,
|
||||
# "NODE_FILE_SCAN": NodeFileScanAction,
|
||||
# "NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
|
||||
# "NODE_FILE_DELETE": NodeFileDeleteAction,
|
||||
# "NODE_FILE_REPAIR": NodeFileRepairAction,
|
||||
# "NODE_FILE_RESTORE": NodeFileRestoreAction,
|
||||
"NODE_FOLDER_SCAN": NodeFolderScanAction,
|
||||
# "NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
|
||||
# "NODE_FOLDER_REPAIR": NodeFolderRepairAction,
|
||||
# "NODE_FOLDER_RESTORE": NodeFolderRestoreAction,
|
||||
# "NODE_OS_SCAN": NodeOSScanAction,
|
||||
# "NODE_SHUTDOWN": NodeShutdownAction,
|
||||
# "NODE_STARTUP": NodeStartupAction,
|
||||
# "NODE_RESET": NodeResetAction,
|
||||
# "NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction,
|
||||
# "NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
|
||||
# "NETWORK_NIC_ENABLE": NetworkNICEnable,
|
||||
# "NETWORK_NIC_DISABLE": NetworkNICDisable,
|
||||
}
|
||||
|
||||
|
||||
def __init__(self,
|
||||
sim:Simulation,
|
||||
actions:List[str],
|
||||
node_uuids:List[str],
|
||||
max_folders_per_node:int = 2,
|
||||
max_files_per_folder:int = 2,
|
||||
max_services_per_node:int = 2,
|
||||
max_nics_per_node:int=8,
|
||||
max_acl_rules:int=10,
|
||||
act_map:Optional[Dict[int, Dict]]=None) -> None:
|
||||
self.sim: Simulation = sim
|
||||
self.node_uuids:List[str] = node_uuids
|
||||
|
||||
action_args = {
|
||||
"num_nodes": len(node_uuids),
|
||||
"num_folders":max_folders_per_node,
|
||||
"num_files": max_files_per_folder,
|
||||
"num_services": max_services_per_node,
|
||||
"num_nics": max_nics_per_node,
|
||||
"num_acl_rules": max_acl_rules}
|
||||
self.actions: Dict[str, AbstractAction] = {}
|
||||
for act_type in actions:
|
||||
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **action_args)
|
||||
|
||||
self.action_map:Dict[int, Tuple[str, Dict]] = {}
|
||||
"""
|
||||
Action mapping that converts an integer to a specific action and parameter choice.
|
||||
|
||||
For example :
|
||||
{0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})}
|
||||
"""
|
||||
if act_map is None:
|
||||
self.action_map = self._enumerate_actions()
|
||||
else:
|
||||
self.action_map = {i:(a['action'], a['options']) for i,a in act_map.items()}
|
||||
# make sure all numbers between 0 and N are represented as dict keys in action map
|
||||
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
|
||||
|
||||
def _enumerate_actions(self,) -> Dict[int, Tuple[AbstractAction, Dict]]:
|
||||
...
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str,Dict]:
|
||||
"""Produce action in CAOS format"""
|
||||
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
|
||||
"""The caos format is basically a action identifier, followed by parameters stored in a dictionary"""
|
||||
act_identifier, act_options = self.action_map[action]
|
||||
return act_identifier, act_options
|
||||
|
||||
class ActionSpace:
|
||||
def form_request(self, action_identifier:str, action_options:Dict):
|
||||
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format"""
|
||||
act_obj = self.actions[action_identifier]
|
||||
return act_obj.form_request(**action_options)
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Discrete(len(self.action_map))
|
||||
|
||||
def get_node_uuid_by_idx(self, node_idx):
|
||||
return self.node_uuids[node_idx]
|
||||
|
||||
def get_folder_uuid_by_idx(self, node_idx, folder_idx) -> Optional[str]:
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
folder_uuids = list(node.file_system.folders.keys())
|
||||
return folder_uuids[folder_idx] if len(folder_uuids)>folder_idx else None
|
||||
|
||||
def get_file_uuid_by_idx(self, node_idx, folder_idx, file_idx) -> Optional[str]:
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
folder_uuids = list(node.file_system.folders.keys())
|
||||
if len(folder_uuids)<=folder_idx:
|
||||
return None
|
||||
folder = node.file_system.folders[folder_uuids[folder_idx]]
|
||||
file_uuids = list(folder.files.keys())
|
||||
return file_uuids[file_idx] if len(file_uuids)>file_idx else None
|
||||
|
||||
def get_service_uuid_by_idx(self, node_idx, service_idx) -> Optional[str]:
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
service_uuids = list(node.services.keys())
|
||||
return service_uuids[service_idx] if len(service_uuids)>service_idx else None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class UC2RedActions(AbstractAction):
|
||||
...
|
||||
|
||||
class UC2GreenActionSpace(ActionManager):
|
||||
...
|
||||
|
||||
@@ -2,32 +2,70 @@
|
||||
# That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word
|
||||
# If you disagree, make a comment in the PR review and we can discuss
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union, TypeAlias
|
||||
import numpy as np
|
||||
|
||||
from primaite.game.agent.actions import ActionSpace
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
ObsType:TypeAlias = Union[Dict, np.ndarray]
|
||||
|
||||
class AbstractAgent(ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_space: Optional[ActionSpace],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationSpace],
|
||||
reward_function: Optional[RewardFunction],
|
||||
) -> None:
|
||||
self.action_space: Optional[ActionSpace] = action_space
|
||||
self.action_space: Optional[ActionManager] = action_space
|
||||
self.observation_space: Optional[ObservationSpace] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
|
||||
# exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info
|
||||
# by for example specifying target ip addresses, or converting a node ID into a uuid
|
||||
self.execution_definition = None
|
||||
|
||||
def get_obs_from_state(self, state:Dict) -> ObsType:
|
||||
"""
|
||||
state : dict state directly from simulation.describe_state
|
||||
output : dict state according to CAOS.
|
||||
"""
|
||||
return self.observation_space.observe(state)
|
||||
|
||||
def get_reward_from_state(self, state:Dict) -> float:
|
||||
return self.reward_function.calculate(state)
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs:ObsType, reward:float=None):
|
||||
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 1-40,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ('NODE', 'SERVICE', 'SCAN', '<fake-node-sid>', '<fake-service-sid>')
|
||||
|
||||
@abstractmethod
|
||||
def format_request(self, action) -> List[str]:
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
return ['network', 'nodes', '<fake-node-uuid>', 'file_system', 'folder', 'root', 'scan']
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
...
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs:ObsType, reward:float=None):
|
||||
return self.action_space.space.sample()
|
||||
|
||||
|
||||
class AbstractGATEAgent(AbstractAgent):
|
||||
"""Base class for actors controlled via external messages, such as RL policies."""
|
||||
|
||||
@@ -55,7 +55,7 @@ class AbstractObservation(ABC):
|
||||
|
||||
|
||||
class FileObservation(AbstractObservation):
|
||||
def __init__(self, where: List[str] = []) -> None:
|
||||
def __init__(self, where: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
_summary_
|
||||
|
||||
@@ -68,12 +68,12 @@ class FileObservation(AbstractObservation):
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: List[str] = where
|
||||
self.where: Optional[List[str]] = where
|
||||
self.default_observation: spaces.Space = {"health_status": 0}
|
||||
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
@@ -89,21 +89,21 @@ class ServiceObservation(AbstractObservation):
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
|
||||
"Default observation is what should be returned when the service doesn't exist."
|
||||
|
||||
def __init__(self, where: List[str] = []) -> None:
|
||||
def __init__(self, where: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_uuid>,'servics', <service_uuid>]`
|
||||
`['network','nodes',<node_uuid>,'services', <service_uuid>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: List[str] = where
|
||||
self.where: Optional[List[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
service_state = access_from_nested_dict(state, self.where)
|
||||
@@ -120,7 +120,7 @@ class LinkObservation(AbstractObservation):
|
||||
default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
def __init__(self, where: List[str] = []) -> None:
|
||||
def __init__(self, where: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
@@ -131,10 +131,10 @@ class LinkObservation(AbstractObservation):
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: List[str] = where
|
||||
self.where: Optional[List[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
@@ -156,7 +156,7 @@ class LinkObservation(AbstractObservation):
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
|
||||
def __init__(self, where: List[str] = [], files: List[FileObservation] = []) -> None:
|
||||
def __init__(self, where: Optional[List[str]] = None, files: List[FileObservation] = []) -> None:
|
||||
"""Initialise folder Observation, including files inside of the folder.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
|
||||
@@ -175,7 +175,7 @@ class FolderObservation(AbstractObservation):
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.where: List[str] = where
|
||||
self.where: Optional[List[str]] = where
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
|
||||
@@ -185,7 +185,7 @@ class FolderObservation(AbstractObservation):
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
folder_state = access_from_nested_dict(state, self.where)
|
||||
if folder_state is NOT_PRESENT_IN_STATE:
|
||||
@@ -213,12 +213,12 @@ class FolderObservation(AbstractObservation):
|
||||
class NicObservation(AbstractObservation):
|
||||
default_observation: spaces.Space = {"nic_status": 0}
|
||||
|
||||
def __init__(self, where: List[str] = []) -> None:
|
||||
super.__init__()
|
||||
self.where: List[str] = where
|
||||
def __init__(self, where: Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
@@ -234,10 +234,11 @@ class NicObservation(AbstractObservation):
|
||||
class NodeObservation(AbstractObservation):
|
||||
def __init__(
|
||||
self,
|
||||
where: List[str] = [],
|
||||
where: Optional[List[str]] = None,
|
||||
services: List[ServiceObservation] = [],
|
||||
folders: List[FolderObservation] = [],
|
||||
nics: List[NicObservation] = [],
|
||||
logon_status:bool=False
|
||||
) -> None:
|
||||
"""
|
||||
Configurable observation for a node in the simulation.
|
||||
@@ -259,12 +260,13 @@ class NodeObservation(AbstractObservation):
|
||||
:param max_nics: Max number of NICS in this node's obs space, defaults to 5
|
||||
:type max_nics: int, optional
|
||||
"""
|
||||
super.__init__()
|
||||
self.where: List[str] = where
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
|
||||
self.services: List[ServiceObservation] = services
|
||||
self.folders: List[FolderObservation] = folders
|
||||
self.nics: List[NicObservation] = nics
|
||||
self.logon_status:bool=logon_status
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
@@ -272,9 +274,11 @@ class NodeObservation(AbstractObservation):
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.nics)},
|
||||
"operating_status": 0,
|
||||
}
|
||||
if self.logon_status:
|
||||
self.default_observation['logon_status']=0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
node_state = access_from_nested_dict(state, self.where)
|
||||
@@ -288,18 +292,24 @@ class NodeObservation(AbstractObservation):
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
|
||||
|
||||
if self.logon_status:
|
||||
obs['logon_status'] = 0
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict(
|
||||
{
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(0),
|
||||
"NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
|
||||
}
|
||||
)
|
||||
space_shape = {
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(5),
|
||||
"NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
|
||||
}
|
||||
if self.logon_status:
|
||||
space_shape['logon_status'] = spaces.Discrete(3)
|
||||
|
||||
return spaces.Dict(space_shape)
|
||||
|
||||
|
||||
|
||||
class AclObservation(AbstractObservation):
|
||||
@@ -308,41 +318,33 @@ class AclObservation(AbstractObservation):
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self, nodes: List[str], ports: List[int], protocols: list[str], where: List[str] = [], num_rules: int = 10
|
||||
self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[List[str]] = None, num_rules: int = 10
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.where: List[str] = where
|
||||
self.where: Optional[List[str]] = where
|
||||
self.num_rules: int = num_rules
|
||||
self.node_to_id: Dict[str, int] = {node: i + 1 for i, node in enumerate(nodes)}
|
||||
self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
"List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
self.port_to_id: Dict[int, int] = {port: i + 1 for i, port in enumerate(ports)}
|
||||
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
|
||||
"List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 1 for i, protocol in enumerate(protocols)}
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: spaces.Space = spaces.Dict(
|
||||
{
|
||||
"RULES": spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
self.default_observation: Dict = {
|
||||
"RULES": {i+ 1:{
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if not self.where:
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
@@ -379,16 +381,16 @@ class AclObservation(AbstractObservation):
|
||||
{
|
||||
"RULE": spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
i + 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
"source_node_id": spaces.Discrete(len(self.nodes) + 1),
|
||||
"source_port": spaces.Discrete(len(self.ports) + 1),
|
||||
"dest_node_id": spaces.Discrete(len(self.nodes) + 1),
|
||||
"dest_port": spaces.Discrete(len(self.ports) + 1),
|
||||
"protocol": spaces.Discrete(len(self.protocols) + 1),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
@@ -398,14 +400,96 @@ class AclObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
|
||||
class ICSObservation(AbstractObservation):
|
||||
def observe(self, state: Dict) -> Any:
|
||||
return 0
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation):
|
||||
def __init__(self, where:Optional[List[str]]=None):
|
||||
self.default_observation: Dict = {}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Discrete(1)
|
||||
return spaces.Dict({})
|
||||
|
||||
class ICSObservation(NullObservation): pass
|
||||
|
||||
|
||||
class UC2BlueObservation(AbstractObservation):
|
||||
def __init__(
|
||||
self,
|
||||
nodes: List[NodeObservation],
|
||||
links: List[LinkObservation],
|
||||
acl: AclObservation,
|
||||
ics: ICSObservation,
|
||||
where:Optional[List[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
self.links: List[LinkObservation] = links
|
||||
self.acl: AclObservation = acl
|
||||
self.ics: ICSObservation = ics
|
||||
|
||||
self.default_observation : Dict = {
|
||||
"NODES": {i+1: n.default_observation for i,n in enumerate(self.nodes)},
|
||||
"LINKS": {i+1: l.default_observation for i,l in enumerate(self.links)},
|
||||
"ACL": self.acl.default_observation,
|
||||
"ICS": self.ics.default_observation,
|
||||
}
|
||||
|
||||
def observe(self, state:Dict) -> Dict:
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
|
||||
obs['NODES'] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
obs['LINKS'] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
|
||||
obs['ACL'] = {self.acl.observe(state)}
|
||||
obs['ICS'] = {self.ics.observe(state)}
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict({
|
||||
"NODES": spaces.Dict({i+1: node.space for i, node in enumerate(self.nodes)}),
|
||||
"LINKS": spaces.Dict({i+1: link.space for i, link in enumerate(self.links)}),
|
||||
"ACL": self.acl.space,
|
||||
"ICS": self.ics.space,
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config:Dict, sim:Simulation):
|
||||
nodes = ...
|
||||
links = ...
|
||||
acl = ...
|
||||
ics = ...
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=['network'])
|
||||
return new
|
||||
|
||||
|
||||
class UC2RedObservation(AbstractObservation):
|
||||
def __init__(self, nodes:List[NodeObservation], where:Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.where:Optional[List[str]] = where
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
|
||||
self.default_observation=...#TODO
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
return super().observe(state)
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
... #TODO
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, sim:Simulation):
|
||||
... #TODO
|
||||
|
||||
class ObservationSpace:
|
||||
"""
|
||||
@@ -422,29 +506,12 @@ class ObservationSpace:
|
||||
# what this class does:
|
||||
# keep a list of observations
|
||||
# create observations for an actor from the config
|
||||
def __init__(
|
||||
self,
|
||||
simulation: Simulation,
|
||||
nodes: List[NodeObservation] = [],
|
||||
links: List[LinkObservation] = [],
|
||||
acl: Optional[AclObservation] = None,
|
||||
ics: Optional[ICSObservation] = None,
|
||||
) -> None:
|
||||
self.simulation: Simulation = simulation
|
||||
self.parts: Dict[str, AbstractObservation] = {}
|
||||
def __init__(self, observation:AbstractObservation) -> None:
|
||||
self.obs: AbstractObservation = observation
|
||||
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
self.links: List[LinkObservation] = links
|
||||
self.acl: Optional[AclObservation] = acl
|
||||
self.ics: Optional[ICSObservation] = ics
|
||||
|
||||
def observe(self) -> None:
|
||||
...
|
||||
def observe(self, state) -> Dict:
|
||||
return self.obs.observe(state)
|
||||
|
||||
@property
|
||||
def space(self) -> None:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def from_config(self) -> None:
|
||||
...
|
||||
return self.obs.space
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AbstractReward(BaseModel):
|
||||
def __call__(self, states: List[Dict]) -> float:
|
||||
"""_summary_
|
||||
|
||||
:param state: _description_
|
||||
:type state: Dict
|
||||
:return: _description_
|
||||
:rtype: float
|
||||
"""
|
||||
class AbstractReward():
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
def calculate(self, state:Dict) -> float:
|
||||
return 0.3
|
||||
|
||||
class RewardFunction(BaseModel):
|
||||
...
|
||||
|
||||
class RewardFunction():
|
||||
def __init__(self, reward_function:AbstractReward):
|
||||
self.reward: AbstractReward = reward_function
|
||||
|
||||
def calculate(self, state:Dict) -> float:
|
||||
return self.reward.calculate(state)
|
||||
|
||||
@@ -4,3 +4,49 @@
|
||||
# 3. create actors and configure their actions/observations/rewards/ anything else
|
||||
# 4. Create connection with ARCD GATE
|
||||
# 5. idk
|
||||
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.game.agent.interface import AbstractAgent
|
||||
|
||||
from typing import List
|
||||
|
||||
class PrimaiteSession:
|
||||
def __init__(self):
|
||||
self.simulation: Simulation = Simulation()
|
||||
self.agents:List[AbstractAgent] = []
|
||||
self.step_counter:int = 0
|
||||
self.episode_counter:int = 0
|
||||
|
||||
|
||||
def step(self):
|
||||
# currently designed with assumption that all agents act once per step in order
|
||||
|
||||
|
||||
for agent in self.agents:
|
||||
# 3. primaite session asks simulation to provide initial state
|
||||
# 4. primate session gives state to all agents
|
||||
# 5. primaite session asks agents to produce an action based on most recent state
|
||||
sim_state = self.simulation.describe_state()
|
||||
|
||||
# 6. each agent takes most recent state and converts it to CAOS observation
|
||||
agent_obs = agent.get_obs_from_state(sim_state)
|
||||
|
||||
# 7. meanwhile each agent also takes state and calculates reward
|
||||
agent_reward = agent.get_reward_from_state(sim_state)
|
||||
|
||||
# 8. each agent takes observation and applies decision rule to observation to create CAOS
|
||||
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
|
||||
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
|
||||
# code should live inside of the GATE agent subclass)
|
||||
# gets action in CAOS format
|
||||
agent_action = agent.get_action(agent_obs, agent_reward)
|
||||
# 9. CAOS action is converted into request (extra information might be needed to enrich
|
||||
# the request, this is what the execution definition is there for)
|
||||
agent_request = agent.format_request(agent_action)
|
||||
|
||||
# 10. primaite session receives the action from the agents and asks the simulation to apply each
|
||||
self.simulation.apply_action(agent_request)
|
||||
|
||||
self.simulation.apply_timestep(self.step_counter)
|
||||
self.step_counter += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user