Files
PrimAITE/sandbox.ipynb
2023-09-26 12:54:56 +01:00

265 lines
12 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"from primaite.simulator.sim_container import Simulation\n",
"from primaite.simulator.network.hardware.nodes.computer import Computer\n",
"from primaite.simulator.network.hardware.nodes.server import Server\n",
"from primaite.simulator.network.hardware.nodes.switch import Switch\n",
"from primaite.simulator.network.hardware.nodes.router import Router\n",
"\n",
"from primaite.simulator.system.applications.database_client import DatabaseClient\n",
"from primaite.simulator.system.services.database_service import DatabaseService\n",
"from primaite.simulator.system.services.dns_client import DNSClient\n",
"from primaite.simulator.system.services.dns_server import DNSServer\n",
"from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot\n",
"\n",
"\n",
"from primaite.simulator.network.hardware.nodes.router import ACLAction\n",
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"\n",
"from ipaddress import IPv4Address\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-09-26 12:19:50,895: Added node 0fb262e1-a714-420a-aec7-be37f0deeb75 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,898: Added node 310ca8d7-01e0-401e-b604-705c290e5376 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,900: Added node b3b08f1f-7805-47b2-bdb6-3d83098cd740 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,903: Added node adb37f3e-2307-4123-bff3-01f125883be8 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,906: Added node 1a490716-2ccd-452d-b87e-324d29120b59 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,911: Added node 033460d8-0249-4bdd-aaf0-751b24cc0a1e to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,914: Added node 1e7e4e49-78bf-4031-8372-ee71902720f3 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,916: Added node c9f24a13-e5c8-437b-9234-b0c3f8120e2c to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,920: Added node c881f3ee-2176-493b-a6c2-cad829bf0b6d to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n",
"2023-09-26 12:19:50,922: Added node a3ea75d8-bc2c-4713-92a4-2588b4f43ed6 to Network 9318bac2-d9f4-4e71-bb4c-09ffc573ed1c\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"service type not found DatabaseBackup\n",
"service type not found WebBrowser\n"
]
}
],
"source": [
"# import yaml\n",
"\n",
"\n",
"from typing import Dict\n",
"from primaite.game.agent.interface import AbstractAgent\n",
"from primaite.simulator.network.hardware.base import NIC, Link, Node\n",
"from primaite.simulator.system.services.service import Service\n",
"\n",
"from primaite.game.agent.scripted_agents import GreenWebBrowsingAgent, RedDatabaseCorruptingAgent\n",
"from primaite.game.agent.GATE_agents import GATERLAgent\n",
"\n",
"class PrimaiteSession:\n",
"    \"\"\"Holds a simulation together with the agents created for it.\n",
"\n",
"    Use :meth:`from_config` to build a session from a YAML configuration file.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        # Annotation only: nothing is assigned here, the attribute is set by from_config().\n",
"        self.simulation: Simulation\n",
"        self.agents = []\n",
"\n",
"    @classmethod\n",
"    def from_config(cls, cfg_path):\n",
"        \"\"\"Create a session from a YAML configuration file.\n",
"\n",
"        Nodes under ``simulation.network.nodes`` are created, given their services and NICs, added to the\n",
"        network and powered on. Links under ``simulation.network.links`` are then connected, and finally the\n",
"        agents under ``game_config.agents`` are instantiated.\n",
"\n",
"        Node entries and service entries whose ``type`` is not recognised are reported on stdout and skipped.\n",
"\n",
"        :param cfg_path: Path of the YAML file to load.\n",
"        :return: The new session, with ``simulation`` set and ``agents`` holding the agents that were created.\n",
"        \"\"\"\n",
"        # Lookup tables from the 'ref' strings used in the config to what was created for them.\n",
"        ref_map_nodes = {}  # node ref -> node uuid (the key used to index net.nodes below)\n",
"        ref_map_services: Dict[str, Service] = {}\n",
"        ref_map_links = {}  # link ref -> link uuid\n",
"        # ref_map_agents: Dict[str, AgentInterface] = {}\n",
"\n",
"        # Installable service classes, keyed by the 'type' value used in the config.\n",
"        # The key is equal to the 'name' attr of the service class itself, because the installed instance is\n",
"        # looked up again under that same key in software_manager.software.\n",
"        service_types_mapping = {\n",
"            'DNSClient': DNSClient,\n",
"            'DNSServer': DNSServer,\n",
"            'DatabaseClient': DatabaseClient,\n",
"            'DatabaseService': DatabaseService,\n",
"            # 'database_backup': ,\n",
"            'DataManipulationBot': DataManipulationBot,\n",
"            # 'web_browser'\n",
"        }\n",
"\n",
"        session = cls()\n",
"        with open(cfg_path, 'r') as file:\n",
"            conf = yaml.safe_load(file)\n",
"\n",
"        # 1. create nodes\n",
"        sim = Simulation()\n",
"        net = sim.network\n",
"        nodes_cfg = conf['simulation']['network']['nodes']\n",
"        links_cfg = conf['simulation']['network']['links']\n",
"        for node_cfg in nodes_cfg:\n",
"            node_ref = node_cfg['ref']\n",
"            n_type = node_cfg['type']\n",
"            if n_type == 'computer':\n",
"                new_node = Computer(\n",
"                    hostname=node_cfg['hostname'],\n",
"                    ip_address=node_cfg['ip_address'],\n",
"                    subnet_mask=node_cfg['subnet_mask'],\n",
"                    default_gateway=node_cfg['default_gateway'],\n",
"                    dns_server=node_cfg['dns_server'],\n",
"                )\n",
"            elif n_type == 'server':\n",
"                new_node = Server(\n",
"                    hostname=node_cfg['hostname'],\n",
"                    ip_address=node_cfg['ip_address'],\n",
"                    subnet_mask=node_cfg['subnet_mask'],\n",
"                    default_gateway=node_cfg['default_gateway'],\n",
"                    dns_server=node_cfg.get('dns_server'),\n",
"                )\n",
"            elif n_type == 'switch':\n",
"                new_node = Switch(hostname=node_cfg['hostname'], num_ports=node_cfg.get('num_ports'))\n",
"            elif n_type == 'router':\n",
"                new_node = Router(hostname=node_cfg['hostname'], num_ports=node_cfg.get('num_ports'))\n",
"                if 'ports' in node_cfg:\n",
"                    for port_num, port_cfg in node_cfg['ports'].items():\n",
"                        new_node.configure_port(\n",
"                            port=port_num,\n",
"                            ip_address=port_cfg['ip_address'],\n",
"                            subnet_mask=port_cfg['subnet_mask'],\n",
"                        )\n",
"                if 'acl' in node_cfg:\n",
"                    for r_num, r_cfg in node_cfg['acl'].items():\n",
"                        # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, so that we\n",
"                        # can do both of these things once: check if a key is defined, access and convert it to\n",
"                        # a Port/IPProtocol. TODO Refactor\n",
"                        # NOTE(review): source and destination address are both read from the same\n",
"                        # 'ip_address' key - confirm against the config schema that this is intended.\n",
"                        new_node.acl.add_rule(\n",
"                            action=ACLAction[r_cfg['action']],\n",
"                            src_port=None if not (p := r_cfg.get('src_port')) else Port[p],\n",
"                            dst_port=None if not (p := r_cfg.get('dst_port')) else Port[p],\n",
"                            protocol=None if not (p := r_cfg.get('protocol')) else IPProtocol[p],\n",
"                            src_ip_address=r_cfg.get('ip_address'),\n",
"                            dst_ip_address=r_cfg.get('ip_address'),\n",
"                            position=r_num,\n",
"                        )\n",
"            else:\n",
"                # No node was created for this entry. Skip it, otherwise the code below would either fail on\n",
"                # an undefined new_node or configure and re-add the node left over from the previous iteration.\n",
"                print('invalid node type')\n",
"                continue\n",
"\n",
"            if 'services' in node_cfg:\n",
"                for service_cfg in node_cfg['services']:\n",
"                    service_ref = service_cfg['ref']\n",
"                    service_type = service_cfg['type']\n",
"                    if service_type not in service_types_mapping:\n",
"                        print(f\"service type not found {service_type}\")\n",
"                        continue\n",
"                    new_node.software_manager.install(service_types_mapping[service_type])\n",
"                    new_service = new_node.software_manager.software[service_type]\n",
"                    ref_map_services[service_ref] = new_service\n",
"\n",
"                    # service-dependent options (a missing or empty 'options' entry means nothing to configure)\n",
"                    opt = service_cfg.get('options') or {}\n",
"                    if service_type == 'DatabaseClient' and 'db_server_ip' in opt:\n",
"                        new_service.configure(server_ip_address=IPv4Address(opt['db_server_ip']))\n",
"                    if service_type == 'DNSServer' and 'domain_mapping' in opt:\n",
"                        for domain, ip in opt['domain_mapping'].items():\n",
"                            new_service.dns_register(domain, ip)\n",
"\n",
"            if 'nics' in node_cfg:\n",
"                # Only the values are used; the keys of the 'nics' mapping are not passed on.\n",
"                for nic_cfg in node_cfg['nics'].values():\n",
"                    new_node.connect_nic(NIC(ip_address=nic_cfg['ip_address'], subnet_mask=nic_cfg['subnet_mask']))\n",
"\n",
"            net.add_node(new_node)\n",
"            new_node.power_on()\n",
"            ref_map_nodes[node_ref] = new_node.uuid\n",
"\n",
"        # 2. create links between nodes\n",
"        def _endpoint(node, port_num):\n",
"            # Switches are indexed through switch_ports, every other node through ethernet_port.\n",
"            ports = node.switch_ports if isinstance(node, Switch) else node.ethernet_port\n",
"            return ports[port_num]\n",
"\n",
"        for link_cfg in links_cfg:\n",
"            node_a = net.nodes[ref_map_nodes[link_cfg['endpoint_a_ref']]]\n",
"            node_b = net.nodes[ref_map_nodes[link_cfg['endpoint_b_ref']]]\n",
"            new_link = net.connect(\n",
"                endpoint_a=_endpoint(node_a, link_cfg['endpoint_a_port']),\n",
"                endpoint_b=_endpoint(node_b, link_cfg['endpoint_b_port']),\n",
"            )\n",
"            ref_map_links[link_cfg['ref']] = new_link.uuid\n",
"\n",
"        session.simulation = sim\n",
"\n",
"        # 3. create agents\n",
"        game_cfg = conf['game_config']\n",
"        # The values below are read so that a config missing them fails here, but they are not consumed yet;\n",
"        # they are inputs for step 4.\n",
"        ports_cfg = game_cfg['ports']\n",
"        protocols_cfg = game_cfg['protocols']\n",
"        agents_cfg = game_cfg['agents']\n",
"\n",
"        for agent_cfg in agents_cfg:\n",
"            agent_ref = agent_cfg['ref']\n",
"            agent_type = agent_cfg['type']\n",
"            action_space_cfg = agent_cfg['action_space']\n",
"            observation_space_cfg = agent_cfg['observation_space']\n",
"            reward_function_cfg = agent_cfg['reward_function']\n",
"            if agent_type == 'GreenWebBrowsingAgent':\n",
"                # Keep the agent on the session; previously the instance was created and then dropped.\n",
"                session.agents.append(GreenWebBrowsingAgent())\n",
"\n",
"        # 4. set up agents' actions and observation spaces. (TODO: not implemented yet)\n",
"        return session\n",
"\n",
"# Build a session from the example config; the relative path is resolved against the kernel's working directory.\n",
"s = PrimaiteSession.from_config('example_config.yaml')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print what describe_state() reports for the simulation held by the session.\n",
"sim_state = s.simulation.describe_state()\n",
"print(sim_state)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}