Files
PrimAITE/sandbox.ipynb
2023-10-02 17:21:43 +01:00

686 lines
31 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"# %reload_ext loads the extension when it is missing and reloads it otherwise, so re-running this\n",
"# cell does not print the 'already loaded' notice that %load_ext emits.\n",
"%reload_ext autoreload\n",
"# Mode 2: reload all imported modules before each execution, so edits to them are picked up.\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.session import PrimaiteSession\n",
"from primaite.simulator.sim_container import Simulation\n",
"from primaite.game.agent.interface import AbstractAgent\n",
"from primaite.simulator.network.networks import arcd_uc2_network\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"sess = PrimaiteSession()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"network = sess.simulation.network"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from ipaddress import IPv4Address\n",
"\n",
"from primaite.simulator.network.container import Network\n",
"from primaite.simulator.network.hardware.base import NIC\n",
"from primaite.simulator.network.hardware.nodes.computer import Computer\n",
"from primaite.simulator.network.hardware.nodes.router import ACLAction, Router\n",
"from primaite.simulator.network.hardware.nodes.server import Server\n",
"from primaite.simulator.network.hardware.nodes.switch import Switch\n",
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"from primaite.simulator.system.applications.database_client import DatabaseClient\n",
"from primaite.simulator.system.services.database_service import DatabaseService\n",
"from primaite.simulator.system.services.dns_client import DNSClient\n",
"from primaite.simulator.system.services.dns_server import DNSServer\n",
"from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-10-02 15:10:20,422: Added node 6abb7664-4d17-45ff-a3c7-dbcccffcfd6d to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,424: Added node 3edbc521-3c80-47e3-8017-dbc38fb00a73 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,428: Added node 94457fb9-04a1-4dc1-9ff7-b64df0da7424 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,432: Added node 0d311d72-139c-41bf-aef7-fa9b01b124d7 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,439: Added node 6161e785-f377-48de-aa4f-20d3646da635 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,444: Added node 55a9e9f8-ee3a-4c28-9b6d-c0c0d78a3f6a to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,447: Added node 2f04ca45-3439-489a-81f7-41cca5ae8adc to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,531: Added node 98660c30-8e48-4b96-967a-d62ca71b4d6d to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,545: Added node 1a184184-b204-40de-986a-4d7459036dbe to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,551: Added node 17b92b9a-6805-4677-85f7-c0c0521a6e25 to Network 045a3114-4aac-4687-a10e-432cfd138325\n",
"2023-10-02 15:10:20,555::ERROR::primaite.simulator.network.hardware.base::175::NIC da:f3:1b:87:24:20/192.168.10.110 cannot be enabled as it is not connected to a Link\n"
]
}
],
"source": [
"router_1 = Router(hostname=\"router_1\", num_ports=5)\n",
"router_1.power_on()\n",
"router_1.configure_port(port=1, ip_address=\"192.168.1.1\", subnet_mask=\"255.255.255.0\")\n",
"router_1.configure_port(port=2, ip_address=\"192.168.10.1\", subnet_mask=\"255.255.255.0\")\n",
"\n",
"# Switch 1\n",
"switch_1 = Switch(hostname=\"switch_1\", num_ports=8)\n",
"switch_1.power_on()\n",
"network.connect(endpoint_a=router_1.ethernet_ports[1], endpoint_b=switch_1.switch_ports[8])\n",
"router_1.enable_port(1)\n",
"\n",
"# Switch 2\n",
"switch_2 = Switch(hostname=\"switch_2\", num_ports=8)\n",
"switch_2.power_on()\n",
"network.connect(endpoint_a=router_1.ethernet_ports[2], endpoint_b=switch_2.switch_ports[8])\n",
"router_1.enable_port(2)\n",
"\n",
"# Client 1\n",
"client_1 = Computer(\n",
" hostname=\"client_1\",\n",
" ip_address=\"192.168.10.21\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.10.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"client_1.power_on()\n",
"client_1.software_manager.install(DNSClient)\n",
"client_1_dns_client_service: DNSServer = client_1.software_manager.software[\"DNSClient\"] # noqa\n",
"client_1_dns_client_service.start()\n",
"network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])\n",
"client_1.software_manager.install(DataManipulationBot)\n",
"db_manipulation_bot: DataManipulationBot = client_1.software_manager.software[\"DataManipulationBot\"]\n",
"db_manipulation_bot.configure(server_ip_address=IPv4Address(\"192.168.1.14\"), payload=\"DROP TABLE IF EXISTS user;\")\n",
"\n",
"# Client 2\n",
"client_2 = Computer(\n",
" hostname=\"client_2\",\n",
" ip_address=\"192.168.10.22\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.10.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"client_2.power_on()\n",
"client_2.software_manager.install(DNSClient)\n",
"client_2_dns_client_service: DNSServer = client_2.software_manager.software[\"DNSClient\"] # noqa\n",
"client_2_dns_client_service.start()\n",
"network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])\n",
"\n",
"# Domain Controller\n",
"domain_controller = Server(\n",
" hostname=\"domain_controller\",\n",
" ip_address=\"192.168.1.10\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.1.1\",\n",
")\n",
"domain_controller.power_on()\n",
"domain_controller.software_manager.install(DNSServer)\n",
"\n",
"network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])\n",
"\n",
"# Database Server\n",
"database_server = Server(\n",
" hostname=\"database_server\",\n",
" ip_address=\"192.168.1.14\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.1.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"database_server.power_on()\n",
"network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3])\n",
"\n",
"ddl = \"\"\"\n",
"CREATE TABLE IF NOT EXISTS user (\n",
"id INTEGER PRIMARY KEY AUTOINCREMENT,\n",
"name VARCHAR(50) NOT NULL,\n",
"email VARCHAR(50) NOT NULL,\n",
"age INT,\n",
"city VARCHAR(50),\n",
"occupation VARCHAR(50)\n",
");\"\"\"\n",
"\n",
"user_insert_statements = [\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');\", # noqa\n",
" \"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');\", # noqa\n",
"]\n",
"database_server.software_manager.install(DatabaseService)\n",
"database_service: DatabaseService = database_server.software_manager.software[\"DatabaseService\"] # noqa\n",
"database_service.start()\n",
"database_service._process_sql(ddl, None) # noqa\n",
"for insert_statement in user_insert_statements:\n",
" database_service._process_sql(insert_statement, None) # noqa\n",
"\n",
"# Web Server\n",
"web_server = Server(\n",
" hostname=\"web_server\",\n",
" ip_address=\"192.168.1.12\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.1.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"web_server.power_on()\n",
"web_server.software_manager.install(DatabaseClient)\n",
"\n",
"database_client: DatabaseClient = web_server.software_manager.software[\"DatabaseClient\"]\n",
"database_client.configure(server_ip_address=IPv4Address(\"192.168.1.14\"))\n",
"network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])\n",
"database_client.run()\n",
"database_client.connect()\n",
"\n",
"# register the web_server to a domain\n",
"dns_server_service: DNSServer = domain_controller.software_manager.software[\"DNSServer\"] # noqa\n",
"dns_server_service.start()\n",
"dns_server_service.dns_register(\"arcd.com\", web_server.ip_address)\n",
"\n",
"# Backup Server\n",
"backup_server = Server(\n",
" hostname=\"backup_server\",\n",
" ip_address=\"192.168.1.16\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.1.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"backup_server.power_on()\n",
"network.connect(endpoint_b=backup_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[4])\n",
"\n",
"# Security Suite\n",
"security_suite = Server(\n",
" hostname=\"security_suite\",\n",
" ip_address=\"192.168.1.110\",\n",
" subnet_mask=\"255.255.255.0\",\n",
" default_gateway=\"192.168.1.1\",\n",
" dns_server=IPv4Address(\"192.168.1.10\"),\n",
")\n",
"security_suite.power_on()\n",
"network.connect(endpoint_b=security_suite.ethernet_port[1], endpoint_a=switch_1.switch_ports[7])\n",
"security_suite.connect_nic(NIC(ip_address=\"192.168.10.110\", subnet_mask=\"255.255.255.0\"))\n",
"network.connect(endpoint_b=security_suite.ethernet_port[2], endpoint_a=switch_2.switch_ports[7])\n",
"\n",
"router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)\n",
"\n",
"router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)\n",
"\n",
"# Allow PostgreSQL requests\n",
"router_1.acl.add_rule(\n",
" action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0\n",
")\n",
"\n",
"# Allow DNS requests\n",
"router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"node_uuid_list = list(sess.simulation.network.nodes.keys())"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.agent.actions import ActionManager"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Build the action manager: the simulation, the action types it knows about, the node uuids and a\n",
"# map from discrete action index to an action type plus its options.\n",
"actman = ActionManager(\n",
"    sess.simulation,\n",
"    [\"DONOTHING\", \"NODE_SERVICE_SCAN\", \"NODE_SERVICE_STOP\", \"NODE_FOLDER_SCAN\"],\n",
"    node_uuid_list,\n",
"    act_map={\n",
"        0: {\"action\": \"DONOTHING\", \"options\": {}},\n",
"        1: {\"action\": \"NODE_SERVICE_SCAN\", \"options\": {\"node_id\": 0, \"service_id\": 0}},\n",
"        2: {\"action\": \"NODE_SERVICE_SCAN\", \"options\": {\"node_id\": 1, \"service_id\": 0}},\n",
"        3: {\"action\": \"NODE_FOLDER_SCAN\", \"options\": {\"node_id\": 4, \"folder_id\": 0}},\n",
"    },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"act_id, act_options = actman.get_action(3)\n",
"my_trial_act = actman.form_request(action_identifier=act_id, action_options=act_options)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"sess.simulation.apply_action(my_trial_act)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['network',\n",
" 'node',\n",
" '6161e785-f377-48de-aa4f-20d3646da635',\n",
" 'file_system',\n",
" 'folder',\n",
" '5aefe92b-923c-4684-b3bf-e78dd18d4771',\n",
" 'scan']"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_trial_act"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sess.step()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sess.step_counter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from gym import spaces"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A tuple of three identical sub-spaces, each a MultiDiscrete built from [3, 2].\n",
"sp = spaces.Tuple(tuple(spaces.MultiDiscrete([3, 2]) for _ in range(3)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sp.sample()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"from primaite.simulator.sim_container import Simulation\n",
"from primaite.simulator.network.hardware.nodes.computer import Computer\n",
"from primaite.simulator.network.hardware.nodes.server import Server\n",
"from primaite.simulator.network.hardware.nodes.switch import Switch\n",
"from primaite.simulator.network.hardware.nodes.router import Router\n",
"\n",
"from primaite.simulator.system.applications.database_client import DatabaseClient\n",
"from primaite.simulator.system.services.database_service import DatabaseService\n",
"from primaite.simulator.system.services.dns_client import DNSClient\n",
"from primaite.simulator.system.services.dns_server import DNSServer\n",
"from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot\n",
"\n",
"\n",
"from primaite.simulator.network.hardware.nodes.router import ACLAction\n",
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"\n",
"from ipaddress import IPv4Address\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# import yaml\n",
"\n",
"\n",
"from typing import Dict\n",
"from primaite.game.agent.interface import AbstractAgent\n",
"from primaite.game.agent.observations import AclObservation, FileObservation, FolderObservation, ICSObservation, LinkObservation, NicObservation, NodeObservation, NullObservation, ServiceObservation, UC2BlueObservation, UC2RedObservation\n",
"from primaite.simulator.network.hardware.base import NIC, Link, Node\n",
"from primaite.simulator.system.services.service import Service\n",
"\n",
"from primaite.game.agent.scripted_agents import GreenWebBrowsingAgent, RedDatabaseCorruptingAgent\n",
"from primaite.game.agent.GATE_agents import GATERLAgent\n",
"\n",
"class PrimaiteSession:\n",
"    \"\"\"Prototype session that builds a simulation, and eventually its agents, from a YAML config.\n",
"\n",
"    NOTE(review): once this cell has run, this class shadows the ``PrimaiteSession`` imported from\n",
"    ``primaite.game.session`` earlier in the notebook.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        # Bare annotation only: no attribute is created here, ``from_config`` assigns it later.\n",
"        self.simulation: Simulation\n",
"        self.agents = []\n",
"\n",
"    @classmethod\n",
"    def from_config(cls, cfg_path):\n",
"        \"\"\"Create a session from the YAML file at ``cfg_path``.\n",
"\n",
"        Nodes and links are read from ``conf['simulation']['network']`` and the agent definitions\n",
"        from ``conf['game_config']``. An observation space is constructed for each agent, but action\n",
"        spaces, reward functions and the agents themselves are not created yet, so the returned\n",
"        session's ``agents`` list is empty.\n",
"\n",
"        :param cfg_path: Path of the YAML configuration file.\n",
"        :return: A new session whose ``simulation`` attribute holds the populated simulation.\n",
"        :raises ValueError: If a node entry has a ``type`` other than computer, server, switch or router.\n",
"        \"\"\"\n",
"        # Maps from the ``ref`` strings used in the config to what was created for them.\n",
"        ref_map_nodes: Dict[str, str] = {}  # node ref -> node uuid\n",
"        ref_map_services: Dict[str, Service] = {}  # service ref -> installed service object\n",
"        ref_map_links: Dict[str, str] = {}  # link ref -> link uuid\n",
"        # ref_map_agents: Dict[str, AgentInterface] = {}\n",
"\n",
"        # Built once rather than for every service entry. Each key is equal to the 'name' attr of\n",
"        # the service class itself.\n",
"        service_types_mapping = {\n",
"            'DNSClient': DNSClient,\n",
"            'DNSServer': DNSServer,\n",
"            'DatabaseClient': DatabaseClient,\n",
"            'DatabaseService': DatabaseService,\n",
"            # 'database_backup': ,\n",
"            'DataManipulationBot': DataManipulationBot,\n",
"            # 'web_browser'\n",
"        }\n",
"\n",
"        def _endpoint(node, port_num):\n",
"            # Switches expose ``switch_ports``; every other node type is looked up via ``ethernet_port``.\n",
"            ports = node.switch_ports if isinstance(node, Switch) else node.ethernet_port\n",
"            return ports[port_num]\n",
"\n",
"        session = cls()\n",
"        with open(cfg_path, 'r') as file:\n",
"            conf = yaml.safe_load(file)\n",
"\n",
"        # 1. create nodes\n",
"        sim = Simulation()\n",
"        net = sim.network\n",
"        nodes_cfg = conf['simulation']['network']['nodes']\n",
"        links_cfg = conf['simulation']['network']['links']\n",
"        for node_cfg in nodes_cfg:\n",
"            node_ref = node_cfg['ref']\n",
"            n_type = node_cfg['type']\n",
"            if n_type == 'computer':\n",
"                new_node = Computer(\n",
"                    hostname=node_cfg['hostname'],\n",
"                    ip_address=node_cfg['ip_address'],\n",
"                    subnet_mask=node_cfg['subnet_mask'],\n",
"                    default_gateway=node_cfg['default_gateway'],\n",
"                    # Optional, exactly as it already is for servers.\n",
"                    dns_server=node_cfg.get('dns_server'),\n",
"                )\n",
"            elif n_type == 'server':\n",
"                new_node = Server(\n",
"                    hostname=node_cfg['hostname'],\n",
"                    ip_address=node_cfg['ip_address'],\n",
"                    subnet_mask=node_cfg['subnet_mask'],\n",
"                    default_gateway=node_cfg['default_gateway'],\n",
"                    dns_server=node_cfg.get('dns_server'),\n",
"                )\n",
"            elif n_type == 'switch':\n",
"                new_node = Switch(hostname=node_cfg['hostname'], num_ports=node_cfg.get('num_ports'))\n",
"            elif n_type == 'router':\n",
"                new_node = Router(hostname=node_cfg['hostname'], num_ports=node_cfg.get('num_ports'))\n",
"                for port_num, port_cfg in node_cfg.get('ports', {}).items():\n",
"                    new_node.configure_port(\n",
"                        port=port_num,\n",
"                        ip_address=port_cfg['ip_address'],\n",
"                        subnet_mask=port_cfg['subnet_mask'],\n",
"                    )\n",
"                for r_num, r_cfg in node_cfg.get('acl', {}).items():\n",
"                    # Config values name enum members; a missing or empty value is passed on as None.\n",
"                    # The walrus operator just avoids repeating each r_cfg.get(...). TODO Refactor\n",
"                    # NOTE(review): src_ip_address and dst_ip_address are both read from the same\n",
"                    # 'ip_address' key - confirm whether separate source/destination keys were meant.\n",
"                    new_node.acl.add_rule(\n",
"                        action=ACLAction[r_cfg['action']],\n",
"                        src_port=None if not (p := r_cfg.get('src_port')) else Port[p],\n",
"                        dst_port=None if not (p := r_cfg.get('dst_port')) else Port[p],\n",
"                        protocol=None if not (p := r_cfg.get('protocol')) else IPProtocol[p],\n",
"                        src_ip_address=r_cfg.get('ip_address'),\n",
"                        dst_ip_address=r_cfg.get('ip_address'),\n",
"                        position=r_num,\n",
"                    )\n",
"            else:\n",
"                # Fail fast: carrying on would use an unbound ``new_node`` for the first entry, or\n",
"                # silently re-register the previous node under this ref for later entries.\n",
"                raise ValueError(f\"invalid node type {n_type!r} for node {node_ref!r}\")\n",
"\n",
"            for service_cfg in node_cfg.get('services', []):\n",
"                service_ref = service_cfg['ref']\n",
"                service_type = service_cfg['type']\n",
"                if service_type not in service_types_mapping:\n",
"                    print(f\"service type not found {service_type}\")\n",
"                    continue\n",
"                new_node.software_manager.install(service_types_mapping[service_type])\n",
"                new_service = new_node.software_manager.software[service_type]\n",
"                ref_map_services[service_ref] = new_service\n",
"\n",
"                # service-dependent options\n",
"                opt = service_cfg.get('options', {})\n",
"                if service_type == 'DatabaseClient' and 'db_server_ip' in opt:\n",
"                    new_service.configure(server_ip_address=IPv4Address(opt['db_server_ip']))\n",
"                if service_type == 'DNSServer' and 'domain_mapping' in opt:\n",
"                    for domain, ip in opt['domain_mapping'].items():\n",
"                        new_service.dns_register(domain, ip)\n",
"\n",
"            # Extra NICs; the keys of the 'nics' mapping are not used.\n",
"            if 'nics' in node_cfg:\n",
"                for nic_cfg in node_cfg['nics'].values():\n",
"                    new_node.connect_nic(NIC(ip_address=nic_cfg['ip_address'], subnet_mask=nic_cfg['subnet_mask']))\n",
"\n",
"            net.add_node(new_node)\n",
"            new_node.power_on()\n",
"            ref_map_nodes[node_ref] = new_node.uuid\n",
"\n",
"        # 2. create links between nodes\n",
"        for link_cfg in links_cfg:\n",
"            node_a = net.nodes[ref_map_nodes[link_cfg['endpoint_a_ref']]]\n",
"            node_b = net.nodes[ref_map_nodes[link_cfg['endpoint_b_ref']]]\n",
"            new_link = net.connect(\n",
"                endpoint_a=_endpoint(node_a, link_cfg['endpoint_a_port']),\n",
"                endpoint_b=_endpoint(node_b, link_cfg['endpoint_b_port']),\n",
"            )\n",
"            ref_map_links[link_cfg['ref']] = new_link.uuid\n",
"\n",
"        session.simulation = sim\n",
"\n",
"        # 3. create agents\n",
"        game_cfg = conf['game_config']\n",
"        ports_cfg = game_cfg['ports']\n",
"        protocols_cfg = game_cfg['protocols']\n",
"        agents_cfg = game_cfg['agents']\n",
"\n",
"        for agent_cfg in agents_cfg:\n",
"            # agent_ref, action_space_cfg and reward_function_cfg are read but not used yet; they\n",
"            # belong to the TODO sections further down.\n",
"            agent_ref = agent_cfg['ref']\n",
"            agent_type = agent_cfg['type']\n",
"            action_space_cfg = agent_cfg['action_space']\n",
"            observation_space_cfg = agent_cfg['observation_space']\n",
"            reward_function_cfg = agent_cfg['reward_function']\n",
"\n",
"            # CREATE OBSERVATION SPACE\n",
"            if observation_space_cfg is None:\n",
"                obs_space = NullObservation()\n",
"            elif observation_space_cfg['type'] == 'UC2BlueObservation':\n",
"                obs_options = observation_space_cfg['options']\n",
"                node_obs_list = []\n",
"                link_obs_list = []\n",
"\n",
"                # Maps every NIC ip address to an index derived from its node's position in the\n",
"                # config; a node with several NICs contributes several addresses.\n",
"                # NOTE(review): the +2 offset presumably reserves the first two indices - confirm.\n",
"                node_ip_to_index = {}\n",
"                for node_idx, node_cfg in enumerate(nodes_cfg):\n",
"                    n_obj = net.nodes[ref_map_nodes[node_cfg['ref']]]\n",
"                    for nic_obj in n_obj.nics.values():\n",
"                        node_ip_to_index[nic_obj.ip_address] = node_idx + 2\n",
"\n",
"                for node_obs_cfg in obs_options['nodes']:\n",
"                    node_uuid = ref_map_nodes[node_obs_cfg['node_ref']]\n",
"                    node_where = ['network', 'nodes', node_uuid]\n",
"\n",
"                    # NOTE(review): ref_map_services holds service objects rather than uuids, so the\n",
"                    # object itself ends up in this path - confirm that is intended.\n",
"                    service_obs_list = [\n",
"                        ServiceObservation(where=node_where + ['services', ref_map_services[service_obs_cfg['service_ref']]])\n",
"                        for service_obs_cfg in node_obs_cfg.get('services', [])\n",
"                    ]\n",
"\n",
"                    folder_obs_list = []\n",
"                    for folder_obs_cfg in node_obs_cfg.get('folders', []):\n",
"                        folder_where = node_where + ['folders', folder_obs_cfg['folder_name']]\n",
"                        file_obs_list = [\n",
"                            FileObservation(where=folder_where + ['files', file_obs_cfg['file_name']])\n",
"                            for file_obs_cfg in folder_obs_cfg.get('files', [])\n",
"                        ]\n",
"                        folder_obs_list.append(FolderObservation(where=folder_where, files=file_obs_list))\n",
"\n",
"                    nic_obs_list = [\n",
"                        NicObservation(where=node_where + ['NICs', nic_uuid])\n",
"                        for nic_uuid in net.nodes[node_uuid].nics.keys()\n",
"                    ]\n",
"                    node_obs_list.append(\n",
"                        NodeObservation(\n",
"                            where=list(node_where),\n",
"                            services=service_obs_list,\n",
"                            folders=folder_obs_list,\n",
"                            nics=nic_obs_list,\n",
"                            logon_status=False,\n",
"                        )\n",
"                    )\n",
"\n",
"                for link_obs_cfg in obs_options['links']:\n",
"                    link_ref = link_obs_cfg['link_ref']\n",
"                    link_obs_list.append(LinkObservation(where=['network', 'links', ref_map_links[link_ref]]))\n",
"\n",
"                # Two fixes here: protocols now comes from the protocols config (it was being given\n",
"                # the ports config), and the router ref is translated to its uuid like every other\n",
"                # node path in this method.\n",
"                router_ref = obs_options['acl']['router_node_ref']\n",
"                acl_obs = AclObservation(\n",
"                    node_ip_to_id=node_ip_to_index,\n",
"                    ports=ports_cfg,\n",
"                    protocols=protocols_cfg,\n",
"                    where=['network', 'nodes', ref_map_nodes[router_ref]],\n",
"                )\n",
"                obs_space = UC2BlueObservation(nodes=node_obs_list, links=link_obs_list, acl=acl_obs, ics=ICSObservation())\n",
"            elif observation_space_cfg['type'] == 'UC2RedObservation':\n",
"                obs_space = UC2RedObservation.from_config(observation_space_cfg['options'], sim=sim)\n",
"            else:\n",
"                print(\"observation space config not specified correctly.\")\n",
"                obs_space = NullObservation()\n",
"\n",
"            # CREATE ACTION SPACE\n",
"\n",
"            # CREATE REWARD FUNCTION\n",
"\n",
"            # CREATE AGENT\n",
"            # TODO: the branches below are placeholders, so obs_space is built but never attached to\n",
"            # an agent and session.agents stays empty.\n",
"            if agent_type == 'GreenWebBrowsingAgent':\n",
"                ...\n",
"            elif agent_type == 'GATERLAgent':\n",
"                ...\n",
"            elif agent_type == 'RedDatabaseCorruptingAgent':\n",
"                ...\n",
"            else:\n",
"                print(\"agent type not found\")\n",
"\n",
"        # 4. set up agents' actions and observation spaces.\n",
"        return session\n",
"\n",
"s = PrimaiteSession.from_config('example_config.yaml')\n",
"# print(s.simulation.describe_state())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"s.agents"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}