Finalise actions interface

This commit is contained in:
Marek Wolan
2023-10-02 17:21:43 +01:00
parent f1346ae278
commit 2b617e01a3
7 changed files with 1382 additions and 174 deletions

View File

@@ -27,11 +27,11 @@ game_config:
- type: LOGON
- type: LOGOFF
applications:
- application_ref: client_2_web_browser
actions:
- type: EXECUTE
execution_definition:
target_address: arcd.com
# - application_ref: client_2_web_browser
# actions:
# - type: EXECUTE
# execution_definition:
# target_address: arcd.com
reward_function: null
agent_settings:
start_step: 5
@@ -42,7 +42,8 @@ game_config:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
network:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
@@ -85,13 +86,307 @@ game_config:
- ref: defender
team: blue
type: GATE_RL_AGENT
type: GATERLAgent
observation_space:
network:
type: UC2BlueObservation
options:
nodes:
- node_ref: router_1 #TODO: more sub-options here
- node_ref: switch_1
- node_ref: switch_2
- node_ref: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
router_node_ref: router_1
ics: null
action_space:
action_list:
- DONOTHING
- NODE_SERVICE_SCAN
- NODE_SERVICE_STOP
# - NODE_SERVICE_START
# - NODE_SERVICE_PAUSE
# - NODE_SERVICE_RESUME
# - NODE_SERVICE_RESTART
# - NODE_SERVICE_DISABLE
# - NODE_SERVICE_ENABLE
# - NODE_FILE_SCAN
# - NODE_FILE_CHECKHASH
# - NODE_FILE_DELETE
# - NODE_FILE_REPAIR
# - NODE_FILE_RESTORE
# - NODE_FOLDER_SCAN
# - NODE_FOLDER_CHECKHASH
# - NODE_FOLDER_REPAIR
# - NODE_FOLDER_RESTORE
# - NODE_OS_SCAN
# - NODE_SHUTDOWN
# - NODE_STARTUP
# - NODE_RESET
# - NETWORK_ACL_ADDRULE
# - NETWORK_ACL_REMOVERULE
# - NETWORK_NIC_ENABLE
- NETWORK_NIC_DISABLE
action_map:
0:
- action: DONOTHING
options: {}
# scan webapp service
1:
- action: NODE_SERVICE_SCAN
options:
- node_id: 2
- service_id: 1
# stop webapp service
2:
- action: NODE_SERVICE_STOP
options:
- node_id: 2
- service_id: 1
# start webapp service
3:
- action: "NODE_SERVICE_START"
options:
- node_id: 2
- service_id: 1
4:
- action: "NODE_SERVICE_PAUSE"
options:
- node_id: 2
- service_id: 1
5:
- action: "NODE_SERVICE_RESUME"
options:
- node_id: 2
- service_id: 1
6:
- action: "NODE_SERVICE_RESTART"
options:
- node_id: 2
- service_id: 1
7:
- action: "NODE_SERVICE_DISABLE"
options:
- node_id: 2
- service_id: 1
8:
- action: "NODE_SERVICE_ENABLE"
options:
- node_id: 2
- service_id: 1
9:
- action: "NODE_FILE_SCAN"
options:
- node_id: 3
- folder_id: 1
- file_id: 1
10:
- action: "NODE_FILE_CHECKHASH"
options:
- node_id: 3
- folder_id: 1
- file_id: 1
11:
- action: "NODE_FILE_DELETE"
options:
- node_id: 3
- folder_id: 1
- file_id: 1
12:
- action: "NODE_FILE_REPAIR"
options:
- node_id: 3
- folder_id: 1
- file_id: 1
13:
- action: "NODE_FILE_RESTORE"
options:
- node_id: 3
- folder_id: 1
- file_id: 1
14:
- action: "NODE_FOLDER_SCAN"
options:
- node_id: 3
- folder_id: 1
15:
- action: "NODE_FOLDER_CHECKHASH"
options:
- node_id: 3
- folder_id: 1
16:
- action: "NODE_FOLDER_REPAIR"
options:
- node_id: 3
- folder_id: 1
17:
- action: "NODE_FOLDER_RESTORE"
options:
- node_id: 3
- folder_id: 1
18:
- action: "NODE_OS_SCAN"
options:
- node_id: 3
19:
- action: "NODE_SHUTDOWN"
options:
- node_id: 6
20:
- action: "NODE_STARTUP"
options:
- node_id: 6
21:
- action: "NODE_RESET"
options:
- node_id: 6
22:
- action: "NETWORK_ACL_ADDRULE"
options:
- position: 6
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
23:
- action: "NETWORK_ACL_ADDRULE"
options:
- position: 5
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
24:
- action: "NETWORK_ACL_ADDRULE"
options:
- position: 4
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
25:
- action: "NETWORK_ACL_ADDRULE"
options:
- position: 3
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
26:
- action: "NETWORK_ACL_ADDRULE"
options:
- position: 2
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
27:
- action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
28:
- action: "NETWORK_ACL_REMOVERULE"
options:
- position: 0
29:
- action: "NETWORK_ACL_REMOVERULE"
options:
- position: 1
30:
- action: "NETWORK_ACL_REMOVERULE"
options:
- position: 2
31:
- action: "NETWORK_ACL_REMOVERULE"
options:
- position: 3
32:
- action: "NETWORK_ACL_REMOVERULE"
options:
- position: 4
33:
- action: "NETWORK_ACL_REMOVERULE"
options:
- position: 5
34:
- action: "NETWORK_ACL_REMOVERULE"
options:
- position: 6
35:
- action: "NETWORK_ACL_REMOVERULE"
options:
- position: 7
36:
- action: "NETWORK_ACL_REMOVERULE"
options:
- position: 8
37:
- action: "NETWORK_ACL_REMOVERULE"
options:
- position: 9
38:
- action: "NETWORK_NIC_DISABLE"
options:
- node_id: 6
- nic_index: 1
39:
- action: "NETWORK_NIC_ENABLE"
options:
- node_id: 6
- nic_index: 1
options:
nodes:
- node_ref: router_1
- node_ref: switch_1
- node_ref: switch_2
- node_ref: domain_controller
- node_ref: web_server
- node_ref: database_server
@@ -99,18 +394,13 @@ game_config:
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
links:
- link_ref: ... #
acl: ... #
ics: ... #
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
action_space:
actions:
- type: DO_NOTHING
network:
nodes:
- node_ref: router_1
reward_function:
# ...
agent_settings:
@@ -175,7 +465,6 @@ simulation:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
hostname: web_server
@@ -200,7 +489,6 @@ simulation:
- ref: database_service
type: DatabaseService
- ref: backup_server
type: server
hostname: backup_server
@@ -224,7 +512,6 @@ simulation:
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
hostname: client_1
@@ -251,7 +538,6 @@ simulation:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1

View File

@@ -2,9 +2,18 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 13,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
@@ -12,7 +21,381 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.session import PrimaiteSession\n",
"from primaite.simulator.sim_container import Simulation\n",
"from primaite.game.agent.interface import AbstractAgent\n",
"from primaite.simulator.network.networks import arcd_uc2_network\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"sess = PrimaiteSession()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"network = sess.simulation.network"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from ipaddress import IPv4Address\n",
"\n",
"from primaite.simulator.network.container import Network\n",
"from primaite.simulator.network.hardware.base import NIC\n",
"from primaite.simulator.network.hardware.nodes.computer import Computer\n",
"from primaite.simulator.network.hardware.nodes.router import ACLAction, Router\n",
"from primaite.simulator.network.hardware.nodes.server import Server\n",
"from primaite.simulator.network.hardware.nodes.switch import Switch\n",
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"from primaite.simulator.system.applications.database_client import DatabaseClient\n",
"from primaite.simulator.system.services.database_service import DatabaseService\n",
"from primaite.simulator.system.services.dns_client import DNSClient\n",
"from primaite.simulator.system.services.dns_server import DNSServer\n",
"from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-10-02 15:10:20,422: Added node 6abb7664-4d17-45ff-a3c7-dbcccffcfd6d to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,424: Added node 3edbc521-3c80-47e3-8017-dbc38fb00a73 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,428: Added node 94457fb9-04a1-4dc1-9ff7-b64df0da7424 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,432: Added node 0d311d72-139c-41bf-aef7-fa9b01b124d7 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,439: Added node 6161e785-f377-48de-aa4f-20d3646da635 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,444: Added node 55a9e9f8-ee3a-4c28-9b6d-c0c0d78a3f6a to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,447: Added node 2f04ca45-3439-489a-81f7-41cca5ae8adc to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,531: Added node 98660c30-8e48-4b96-967a-d62ca71b4d6d to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,545: Added node 1a184184-b204-40de-986a-4d7459036dbe to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,551: Added node 17b92b9a-6805-4677-85f7-c0c0521a6e25 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,555::ERROR::primaite.simulator.network.hardware.base::175::NIC da:f3:1b:87:24:20/192.168.10.110 cannot be enabled as it is not connected to a Link\n"
]
}
],
"source": [
"router_1 = Router(hostname=\"router_1\", num_ports=5)\n",
"router_1.power_on()\n",
"router_1.configure_port(port=1, ip_address=\"192.168.1.1\", subnet_mask=\"255.255.255.0\")\n",
"router_1.configure_port(port=2, ip_address=\"192.168.10.1\", subnet_mask=\"255.255.255.0\")\n",
"\n",
"# Switch 1\n",
"switch_1 = Switch(hostname=\"switch_1\", num_ports=8)\n",
"switch_1.power_on()\n",
"network.connect(endpoint_a=router_1.ethernet_ports[1], endpoint_b=switch_1.switch_ports[8])\n",
"router_1.enable_port(1)\n",
"\n",
"# Switch 2\n",
"switch_2 = Switch(hostname=\"switch_2\", num_ports=8)\n",
"switch_2.power_on()\n",
"network.connect(endpoint_a=router_1.ethernet_ports[2], endpoint_b=switch_2.switch_ports[8])\n",
"router_1.enable_port(2)\n",
"\n",
"# Client 1\n",
"client_1 = Computer(\n",
" hostname=\"client_1\",\n",
" ip_address=\"192.168.10.21\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.10.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"client_1.power_on()\n",
"client_1.software_manager.install(DNSClient)\n",
"client_1_dns_client_service: DNSServer = client_1.software_manager.software[\"DNSClient\"] # noqa\n",
"client_1_dns_client_service.start()\n",
"network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])\n",
"client_1.software_manager.install(DataManipulationBot)\n",
"db_manipulation_bot: DataManipulationBot = client_1.software_manager.software[\"DataManipulationBot\"]\n",
"db_manipulation_bot.configure(server_ip_address=IPv4Address(\"192.168.1.14\"), payload=\"DROP TABLE IF EXISTS user;\")\n",
"\n",
"# Client 2\n",
"client_2 = Computer(\n",
" hostname=\"client_2\",\n",
" ip_address=\"192.168.10.22\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.10.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"client_2.power_on()\n",
"client_2.software_manager.install(DNSClient)\n",
"client_2_dns_client_service: DNSServer = client_2.software_manager.software[\"DNSClient\"] # noqa\n",
"client_2_dns_client_service.start()\n",
"network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])\n",
"\n",
"# Domain Controller\n",
"domain_controller = Server(\n",
" hostname=\"domain_controller\",\n",
" ip_address=\"192.168.1.10\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.1.1\",\n",
")\n",
"domain_controller.power_on()\n",
"domain_controller.software_manager.install(DNSServer)\n",
"\n",
"network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])\n",
"\n",
"# Database Server\n",
"database_server = Server(\n",
" hostname=\"database_server\",\n",
" ip_address=\"192.168.1.14\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.1.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"database_server.power_on()\n",
"network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3])\n",
"\n",
"ddl = \"\"\"\n",
"CREATE TABLE IF NOT EXISTS user (\n",
"id INTEGER PRIMARY KEY AUTOINCREMENT,\n",
"name VARCHAR(50) NOT NULL,\n",
"email VARCHAR(50) NOT NULL,\n",
"age INT,\n",
"city VARCHAR(50),\n",
"occupation VARCHAR(50)\n",
");\"\"\"\n",
"\n",
"user_insert_statements = [\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');\", # noqa\n",
"]\n",
"database_server.software_manager.install(DatabaseService)\n",
"database_service: DatabaseService = database_server.software_manager.software[\"DatabaseService\"] # noqa\n",
"database_service.start()\n",
"database_service._process_sql(ddl, None) # noqa\n",
"for insert_statement in user_insert_statements:\n",
" database_service._process_sql(insert_statement, None) # noqa\n",
"\n",
"# Web Server\n",
"web_server = Server(\n",
" hostname=\"web_server\",\n",
" ip_address=\"192.168.1.12\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.1.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"web_server.power_on()\n",
"web_server.software_manager.install(DatabaseClient)\n",
"\n",
"database_client: DatabaseClient = web_server.software_manager.software[\"DatabaseClient\"]\n",
"database_client.configure(server_ip_address=IPv4Address(\"192.168.1.14\"))\n",
"network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])\n",
"database_client.run()\n",
"database_client.connect()\n",
"\n",
"# register the web_server to a domain\n",
"dns_server_service: DNSServer = domain_controller.software_manager.software[\"DNSServer\"] # noqa\n",
"dns_server_service.start()\n",
"dns_server_service.dns_register(\"arcd.com\", web_server.ip_address)\n",
"\n",
"# Backup Server\n",
"backup_server = Server(\n",
" hostname=\"backup_server\",\n",
" ip_address=\"192.168.1.16\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.1.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"backup_server.power_on()\n",
"network.connect(endpoint_b=backup_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[4])\n",
"\n",
"# Security Suite\n",
"security_suite = Server(\n",
" hostname=\"security_suite\",\n",
" ip_address=\"192.168.1.110\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.1.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"security_suite.power_on()\n",
"network.connect(endpoint_b=security_suite.ethernet_port[1], endpoint_a=switch_1.switch_ports[7])\n",
"security_suite.connect_nic(NIC(ip_address=\"192.168.10.110\", subnet_mask=\"255.255.255.0\"))\n",
"network.connect(endpoint_b=security_suite.ethernet_port[2], endpoint_a=switch_2.switch_ports[7])\n",
"\n",
"router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)\n",
"\n",
"router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)\n",
"\n",
"# Allow PostgreSQL requests\n",
"router_1.acl.add_rule(\n",
" action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0\n",
")\n",
"\n",
"# Allow DNS requests\n",
"router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"node_uuid_list = list(sess.simulation.network.nodes.keys())"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.agent.actions import ActionManager"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"actman = ActionManager(sess.simulation, [\"DONOTHING\", \"NODE_SERVICE_SCAN\", \"NODE_SERVICE_STOP\", \"NODE_FOLDER_SCAN\"],node_uuid_list,act_map={\n",
" 0:{\n",
" \"action\": \"DONOTHING\",\n",
" \"options\": {}\n",
" },\n",
" 1:{\n",
" \"action\": \"NODE_SERVICE_SCAN\",\n",
" \"options\": {\"node_id\":0, \"service_id\":0},\n",
" },\n",
" 2:{\n",
" \"action\": \"NODE_SERVICE_SCAN\",\n",
" \"options\": {\"node_id\":1, \"service_id\":0},\n",
" },\n",
" 3:{\n",
" \"action\": \"NODE_FOLDER_SCAN\",\n",
" \"options\": {\"node_id\":4, \"folder_id\":0},\n",
" }\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"act_id, act_options = actman.get_action(3)\n",
"my_trial_act = actman.form_request(action_identifier=act_id, action_options=act_options)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"sess.simulation.apply_action(my_trial_act)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['network',\n",
" 'node',\n",
" '6161e785-f377-48de-aa4f-20d3646da635',\n",
" 'file_system',\n",
" 'folder',\n",
" '5aefe92b-923c-4684-b3bf-e78dd18d4771',\n",
" 'scan']"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_trial_act"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sess.step()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sess.step_counter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from gym import spaces"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sp = spaces.Tuple( (spaces.MultiDiscrete([3, 2]), spaces.MultiDiscrete([3, 2]), spaces.MultiDiscrete([3, 2]),))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sp.sample()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -39,40 +422,16 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-09-26 12:19:50,895: Added node 0fb262e1-a714-420a-aec7-be37f0deeb75 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,898: Added node 310ca8d7-01e0-401e-b604-705c290e5376 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,900: Added node b3b08f1f-7805-47b2-bdb6-3d83098cd740 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,903: Added node adb37f3e-2307-4123-bff3-01f125883be8 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,906: Added node 1a490716-2ccd-452d-b87e-324d29120b59 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,911: Added node 033460d8-0249-4bdd-aaf0-751b24cc0a1e to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,914: Added node 1e7e4e49-78bf-4031-8372-ee71902720f3 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,916: Added node c9f24a13-e5c8-437b-9234-b0c3f8120e2c to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,920: Added node c881f3ee-2176-493b-a6c2-cad829bf0b6d to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,922: Added node a3ea75d8-bc2c-4713-92a4-2588b4f43ed6 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"service type not found DatabaseBackup\n",
"service type not found WebBrowser\n"
]
}
],
"outputs": [],
"source": [
"# import yaml\n",
"\n",
"\n",
"from typing import Dict\n",
"from primaite.game.agent.interface import AbstractAgent\n",
"from primaite.game.agent.observations import AclObservation, FileObservation, FolderObservation, ICSObservation, LinkObservation, NicObservation, NodeObservation, NullObservation, ServiceObservation, UC2BlueObservation, UC2RedObservation\n",
"from primaite.simulator.network.hardware.base import NIC, Link, Node\n",
"from primaite.simulator.system.services.service import Service\n",
"\n",
@@ -130,8 +489,8 @@
" subnet_mask=port_cfg['subnet_mask'])\n",
" if 'acl' in node_cfg:\n",
" for r_num, r_cfg in node_cfg['acl'].items():\n",
" # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, so that we can do\n",
" # both of these things once: check if a key is defined, access and convert it to a \n",
" # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating \n",
" # this: 'r_cfg.get('src_port')'\n",
" # Port/IPProtocol. TODO Refactor\n",
" new_node.acl.add_rule(\n",
" action = ACLAction[r_cfg['action']],\n",
@@ -211,8 +570,70 @@
" action_space_cfg = agent_cfg['action_space']\n",
" observation_space_cfg = agent_cfg['observation_space']\n",
" reward_function_cfg = agent_cfg['reward_function']\n",
" \n",
" # CREATE OBSERVATION SPACE\n",
" if observation_space_cfg is None:\n",
" obs_space = NullObservation()\n",
" elif observation_space_cfg['type'] == 'UC2BlueObservation':\n",
" node_obs_list = []\n",
" link_obs_list = []\n",
" \n",
" \n",
" #node ip to index maps ip addresses to node id, as there are potentially multiple nics on a node, there are multiple ip addresses\n",
" node_ip_to_index = {}\n",
" for node_idx, node_cfg in enumerate(nodes_cfg):\n",
" n_ref = node_cfg['ref']\n",
" n_obj = net.nodes[ref_map_nodes[n_ref]]\n",
" for nic_uuid, nic_obj in n_obj.nics.items():\n",
" node_ip_to_index[nic_obj.ip_address] = node_idx + 2\n",
"\n",
" \n",
" \n",
" for node_obs_cfg in observation_space_cfg['options']['nodes']:\n",
" node_ref = node_obs_cfg['node_ref']\n",
" folder_obs_list = []\n",
" service_obs_list = []\n",
" if 'services' in node_obs_cfg:\n",
" for service_obs_cfg in node_obs_cfg['services']:\n",
" service_obs_list.append(ServiceObservation(where=['network','nodes',ref_map_nodes[node_ref],'services',ref_map_services[service_obs_cfg['service_ref']]]))\n",
" if 'folders' in node_obs_cfg:\n",
" for folder_obs_cfg in node_obs_cfg['folders']:\n",
" file_obs_list = []\n",
" if 'files' in folder_obs_cfg:\n",
" for file_obs_cfg in folder_obs_cfg['files']:\n",
" file_obs_list.append(FileObservation(where=['network','nodes',ref_map_nodes[node_ref], 'folders',folder_obs_cfg['folder_name'], 'files', file_obs_cfg['file_name']]))\n",
" folder_obs_list.append(FolderObservation(where=['network','nodes',ref_map_nodes[node_ref], 'folders',folder_obs_cfg['folder_name']], files=file_obs_list))\n",
" nic_obs_list = []\n",
" for nic_uuid in net.nodes[ref_map_nodes[node_obs_cfg['node_ref']]].nics.keys():\n",
" nic_obs_list.append(NicObservation(where=['network','nodes',ref_map_nodes[node_ref],'NICs',nic_uuid]))\n",
" node_obs_list.append(NodeObservation(where=['network','nodes',ref_map_nodes[node_ref]], services=service_obs_list, folders=folder_obs_list,nics=nic_obs_list, logon_status=False))\n",
" for link_obs_cfg in observation_space_cfg['options']['links']:\n",
" link_ref = link_obs_cfg['link_ref']\n",
" link_obs_list.append(LinkObservation(where=['network' ,'links', ref_map_links[link_ref]]))\n",
"\n",
" acl_obs = AclObservation(node_ip_to_id=node_ip_to_index, ports=game_cfg['ports'], protocols=game_cfg['ports'], where=['network','nodes',observation_space_cfg['options']['acl']['router_node_ref']])\n",
" obs_space = UC2BlueObservation(nodes=node_obs_list,links=link_obs_list,acl=acl_obs, ics=ICSObservation())\n",
" elif observation_space_cfg['type'] == 'UC2RedObservation':\n",
" obs_space = UC2RedObservation.from_config(observation_space_cfg['options'], sim=sim)\n",
" else:\n",
" print(\"observation space config not specified correctly.\")\n",
" obs_space = NullObservation()\n",
" \n",
" # CREATE ACTION SPACE\n",
" \n",
"\n",
"\n",
" # CREATE REWARD FUNCTION\n",
"\n",
" # CREATE AGENT\n",
" if agent_type == 'GreenWebBrowsingAgent':\n",
" new_agent = GreenWebBrowsingAgent()\n",
" ...\n",
" elif agent_type == 'GATERLAgent':\n",
" ...\n",
" elif agent_type == 'RedDatabaseCorruptingAgent':\n",
" ...\n",
" else:\n",
" print(\"agent type not found\")\n",
"\n",
"\n",
" #4. set up agents' actions and observation spaces.\n",
@@ -228,7 +649,7 @@
"metadata": {},
"outputs": [],
"source": [
"print(s.simulation.describe_state())"
"s.agents"
]
},
{

View File

@@ -1,21 +1,374 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from pydantic import BaseModel
from typing import Any, Dict, List, Optional, Tuple
import itertools
class AbstractAction(BaseModel):
from primaite.simulator.sim_container import Simulation
from gym import spaces
class ExecutionDefiniton(ABC):
"""
Converter from actions to simulator requests.
Allows adding extra data/context that defines in more detail what an action means.
"""
"""
Examples:
('node', 'service', 'scan', 2, 0) means scan the first service on node index 2
-> ['network', 'nodes', <node-idx-2-uuid>, 'services', <svc-idx-0-uuid>, 'scan'w]
"""
...
class AbstractAction(ABC):
@abstractmethod
def __call__(self, action: Any) -> List[str]:
"""_summary_
:param action: _description_
:type action: Any
:return: _description_
:rtype: List[str]
def __init__(self, manager:"ActionManager", **kwargs) -> None:
"""
Init method for action.
All action init functions should accept **kwargs as a way of ignoring extra arguments.
Since many parameters are defined for the action space as a whole (such as max files per folder, max services
per node), we need to pass those options to every action that gets created. To pervent verbosity, these
parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply.
"""
self.name:str = ""
"""Human-readable action identifier used for printing, logging, and reporting."""
self.shape = (0,)
"""Tuple describing number of options for each parameter of this action. Can be passed to
gym.spaces.MultiDiscrete to form a valid space."""
self.manager:ActionManager = manager
@abstractmethod
def form_request(self) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return []
class DoNothingAction(AbstractAction):
def __init__(self, manager:"ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
self.name = "DONOTHING"
self.shape = (1,)
def form_request(self) -> List[str]:
return ["do_nothing"]
class NodeServiceAbstractAction(AbstractAction):
"""
Base class for service actions.
Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from
this base class.
"""
@abstractmethod
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Tuple[int] = (num_nodes, num_services)
self.verb:str
def form_request(self, node_id:int, service_id:int) -> List[str]:
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
service_uuid = self.manager.get_service_uuid_by_idx(node_id, service_id)
if node_uuid is None or service_uuid is None:
return ["do_nothing"]
return ['network', 'node', node_uuid, 'services', service_uuid, self.verb]
class NodeServiceScanAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "start"
class NodeServicePauseAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "enable"
class NodeFolderAbstractAction(AbstractAction):
@abstractmethod
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager=manager)
self.shape = (num_nodes, num_folders)
self.verb: str
def form_request(self, node_id:int, folder_id:int) -> List[str]:
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
if node_uuid is None or folder_uuid is None:
return ["do_nothing"]
return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, self.verb]
class NodeFolderScanAction(NodeFolderAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
self.verb:str = "scan"
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
self.verb:str = "checkhash"
class NodeFolderRepairAction(NodeFolderAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
self.verb:str = "repair"
class NodeFolderRestoreAction(NodeFolderAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
self.verb:str = "restore"
class NodeFileAbstractAction(AbstractAction):
@abstractmethod
def __init__(self, manager:"ActionManager", num_nodes:int, num_folders:int, num_files:int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape:Tuple[int] = (num_nodes, num_folders, num_files)
self.verb:str
def form_request(self, node_id:int, folder_id:int, file_id:int) -> List[str]:
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
if node_uuid is None or folder_uuid is None or file_uuid is None:
return ["do_nothing"]
return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, 'files', file_uuid, self.verb]
class NodeFileScanAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
self.verb = "scan"
class NodeFileCheckhashAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
self.verb = "checkhash"
class NodeFileDeleteAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
self.verb = "delete"
class NodeFileRepairAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
self.verb = "repair"
class NodeFileRestoreAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
self.verb = "restore"
class NodeAbstractAction(AbstractAction):
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Tuple[int] = (num_nodes,)
self.verb: str
def form_request(self, node_id:int) -> List[str]:
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
return ["network", "node", node_uuid, self.verb]
class NodeOSScanAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = 'scan'
class NodeShutdownAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = 'shutdown'
class NodeStartupAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = 'start'
class NodeResetAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = 'reset'
class NetworkACLAddRuleAction(AbstractAction):
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
num_permissions = 2
self.shape: Tuple[int] = (max_acl_rules, num_permissions, num_nics, num_nics, num_ports, num_ports, num_protocols)
class ActionManager:
# let the action manager handle the conversion of action spaces into a single discrete integer space.
#
# when action space is created, it will take subspaces and generate an action map by enumerating all possibilities,
# BUT, the action map can be provided in the config, in which case it will use that.
# action map is basically just a mapping between integer and CAOS action (incl. parameter values)
# for example the action map can be:
# 0: DONOTHING
# 1: NODE, FILE, SCAN, NODEID=2, FOLDERID=1, FILEID=0
# 2: ......
__act_class_identifiers:Dict[str,type] = {
"DONOTHING": DoNothingAction,
"NODE_SERVICE_SCAN": NodeServiceScanAction,
"NODE_SERVICE_STOP": NodeServiceStopAction,
# "NODE_SERVICE_START": NodeServiceStartAction,
# "NODE_SERVICE_PAUSE": NodeServicePauseAction,
# "NODE_SERVICE_RESUME": NodeServiceResumeAction,
# "NODE_SERVICE_RESTART": NodeServiceRestartAction,
# "NODE_SERVICE_DISABLE": NodeServiceDisableAction,
# "NODE_SERVICE_ENABLE": NodeServiceEnableAction,
# "NODE_FILE_SCAN": NodeFileScanAction,
# "NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
# "NODE_FILE_DELETE": NodeFileDeleteAction,
# "NODE_FILE_REPAIR": NodeFileRepairAction,
# "NODE_FILE_RESTORE": NodeFileRestoreAction,
"NODE_FOLDER_SCAN": NodeFolderScanAction,
# "NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
# "NODE_FOLDER_REPAIR": NodeFolderRepairAction,
# "NODE_FOLDER_RESTORE": NodeFolderRestoreAction,
# "NODE_OS_SCAN": NodeOSScanAction,
# "NODE_SHUTDOWN": NodeShutdownAction,
# "NODE_STARTUP": NodeStartupAction,
# "NODE_RESET": NodeResetAction,
# "NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction,
# "NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
# "NETWORK_NIC_ENABLE": NetworkNICEnable,
# "NETWORK_NIC_DISABLE": NetworkNICDisable,
}
def __init__(self,
sim:Simulation,
actions:List[str],
node_uuids:List[str],
max_folders_per_node:int = 2,
max_files_per_folder:int = 2,
max_services_per_node:int = 2,
max_nics_per_node:int=8,
max_acl_rules:int=10,
act_map:Optional[Dict[int, Dict]]=None) -> None:
self.sim: Simulation = sim
self.node_uuids:List[str] = node_uuids
action_args = {
"num_nodes": len(node_uuids),
"num_folders":max_folders_per_node,
"num_files": max_files_per_folder,
"num_services": max_services_per_node,
"num_nics": max_nics_per_node,
"num_acl_rules": max_acl_rules}
self.actions: Dict[str, AbstractAction] = {}
for act_type in actions:
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **action_args)
self.action_map:Dict[int, Tuple[str, Dict]] = {}
"""
Action mapping that converts an integer to a specific action and parameter choice.
For example :
{0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})}
"""
if act_map is None:
self.action_map = self._enumerate_actions()
else:
self.action_map = {i:(a['action'], a['options']) for i,a in act_map.items()}
# make sure all numbers between 0 and N are represented as dict keys in action map
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
def _enumerate_actions(self,) -> Dict[int, Tuple[AbstractAction, Dict]]:
...
def get_action(self, action: int) -> Tuple[str,Dict]:
"""Produce action in CAOS format"""
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
"""The caos format is basically a action identifier, followed by parameters stored in a dictionary"""
act_identifier, act_options = self.action_map[action]
return act_identifier, act_options
class ActionSpace:
def form_request(self, action_identifier:str, action_options:Dict):
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format"""
act_obj = self.actions[action_identifier]
return act_obj.form_request(**action_options)
@property
def space(self) -> spaces.Space:
return spaces.Discrete(len(self.action_map))
def get_node_uuid_by_idx(self, node_idx):
return self.node_uuids[node_idx]
def get_folder_uuid_by_idx(self, node_idx, folder_idx) -> Optional[str]:
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
return folder_uuids[folder_idx] if len(folder_uuids)>folder_idx else None
def get_file_uuid_by_idx(self, node_idx, folder_idx, file_idx) -> Optional[str]:
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
if len(folder_uuids)<=folder_idx:
return None
folder = node.file_system.folders[folder_uuids[folder_idx]]
file_uuids = list(folder.files.keys())
return file_uuids[file_idx] if len(file_uuids)>file_idx else None
def get_service_uuid_by_idx(self, node_idx, service_idx) -> Optional[str]:
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
service_uuids = list(node.services.keys())
return service_uuids[service_idx] if len(service_uuids)>service_idx else None
class UC2RedActions(AbstractAction):
...
class UC2GreenActionSpace(ActionManager):
...

View File

@@ -2,32 +2,70 @@
# That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word
# If you disagree, make a comment in the PR review and we can discuss
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union, TypeAlias
import numpy as np
from primaite.game.agent.actions import ActionSpace
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction
ObsType:TypeAlias = Union[Dict, np.ndarray]
class AbstractAgent(ABC):
"""Base class for scripted and RL agents."""
def __init__(
self,
action_space: Optional[ActionSpace],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationSpace],
reward_function: Optional[RewardFunction],
) -> None:
self.action_space: Optional[ActionSpace] = action_space
self.action_space: Optional[ActionManager] = action_space
self.observation_space: Optional[ObservationSpace] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
# exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info
# by for example specifying target ip addresses, or converting a node ID into a uuid
self.execution_definition = None
def get_obs_from_state(self, state:Dict) -> ObsType:
"""
state : dict state directly from simulation.describe_state
output : dict state according to CAOS.
"""
return self.observation_space.observe(state)
def get_reward_from_state(self, state:Dict) -> float:
return self.reward_function.calculate(state)
@abstractmethod
def get_action(self, obs:ObsType, reward:float=None):
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 1-40,
# then use a bespoke conversion to take 1-40 int back into CAOS action
return ('NODE', 'SERVICE', 'SCAN', '<fake-node-sid>', '<fake-service-sid>')
@abstractmethod
def format_request(self, action) -> List[str]:
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
return ['network', 'nodes', '<fake-node-uuid>', 'file_system', 'folder', 'root', 'scan']
class AbstractScriptedAgent(AbstractAgent):
"""Base class for actors which generate their own behaviour."""
...
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs:ObsType, reward:float=None):
return self.action_space.space.sample()
class AbstractGATEAgent(AbstractAgent):
"""Base class for actors controlled via external messages, such as RL policies."""

View File

@@ -55,7 +55,7 @@ class AbstractObservation(ABC):
class FileObservation(AbstractObservation):
def __init__(self, where: List[str] = []) -> None:
def __init__(self, where: Optional[List[str]] = None) -> None:
"""
_summary_
@@ -68,12 +68,12 @@ class FileObservation(AbstractObservation):
:type where: Optional[List[str]]
"""
super().__init__()
self.where: List[str] = where
self.where: Optional[List[str]] = where
self.default_observation: spaces.Space = {"health_status": 0}
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
def observe(self, state: Dict) -> Dict:
if not self.where:
if self.where is None:
return self.default_observation
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
@@ -89,21 +89,21 @@ class ServiceObservation(AbstractObservation):
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
"Default observation is what should be returned when the service doesn't exist."
def __init__(self, where: List[str] = []) -> None:
def __init__(self, where: Optional[List[str]] = None) -> None:
"""
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_uuid>,'servics', <service_uuid>]`
`['network','nodes',<node_uuid>,'services', <service_uuid>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: List[str] = where
self.where: Optional[List[str]] = where
def observe(self, state: Dict) -> Dict:
if not self.where:
if self.where is None:
return self.default_observation
service_state = access_from_nested_dict(state, self.where)
@@ -120,7 +120,7 @@ class LinkObservation(AbstractObservation):
default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}}
"Default observation is what should be returned when the link doesn't exist."
def __init__(self, where: List[str] = []) -> None:
def __init__(self, where: Optional[List[str]] = None) -> None:
"""
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
@@ -131,10 +131,10 @@ class LinkObservation(AbstractObservation):
:type where: Optional[List[str]]
"""
super().__init__()
self.where: List[str] = where
self.where: Optional[List[str]] = where
def observe(self, state: Dict) -> Dict:
if not self.where:
if self.where is None:
return self.default_observation
link_state = access_from_nested_dict(state, self.where)
@@ -156,7 +156,7 @@ class LinkObservation(AbstractObservation):
class FolderObservation(AbstractObservation):
def __init__(self, where: List[str] = [], files: List[FileObservation] = []) -> None:
def __init__(self, where: Optional[List[str]] = None, files: List[FileObservation] = []) -> None:
"""Initialise folder Observation, including files inside of the folder.
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
@@ -175,7 +175,7 @@ class FolderObservation(AbstractObservation):
"""
super().__init__()
self.where: List[str] = where
self.where: Optional[List[str]] = where
self.files: List[FileObservation] = files
@@ -185,7 +185,7 @@ class FolderObservation(AbstractObservation):
}
def observe(self, state: Dict) -> Dict:
if not self.where:
if self.where is None:
return self.default_observation
folder_state = access_from_nested_dict(state, self.where)
if folder_state is NOT_PRESENT_IN_STATE:
@@ -213,12 +213,12 @@ class FolderObservation(AbstractObservation):
class NicObservation(AbstractObservation):
default_observation: spaces.Space = {"nic_status": 0}
def __init__(self, where: List[str] = []) -> None:
super.__init__()
self.where: List[str] = where
def __init__(self, where: Optional[List[str]] = None) -> None:
super().__init__()
self.where: Optional[List[str]] = where
def observe(self, state: Dict) -> Dict:
if not self.where:
if self.where is None:
return self.default_observation
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
@@ -234,10 +234,11 @@ class NicObservation(AbstractObservation):
class NodeObservation(AbstractObservation):
def __init__(
self,
where: List[str] = [],
where: Optional[List[str]] = None,
services: List[ServiceObservation] = [],
folders: List[FolderObservation] = [],
nics: List[NicObservation] = [],
logon_status:bool=False
) -> None:
"""
Configurable observation for a node in the simulation.
@@ -259,12 +260,13 @@ class NodeObservation(AbstractObservation):
:param max_nics: Max number of NICS in this node's obs space, defaults to 5
:type max_nics: int, optional
"""
super.__init__()
self.where: List[str] = where
super().__init__()
self.where: Optional[List[str]] = where
self.services: List[ServiceObservation] = services
self.folders: List[FolderObservation] = folders
self.nics: List[NicObservation] = nics
self.logon_status:bool=logon_status
self.default_observation: Dict = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
@@ -272,9 +274,11 @@ class NodeObservation(AbstractObservation):
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.nics)},
"operating_status": 0,
}
if self.logon_status:
self.default_observation['logon_status']=0
def observe(self, state: Dict) -> Dict:
if not self.where:
if self.where is None:
return self.default_observation
node_state = access_from_nested_dict(state, self.where)
@@ -288,18 +292,24 @@ class NodeObservation(AbstractObservation):
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.logon_status:
obs['logon_status'] = 0
return obs
@property
def space(self) -> spaces.Space:
return spaces.Dict(
{
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(0),
"NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
}
)
space_shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
}
if self.logon_status:
space_shape['logon_status'] = spaces.Discrete(3)
return spaces.Dict(space_shape)
class AclObservation(AbstractObservation):
@@ -308,41 +318,33 @@ class AclObservation(AbstractObservation):
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self, nodes: List[str], ports: List[int], protocols: list[str], where: List[str] = [], num_rules: int = 10
self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[List[str]] = None, num_rules: int = 10
) -> None:
super().__init__()
self.where: List[str] = where
self.where: Optional[List[str]] = where
self.num_rules: int = num_rules
self.node_to_id: Dict[str, int] = {node: i + 1 for i, node in enumerate(nodes)}
self.node_to_id: Dict[str, int] = node_ip_to_id
"List of node IP addresses, order in this list determines how they are converted to an ID"
self.port_to_id: Dict[int, int] = {port: i + 1 for i, port in enumerate(ports)}
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
"List of ports which are part of the game that define the ordering when converting to an ID"
self.protocol_to_id: Dict[str, int] = {protocol: i + 1 for i, protocol in enumerate(protocols)}
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: spaces.Space = spaces.Dict(
{
"RULES": spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
)
for i in range(self.num_rules)
}
)
self.default_observation: Dict = {
"RULES": {i+ 1:{
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
)
}
def observe(self, state: Dict) -> Dict:
if not self.where:
if self.where is None:
return self.default_observation
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
@@ -379,16 +381,16 @@ class AclObservation(AbstractObservation):
{
"RULE": spaces.Dict(
{
i
+ 1: spaces.Dict(
i + 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
"source_node_id": spaces.Discrete(len(self.nodes) + 1),
"source_port": spaces.Discrete(len(self.ports) + 1),
"dest_node_id": spaces.Discrete(len(self.nodes) + 1),
"dest_port": spaces.Discrete(len(self.ports) + 1),
"protocol": spaces.Discrete(len(self.protocols) + 1),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
@@ -398,14 +400,96 @@ class AclObservation(AbstractObservation):
)
class ICSObservation(AbstractObservation):
def observe(self, state: Dict) -> Any:
return 0
class NullObservation(AbstractObservation):
def __init__(self, where:Optional[List[str]]=None):
self.default_observation: Dict = {}
def observe(self, state: Dict) -> Dict:
return {}
@property
def space(self) -> spaces.Space:
return spaces.Discrete(1)
return spaces.Dict({})
class ICSObservation(NullObservation): pass
class UC2BlueObservation(AbstractObservation):
def __init__(
self,
nodes: List[NodeObservation],
links: List[LinkObservation],
acl: AclObservation,
ics: ICSObservation,
where:Optional[List[str]] = None,
) -> None:
super().__init__()
self.where: Optional[List[str]] = where
self.nodes: List[NodeObservation] = nodes
self.links: List[LinkObservation] = links
self.acl: AclObservation = acl
self.ics: ICSObservation = ics
self.default_observation : Dict = {
"NODES": {i+1: n.default_observation for i,n in enumerate(self.nodes)},
"LINKS": {i+1: l.default_observation for i,l in enumerate(self.links)},
"ACL": self.acl.default_observation,
"ICS": self.ics.default_observation,
}
def observe(self, state:Dict) -> Dict:
if self.where is None:
return self.default_observation
obs = {}
obs['NODES'] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
obs['LINKS'] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
obs['ACL'] = {self.acl.observe(state)}
obs['ICS'] = {self.ics.observe(state)}
return obs
@property
def space(self) -> spaces.Space:
return spaces.Dict({
"NODES": spaces.Dict({i+1: node.space for i, node in enumerate(self.nodes)}),
"LINKS": spaces.Dict({i+1: link.space for i, link in enumerate(self.links)}),
"ACL": self.acl.space,
"ICS": self.ics.space,
})
@classmethod
def from_config(cls, config:Dict, sim:Simulation):
nodes = ...
links = ...
acl = ...
ics = ...
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=['network'])
return new
class UC2RedObservation(AbstractObservation):
def __init__(self, nodes:List[NodeObservation], where:Optional[List[str]] = None) -> None:
super().__init__()
self.where:Optional[List[str]] = where
self.nodes: List[NodeObservation] = nodes
self.default_observation=...#TODO
def observe(self, state: Dict) -> Any:
return super().observe(state)
@property
def space(self) -> spaces.Space:
... #TODO
@classmethod
def from_config(cls, config: Dict, sim:Simulation):
... #TODO
class ObservationSpace:
"""
@@ -422,29 +506,12 @@ class ObservationSpace:
# what this class does:
# keep a list of observations
# create observations for an actor from the config
def __init__(
self,
simulation: Simulation,
nodes: List[NodeObservation] = [],
links: List[LinkObservation] = [],
acl: Optional[AclObservation] = None,
ics: Optional[ICSObservation] = None,
) -> None:
self.simulation: Simulation = simulation
self.parts: Dict[str, AbstractObservation] = {}
def __init__(self, observation:AbstractObservation) -> None:
self.obs: AbstractObservation = observation
self.nodes: List[NodeObservation] = nodes
self.links: List[LinkObservation] = links
self.acl: Optional[AclObservation] = acl
self.ics: Optional[ICSObservation] = ics
def observe(self) -> None:
...
def observe(self, state) -> Dict:
return self.obs.observe(state)
@property
def space(self) -> None:
...
@classmethod
def from_config(self) -> None:
...
return self.obs.space

View File

@@ -1,20 +1,17 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from pydantic import BaseModel
class AbstractReward(BaseModel):
def __call__(self, states: List[Dict]) -> float:
"""_summary_
:param state: _description_
:type state: Dict
:return: _description_
:rtype: float
"""
class AbstractReward():
def __init__(self):
...
def calculate(self, state:Dict) -> float:
return 0.3
class RewardFunction(BaseModel):
...
class RewardFunction():
def __init__(self, reward_function:AbstractReward):
self.reward: AbstractReward = reward_function
def calculate(self, state:Dict) -> float:
return self.reward.calculate(state)

View File

@@ -4,3 +4,49 @@
# 3. create actors and configure their actions/observations/rewards/ anything else
# 4. Create connection with ARCD GATE
# 5. idk
from primaite.simulator.sim_container import Simulation
from primaite.game.agent.interface import AbstractAgent
from typing import List
class PrimaiteSession:
def __init__(self):
self.simulation: Simulation = Simulation()
self.agents:List[AbstractAgent] = []
self.step_counter:int = 0
self.episode_counter:int = 0
def step(self):
# currently designed with assumption that all agents act once per step in order
for agent in self.agents:
# 3. primaite session asks simulation to provide initial state
# 4. primate session gives state to all agents
# 5. primaite session asks agents to produce an action based on most recent state
sim_state = self.simulation.describe_state()
# 6. each agent takes most recent state and converts it to CAOS observation
agent_obs = agent.get_obs_from_state(sim_state)
# 7. meanwhile each agent also takes state and calculates reward
agent_reward = agent.get_reward_from_state(sim_state)
# 8. each agent takes observation and applies decision rule to observation to create CAOS
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
# code should live inside of the GATE agent subclass)
# gets action in CAOS format
agent_action = agent.get_action(agent_obs, agent_reward)
# 9. CAOS action is converted into request (extra information might be needed to enrich
# the request, this is what the execution definition is there for)
agent_request = agent.format_request(agent_action)
# 10. primaite session receives the action from the agents and asks the simulation to apply each
self.simulation.apply_action(agent_request)
self.simulation.apply_timestep(self.step_counter)
self.step_counter += 1