From 3dea9743c355d61b1fd7b7461abe7a79dc77f822 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 6 Oct 2023 20:32:52 +0100 Subject: [PATCH] Get primaite session step working --- example_config.yaml | 455 +++++++++++++----------- sandbox.ipynb | 195 ++++++---- src/primaite/game/agent/actions.py | 158 +++++--- src/primaite/game/agent/interface.py | 16 +- src/primaite/game/agent/rewards.py | 17 + src/primaite/game/session.py | 34 +- src/primaite/simulator/sim_container.py | 1 + 7 files changed, 520 insertions(+), 356 deletions(-) diff --git a/example_config.yaml b/example_config.yaml index b47355c3..9c75c92e 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -19,20 +19,28 @@ game_config: type: GreenWebBrowsingAgent observation_space: null action_space: - actions: - - type: DONOTHING - nodes: - - node_ref: client_2 - actions: - - type: LOGON - - type: LOGOFF - applications: - # - application_ref: client_2_web_browser - # actions: - # - type: EXECUTE - # execution_definition: - # target_address: arcd.com - reward_function: null + action_list: + - type: DONOTHING + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + agent_settings: start_step: 5 frequency: 4 @@ -41,6 +49,7 @@ game_config: - ref: client_1_data_manipulation_red_bot team: RED type: RedDatabaseCorruptingAgent + observation_space: type: UC2RedObservation options: @@ -55,27 +64,56 @@ game_config: - operating_status - health_status folders: {} + action_space: - actions: - - type: DO_NOTHING - network: + action_list: + - type: DONOTHING + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # server_ip: 192.168.1.14 + # payload: "DROP TABLE IF EXISTS user;" + # success_rate: 80% + - type: NODE_FILE_DELETE + - type: NODE_FILE_CORRUPT + # - type: NODE_FOLDER_DELETE + # - type: NODE_FOLDER_CORRUPT + - type: NODE_OS_SCAN + # - type: NODE_LOGON + # - type: NODE_LOGOFF + options: nodes: - node_ref: client_1 - actions: - - type: SCAN - - type: LOGON - - type: LOGOFF - services: - - service_ref: data_manipulation_bot - actions: - - type: COMPROMISE - execution_definition: - server_ip: 192.168.1.14 - payload: "DROP TABLE IF EXISTS user;" - success_rate: 80% - folders: - files: {} - reward_function: null + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + # max_nics_per_node: 8 + # max_acl_rules: 10 + + # actions: + # - type: DO_NOTHING + # network: + # nodes: + # - node_ref: client_1 + # actions: + # - type: SCAN + # - type: LOGON + # - type: LOGOFF + # services: + # - service_ref: data_manipulation_bot + # actions: + # - type: COMPROMISE + # execution_definition: + # server_ip: 192.168.1.14 + # payload: "DROP TABLE IF EXISTS user;" + # success_rate: 80% + # folders: + # files: {} + + reward_function: + reward_components: + - type: DUMMY + agent_settings: # options specific to this particular agent type, basically args of __init__(self) start_step: 25 frequency: 20 @@ -85,8 +123,9 @@ game_config: - ref: defender - team: blue + team: BLUE type: GATERLAgent + observation_space: type: UC2BlueObservation options: @@ -128,7 +167,6 @@ game_config: router_node_ref: router_1 ics: null - action_space: action_list: - type: DONOTHING @@ -164,227 +202,227 @@ game_config: action_map: 0: - - action: DONOTHING + action: DONOTHING options: {} # scan webapp service 1: - - action: NODE_SERVICE_SCAN + action: NODE_SERVICE_SCAN options: - node_id: 2 - service_id: 1 # stop webapp service 2: - - action: NODE_SERVICE_STOP + action: NODE_SERVICE_STOP options: - node_id: 2 - service_id: 1 # start webapp service 3: - - action: "NODE_SERVICE_START" - options: - - node_id: 2 - - service_id: 1 + action: "NODE_SERVICE_START" + options: + - node_id: 2 + - service_id: 1 4: - - action: "NODE_SERVICE_PAUSE" - options: - - node_id: 2 - - service_id: 1 + action: "NODE_SERVICE_PAUSE" + options: + - node_id: 2 + - service_id: 1 5: - - action: "NODE_SERVICE_RESUME" - options: - - node_id: 2 - - service_id: 1 + action: "NODE_SERVICE_RESUME" + options: + - node_id: 2 + - service_id: 1 6: - - action: "NODE_SERVICE_RESTART" - options: - - node_id: 2 - - service_id: 1 + action: "NODE_SERVICE_RESTART" + options: + - node_id: 2 + - service_id: 1 7: - - action: "NODE_SERVICE_DISABLE" - options: - - node_id: 2 - - service_id: 1 + action: "NODE_SERVICE_DISABLE" + options: + - node_id: 2 + - service_id: 1 8: - - action: "NODE_SERVICE_ENABLE" - options: - - node_id: 2 - - service_id: 1 + action: "NODE_SERVICE_ENABLE" + options: + - node_id: 2 + - service_id: 1 9: - - action: "NODE_FILE_SCAN" - options: - - node_id: 3 - - folder_id: 1 - - file_id: 1 + action: "NODE_FILE_SCAN" + options: + - node_id: 3 + - folder_id: 1 + - file_id: 1 10: - - action: "NODE_FILE_CHECKHASH" - options: - - node_id: 3 - - folder_id: 1 - - file_id: 1 + action: "NODE_FILE_CHECKHASH" + options: + - node_id: 3 + - folder_id: 1 + - file_id: 1 11: - - action: "NODE_FILE_DELETE" - options: - - node_id: 3 - - folder_id: 1 - - file_id: 1 + action: "NODE_FILE_DELETE" + options: + - node_id: 3 + - folder_id: 1 + - file_id: 1 12: - - action: "NODE_FILE_REPAIR" - options: - - node_id: 3 - - folder_id: 1 - - file_id: 1 + action: "NODE_FILE_REPAIR" + options: + - node_id: 3 + - folder_id: 1 + - file_id: 1 13: - - action: "NODE_FILE_RESTORE" - options: - - node_id: 3 - - folder_id: 1 - - file_id: 1 + action: "NODE_FILE_RESTORE" + options: + - node_id: 3 + - folder_id: 1 + - file_id: 1 14: - - action: "NODE_FOLDER_SCAN" - options: - - node_id: 3 - - folder_id: 1 + action: "NODE_FOLDER_SCAN" + options: + - node_id: 3 + - folder_id: 1 15: - - action: "NODE_FOLDER_CHECKHASH" - options: - - node_id: 3 - - folder_id: 1 + action: "NODE_FOLDER_CHECKHASH" + options: + - node_id: 3 + - folder_id: 1 16: - - action: "NODE_FOLDER_REPAIR" - options: - - node_id: 3 - - folder_id: 1 + action: "NODE_FOLDER_REPAIR" + options: + - node_id: 3 + - folder_id: 1 17: - - action: "NODE_FOLDER_RESTORE" - options: - - node_id: 3 - - folder_id: 1 + action: "NODE_FOLDER_RESTORE" + options: + - node_id: 3 + - folder_id: 1 18: - - action: "NODE_OS_SCAN" - options: - - node_id: 3 + action: "NODE_OS_SCAN" + options: + - node_id: 3 19: - - action: "NODE_SHUTDOWN" - options: - - node_id: 6 + action: "NODE_SHUTDOWN" + options: + - node_id: 6 20: - - action: "NODE_STARTUP" - options: - - node_id: 6 + action: "NODE_STARTUP" + options: + - node_id: 6 21: - - action: "NODE_RESET" - options: - - node_id: 6 + action: "NODE_RESET" + options: + - node_id: 6 22: - - action: "NETWORK_ACL_ADDRULE" - options: - - position: 6 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + action: "NETWORK_ACL_ADDRULE" + options: + - position: 6 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... 23: - - action: "NETWORK_ACL_ADDRULE" - options: - - position: 5 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + action: "NETWORK_ACL_ADDRULE" + options: + - position: 5 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... 24: - - action: "NETWORK_ACL_ADDRULE" - options: - - position: 4 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + action: "NETWORK_ACL_ADDRULE" + options: + - position: 4 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... 25: - - action: "NETWORK_ACL_ADDRULE" - options: - - position: 3 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + action: "NETWORK_ACL_ADDRULE" + options: + - position: 3 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... 26: - - action: "NETWORK_ACL_ADDRULE" - options: - - position: 2 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + action: "NETWORK_ACL_ADDRULE" + options: + - position: 2 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... 27: - - action: "NETWORK_ACL_ADDRULE" - options: - - position: 1 - - permission: 2 - - source_node_id: ... - - dest_node_id: ... - - source_port_id: ... - - dest_port_id: ... - - protocol_id: ... + action: "NETWORK_ACL_ADDRULE" + options: + - position: 1 + - permission: 2 + - source_node_id: ... + - dest_node_id: ... + - source_port_id: ... + - dest_port_id: ... + - protocol_id: ... 28: - - action: "NETWORK_ACL_REMOVERULE" - options: - - position: 0 + action: "NETWORK_ACL_REMOVERULE" + options: + - position: 0 29: - - action: "NETWORK_ACL_REMOVERULE" - options: - - position: 1 + action: "NETWORK_ACL_REMOVERULE" + options: + - position: 1 30: - - action: "NETWORK_ACL_REMOVERULE" - options: - - position: 2 + action: "NETWORK_ACL_REMOVERULE" + options: + - position: 2 31: - - action: "NETWORK_ACL_REMOVERULE" - options: - - position: 3 + action: "NETWORK_ACL_REMOVERULE" + options: + - position: 3 32: - - action: "NETWORK_ACL_REMOVERULE" - options: - - position: 4 + action: "NETWORK_ACL_REMOVERULE" + options: + - position: 4 33: - - action: "NETWORK_ACL_REMOVERULE" - options: - - position: 5 + action: "NETWORK_ACL_REMOVERULE" + options: + - position: 5 34: - - action: "NETWORK_ACL_REMOVERULE" - options: - - position: 6 + action: "NETWORK_ACL_REMOVERULE" + options: + - position: 6 35: - - action: "NETWORK_ACL_REMOVERULE" - options: - - position: 7 + action: "NETWORK_ACL_REMOVERULE" + options: + - position: 7 36: - - action: "NETWORK_ACL_REMOVERULE" - options: - - position: 8 + action: "NETWORK_ACL_REMOVERULE" + options: + - position: 8 37: - - action: "NETWORK_ACL_REMOVERULE" - options: - - position: 9 + action: "NETWORK_ACL_REMOVERULE" + options: + - position: 9 38: - - action: "NETWORK_NIC_DISABLE" - options: - - node_id: 6 - - nic_index: 1 + action: "NETWORK_NIC_DISABLE" + options: + - node_id: 6 + - nic_index: 1 39: - - action: "NETWORK_NIC_ENABLE" - options: - - node_id: 6 - - nic_index: 1 + action: "NETWORK_NIC_ENABLE" + options: + - node_id: 6 + - nic_index: 1 options: nodes: @@ -404,9 +442,10 @@ game_config: max_nics_per_node: 8 max_acl_rules: 10 - reward_function: - # ... + reward_components: + - type: DUMMY + agent_settings: # ... diff --git a/sandbox.ipynb b/sandbox.ipynb index 3ff72170..51849298 100644 --- a/sandbox.ipynb +++ b/sandbox.ipynb @@ -2,18 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 13, + "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -21,28 +12,102 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.session import PrimaiteSession\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import itertools" + ] + }, + { + "cell_type": "code", + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from primaite.game.session import PrimaiteSession\n", "from primaite.simulator.sim_container import Simulation\n", "from primaite.game.agent.interface import AbstractAgent\n", - "from primaite.simulator.network.networks import arcd_uc2_network\n" + "from primaite.simulator.network.networks import arcd_uc2_network\n", + "import yaml\n" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "sess = PrimaiteSession()" + "with open('example_config.yaml', 'r') as file:\n", + " cfg = yaml.safe_load(file)" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-10-06 19:05:49,548: Added node 387fba92-e5ff-4ead-b525-1872091935ad to Network 35300ca7-ca53-41a4-b617-1f64c7645e52\n", + "2023-10-06 19:05:49,557: Added node a808ea99-5c8b-42c4-8e38-bf406ceb1f87 to Network 35300ca7-ca53-41a4-b617-1f64c7645e52\n", + "2023-10-06 19:05:49,562: Added node 922c77bb-096a-4236-9e1e-a44da15c1718 to Network 35300ca7-ca53-41a4-b617-1f64c7645e52\n", + "2023-10-06 19:05:49,579: Added node f11cc63b-537c-4813-be09-a1f5597dfe14 to Network 35300ca7-ca53-41a4-b617-1f64c7645e52\n", + "2023-10-06 19:05:49,591: Added node a866b811-efa2-41cc-adc0-a4752f40a0b8 to Network 35300ca7-ca53-41a4-b617-1f64c7645e52\n", + "2023-10-06 19:05:49,607: Added node a01c22b8-cdfb-4105-a8d0-c67c53b3d08b to Network 35300ca7-ca53-41a4-b617-1f64c7645e52\n", + "2023-10-06 19:05:49,635: Added node 217074fc-021e-4b19-94db-3bc2d5f15d49 to Network 35300ca7-ca53-41a4-b617-1f64c7645e52\n", + "2023-10-06 19:05:49,641: Added node 28db0167-0621-4fdb-9e2b-65e25a91a101 to Network 35300ca7-ca53-41a4-b617-1f64c7645e52\n", + "2023-10-06 19:05:49,648: Added node e754e649-7ba3-4f80-8621-906255cf8749 to Network 35300ca7-ca53-41a4-b617-1f64c7645e52\n", + "2023-10-06 19:05:49,657: Added node 65508b30-defa-46c8-af44-9f0c4c0ae59d to Network 35300ca7-ca53-41a4-b617-1f64c7645e52\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "service type not found DatabaseBackup\n", + "service type not found WebBrowser\n" + ] + } + ], + "source": [ + "sess = PrimaiteSession.from_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sess.agents" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -51,7 +116,16 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "sess.step()" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -74,27 +148,9 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-10-02 15:10:20,422: Added node 6abb7664-4d17-45ff-a3c7-dbcccffcfd6d to Network 045a3114-4aac-4687-a10e-432cfd138325\n", - "2023-10-02 15:10:20,424: Added node 3edbc521-3c80-47e3-8017-dbc38fb00a73 to Network 045a3114-4aac-4687-a10e-432cfd138325\n", - "2023-10-02 15:10:20,428: Added node 94457fb9-04a1-4dc1-9ff7-b64df0da7424 to Network 045a3114-4aac-4687-a10e-432cfd138325\n", - "2023-10-02 15:10:20,432: Added node 0d311d72-139c-41bf-aef7-fa9b01b124d7 to Network 045a3114-4aac-4687-a10e-432cfd138325\n", - "2023-10-02 15:10:20,439: Added node 6161e785-f377-48de-aa4f-20d3646da635 to Network 045a3114-4aac-4687-a10e-432cfd138325\n", - "2023-10-02 15:10:20,444: Added node 55a9e9f8-ee3a-4c28-9b6d-c0c0d78a3f6a to Network 045a3114-4aac-4687-a10e-432cfd138325\n", - "2023-10-02 15:10:20,447: Added node 2f04ca45-3439-489a-81f7-41cca5ae8adc to Network 045a3114-4aac-4687-a10e-432cfd138325\n", - "2023-10-02 15:10:20,531: Added node 98660c30-8e48-4b96-967a-d62ca71b4d6d to Network 045a3114-4aac-4687-a10e-432cfd138325\n", - "2023-10-02 15:10:20,545: Added node 1a184184-b204-40de-986a-4d7459036dbe to Network 045a3114-4aac-4687-a10e-432cfd138325\n", - "2023-10-02 15:10:20,551: Added node 17b92b9a-6805-4677-85f7-c0c0521a6e25 to Network 045a3114-4aac-4687-a10e-432cfd138325\n", - "2023-10-02 15:10:20,555::ERROR::primaite.simulator.network.hardware.base::175::NIC da:f3:1b:87:24:20/192.168.10.110 cannot be enabled as it is not connected to a Link\n" - ] - } - ], + "outputs": [], "source": [ "router_1 = Router(hostname=\"router_1\", num_ports=5)\n", "router_1.power_on()\n", @@ -261,7 +317,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -270,7 +326,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -279,7 +335,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -305,7 +361,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -315,7 +371,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -324,26 +380,9 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['network',\n", - " 'node',\n", - " '6161e785-f377-48de-aa4f-20d3646da635',\n", - " 'file_system',\n", - " 'folder',\n", - " '5aefe92b-923c-4684-b3bf-e78dd18d4771',\n", - " 'scan']" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "my_trial_act" ] @@ -455,8 +494,8 @@ " session = cls()\n", " with open(cfg_path, 'r') as file:\n", " conf = yaml.safe_load(file)\n", - " \n", - " #1. create nodes \n", + "\n", + " #1. create nodes\n", " sim = Simulation()\n", " net = sim.network\n", " nodes_cfg = conf['simulation']['network']['nodes']\n", @@ -465,15 +504,15 @@ " node_ref = node_cfg['ref']\n", " n_type = node_cfg['type']\n", " if n_type == 'computer':\n", - " new_node = Computer(hostname = node_cfg['hostname'], \n", - " ip_address = node_cfg['ip_address'], \n", - " subnet_mask = node_cfg['subnet_mask'], \n", + " new_node = Computer(hostname = node_cfg['hostname'],\n", + " ip_address = node_cfg['ip_address'],\n", + " subnet_mask = node_cfg['subnet_mask'],\n", " default_gateway = node_cfg['default_gateway'],\n", " dns_server = node_cfg['dns_server'])\n", " elif n_type == 'server':\n", - " new_node = Server(hostname = node_cfg['hostname'], \n", - " ip_address = node_cfg['ip_address'], \n", - " subnet_mask = node_cfg['subnet_mask'], \n", + " new_node = Server(hostname = node_cfg['hostname'],\n", + " ip_address = node_cfg['ip_address'],\n", + " subnet_mask = node_cfg['subnet_mask'],\n", " default_gateway = node_cfg['default_gateway'],\n", " dns_server = node_cfg.get('dns_server'))\n", " elif n_type == 'switch':\n", @@ -484,12 +523,12 @@ " num_ports = node_cfg.get('num_ports'))\n", " if 'ports' in node_cfg:\n", " for port_num, port_cfg in node_cfg['ports'].items():\n", - " new_node.configure_port(port=port_num, \n", + " new_node.configure_port(port=port_num,\n", " ip_address=port_cfg['ip_address'],\n", " subnet_mask=port_cfg['subnet_mask'])\n", " if 'acl' in node_cfg:\n", " for r_num, r_cfg in node_cfg['acl'].items():\n", - " # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating \n", + " # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating\n", " # this: 'r_cfg.get('src_port')'\n", " # Port/IPProtocol. TODO Refactor\n", " new_node.acl.add_rule(\n", @@ -570,15 +609,15 @@ " action_space_cfg = agent_cfg['action_space']\n", " observation_space_cfg = agent_cfg['observation_space']\n", " reward_function_cfg = agent_cfg['reward_function']\n", - " \n", + "\n", " # CREATE OBSERVATION SPACE\n", " if observation_space_cfg is None:\n", " obs_space = NullObservation()\n", " elif observation_space_cfg['type'] == 'UC2BlueObservation':\n", " node_obs_list = []\n", " link_obs_list = []\n", - " \n", - " \n", + "\n", + "\n", " #node ip to index maps ip addresses to node id, as there are potentially multiple nics on a node, there are multiple ip addresses\n", " node_ip_to_index = {}\n", " for node_idx, node_cfg in enumerate(nodes_cfg):\n", @@ -587,8 +626,8 @@ " for nic_uuid, nic_obj in n_obj.nics.items():\n", " node_ip_to_index[nic_obj.ip_address] = node_idx + 2\n", "\n", - " \n", - " \n", + "\n", + "\n", " for node_obs_cfg in observation_space_cfg['options']['nodes']:\n", " node_ref = node_obs_cfg['node_ref']\n", " folder_obs_list = []\n", @@ -618,9 +657,9 @@ " else:\n", " print(\"observation space config not specified correctly.\")\n", " obs_space = NullObservation()\n", - " \n", + "\n", " # CREATE ACTION SPACE\n", - " \n", + "\n", "\n", "\n", " # CREATE REWARD FUNCTION\n", diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index f6f96161..3f674fbb 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1,12 +1,14 @@ import itertools from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gym import spaces -from primaite.game.session import PrimaiteSession from primaite.simulator.sim_container import Simulation +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession + class ExecutionDefiniton(ABC): """ @@ -59,7 +61,7 @@ class DoNothingAction(AbstractAction): # i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter # with one option. This just aids the Action Manager to enumerate all possibilities. - def form_request(self) -> List[str]: + def form_request(self, **kwargs) -> List[str]: return ["do_nothing"] @@ -86,56 +88,56 @@ class NodeServiceAbstractAction(AbstractAction): class NodeServiceScanAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: - super().__init__(manager=manager) + def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "scan" class NodeServiceStopAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: - super().__init__(manager=manager) + def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "stop" class NodeServiceStartAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: - super().__init__(manager=manager) + def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "start" class NodeServicePauseAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: - super().__init__(manager=manager) + def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "pause" class NodeServiceResumeAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: - super().__init__(manager=manager) + def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "resume" class NodeServiceRestartAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: - super().__init__(manager=manager) + def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "restart" class NodeServiceDisableAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: - super().__init__(manager=manager) + def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "disable" class NodeServiceEnableAction(NodeServiceAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None: - super().__init__(manager=manager) + def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services) self.verb = "enable" class NodeFolderAbstractAction(AbstractAction): @abstractmethod - def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None: + def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None: super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders} self.verb: str @@ -149,26 +151,26 @@ class NodeFolderAbstractAction(AbstractAction): class NodeFolderScanAction(NodeFolderAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None: - super().__init__(manager, num_nodes, num_folders, **kwargs) + def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None: + super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "scan" class NodeFolderCheckhashAction(NodeFolderAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None: - super().__init__(manager, num_nodes, num_folders, **kwargs) + def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None: + super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "checkhash" class NodeFolderRepairAction(NodeFolderAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None: - super().__init__(manager, num_nodes, num_folders, **kwargs) + def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None: + super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "repair" class NodeFolderRestoreAction(NodeFolderAbstractAction): - def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None: - super().__init__(manager, num_nodes, num_folders, **kwargs) + def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None: + super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "restore" @@ -190,34 +192,40 @@ class NodeFileAbstractAction(AbstractAction): class NodeFileScanAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) + super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb = "scan" class NodeFileCheckhashAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) + super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb = "checkhash" class NodeFileDeleteAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) + super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb = "delete" class NodeFileRepairAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) + super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb = "repair" class NodeFileRestoreAction(NodeFileAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: - super().__init__(manager, num_nodes, num_folders, num_files, **kwargs) + super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb = "restore" +class NodeFileCorruptAction(NodeFileAbstractAction): + def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None: + super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) + self.verb = "corrupt" + + class NodeAbstractAction(AbstractAction): @abstractmethod def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: @@ -232,25 +240,25 @@ class NodeAbstractAction(AbstractAction): class NodeOSScanAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager) + super().__init__(manager=manager, num_nodes=num_nodes) self.verb = "scan" class NodeShutdownAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager) + super().__init__(manager=manager, num_nodes=num_nodes) self.verb = "shutdown" class NodeStartupAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager) + super().__init__(manager=manager, num_nodes=num_nodes) self.verb = "start" class NodeResetAction(NodeAbstractAction): def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None: - super().__init__(manager=manager) + super().__init__(manager=manager, num_nodes=num_nodes) self.verb = "reset" @@ -371,6 +379,7 @@ class ActionManager: "NODE_FILE_DELETE": NodeFileDeleteAction, "NODE_FILE_REPAIR": NodeFileRepairAction, "NODE_FILE_RESTORE": NodeFileRestoreAction, + "NODE_FILE_CORRUPT": NodeFileCorruptAction, "NODE_FOLDER_SCAN": NodeFolderScanAction, "NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction, "NODE_FOLDER_REPAIR": NodeFolderRepairAction, @@ -387,7 +396,7 @@ class ActionManager: def __init__( self, - session: PrimaiteSession, # reference to session for looking up stuff + session: "PrimaiteSession", # reference to session for looking up stuff actions: List[str], # stores list of actions available to agent node_uuids: List[str], # allows mapping index to node max_folders_per_node: int = 2, # allows calculating shape @@ -400,7 +409,7 @@ class ActionManager: ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address. act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions ) -> None: - self.session: PrimaiteSession = session + self.session: "PrimaiteSession" = session self.sim: Simulation = self.session.simulation self.node_uuids: List[str] = node_uuids self.protocols: List[str] = protocols @@ -417,7 +426,8 @@ class ActionManager: for nic_uuid, nic_obj in nics.items(): self.ip_address_list.append(nic_obj.ip_address) - action_args = { + # action_args are settings which are applied to the action space as a whole. + global_action_args = { "num_nodes": len(node_uuids), "num_folders": max_folders_per_node, "num_files": max_files_per_folder, @@ -427,10 +437,21 @@ class ActionManager: "num_protocols": len(self.protocols), "num_ports": len(self.protocols), "num_ips": len(self.ip_address_list), + "max_acl_rules":max_acl_rules, + "max_nics_per_node": max_nics_per_node, } self.actions: Dict[str, AbstractAction] = {} - for act_type in actions: - self.actions[act_type] = self.__act_class_identifiers[act_type](self, **action_args) + for act_spec in actions: + # each action is provided into the action space config like this: + # - type: ACTION_TYPE + # options: + # option_1: value1 + # option_2: value2 + # where `type` decides which AbstractAction subclass should be used + # and `options` is an optional dict of options to pass to the init method of the action class + act_type = act_spec.get('type') + act_options = act_spec.get('options', {}) + self.actions[act_type] = self.__act_class_identifiers[act_type](self, **global_action_args, **act_options) self.action_map: Dict[int, Tuple[str, Dict]] = {} """ @@ -448,15 +469,41 @@ class ActionManager: def _enumerate_actions( self, - ) -> Dict[int, Tuple[AbstractAction, Dict]]: + ) -> Dict[int, Tuple[str, Dict]]: + """Generate a list of all the possible actions that could be taken. + + This enumerates all actions all combinations of parametes you could choose for those actions. The output + of this function is intended to populate the self.action_map parameter in the situation where the user provides + a list of action types, but doesn't specify any subset of actions that should be made available to the agent. + + The enumeration relies on the Actions' `shape` attribute. + + :return: An action map maps consecutive integers to a combination of Action type and parameter choices. + An example output could be: + {0: ("DONOTHING", {'dummy': 0}), + 1: ("NODE_OS_SCAN", {'node_id': 0}), + 2: ("NODE_OS_SCAN", {'node_id': 1}), + 3: ("NODE_FOLDER_SCAN", {'node_id:0, folder_id:0}), + ... #etc... + } + :rtype: Dict[int, Tuple[AbstractAction, Dict]] + """ all_action_possibilities = [] - for action in self.actions.values(): - param_names = (list(action.shape.keys()),) + for act_name, action in self.actions.items(): + param_names = list(action.shape.keys()) num_possibilities = list(action.shape.values()) possibilities = [range(n) for n in num_possibilities] - itertools.product(action.shape.values()) - all_action_possibilities.append((action, {})) + param_combinations = list(itertools.product(*possibilities)) + all_action_possibilities.extend( + [ + ( + act_name, {param_names[i]:param_combinations[j][i] for i in range(len(param_names))} + ) for j in range(len(param_combinations))] + ) + + return {i:p for i,p in enumerate(all_action_possibilities)} + def get_action(self, action: int) -> Tuple[str, Dict]: """Produce action in CAOS format""" @@ -517,21 +564,16 @@ class ActionManager: return nics[nic_idx] @classmethod - def from_config(cls, session: PrimaiteSession, cfg: Dict) -> "ActionManager": + def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager": obj = cls( session=session, actions=cfg["action_list"], - node_uuids=cfg["options"]["nodes"], - max_folders_per_node=cfg["options"]["max_folders_per_node"], - max_files_per_folder=cfg["options"]["max_files_per_folder"], - max_services_per_node=cfg["options"]["max_services_per_node"], - max_nics_per_node=cfg["options"]["max_nics_per_node"], - max_acl_rules=cfg["options"]["max_acl_rules"], - max_X=cfg["options"]["max_X"], - protocols=session.options.ports, - ports=session.options.protocols, + # node_uuids=cfg["options"]["node_uuids"], + **cfg['options'], + protocols=session.options.protocols, + ports=session.options.ports, ip_address_list=None, - act_map=cfg["action_map"], + act_map=cfg.get("action_map"), ) return obj diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 528c0b1a..4fd52d96 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -2,7 +2,7 @@ # That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word # If you disagree, make a comment in the PR review and we can discuss from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, TypeAlias, Union +from typing import Any, Dict, List, Optional, Tuple, TypeAlias, Union import numpy as np @@ -41,17 +41,17 @@ class AbstractAgent(ABC): return self.reward_function.calculate(state) @abstractmethod - def get_action(self, obs: ObsType, reward: float = None): + def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: # in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39, # then use a bespoke conversion to take 1-40 int back into CAOS action - return ("NODE", "SERVICE", "SCAN", "", "") + return ("DO_NOTHING", {} ) - @abstractmethod - def format_request(self, action) -> List[str]: + def format_request(self, action:Tuple[str,Dict], options:Dict[str, int]) -> List[str]: # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator. # therefore the execution definition needs to be a mapping from CAOS into SIMULATOR """Format action into format expected by the simulator, and apply execution definition if applicable.""" - return ["network", "nodes", "", "file_system", "folder", "root", "scan"] + request = self.action_space.form_request(action_identifier=action, action_options=options) + return request class AbstractScriptedAgent(AbstractAgent): @@ -63,8 +63,8 @@ class AbstractScriptedAgent(AbstractAgent): class RandomAgent(AbstractScriptedAgent): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, reward: float = None): - return self.action_space.space.sample() + def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + return self.action_space.get_action(self.action_space.space.sample()) class AbstractGATEAgent(AbstractAgent): diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index ec778176..a4ceb2dd 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -5,13 +5,30 @@ class AbstractReward(): def __init__(self): ... + @abstractmethod def calculate(self, state:Dict) -> float: return 0.3 +class DummyReward(AbstractReward): + + def calculate(self, state: Dict) -> float: + return -0.1 class RewardFunction(): + __rew_class_identifiers:Dict[str,type[AbstractReward]] = { + "DUMMY" : DummyReward + } def __init__(self, reward_function:AbstractReward): self.reward: AbstractReward = reward_function def calculate(self, state:Dict) -> float: return self.reward.calculate(state) + + @classmethod + def from_config(cls, cfg:Dict) -> "RewardFunction": + for rew_component_cfg in cfg['reward_components']: + rew_type = rew_component_cfg['type'] + rew_component = cls.__rew_class_identifiers[rew_type]() + new = cls(reward_function=rew_component) + return new + diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 0f88b322..46e834d6 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -11,7 +11,7 @@ from typing import Dict, List from pydantic import BaseModel from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent +from primaite.game.agent.interface import AbstractAgent, RandomAgent from primaite.game.agent.observations import ( AclObservation, FileObservation, @@ -25,6 +25,7 @@ from primaite.game.agent.observations import ( UC2BlueObservation, UC2RedObservation, ) +from primaite.game.agent.rewards import RewardFunction from primaite.simulator.network.hardware.base import Link, NIC, Node from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router @@ -74,10 +75,10 @@ class PrimaiteSession: # to discrete(40) is only necessary for purposes of RL learning, therefore that bit of # code should live inside of the GATE agent subclass) # gets action in CAOS format - agent_action = agent.get_action(agent_obs, agent_reward) + agent_action, action_options = agent.get_action(agent_obs, agent_reward) # 9. CAOS action is converted into request (extra information might be needed to enrich # the request, this is what the execution definition is there for) - agent_request = agent.format_request(agent_action) + agent_request = agent.format_request(agent_action, action_options) # 10. primaite session receives the action from the agents and asks the simulation to apply each self.simulation.apply_action(agent_request) @@ -88,6 +89,10 @@ class PrimaiteSession: @classmethod def from_config(cls, cfg: dict) -> "PrimaiteSession": sess = cls() + sess.options = PrimaiteSessionOptions( + ports = cfg['game_config']['ports'], + protocols = cfg['game_config']['protocols'], + ) sim = sess.simulation net = sim.network @@ -304,13 +309,33 @@ class PrimaiteSession: obs_space = NullObservation() # CREATE ACTION SPACE + action_space_cfg['options']['node_uuids'] = [] + # if a list of nodes is defined, convert them from node references to node UUIDs + for action_node_option in action_space_cfg.get('options',{}).pop('nodes', {}): + if 'node_ref' in action_node_option: + node_uuid = ref_map_nodes[action_node_option['node_ref']] + action_space_cfg['options']['node_uuids'].append(node_uuid) + # Each action space can potentially have a different list of nodes that it can apply to. Therefore, + # we will pass node_uuids as a part of the action space config. + # However, it's not possible to specify the node uuids directly in the config, as they are generated + # dynamically, so we have to translate node references to uuids before passing this config on. + + if 'action_list' in action_space_cfg: + for action_config in action_space_cfg['action_list']: + if 'options' in action_config: + if 'target_router_ref' in action_config['options']: + _target = action_config['options']['target_router_ref'] + action_config['options']['target_router_uuid'] = ref_map_nodes[_target] + action_space = ActionManager.from_config(sess, action_space_cfg) # CREATE REWARD FUNCTION + rew_function = RewardFunction.from_config(reward_function_cfg) # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": - ... + new_agent = RandomAgent(action_space=action_space, observation_space=obs_space, reward_function=rew_function) + sess.agents.append(new_agent) elif agent_type == "GATERLAgent": ... elif agent_type == "RedDatabaseCorruptingAgent": @@ -318,4 +343,5 @@ class PrimaiteSession: else: print("agent type not found") + return sess diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index d647b0bc..1df5fe12 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -27,6 +27,7 @@ class Simulation(SimComponent): am.add_action("network", Action(func=self.network._action_manager)) # pass through domain actions to the domain object am.add_action("domain", Action(func=self.domain._action_manager)) + am.add_action("do_nothing", Action(func=lambda request, context: ())) return am def describe_state(self) -> Dict: