diff --git a/example_config.yaml b/example_config.yaml index 79cfccac..8cf401cc 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -27,11 +27,11 @@ game_config: - type: LOGON - type: LOGOFF applications: - - application_ref: client_2_web_browser - actions: - - type: EXECUTE - execution_definition: - target_address: arcd.com + # - application_ref: client_2_web_browser + # actions: + # - type: EXECUTE + # execution_definition: + # target_address: arcd.com reward_function: null agent_settings: start_step: 5 @@ -42,7 +42,8 @@ game_config: team: RED type: RedDatabaseCorruptingAgent observation_space: - network: + type: UC2RedObservation + options: nodes: - node_ref: client_1 observations: @@ -85,13 +86,307 @@ game_config: - ref: defender team: blue - type: GATE_RL_AGENT + type: GATERLAgent observation_space: - network: + type: UC2BlueObservation + options: nodes: - node_ref: router_1 #TODO: more sub-options here - node_ref: switch_1 - node_ref: switch_2 + - node_ref: domain_controller + services: + - service_ref: domain_controller_dns_server + - node_ref: web_server + services: + - service_ref: web_server_database_client + - node_ref: database_server + services: + - service_ref: database_service + folders: + - folder_name: database + files: + - file_name: database.db + - node_ref: backup_server + # services: + # - service_ref: backup_service + - node_ref: security_suite + - node_ref: client_1 + - node_ref: client_2 + links: + - link_ref: router_1___switch_1 + - link_ref: router_1___switch_2 + - link_ref: switch_1___domain_controller + - link_ref: switch_1___web_server + - link_ref: switch_1___database_server + - link_ref: switch_1___backup_server + - link_ref: switch_1___security_suite + - link_ref: switch_2___client_1 + - link_ref: switch_2___client_2 + - link_ref: switch_2___security_suite + acl: + router_node_ref: router_1 + ics: null + + + action_space: + action_list: + - DONOTHING + - NODE_SERVICE_SCAN + - NODE_SERVICE_STOP + # - NODE_SERVICE_START + # - NODE_SERVICE_PAUSE + # - NODE_SERVICE_RESUME + # - NODE_SERVICE_RESTART + # - NODE_SERVICE_DISABLE + # - NODE_SERVICE_ENABLE + # - NODE_FILE_SCAN + # - NODE_FILE_CHECKHASH + # - NODE_FILE_DELETE + # - NODE_FILE_REPAIR + # - NODE_FILE_RESTORE + # - NODE_FOLDER_SCAN + # - NODE_FOLDER_CHECKHASH + # - NODE_FOLDER_REPAIR + # - NODE_FOLDER_RESTORE + # - NODE_OS_SCAN + # - NODE_SHUTDOWN + # - NODE_STARTUP + # - NODE_RESET + # - NETWORK_ACL_ADDRULE + # - NETWORK_ACL_REMOVERULE + # - NETWORK_NIC_ENABLE + - NETWORK_NIC_DISABLE + + action_map: + 0: + - action: DONOTHING + options: {} + # scan webapp service + 1: + - action: NODE_SERVICE_SCAN + options: + - node_id: 2 + - service_id: 1 + # stop webapp service + 2: + - action: NODE_SERVICE_STOP + options: + - node_id: 2 + - service_id: 1 + # start webapp service + 3: + - action: "NODE_SERVICE_START" + options: + - node_id: 2 + - service_id: 1 + 4: + - action: "NODE_SERVICE_PAUSE" + options: + - node_id: 2 + - service_id: 1 + 5: + - action: "NODE_SERVICE_RESUME" + options: + - node_id: 2 + - service_id: 1 + 6: + - action: "NODE_SERVICE_RESTART" + options: + - node_id: 2 + - service_id: 1 + 7: + - action: "NODE_SERVICE_DISABLE" + options: + - node_id: 2 + - service_id: 1 + 8: + - action: "NODE_SERVICE_ENABLE" + options: + - node_id: 2 + - service_id: 1 + 9: + - action: "NODE_FILE_SCAN" + options: + - node_id: 3 + - folder_id: 1 + - file_id: 1 + 10: + - action: "NODE_FILE_CHECKHASH" + options: + - node_id: 3 + - folder_id: 1 + - file_id: 1 + 11: + - action: "NODE_FILE_DELETE" + options: + - node_id: 3 + - folder_id: 1 + - file_id: 1 + 12: + - action: "NODE_FILE_REPAIR" + options: + - node_id: 3 + - folder_id: 1 + - file_id: 1 + 13: + - action: "NODE_FILE_RESTORE" + options: + - node_id: 3 + - folder_id: 1 + - file_id: 1 + 14: + - action: "NODE_FOLDER_SCAN" + options: + - node_id: 3 + - folder_id: 1 + 15: + - action: "NODE_FOLDER_CHECKHASH" + options: + - node_id: 3 + - folder_id: 1 + 16: + - action: "NODE_FOLDER_REPAIR" + options: + - node_id: 3 + - folder_id: 1 + 17: + - action: "NODE_FOLDER_RESTORE" + options: + - node_id: 3 + - folder_id: 1 + 18: + - action: "NODE_OS_SCAN" + options: + - node_id: 3 + 19: + - action: "NODE_SHUTDOWN" + options: + - node_id: 6 + 20: + - action: "NODE_STARTUP" + options: + - node_id: 6 + 21: + - action: "NODE_RESET" + options: + - node_id: 6 + 22: + - action: "NETWORK_ACL_ADDRULE" + options: + - position: 6 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... + 23: + - action: "NETWORK_ACL_ADDRULE" + options: + - position: 5 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... + 24: + - action: "NETWORK_ACL_ADDRULE" + options: + - position: 4 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... + 25: + - action: "NETWORK_ACL_ADDRULE" + options: + - position: 3 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... + 26: + - action: "NETWORK_ACL_ADDRULE" + options: + - position: 2 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... + 27: + - action: "NETWORK_ACL_ADDRULE" + options: + - position: 1 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... + 28: + - action: "NETWORK_ACL_REMOVERULE" + options: + - position: 0 + 29: + - action: "NETWORK_ACL_REMOVERULE" + options: + - position: 1 + 30: + - action: "NETWORK_ACL_REMOVERULE" + options: + - position: 2 + 31: + - action: "NETWORK_ACL_REMOVERULE" + options: + - position: 3 + 32: + - action: "NETWORK_ACL_REMOVERULE" + options: + - position: 4 + 33: + - action: "NETWORK_ACL_REMOVERULE" + options: + - position: 5 + 34: + - action: "NETWORK_ACL_REMOVERULE" + options: + - position: 6 + 35: + - action: "NETWORK_ACL_REMOVERULE" + options: + - position: 7 + 36: + - action: "NETWORK_ACL_REMOVERULE" + options: + - position: 8 + 37: + - action: "NETWORK_ACL_REMOVERULE" + options: + - position: 9 + 38: + - action: "NETWORK_NIC_DISABLE" + options: + - node_id: 6 + - nic_index: 1 + 39: + - action: "NETWORK_NIC_ENABLE" + options: + - node_id: 6 + - nic_index: 1 + + options: + nodes: + - node_ref: router_1 + - node_ref: switch_1 + - node_ref: switch_2 - node_ref: domain_controller - node_ref: web_server - node_ref: database_server @@ -99,18 +394,13 @@ game_config: - node_ref: security_suite - node_ref: client_1 - node_ref: client_2 - links: - - link_ref: ... # - acl: ... # - ics: ... # + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 - action_space: - actions: - - type: DO_NOTHING - network: - nodes: - - node_ref: router_1 reward_function: # ... agent_settings: @@ -175,7 +465,6 @@ simulation: domain_mapping: arcd.com: 192.168.1.12 # web server - - ref: web_server type: server hostname: web_server @@ -200,7 +489,6 @@ simulation: - ref: database_service type: DatabaseService - - ref: backup_server type: server hostname: backup_server @@ -224,7 +512,6 @@ simulation: ip_address: 192.168.10.110 subnet_mask: 255.255.255.0 - - ref: client_1 type: computer hostname: client_1 @@ -251,7 +538,6 @@ simulation: - ref: client_2_dns_client type: DNSClient - links: - ref: router_1___switch_1 endpoint_a_ref: router_1 diff --git a/sandbox.ipynb b/sandbox.ipynb index 05efcfa2..3ff72170 100644 --- a/sandbox.ipynb +++ b/sandbox.ipynb @@ -2,9 +2,18 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -12,7 +21,381 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.session import PrimaiteSession\n", + "from primaite.simulator.sim_container import Simulation\n", + "from primaite.game.agent.interface import AbstractAgent\n", + "from primaite.simulator.network.networks import arcd_uc2_network\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "sess = PrimaiteSession()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "network = sess.simulation.network" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "from ipaddress import IPv4Address\n", + "\n", + "from primaite.simulator.network.container import Network\n", + "from primaite.simulator.network.hardware.base import NIC\n", + "from primaite.simulator.network.hardware.nodes.computer import Computer\n", + "from primaite.simulator.network.hardware.nodes.router import ACLAction, Router\n", + "from primaite.simulator.network.hardware.nodes.server import Server\n", + "from primaite.simulator.network.hardware.nodes.switch import Switch\n", + "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", + "from primaite.simulator.network.transmission.transport_layer import Port\n", + "from primaite.simulator.system.applications.database_client import DatabaseClient\n", + "from primaite.simulator.system.services.database_service import DatabaseService\n", + "from primaite.simulator.system.services.dns_client import DNSClient\n", + "from primaite.simulator.system.services.dns_server import DNSServer\n", + "from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-10-02 15:10:20,422: Added node 6abb7664-4d17-45ff-a3c7-dbcccffcfd6d to Network 045a3114-4aac-4687-a10e-432cfd138325\n", + "2023-10-02 15:10:20,424: Added node 3edbc521-3c80-47e3-8017-dbc38fb00a73 to Network 045a3114-4aac-4687-a10e-432cfd138325\n", + "2023-10-02 15:10:20,428: Added node 94457fb9-04a1-4dc1-9ff7-b64df0da7424 to Network 045a3114-4aac-4687-a10e-432cfd138325\n", + "2023-10-02 15:10:20,432: Added node 0d311d72-139c-41bf-aef7-fa9b01b124d7 to Network 045a3114-4aac-4687-a10e-432cfd138325\n", + "2023-10-02 15:10:20,439: Added node 6161e785-f377-48de-aa4f-20d3646da635 to Network 045a3114-4aac-4687-a10e-432cfd138325\n", + "2023-10-02 15:10:20,444: Added node 55a9e9f8-ee3a-4c28-9b6d-c0c0d78a3f6a to Network 045a3114-4aac-4687-a10e-432cfd138325\n", + "2023-10-02 15:10:20,447: Added node 2f04ca45-3439-489a-81f7-41cca5ae8adc to Network 045a3114-4aac-4687-a10e-432cfd138325\n", + "2023-10-02 15:10:20,531: Added node 98660c30-8e48-4b96-967a-d62ca71b4d6d to Network 045a3114-4aac-4687-a10e-432cfd138325\n", + "2023-10-02 15:10:20,545: Added node 1a184184-b204-40de-986a-4d7459036dbe to Network 045a3114-4aac-4687-a10e-432cfd138325\n", + "2023-10-02 15:10:20,551: Added node 17b92b9a-6805-4677-85f7-c0c0521a6e25 to Network 045a3114-4aac-4687-a10e-432cfd138325\n", + "2023-10-02 15:10:20,555::ERROR::primaite.simulator.network.hardware.base::175::NIC da:f3:1b:87:24:20/192.168.10.110 cannot be enabled as it is not connected to a Link\n" + ] + } + ], + "source": [ + "router_1 = Router(hostname=\"router_1\", num_ports=5)\n", + "router_1.power_on()\n", + "router_1.configure_port(port=1, ip_address=\"192.168.1.1\", subnet_mask=\"255.255.255.0\")\n", + "router_1.configure_port(port=2, ip_address=\"192.168.10.1\", subnet_mask=\"255.255.255.0\")\n", + "\n", + "# Switch 1\n", + "switch_1 = Switch(hostname=\"switch_1\", num_ports=8)\n", + "switch_1.power_on()\n", + "network.connect(endpoint_a=router_1.ethernet_ports[1], endpoint_b=switch_1.switch_ports[8])\n", + "router_1.enable_port(1)\n", + "\n", + "# Switch 2\n", + "switch_2 = Switch(hostname=\"switch_2\", num_ports=8)\n", + "switch_2.power_on()\n", + "network.connect(endpoint_a=router_1.ethernet_ports[2], endpoint_b=switch_2.switch_ports[8])\n", + "router_1.enable_port(2)\n", + "\n", + "# Client 1\n", + "client_1 = Computer(\n", + " hostname=\"client_1\",\n", + " ip_address=\"192.168.10.21\",\n", + " subnet_mask=\"255.255.255.0\",\n", + " default_gateway=\"192.168.10.1\",\n", + " dns_server=IPv4Address(\"192.168.1.10\"),\n", + ")\n", + "client_1.power_on()\n", + "client_1.software_manager.install(DNSClient)\n", + "client_1_dns_client_service: DNSServer = client_1.software_manager.software[\"DNSClient\"] # noqa\n", + "client_1_dns_client_service.start()\n", + "network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])\n", + "client_1.software_manager.install(DataManipulationBot)\n", + "db_manipulation_bot: DataManipulationBot = client_1.software_manager.software[\"DataManipulationBot\"]\n", + "db_manipulation_bot.configure(server_ip_address=IPv4Address(\"192.168.1.14\"), payload=\"DROP TABLE IF EXISTS user;\")\n", + "\n", + "# Client 2\n", + "client_2 = Computer(\n", + " hostname=\"client_2\",\n", + " ip_address=\"192.168.10.22\",\n", + " subnet_mask=\"255.255.255.0\",\n", + " default_gateway=\"192.168.10.1\",\n", + " dns_server=IPv4Address(\"192.168.1.10\"),\n", + ")\n", + "client_2.power_on()\n", + "client_2.software_manager.install(DNSClient)\n", + "client_2_dns_client_service: DNSServer = client_2.software_manager.software[\"DNSClient\"] # noqa\n", + "client_2_dns_client_service.start()\n", + "network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])\n", + "\n", + "# Domain Controller\n", + "domain_controller = Server(\n", + " hostname=\"domain_controller\",\n", + " ip_address=\"192.168.1.10\",\n", + " subnet_mask=\"255.255.255.0\",\n", + " default_gateway=\"192.168.1.1\",\n", + ")\n", + "domain_controller.power_on()\n", + "domain_controller.software_manager.install(DNSServer)\n", + "\n", + "network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])\n", + "\n", + "# Database Server\n", + "database_server = Server(\n", + " hostname=\"database_server\",\n", + " ip_address=\"192.168.1.14\",\n", + " subnet_mask=\"255.255.255.0\",\n", + " default_gateway=\"192.168.1.1\",\n", + " dns_server=IPv4Address(\"192.168.1.10\"),\n", + ")\n", + "database_server.power_on()\n", + "network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3])\n", + "\n", + "ddl = \"\"\"\n", + "CREATE TABLE IF NOT EXISTS user (\n", + "id INTEGER PRIMARY KEY AUTOINCREMENT,\n", + "name VARCHAR(50) NOT NULL,\n", + "email VARCHAR(50) NOT NULL,\n", + "age INT,\n", + "city VARCHAR(50),\n", + "occupation VARCHAR(50)\n", + ");\"\"\"\n", + "\n", + "user_insert_statements = [\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');\", # noqa\n", + " \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');\", # noqa\n", + "]\n", + "database_server.software_manager.install(DatabaseService)\n", + "database_service: DatabaseService = database_server.software_manager.software[\"DatabaseService\"] # noqa\n", + "database_service.start()\n", + "database_service._process_sql(ddl, None) # noqa\n", + "for insert_statement in user_insert_statements:\n", + " database_service._process_sql(insert_statement, None) # noqa\n", + "\n", + "# Web Server\n", + "web_server = Server(\n", + " hostname=\"web_server\",\n", + " ip_address=\"192.168.1.12\",\n", + " subnet_mask=\"255.255.255.0\",\n", + " default_gateway=\"192.168.1.1\",\n", + " dns_server=IPv4Address(\"192.168.1.10\"),\n", + ")\n", + "web_server.power_on()\n", + "web_server.software_manager.install(DatabaseClient)\n", + "\n", + "database_client: DatabaseClient = web_server.software_manager.software[\"DatabaseClient\"]\n", + "database_client.configure(server_ip_address=IPv4Address(\"192.168.1.14\"))\n", + "network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])\n", + "database_client.run()\n", + "database_client.connect()\n", + "\n", + "# register the web_server to a domain\n", + "dns_server_service: DNSServer = domain_controller.software_manager.software[\"DNSServer\"] # noqa\n", + "dns_server_service.start()\n", + "dns_server_service.dns_register(\"arcd.com\", web_server.ip_address)\n", + "\n", + "# Backup Server\n", + "backup_server = Server(\n", + " hostname=\"backup_server\",\n", + " ip_address=\"192.168.1.16\",\n", + " subnet_mask=\"255.255.255.0\",\n", + " default_gateway=\"192.168.1.1\",\n", + " dns_server=IPv4Address(\"192.168.1.10\"),\n", + ")\n", + "backup_server.power_on()\n", + "network.connect(endpoint_b=backup_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[4])\n", + "\n", + "# Security Suite\n", + "security_suite = Server(\n", + " hostname=\"security_suite\",\n", + " ip_address=\"192.168.1.110\",\n", + " subnet_mask=\"255.255.255.0\",\n", + " default_gateway=\"192.168.1.1\",\n", + " dns_server=IPv4Address(\"192.168.1.10\"),\n", + ")\n", + "security_suite.power_on()\n", + "network.connect(endpoint_b=security_suite.ethernet_port[1], endpoint_a=switch_1.switch_ports[7])\n", + "security_suite.connect_nic(NIC(ip_address=\"192.168.10.110\", subnet_mask=\"255.255.255.0\"))\n", + "network.connect(endpoint_b=security_suite.ethernet_port[2], endpoint_a=switch_2.switch_ports[7])\n", + "\n", + "router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)\n", + "\n", + "router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)\n", + "\n", + "# Allow PostgreSQL requests\n", + "router_1.acl.add_rule(\n", + " action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0\n", + ")\n", + "\n", + "# Allow DNS requests\n", + "router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "node_uuid_list = list(sess.simulation.network.nodes.keys())" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.agent.actions import ActionManager" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "actman = ActionManager(sess.simulation, [\"DONOTHING\", \"NODE_SERVICE_SCAN\", \"NODE_SERVICE_STOP\", \"NODE_FOLDER_SCAN\"],node_uuid_list,act_map={\n", + " 0:{\n", + " \"action\": \"DONOTHING\",\n", + " \"options\": {}\n", + " },\n", + " 1:{\n", + " \"action\": \"NODE_SERVICE_SCAN\",\n", + " \"options\": {\"node_id\":0, \"service_id\":0},\n", + " },\n", + " 2:{\n", + " \"action\": \"NODE_SERVICE_SCAN\",\n", + " \"options\": {\"node_id\":1, \"service_id\":0},\n", + " },\n", + " 3:{\n", + " \"action\": \"NODE_FOLDER_SCAN\",\n", + " \"options\": {\"node_id\":4, \"folder_id\":0},\n", + " }\n", + "})" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "act_id, act_options = actman.get_action(3)\n", + "my_trial_act = actman.form_request(action_identifier=act_id, action_options=act_options)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "sess.simulation.apply_action(my_trial_act)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['network',\n", + " 'node',\n", + " '6161e785-f377-48de-aa4f-20d3646da635',\n", + " 'file_system',\n", + " 'folder',\n", + " '5aefe92b-923c-4684-b3bf-e78dd18d4771',\n", + " 'scan']" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "my_trial_act" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sess.step()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sess.step_counter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gym import spaces" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sp = spaces.Tuple( (spaces.MultiDiscrete([3, 2]), spaces.MultiDiscrete([3, 2]), spaces.MultiDiscrete([3, 2]),))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sp.sample()" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -39,40 +422,16 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-09-26 12:19:50,895: Added node 0fb262e1-a714-420a-aec7-be37f0deeb75 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n", - "2023-09-26 12:19:50,898: Added node 310ca8d7-01e0-401e-b604-705c290e5376 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n", - "2023-09-26 12:19:50,900: Added node b3b08f1f-7805-47b2-bdb6-3d83098cd740 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n", - "2023-09-26 12:19:50,903: Added node adb37f3e-2307-4123-bff3-01f125883be8 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n", - "2023-09-26 12:19:50,906: Added node 1a490716-2ccd-452d-b87e-324d29120b59 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n", - "2023-09-26 12:19:50,911: Added node 033460d8-0249-4bdd-aaf0-751b24cc0a1e to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n", - "2023-09-26 12:19:50,914: Added node 1e7e4e49-78bf-4031-8372-ee71902720f3 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n", - "2023-09-26 12:19:50,916: Added node c9f24a13-e5c8-437b-9234-b0c3f8120e2c to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n", - "2023-09-26 12:19:50,920: Added node c881f3ee-2176-493b-a6c2-cad829bf0b6d to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n", - "2023-09-26 12:19:50,922: Added node a3ea75d8-bc2c-4713-92a4-2588b4f43ed6 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "service type not found DatabaseBackup\n", - "service type not found WebBrowser\n" - ] - } - ], + "outputs": [], "source": [ "# import yaml\n", "\n", "\n", "from typing import Dict\n", "from primaite.game.agent.interface import AbstractAgent\n", + "from primaite.game.agent.observations import AclObservation, FileObservation, FolderObservation, ICSObservation, LinkObservation, NicObservation, NodeObservation, NullObservation, ServiceObservation, UC2BlueObservation, UC2RedObservation\n", "from primaite.simulator.network.hardware.base import NIC, Link, Node\n", "from primaite.simulator.system.services.service import Service\n", "\n", @@ -130,8 +489,8 @@ " subnet_mask=port_cfg['subnet_mask'])\n", " if 'acl' in node_cfg:\n", " for r_num, r_cfg in node_cfg['acl'].items():\n", - " # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, so that we can do\n", - " # both of these things once: check if a key is defined, access and convert it to a \n", + " # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating \n", + " # this: 'r_cfg.get('src_port')'\n", " # Port/IPProtocol. TODO Refactor\n", " new_node.acl.add_rule(\n", " action = ACLAction[r_cfg['action']],\n", @@ -211,8 +570,70 @@ " action_space_cfg = agent_cfg['action_space']\n", " observation_space_cfg = agent_cfg['observation_space']\n", " reward_function_cfg = agent_cfg['reward_function']\n", + " \n", + " # CREATE OBSERVATION SPACE\n", + " if observation_space_cfg is None:\n", + " obs_space = NullObservation()\n", + " elif observation_space_cfg['type'] == 'UC2BlueObservation':\n", + " node_obs_list = []\n", + " link_obs_list = []\n", + " \n", + " \n", + " #node ip to index maps ip addresses to node id, as there are potentially multiple nics on a node, there are multiple ip addresses\n", + " node_ip_to_index = {}\n", + " for node_idx, node_cfg in enumerate(nodes_cfg):\n", + " n_ref = node_cfg['ref']\n", + " n_obj = net.nodes[ref_map_nodes[n_ref]]\n", + " for nic_uuid, nic_obj in n_obj.nics.items():\n", + " node_ip_to_index[nic_obj.ip_address] = node_idx + 2\n", + "\n", + " \n", + " \n", + " for node_obs_cfg in observation_space_cfg['options']['nodes']:\n", + " node_ref = node_obs_cfg['node_ref']\n", + " folder_obs_list = []\n", + " service_obs_list = []\n", + " if 'services' in node_obs_cfg:\n", + " for service_obs_cfg in node_obs_cfg['services']:\n", + " service_obs_list.append(ServiceObservation(where=['network','nodes',ref_map_nodes[node_ref],'services',ref_map_services[service_obs_cfg['service_ref']]]))\n", + " if 'folders' in node_obs_cfg:\n", + " for folder_obs_cfg in node_obs_cfg['folders']:\n", + " file_obs_list = []\n", + " if 'files' in folder_obs_cfg:\n", + " for file_obs_cfg in folder_obs_cfg['files']:\n", + " file_obs_list.append(FileObservation(where=['network','nodes',ref_map_nodes[node_ref], 'folders',folder_obs_cfg['folder_name'], 'files', file_obs_cfg['file_name']]))\n", + " folder_obs_list.append(FolderObservation(where=['network','nodes',ref_map_nodes[node_ref], 'folders',folder_obs_cfg['folder_name']], files=file_obs_list))\n", + " nic_obs_list = []\n", + " for nic_uuid in net.nodes[ref_map_nodes[node_obs_cfg['node_ref']]].nics.keys():\n", + " nic_obs_list.append(NicObservation(where=['network','nodes',ref_map_nodes[node_ref],'NICs',nic_uuid]))\n", + " node_obs_list.append(NodeObservation(where=['network','nodes',ref_map_nodes[node_ref]], services=service_obs_list, folders=folder_obs_list,nics=nic_obs_list, logon_status=False))\n", + " for link_obs_cfg in observation_space_cfg['options']['links']:\n", + " link_ref = link_obs_cfg['link_ref']\n", + " link_obs_list.append(LinkObservation(where=['network' ,'links', ref_map_links[link_ref]]))\n", + "\n", + " acl_obs = AclObservation(node_ip_to_id=node_ip_to_index, ports=game_cfg['ports'], protocols=game_cfg['ports'], where=['network','nodes',observation_space_cfg['options']['acl']['router_node_ref']])\n", + " obs_space = UC2BlueObservation(nodes=node_obs_list,links=link_obs_list,acl=acl_obs, ics=ICSObservation())\n", + " elif observation_space_cfg['type'] == 'UC2RedObservation':\n", + " obs_space = UC2RedObservation.from_config(observation_space_cfg['options'], sim=sim)\n", + " else:\n", + " print(\"observation space config not specified correctly.\")\n", + " obs_space = NullObservation()\n", + " \n", + " # CREATE ACTION SPACE\n", + " \n", + "\n", + "\n", + " # CREATE REWARD FUNCTION\n", + "\n", + " # CREATE AGENT\n", " if agent_type == 'GreenWebBrowsingAgent':\n", - " new_agent = GreenWebBrowsingAgent()\n", + " ...\n", + " elif agent_type == 'GATERLAgent':\n", + " ...\n", + " elif agent_type == 'RedDatabaseCorruptingAgent':\n", + " ...\n", + " else:\n", + " print(\"agent type not found\")\n", "\n", "\n", " #4. set up agents' actions and observation spaces.\n", @@ -228,7 +649,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(s.simulation.describe_state())" + "s.agents" ] }, { diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index cefd9917..cb7061fc 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1,21 +1,374 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List - -from pydantic import BaseModel +from typing import Any, Dict, List, Optional, Tuple +import itertools -class AbstractAction(BaseModel): +from primaite.simulator.sim_container import Simulation + +from gym import spaces + +class ExecutionDefiniton(ABC): + """ + Converter from actions to simulator requests. + + Allows adding extra data/context that defines in more detail what an action means. + """ + + """ + Examples: + ('node', 'service', 'scan', 2, 0) means scan the first service on node index 2 + -> ['network', 'nodes', , 'services', , 'scan'w] + """ + ... + + +class AbstractAction(ABC): + @abstractmethod - def __call__(self, action: Any) -> List[str]: - """_summary_ - - :param action: _description_ - :type action: Any - :return: _description_ - :rtype: List[str] + def __init__(self, manager:"ActionManager", **kwargs) -> None: """ + Init method for action. + + All action init functions should accept **kwargs as a way of ignoring extra arguments. + + Since many parameters are defined for the action space as a whole (such as max files per folder, max services + per node), we need to pass those options to every action that gets created. To pervent verbosity, these + parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply. + """ + self.name:str = "" + """Human-readable action identifier used for printing, logging, and reporting.""" + self.shape = (0,) + """Tuple describing number of options for each parameter of this action. Can be passed to + gym.spaces.MultiDiscrete to form a valid space.""" + self.manager:ActionManager = manager + + + @abstractmethod + def form_request(self) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [] + + +class DoNothingAction(AbstractAction): + def __init__(self, manager:"ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + self.name = "DONOTHING" + self.shape = (1,) + + def form_request(self) -> List[str]: + return ["do_nothing"] + +class NodeServiceAbstractAction(AbstractAction): + """ + Base class for service actions. + + Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from + this base class. + """ + @abstractmethod + def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + super().__init__(manager=manager) + self.shape: Tuple[int] = (num_nodes, num_services) + self.verb:str + + def form_request(self, node_id:int, service_id:int) -> List[str]: + node_uuid = self.manager.get_node_uuid_by_idx(node_id) + service_uuid = self.manager.get_service_uuid_by_idx(node_id, service_id) + if node_uuid is None or service_uuid is None: + return ["do_nothing"] + return ['network', 'node', node_uuid, 'services', service_uuid, self.verb] + +class NodeServiceScanAction(NodeServiceAbstractAction): + def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = "scan" + +class NodeServiceStopAction(NodeServiceAbstractAction): + def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = "stop" + +class NodeServiceStartAction(NodeServiceAbstractAction): + def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = "start" + +class NodeServicePauseAction(NodeServiceAbstractAction): + def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = "pause" + +class NodeServiceResumeAction(NodeServiceAbstractAction): + def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = "resume" + +class NodeServiceRestartAction(NodeServiceAbstractAction): + def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = "restart" + +class NodeServiceDisableAction(NodeServiceAbstractAction): + def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = "disable" + +class NodeServiceEnableAction(NodeServiceAbstractAction): + def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = "enable" + + + +class NodeFolderAbstractAction(AbstractAction): + @abstractmethod + def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None: + super().__init__(manager=manager) + self.shape = (num_nodes, num_folders) + self.verb: str + + def form_request(self, node_id:int, folder_id:int) -> List[str]: + node_uuid = self.manager.get_node_uuid_by_idx(node_id) + folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id) + if node_uuid is None or folder_uuid is None: + return ["do_nothing"] + return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, self.verb] + +class NodeFolderScanAction(NodeFolderAbstractAction): + def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None: + super().__init__(manager, num_nodes, num_folders, **kwargs) + self.verb:str = "scan" + +class NodeFolderCheckhashAction(NodeFolderAbstractAction): + def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None: + super().__init__(manager, num_nodes, num_folders, **kwargs) + self.verb:str = "checkhash" + +class NodeFolderRepairAction(NodeFolderAbstractAction): + def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None: + super().__init__(manager, num_nodes, num_folders, **kwargs) + self.verb:str = "repair" + +class NodeFolderRestoreAction(NodeFolderAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None: + super().__init__(manager, num_nodes, num_folders, **kwargs) + self.verb:str = "restore" + + +class NodeFileAbstractAction(AbstractAction): + @abstractmethod + def __init__(self, manager:"ActionManager", num_nodes:int, num_folders:int, num_files:int, **kwargs) -> None: + super().__init__(manager=manager) + self.shape:Tuple[int] = (num_nodes, num_folders, num_files) + self.verb:str + + def form_request(self, node_id:int, folder_id:int, file_id:int) -> List[str]: + node_uuid = self.manager.get_node_uuid_by_idx(node_id) + folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id) + file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id) + if node_uuid is None or folder_uuid is None or file_uuid is None: + return ["do_nothing"] + return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, 'files', file_uuid, self.verb] + +class NodeFileScanAction(NodeFileAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: + super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) + self.verb = "scan" + +class NodeFileCheckhashAction(NodeFileAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: + super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) + self.verb = "checkhash" + +class NodeFileDeleteAction(NodeFileAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: + super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) + self.verb = "delete" + +class NodeFileRepairAction(NodeFileAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: + super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) + self.verb = "repair" + +class NodeFileRestoreAction(NodeFileAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: + super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) + self.verb = "restore" + +class NodeAbstractAction(AbstractAction): + @abstractmethod + def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: + super().__init__(manager=manager) + self.shape: Tuple[int] = (num_nodes,) + self.verb: str + + def form_request(self, node_id:int) -> List[str]: + node_uuid = self.manager.get_node_uuid_by_idx(node_id) + return ["network", "node", node_uuid, self.verb] + +class NodeOSScanAction(NodeAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = 'scan' + +class NodeShutdownAction(NodeAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = 'shutdown' + +class NodeStartupAction(NodeAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = 'start' + +class NodeResetAction(NodeAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: + super().__init__(manager=manager) + self.verb = 'reset' + +class NetworkACLAddRuleAction(AbstractAction): + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + num_permissions = 2 + self.shape: Tuple[int] = (max_acl_rules, num_permissions, num_nics, num_nics, num_ports, num_ports, num_protocols) + + + + + + +class ActionManager: + # let the action manager handle the conversion of action spaces into a single discrete integer space. + # + + + # when action space is created, it will take subspaces and generate an action map by enumerating all possibilities, + # BUT, the action map can be provided in the config, in which case it will use that. + + # action map is basically just a mapping between integer and CAOS action (incl. parameter values) + # for example the action map can be: + # 0: DONOTHING + # 1: NODE, FILE, SCAN, NODEID=2, FOLDERID=1, FILEID=0 + # 2: ...... + __act_class_identifiers:Dict[str,type] = { + "DONOTHING": DoNothingAction, + "NODE_SERVICE_SCAN": NodeServiceScanAction, + "NODE_SERVICE_STOP": NodeServiceStopAction, + # "NODE_SERVICE_START": NodeServiceStartAction, + # "NODE_SERVICE_PAUSE": NodeServicePauseAction, + # "NODE_SERVICE_RESUME": NodeServiceResumeAction, + # "NODE_SERVICE_RESTART": NodeServiceRestartAction, + # "NODE_SERVICE_DISABLE": NodeServiceDisableAction, + # "NODE_SERVICE_ENABLE": NodeServiceEnableAction, + # "NODE_FILE_SCAN": NodeFileScanAction, + # "NODE_FILE_CHECKHASH": NodeFileCheckhashAction, + # "NODE_FILE_DELETE": NodeFileDeleteAction, + # "NODE_FILE_REPAIR": NodeFileRepairAction, + # "NODE_FILE_RESTORE": NodeFileRestoreAction, + "NODE_FOLDER_SCAN": NodeFolderScanAction, + # "NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction, + # "NODE_FOLDER_REPAIR": NodeFolderRepairAction, + # "NODE_FOLDER_RESTORE": NodeFolderRestoreAction, + # "NODE_OS_SCAN": NodeOSScanAction, + # "NODE_SHUTDOWN": NodeShutdownAction, + # "NODE_STARTUP": NodeStartupAction, + # "NODE_RESET": NodeResetAction, + # "NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction, + # "NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction, + # "NETWORK_NIC_ENABLE": NetworkNICEnable, + # "NETWORK_NIC_DISABLE": NetworkNICDisable, + } + + + def __init__(self, + sim:Simulation, + actions:List[str], + node_uuids:List[str], + max_folders_per_node:int = 2, + max_files_per_folder:int = 2, + max_services_per_node:int = 2, + max_nics_per_node:int=8, + max_acl_rules:int=10, + act_map:Optional[Dict[int, Dict]]=None) -> None: + self.sim: Simulation = sim + self.node_uuids:List[str] = node_uuids + + action_args = { + "num_nodes": len(node_uuids), + "num_folders":max_folders_per_node, + "num_files": max_files_per_folder, + "num_services": max_services_per_node, + "num_nics": max_nics_per_node, + "num_acl_rules": max_acl_rules} + self.actions: Dict[str, AbstractAction] = {} + for act_type in actions: + self.actions[act_type] = self.__act_class_identifiers[act_type](self, **action_args) + + self.action_map:Dict[int, Tuple[str, Dict]] = {} + """ + Action mapping that converts an integer to a specific action and parameter choice. + + For example : + {0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})} + """ + if act_map is None: + self.action_map = self._enumerate_actions() + else: + self.action_map = {i:(a['action'], a['options']) for i,a in act_map.items()} + # make sure all numbers between 0 and N are represented as dict keys in action map + assert all([i in self.action_map.keys() for i in range(len(self.action_map))]) + + def _enumerate_actions(self,) -> Dict[int, Tuple[AbstractAction, Dict]]: ... + def get_action(self, action: int) -> Tuple[str,Dict]: + """Produce action in CAOS format""" + """the agent chooses an action (as an integer), this is converted into an action in CAOS format""" + """The caos format is basically a action identifier, followed by parameters stored in a dictionary""" + act_identifier, act_options = self.action_map[action] + return act_identifier, act_options -class ActionSpace: + def form_request(self, action_identifier:str, action_options:Dict): + """Take action in CAOS format and use the execution definition to change it into PrimAITE request format""" + act_obj = self.actions[action_identifier] + return act_obj.form_request(**action_options) + + @property + def space(self) -> spaces.Space: + return spaces.Discrete(len(self.action_map)) + + def get_node_uuid_by_idx(self, node_idx): + return self.node_uuids[node_idx] + + def get_folder_uuid_by_idx(self, node_idx, folder_idx) -> Optional[str]: + node_uuid = self.get_node_uuid_by_idx(node_idx) + node = self.sim.network.nodes[node_uuid] + folder_uuids = list(node.file_system.folders.keys()) + return folder_uuids[folder_idx] if len(folder_uuids)>folder_idx else None + + def get_file_uuid_by_idx(self, node_idx, folder_idx, file_idx) -> Optional[str]: + node_uuid = self.get_node_uuid_by_idx(node_idx) + node = self.sim.network.nodes[node_uuid] + folder_uuids = list(node.file_system.folders.keys()) + if len(folder_uuids)<=folder_idx: + return None + folder = node.file_system.folders[folder_uuids[folder_idx]] + file_uuids = list(folder.files.keys()) + return file_uuids[file_idx] if len(file_uuids)>file_idx else None + + def get_service_uuid_by_idx(self, node_idx, service_idx) -> Optional[str]: + node_uuid = self.get_node_uuid_by_idx(node_idx) + node = self.sim.network.nodes[node_uuid] + service_uuids = list(node.services.keys()) + return service_uuids[service_idx] if len(service_uuids)>service_idx else None + + + + + + +class UC2RedActions(AbstractAction): + ... + +class UC2GreenActionSpace(ActionManager): ... diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index b1ade94b..0e682b60 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -2,32 +2,70 @@ # That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word # If you disagree, make a comment in the PR review and we can discuss from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union, TypeAlias +import numpy as np -from primaite.game.agent.actions import ActionSpace +from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction +ObsType:TypeAlias = Union[Dict, np.ndarray] class AbstractAgent(ABC): """Base class for scripted and RL agents.""" def __init__( self, - action_space: Optional[ActionSpace], + action_space: Optional[ActionManager], observation_space: Optional[ObservationSpace], reward_function: Optional[RewardFunction], ) -> None: - self.action_space: Optional[ActionSpace] = action_space + self.action_space: Optional[ActionManager] = action_space self.observation_space: Optional[ObservationSpace] = observation_space self.reward_function: Optional[RewardFunction] = reward_function + # exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info + # by for example specifying target ip addresses, or converting a node ID into a uuid + self.execution_definition = None + + def get_obs_from_state(self, state:Dict) -> ObsType: + """ + state : dict state directly from simulation.describe_state + output : dict state according to CAOS. + """ + return self.observation_space.observe(state) + + def get_reward_from_state(self, state:Dict) -> float: + return self.reward_function.calculate(state) + + @abstractmethod + def get_action(self, obs:ObsType, reward:float=None): + # in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 1-40, + # then use a bespoke conversion to take 1-40 int back into CAOS action + return ('NODE', 'SERVICE', 'SCAN', '', '') + + @abstractmethod + def format_request(self, action) -> List[str]: + # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator. + # therefore the execution definition needs to be a mapping from CAOS into SIMULATOR + """Format action into format expected by the simulator, and apply execution definition if applicable.""" + return ['network', 'nodes', '', 'file_system', 'folder', 'root', 'scan'] + + + + class AbstractScriptedAgent(AbstractAgent): """Base class for actors which generate their own behaviour.""" ... +class RandomAgent(AbstractScriptedAgent): + """Agent that ignores its observation and acts completely at random.""" + + def get_action(self, obs:ObsType, reward:float=None): + return self.action_space.space.sample() + class AbstractGATEAgent(AbstractAgent): """Base class for actors controlled via external messages, such as RL policies.""" diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index 4d4796e1..f919a723 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -55,7 +55,7 @@ class AbstractObservation(ABC): class FileObservation(AbstractObservation): - def __init__(self, where: List[str] = []) -> None: + def __init__(self, where: Optional[List[str]] = None) -> None: """ _summary_ @@ -68,12 +68,12 @@ class FileObservation(AbstractObservation): :type where: Optional[List[str]] """ super().__init__() - self.where: List[str] = where + self.where: Optional[List[str]] = where self.default_observation: spaces.Space = {"health_status": 0} "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." def observe(self, state: Dict) -> Dict: - if not self.where: + if self.where is None: return self.default_observation file_state = access_from_nested_dict(state, self.where) if file_state is NOT_PRESENT_IN_STATE: @@ -89,21 +89,21 @@ class ServiceObservation(AbstractObservation): default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} "Default observation is what should be returned when the service doesn't exist." - def __init__(self, where: List[str] = []) -> None: + def __init__(self, where: Optional[List[str]] = None) -> None: """ :param where: Store information about where in the simulation state dictionary to find the relevant information. Optional. If None, this corresponds that the file does not exist and the observation will be populated with zeroes. A typical location for a service looks like this: - `['network','nodes',,'servics', ]` + `['network','nodes',,'services', ]` :type where: Optional[List[str]] """ super().__init__() - self.where: List[str] = where + self.where: Optional[List[str]] = where def observe(self, state: Dict) -> Dict: - if not self.where: + if self.where is None: return self.default_observation service_state = access_from_nested_dict(state, self.where) @@ -120,7 +120,7 @@ class LinkObservation(AbstractObservation): default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}} "Default observation is what should be returned when the link doesn't exist." - def __init__(self, where: List[str] = []) -> None: + def __init__(self, where: Optional[List[str]] = None) -> None: """ :param where: Store information about where in the simulation state dictionary to find the relevant information. Optional. If None, this corresponds that the file does not exist and the observation will be populated with @@ -131,10 +131,10 @@ class LinkObservation(AbstractObservation): :type where: Optional[List[str]] """ super().__init__() - self.where: List[str] = where + self.where: Optional[List[str]] = where def observe(self, state: Dict) -> Dict: - if not self.where: + if self.where is None: return self.default_observation link_state = access_from_nested_dict(state, self.where) @@ -156,7 +156,7 @@ class LinkObservation(AbstractObservation): class FolderObservation(AbstractObservation): - def __init__(self, where: List[str] = [], files: List[FileObservation] = []) -> None: + def __init__(self, where: Optional[List[str]] = None, files: List[FileObservation] = []) -> None: """Initialise folder Observation, including files inside of the folder. :param where: Where in the simulation state dictionary to find the relevant information for this folder. @@ -175,7 +175,7 @@ class FolderObservation(AbstractObservation): """ super().__init__() - self.where: List[str] = where + self.where: Optional[List[str]] = where self.files: List[FileObservation] = files @@ -185,7 +185,7 @@ class FolderObservation(AbstractObservation): } def observe(self, state: Dict) -> Dict: - if not self.where: + if self.where is None: return self.default_observation folder_state = access_from_nested_dict(state, self.where) if folder_state is NOT_PRESENT_IN_STATE: @@ -213,12 +213,12 @@ class FolderObservation(AbstractObservation): class NicObservation(AbstractObservation): default_observation: spaces.Space = {"nic_status": 0} - def __init__(self, where: List[str] = []) -> None: - super.__init__() - self.where: List[str] = where + def __init__(self, where: Optional[List[str]] = None) -> None: + super().__init__() + self.where: Optional[List[str]] = where def observe(self, state: Dict) -> Dict: - if not self.where: + if self.where is None: return self.default_observation nic_state = access_from_nested_dict(state, self.where) if nic_state is NOT_PRESENT_IN_STATE: @@ -234,10 +234,11 @@ class NicObservation(AbstractObservation): class NodeObservation(AbstractObservation): def __init__( self, - where: List[str] = [], + where: Optional[List[str]] = None, services: List[ServiceObservation] = [], folders: List[FolderObservation] = [], nics: List[NicObservation] = [], + logon_status:bool=False ) -> None: """ Configurable observation for a node in the simulation. @@ -259,12 +260,13 @@ class NodeObservation(AbstractObservation): :param max_nics: Max number of NICS in this node's obs space, defaults to 5 :type max_nics: int, optional """ - super.__init__() - self.where: List[str] = where + super().__init__() + self.where: Optional[List[str]] = where self.services: List[ServiceObservation] = services self.folders: List[FolderObservation] = folders self.nics: List[NicObservation] = nics + self.logon_status:bool=logon_status self.default_observation: Dict = { "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, @@ -272,9 +274,11 @@ class NodeObservation(AbstractObservation): "NICS": {i + 1: n.default_observation for i, n in enumerate(self.nics)}, "operating_status": 0, } + if self.logon_status: + self.default_observation['logon_status']=0 def observe(self, state: Dict) -> Dict: - if not self.where: + if self.where is None: return self.default_observation node_state = access_from_nested_dict(state, self.where) @@ -288,18 +292,24 @@ class NodeObservation(AbstractObservation): obs["operating_status"] = node_state["operating_state"] obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)} + if self.logon_status: + obs['logon_status'] = 0 + return obs @property def space(self) -> spaces.Space: - return spaces.Dict( - { - "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), - "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), - "operating_status": spaces.Discrete(0), - "NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}), - } - ) + space_shape = { + "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), + "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), + "operating_status": spaces.Discrete(5), + "NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}), + } + if self.logon_status: + space_shape['logon_status'] = spaces.Discrete(3) + + return spaces.Dict(space_shape) + class AclObservation(AbstractObservation): @@ -308,41 +318,33 @@ class AclObservation(AbstractObservation): # if a file is created at runtime, we have currently got no way of telling the observation space to track it. # this needs adding, but not for the MVP. def __init__( - self, nodes: List[str], ports: List[int], protocols: list[str], where: List[str] = [], num_rules: int = 10 + self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[List[str]] = None, num_rules: int = 10 ) -> None: super().__init__() - self.where: List[str] = where + self.where: Optional[List[str]] = where self.num_rules: int = num_rules - self.node_to_id: Dict[str, int] = {node: i + 1 for i, node in enumerate(nodes)} + self.node_to_id: Dict[str, int] = node_ip_to_id "List of node IP addresses, order in this list determines how they are converted to an ID" - self.port_to_id: Dict[int, int] = {port: i + 1 for i, port in enumerate(ports)} + self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} "List of ports which are part of the game that define the ordering when converting to an ID" - self.protocol_to_id: Dict[str, int] = {protocol: i + 1 for i, protocol in enumerate(protocols)} + self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} "List of protocols which are part of the game, defines ordering when converting to an ID" - self.default_observation: spaces.Space = spaces.Dict( - { - "RULES": spaces.Dict( - { - i - + 1: spaces.Dict( - { - "position": i, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - ) - for i in range(self.num_rules) - } - ) + self.default_observation: Dict = { + "RULES": {i+ 1:{ + "position": i, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + for i in range(self.num_rules) } - ) + } def observe(self, state: Dict) -> Dict: - if not self.where: + if self.where is None: return self.default_observation acl_state: Dict = access_from_nested_dict(state, self.where) if acl_state is NOT_PRESENT_IN_STATE: @@ -379,16 +381,16 @@ class AclObservation(AbstractObservation): { "RULE": spaces.Dict( { - i - + 1: spaces.Dict( + i + 1: spaces.Dict( { "position": spaces.Discrete(self.num_rules), "permission": spaces.Discrete(3), - "source_node_id": spaces.Discrete(len(self.nodes) + 1), - "source_port": spaces.Discrete(len(self.ports) + 1), - "dest_node_id": spaces.Discrete(len(self.nodes) + 1), - "dest_port": spaces.Discrete(len(self.ports) + 1), - "protocol": spaces.Discrete(len(self.protocols) + 1), + # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) + "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "source_port": spaces.Discrete(len(self.port_to_id) + 2), + "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "dest_port": spaces.Discrete(len(self.port_to_id) + 2), + "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), } ) for i in range(self.num_rules) @@ -398,14 +400,96 @@ class AclObservation(AbstractObservation): ) -class ICSObservation(AbstractObservation): - def observe(self, state: Dict) -> Any: - return 0 + + +class NullObservation(AbstractObservation): + def __init__(self, where:Optional[List[str]]=None): + self.default_observation: Dict = {} + + def observe(self, state: Dict) -> Dict: + return {} @property def space(self) -> spaces.Space: - return spaces.Discrete(1) + return spaces.Dict({}) +class ICSObservation(NullObservation): pass + + +class UC2BlueObservation(AbstractObservation): + def __init__( + self, + nodes: List[NodeObservation], + links: List[LinkObservation], + acl: AclObservation, + ics: ICSObservation, + where:Optional[List[str]] = None, + ) -> None: + super().__init__() + self.where: Optional[List[str]] = where + + self.nodes: List[NodeObservation] = nodes + self.links: List[LinkObservation] = links + self.acl: AclObservation = acl + self.ics: ICSObservation = ics + + self.default_observation : Dict = { + "NODES": {i+1: n.default_observation for i,n in enumerate(self.nodes)}, + "LINKS": {i+1: l.default_observation for i,l in enumerate(self.links)}, + "ACL": self.acl.default_observation, + "ICS": self.ics.default_observation, + } + + def observe(self, state:Dict) -> Dict: + if self.where is None: + return self.default_observation + + obs = {} + + obs['NODES'] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} + obs['LINKS'] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} + obs['ACL'] = {self.acl.observe(state)} + obs['ICS'] = {self.ics.observe(state)} + + return obs + + @property + def space(self) -> spaces.Space: + return spaces.Dict({ + "NODES": spaces.Dict({i+1: node.space for i, node in enumerate(self.nodes)}), + "LINKS": spaces.Dict({i+1: link.space for i, link in enumerate(self.links)}), + "ACL": self.acl.space, + "ICS": self.ics.space, + }) + + @classmethod + def from_config(cls, config:Dict, sim:Simulation): + nodes = ... + links = ... + acl = ... + ics = ... + new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=['network']) + return new + + +class UC2RedObservation(AbstractObservation): + def __init__(self, nodes:List[NodeObservation], where:Optional[List[str]] = None) -> None: + super().__init__() + self.where:Optional[List[str]] = where + self.nodes: List[NodeObservation] = nodes + + self.default_observation=...#TODO + + def observe(self, state: Dict) -> Any: + return super().observe(state) + + @property + def space(self) -> spaces.Space: + ... #TODO + + @classmethod + def from_config(cls, config: Dict, sim:Simulation): + ... #TODO class ObservationSpace: """ @@ -422,29 +506,12 @@ class ObservationSpace: # what this class does: # keep a list of observations # create observations for an actor from the config - def __init__( - self, - simulation: Simulation, - nodes: List[NodeObservation] = [], - links: List[LinkObservation] = [], - acl: Optional[AclObservation] = None, - ics: Optional[ICSObservation] = None, - ) -> None: - self.simulation: Simulation = simulation - self.parts: Dict[str, AbstractObservation] = {} + def __init__(self, observation:AbstractObservation) -> None: + self.obs: AbstractObservation = observation - self.nodes: List[NodeObservation] = nodes - self.links: List[LinkObservation] = links - self.acl: Optional[AclObservation] = acl - self.ics: Optional[ICSObservation] = ics - - def observe(self) -> None: - ... + def observe(self, state) -> Dict: + return self.obs.observe(state) @property def space(self) -> None: - ... - - @classmethod - def from_config(self) -> None: - ... + return self.obs.space diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 1db54176..ec778176 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -1,20 +1,17 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List -from pydantic import BaseModel - - -class AbstractReward(BaseModel): - def __call__(self, states: List[Dict]) -> float: - """_summary_ - - :param state: _description_ - :type state: Dict - :return: _description_ - :rtype: float - """ +class AbstractReward(): + def __init__(self): ... + def calculate(self, state:Dict) -> float: + return 0.3 -class RewardFunction(BaseModel): - ... + +class RewardFunction(): + def __init__(self, reward_function:AbstractReward): + self.reward: AbstractReward = reward_function + + def calculate(self, state:Dict) -> float: + return self.reward.calculate(state) diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 47ef4ce9..fcd8b4b3 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -4,3 +4,49 @@ # 3. create actors and configure their actions/observations/rewards/ anything else # 4. Create connection with ARCD GATE # 5. idk + +from primaite.simulator.sim_container import Simulation +from primaite.game.agent.interface import AbstractAgent + +from typing import List + +class PrimaiteSession: + def __init__(self): + self.simulation: Simulation = Simulation() + self.agents:List[AbstractAgent] = [] + self.step_counter:int = 0 + self.episode_counter:int = 0 + + + def step(self): + # currently designed with assumption that all agents act once per step in order + + + for agent in self.agents: + # 3. primaite session asks simulation to provide initial state + # 4. primate session gives state to all agents + # 5. primaite session asks agents to produce an action based on most recent state + sim_state = self.simulation.describe_state() + + # 6. each agent takes most recent state and converts it to CAOS observation + agent_obs = agent.get_obs_from_state(sim_state) + + # 7. meanwhile each agent also takes state and calculates reward + agent_reward = agent.get_reward_from_state(sim_state) + + # 8. each agent takes observation and applies decision rule to observation to create CAOS + # action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action + # to discrete(40) is only necessary for purposes of RL learning, therefore that bit of + # code should live inside of the GATE agent subclass) + # gets action in CAOS format + agent_action = agent.get_action(agent_obs, agent_reward) + # 9. CAOS action is converted into request (extra information might be needed to enrich + # the request, this is what the execution definition is there for) + agent_request = agent.format_request(agent_action) + + # 10. primaite session receives the action from the agents and asks the simulation to apply each + self.simulation.apply_action(agent_request) + + self.simulation.apply_timestep(self.step_counter) + self.step_counter += 1 +