diff --git a/.gitignore b/.gitignore index 892751d9..1ce2ca9d 100644 --- a/.gitignore +++ b/.gitignore @@ -156,3 +156,4 @@ benchmark/output # src/primaite/notebooks/scratch.ipynb src/primaite/notebooks/scratch.py sandbox.py +sandbox/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 227cec69..cb2f418b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,17 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +- Fixed a bug where ACL rules were not resetting on episode reset. +- Fixed a bug where blue agent's ACL actions were being applied against the wrong IP addresses +- Fixed a bug where deleted files and folders did not reset correctly on episode reset. +- Fixed a bug where service health status was using the actual health state instead of the visible health state +- Fixed a bug where the database file health status was using the incorrect value for negative rewards +- Fixed a bug preventing file actions from reaching their intended file +- Made database patch correctly take 2 timesteps instead of being immediate +- Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING +- Temporarily disable the blue agent file delete action due to crashes. This issue is resolved in another branch that will be merged into dev soon. +- Fix a bug where ACLs were not showing up correctly in the observation space. +- Added a notebook which explains Data manipulation scenario, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function. - Made packet capture and system logging optional (off by default). To turn on, change the io_settings.save_pcap_logs and io_settings.save_sys_logs settings in the config. - Made observation space flattening optional (on by default). To turn off for an agent, change the agent_settings.flatten_obs setting in the config. - Fixed an issue where the data manipulation attack was triggered at episode start. diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 09fb39d2..72f12ef8 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -3.0.0b5 +3.0.0b6dev diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index ee0eb7ff..d8cd0099 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -31,7 +31,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -304,63 +304,63 @@ agents: action: "NODE_RESET" options: node_id: 5 - 22: + 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 7 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 23: + 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 24: # block tcp traffic from client 1 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 3 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 25: # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 26: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 27: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 @@ -504,6 +504,24 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 + reward_function: reward_components: diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index c1e2ea81..6aa54487 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -25,7 +25,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 4c47bfaa..6b15c5f8 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -296,6 +296,16 @@ class NodeFileDeleteAction(NodeFileAbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb: str = "delete" + def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_uuid = self.manager.get_node_uuid_by_idx(node_id) + folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id) + file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id) + if node_uuid is None or folder_uuid is None or file_uuid is None: + return ["do_nothing"] + return ["do_nothing"] + # return ["network", "node", node_uuid, "file_system", "delete", "file", folder_uuid, file_uuid] + class NodeFileRepairAction(NodeFileAbstractAction): """Action which repairs a file.""" @@ -460,13 +470,13 @@ class NetworkACLAddRuleAction(AbstractAction): dst_ip = "ALL" return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS else: - dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id) + dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2) # subtract 2 to account for UNUSED=0, and ALL=1 if dest_port_id == 1: dst_port = "ALL" else: - dst_port = self.manager.get_port_by_idx(dest_port_id) + dst_port = self.manager.get_port_by_idx(dest_port_id - 2) # subtract 2 to account for UNUSED=0, and ALL=1 return [ @@ -914,6 +924,15 @@ class ActionManager: :return: The constructed ActionManager. :rtype: ActionManager """ + ip_address_order = cfg["options"].pop("ip_address_order", {}) + ip_address_list = [] + for entry in ip_address_order: + node_ref = entry["node_ref"] + nic_num = entry["nic_num"] + node_obj = game.simulation.network.get_node_by_hostname(node_ref) + ip_address = node_obj.ethernet_port[nic_num].ip_address + ip_address_list.append(ip_address) + obj = cls( game=game, actions=cfg["action_list"], @@ -921,7 +940,7 @@ class ActionManager: **cfg["options"], protocols=game.options.protocols, ports=game.options.ports, - ip_address_list=None, + ip_address_list=ip_address_list or None, act_map=cfg.get("action_map"), ) diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index cac5b91e..b7962827 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -1,5 +1,6 @@ """Manages the observation space for the agent.""" from abc import ABC, abstractmethod +from ipaddress import IPv4Address from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces @@ -648,10 +649,13 @@ class AclObservation(AbstractObservation): # TODO: what if the ACL has more rules than num of max rules for obs space obs = {} - for i, rule_state in acl_state.items(): + acl_items = dict(acl_state.items()) + i = 1 # don't show rule 0 for compatibility reasons. + while i < self.num_rules + 1: + rule_state = acl_items[i] if rule_state is None: - obs[i + 1] = { - "position": i, + obs[i] = { + "position": i - 1, "permission": 0, "source_node_id": 0, "source_port": 0, @@ -660,15 +664,26 @@ class AclObservation(AbstractObservation): "protocol": 0, } else: - obs[i + 1] = { - "position": i, + src_ip = rule_state["src_ip_address"] + src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] + dst_ip = rule_state["dst_ip_address"] + dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] + src_port = rule_state["src_port"] + src_port_id = 1 if src_port is None else self.port_to_id[src_port] + dst_port = rule_state["dst_port"] + dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] + protocol = rule_state["protocol"] + protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] + obs[i] = { + "position": i - 1, "permission": rule_state["action"], - "source_node_id": self.node_to_id[rule_state["src_ip_address"]], - "source_port": self.port_to_id[rule_state["src_port"]], - "dest_node_id": self.node_to_id[rule_state["dst_ip_address"]], - "dest_port": self.port_to_id[rule_state["dst_port"]], - "protocol": self.protocol_to_id[rule_state["protocol"]], + "source_node_id": src_node_id, + "source_port": src_port_id, + "dest_node_id": dst_node_ip, + "dest_port": dst_port_id, + "protocol": protocol_id, } + i += 1 return obs @property diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 8f064be3..6cee127f 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -110,6 +110,13 @@ class DatabaseFileIntegrity(AbstractReward): :type state: Dict """ database_file_state = access_from_nested_dict(state, self.location_in_state) + if database_file_state is NOT_PRESENT_IN_STATE: + _LOGGER.info( + f"Could not calculate {self.__class__} reward because " + "simulation state did not contain enough information." + ) + return 0.0 + health_status = database_file_state["health_status"] if health_status == 2: return -1 diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 08098754..146261f9 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -13,11 +13,9 @@ from primaite.game.agent.rewards import RewardFunction from primaite.session.io import SessionIO, SessionIOSettings from primaite.simulator.network.hardware.base import NIC, NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer -from primaite.simulator.network.hardware.nodes.router import ACLAction, Router +from primaite.simulator.network.hardware.nodes.router import Router from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.hardware.nodes.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser @@ -115,7 +113,7 @@ class PrimaiteGame: self.update_agents(sim_state) # Apply all actions to simulation as requests - self.apply_agent_actions() + agent_actions = self.apply_agent_actions() # noqa # Advance timestep self.advance_timestep() @@ -133,12 +131,15 @@ class PrimaiteGame: def apply_agent_actions(self) -> None: """Apply all actions to simulation as requests.""" + agent_actions = {} for agent in self.agents: obs = agent.observation_manager.current_observation rew = agent.reward_function.current_reward action_choice, options = agent.get_action(obs, rew) + agent_actions[agent.agent_name] = (action_choice, options) request = agent.format_request(action_choice, options) self.simulation.apply_request(request) + return agent_actions def advance_timestep(self) -> None: """Advance timestep.""" @@ -227,31 +228,7 @@ class PrimaiteGame: operating_state=NodeOperatingState.ON, ) elif n_type == "router": - new_node = Router( - hostname=node_cfg["hostname"], - num_ports=node_cfg.get("num_ports"), - operating_state=NodeOperatingState.ON, - ) - if "ports" in node_cfg: - for port_num, port_cfg in node_cfg["ports"].items(): - new_node.configure_port( - port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"] - ) - # new_node.enable_port(port_num) - if "acl" in node_cfg: - for r_num, r_cfg in node_cfg["acl"].items(): - # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating - # this: 'r_cfg.get('src_port')' - # Port/IPProtocol. TODO Refactor - new_node.acl.add_rule( - action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], - src_ip_address=r_cfg.get("ip_address"), - dst_ip_address=r_cfg.get("ip_address"), - position=r_num, - ) + new_node = Router.from_config(node_cfg) else: _LOGGER.warning(f"invalid node type {n_type} in config") if "services" in node_cfg: diff --git a/src/primaite/notebooks/_package_data/uc2_attack.png b/src/primaite/notebooks/_package_data/uc2_attack.png new file mode 100644 index 00000000..8b8df5ce Binary files /dev/null and b/src/primaite/notebooks/_package_data/uc2_attack.png differ diff --git a/src/primaite/notebooks/_package_data/uc2_network.png b/src/primaite/notebooks/_package_data/uc2_network.png new file mode 100644 index 00000000..20fa43c9 Binary files /dev/null and b/src/primaite/notebooks/_package_data/uc2_network.png differ diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index 3950ef10..679e8226 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -1,30 +1,337 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data Manipulation Scenario\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scenario\n", + "\n", + "The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database.\n", + "\n", + "[](_package_data/uc2_network.png)\n", + "\n", + "_(click image to enlarge)_\n", + "\n", + "The red agent deletes the contents of the database. When this happens, the web app cannot fetch data and users navigating to the website get a 404 error.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Network\n", + "\n", + "- The web server has:\n", + " - a web service that replies to user HTTP requests\n", + " - a database client that fetches data for the web service\n", + "- The database server has:\n", + " - a POSTGRES database service\n", + " - a database file which is accessed by the database service\n", + " - FTP client used for backing up the data to the backup_server\n", + "- The backup server has:\n", + " - a copy of the database file in a known good state\n", + " - FTP server that can send the backed up file back to the database server\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Green agent\n", + "\n", + "The green agent is logged onto client 2. It sometimes uses the web browser on client 2 to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Red agent\n", + "\n", + "The red agent waits a bit then sends a DELETE query to the database from client 1. If the delete is successful, the database file is flagged as compromised to signal that data is not available.\n", + "\n", + "[](_package_data/uc2_attack.png)\n", + "\n", + "_(click image to enlarge)_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Blue agent\n", + "\n", + "The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking client 1 from reaching the database server. This can be done by removing client 1's network connection or adding ACL rules on the router to stop the packets from arriving." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Reinforcement learning details" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scripted agents:\n", + "### Red\n", + "The red agent sits on client 1 and uses an application called DataManipulationBot whose sole purpose is to send a DELETE query to the database.\n", + "The red agent can choose one of two action each timestep:\n", + "1. do nothing\n", + "2. execute the data manipulation application\n", + "The schedule for selecting when to execute the application is controlled by three parameters:\n", + "- start time\n", + "- frequency\n", + "- variance\n", + "Attacks start at a random timestep between (start_time - variance) and (start_time + variance). After each attack, another is attempted after a random delay between (frequency - variance) and (frequency + variance) timesteps.\n", + "\n", + "The data manipulation app itself has an element of randomness because the attack has a probability of success. The default is 0.8 to succeed with the port scan step and 0.8 to succeed with the attack itself.\n", + "Upon a successful attack, the database file becomes corrupted which incurs a negative reward for the RL defender.\n", + "\n", + "The red agent does not use information about the state of the network to decide its action.\n", + "\n", + "### Green\n", + "The green agent sits on client 2 and uses the web browser application to send requests to the web server. The schedule of the green agent is currently random, meaning it will request webpage with a 50% probability, and do nothing with a 50% probability.\n", + "\n", + "When the green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Observation Space\n", + "\n", + "The blue agent's observation space is structured as nested dictionary with the following information:\n", + "```\n", + "\n", + "- NODES\n", + " - \n", + " - SERVICES\n", + " - \n", + " - operating_status\n", + " - health_status\n", + " - FOLDERS\n", + " - \n", + " - health_status\n", + " - FILES\n", + " - \n", + " - health_status\n", + " - NICS\n", + " - \n", + " - nic_status\n", + " - operating_status\n", + "- LINKS\n", + " - \n", + " - PROTOCOLS\n", + " - ALL\n", + " - load\n", + "- ACL\n", + " - \n", + " - position\n", + " - permission\n", + " - source_node_id\n", + " - source_port\n", + " - dest_node_id\n", + " - dest_port\n", + " - protocol\n", + "- ICS\n", + "```\n", + "\n", + "### Mappings\n", + "\n", + "The dict keys for `node_id` are in the following order:\n", + "|node_id|node name|\n", + "|--|--|\n", + "|1|domain_controller|\n", + "|2|web_server|\n", + "|3|database_server|\n", + "|4|backup_server|\n", + "|5|security_suite|\n", + "|6|client_1|\n", + "|7|client_2|\n", + "\n", + "Service 1 on node 2 (web_server) corresponds to the Web Server service. Other services are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n", + "\n", + "Folder 1 on node 3 corresponds to the database folder. File 1 in that folder corresponds to the database storage file. Other files and folders are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n", + "\n", + "The dict keys for `link_id` are in the following order:\n", + "|link_id|endpoint_a|endpoint_b|\n", + "|--|--|--|\n", + "|1|router_1|switch_1|\n", + "|1|router_1|switch_2|\n", + "|1|switch_1|domain_controller|\n", + "|1|switch_1|web_server|\n", + "|1|switch_1|database_server|\n", + "|1|switch_1|backup_server|\n", + "|1|switch_1|security_suite|\n", + "|1|switch_2|client_1|\n", + "|1|switch_2|client_2|\n", + "|1|switch_2|security_suite|\n", + "\n", + "The ACL rules in the observation space appear in the same order that they do in the actual ACL. Though, only the first 10 rules are shown, there are default rules lower down that cannot be changed by the agent. The extra rules just allow the network to function normally, by allowing pings, ARP traffic, etc.\n", + "\n", + "Most nodes have only 1 nic, so the observation for those is placed at NIC index 1 in the observation space. Only the security suite has 2 NICs, the second NIC in the observation space is the one that connects the security suite with swtich_2.\n", + "\n", + "The meaning of the services' operating_state is:\n", + "|operating_state|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|RUNNING|\n", + "|2|STOPPED|\n", + "|3|PAUSED|\n", + "|4|DISABLED|\n", + "|5|INSTALLING|\n", + "|6|RESTARTING|\n", + "\n", + "The meaning of the services' health_state is:\n", + "|health_state|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|GOOD|\n", + "|2|PATCHING|\n", + "|3|COMPROMISED|\n", + "|4|OVERWHELMED|\n", + "\n", + "The meaning of the files' and folders' health_state is:\n", + "|health_state|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|GOOD|\n", + "|2|COMPROMISED|\n", + "|3|CORRUPT|\n", + "|4|RESTORING|\n", + "|5|REPAIRING|\n", + "\n", + "The meaning of the NICs' operating_status is:\n", + "|operating_status|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|ENABLED|\n", + "|2|DISABLED|\n", + "\n", + "Link load has the following meaning:\n", + "|load|percent utilisation|\n", + "|--|--|\n", + "|0|exactly 0%|\n", + "|1|0-11%|\n", + "|2|11-22%|\n", + "|3|22-33%|\n", + "|4|33-44%|\n", + "|5|44-55%|\n", + "|6|55-66%|\n", + "|7|66-77%|\n", + "|8|77-88%|\n", + "|9|88-99%|\n", + "|10|exactly 100%|\n", + "\n", + "ACL permission has the following meaning:\n", + "|permission|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|ALLOW|\n", + "|2|DENY|\n", + "\n", + "ACL source / destination node ids actually correspond to IP addresses (since ACLs work with IP addresses)\n", + "|source / dest node id|ip_address|label|\n", + "|--|--|--|\n", + "|0| | UNUSED|\n", + "|1| |ALL addresses|\n", + "|2| 192.168.1.10 | domain_controller|\n", + "|3| 192.168.1.12 | web_server \n", + "|4| 192.168.1.14 | database_server|\n", + "|5| 192.168.1.16 | backup_server|\n", + "|6| 192.168.1.110 | security_suite (eth-1)|\n", + "|7| 192.168.10.21 | client_1|\n", + "|8| 192.168.10.22 | client_2|\n", + "|9| 192.168.10.110| security_suite (eth-2)|\n", + "\n", + "ACL source / destination port ids have the following encoding:\n", + "|port id|port number| port use |\n", + "|--|--|--|\n", + "|0||UNUSED|\n", + "|1||ALL|\n", + "|2|219|ARP|\n", + "|3|53|DNS|\n", + "|4|80|HTTP|\n", + "|5|5432|POSTGRES_SERVER|\n", + "\n", + "ACL protocol ids have the following encoding:\n", + "|protocol id|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|ALL|\n", + "|2|ICMP|\n", + "|3|TCP|\n", + "|4|UDP|\n", + "\n", + "protocol" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Action Space\n", + "\n", + "The blue agent chooses from a list of 54 pre-defined actions. The full list is defined in the `action_map` in the config. The most important ones are explained here:\n", + "\n", + "- `0`: Do nothing\n", + "- `1`: Scan the web service - this refreshes the health status in the observation space\n", + "- `9`: Scan the database file - this refreshes the health status of the database file\n", + "- `13`: Patch the database service - This triggers the database to restore data from the backup server\n", + "- `19`: Shut down client 1\n", + "- `22`: Block outgoing traffic from client 1\n", + "- `26`: Block TCP traffic from client 1 to the database node\n", + "- `28-37`: Remove ACL rules 1-10\n", + "- `42`: Disconnect client 1 from the network\n", + "\n", + "The other actions will either have no effect or will negatively impact the network, so the blue agent should avoid taking other actions, and learn about these actions." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reward Function\n", + "\n", + "The blue agent's reward is calculated using two measures:\n", + "1. Whether the database file is in a good state (+1 for good, -1 for corrupted, 0 for any other state)\n", + "2. Whether the green agent's most recent webpage request was successful (+1 for a `200` return code, -1 for a `404` return code and 0 otherwise).\n", + "These two components are averaged to get the final reward.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demonstration" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, load the required modules" + ] + }, { "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "2023-11-26 23:25:47,985\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", - "2023-11-26 23:25:51,213\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", - "2023-11-26 23:25:51,491\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n" - ] - } - ], + "outputs": [], "source": [ - "from primaite.session.session import PrimaiteSession\n", - "from primaite.game.game import PrimaiteGame\n", - "from primaite.config.load import example_config_path\n", - "\n", - "from primaite.simulator.system.services.database.database_service import DatabaseService\n", - "\n", - "import yaml" + "%load_ext autoreload\n", + "%autoreload 2" ] }, { @@ -36,61 +343,182 @@ "name": "stderr", "output_type": "stream", "text": [ - "2023-11-26 23:25:51,579::ERROR::primaite.simulator.network.hardware.base::175::NIC a9:92:0a:5e:1b:e4/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,580::ERROR::primaite.simulator.network.hardware.base::175::NIC ef:03:23:af:3c:19/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,581::ERROR::primaite.simulator.network.hardware.base::175::NIC ae:cf:83:2f:94:17/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,582::ERROR::primaite.simulator.network.hardware.base::175::NIC 4c:b2:99:e2:4a:5d/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,583::ERROR::primaite.simulator.network.hardware.base::175::NIC b9:eb:f9:c2:17:2f/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,590::ERROR::primaite.simulator.network.hardware.base::175::NIC cb:df:ca:54:be:01/192.168.1.10 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,595::ERROR::primaite.simulator.network.hardware.base::175::NIC 6e:32:12:da:4d:0d/192.168.1.12 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,600::ERROR::primaite.simulator.network.hardware.base::175::NIC 58:6e:9b:a7:68:49/192.168.1.14 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,604::ERROR::primaite.simulator.network.hardware.base::175::NIC 33:db:a6:40:dd:a3/192.168.1.16 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,608::ERROR::primaite.simulator.network.hardware.base::175::NIC 72:aa:2b:c0:4c:5f/192.168.1.110 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,610::ERROR::primaite.simulator.network.hardware.base::175::NIC 11:d7:0e:90:d9:a4/192.168.10.110 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,614::ERROR::primaite.simulator.network.hardware.base::175::NIC 86:2b:a4:e5:4d:0f/192.168.10.21 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,631::ERROR::primaite.simulator.network.hardware.base::175::NIC af:ad:8f:84:f1:db/192.168.10.22 cannot be enabled as it is not connected to a Link\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "installing DNSServer on node domain_controller\n", - "installing DatabaseClient on node web_server\n", - "installing WebServer on node web_server\n", - "installing DatabaseService on node database_server\n", - "installing FTPClient on node database_server\n", - "installing FTPServer on node backup_server\n", - "installing DNSClient on node client_1\n", - "installing DNSClient on node client_2\n" + "/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "2024-01-25 14:43:32,056\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2024-01-25 14:43:35,213\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" ] } ], "source": [ + "# Imports\n", + "from primaite.config.load import example_config_path\n", + "from primaite.session.environment import PrimaiteGymEnv\n", + "from primaite.game.game import PrimaiteGame\n", + "import yaml\n", + "from pprint import pprint\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Instantiate the environment. We also disable the agent observation flattening.\n", "\n", - "with open(example_config_path(),'r') as cfgfile:\n", - " cfg = yaml.safe_load(cfgfile)\n", - "game = PrimaiteGame.from_config(cfg)\n", - "net = game.simulation.network\n", - "database_server = net.get_node_by_hostname('database_server')\n", - "web_server = net.get_node_by_hostname('web_server')\n", - "client_1 = net.get_node_by_hostname('client_1')\n", - "\n", - "db_service = database_server.software_manager.software[\"DatabaseService\"]\n", - "db_client = web_server.software_manager.software[\"DatabaseClient\"]\n", - "# db_client.run()\n", - "db_manipulation_bot = client_1.software_manager.software[\"DataManipulationBot\"]\n", - "db_manipulation_bot.port_scan_p_of_success=1.0\n", - "db_manipulation_bot.data_manipulation_p_of_success=1.0\n" + "This cell will print the observation when the network is healthy. You should be able to verify Node file and service statuses against the description above." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Resetting environment, episode 0, avg. reward: 0.0\n", + "env created successfully\n", + "{'ACL': {1: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 0,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 2: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 1,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 3: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 2,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 4: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 3,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 5: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 4,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 6: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 5,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 7: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 6,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 8: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 7,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 9: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 8,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 10: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 9,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0}},\n", + " 'ICS': 0,\n", + " 'LINKS': {1: {'PROTOCOLS': {'ALL': 1}},\n", + " 2: {'PROTOCOLS': {'ALL': 1}},\n", + " 3: {'PROTOCOLS': {'ALL': 1}},\n", + " 4: {'PROTOCOLS': {'ALL': 1}},\n", + " 5: {'PROTOCOLS': {'ALL': 1}},\n", + " 6: {'PROTOCOLS': {'ALL': 1}},\n", + " 7: {'PROTOCOLS': {'ALL': 1}},\n", + " 8: {'PROTOCOLS': {'ALL': 1}},\n", + " 9: {'PROTOCOLS': {'ALL': 1}},\n", + " 10: {'PROTOCOLS': {'ALL': 1}}},\n", + " 'NODES': {1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}},\n", + " 'health_status': 1}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}}\n" + ] + } + ], "source": [ - "db_client.run()" + "# create the env\n", + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + " # set success probability to 1.0 to avoid rerunning cells.\n", + " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n", + " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", + "game = PrimaiteGame.from_config(cfg)\n", + "env = PrimaiteGymEnv(game = game)\n", + "# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n", + "env.agent.flatten_obs = False\n", + "obs, info = env.reset()\n", + "print('env created successfully')\n", + "pprint(obs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will go from 1.0 to 0.0, and to -1.0 when the green agent tries to access the webpage." ] }, { @@ -99,18 +527,55 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 1, Red action: DONOTHING, Blue reward:0.5\n", + "step: 2, Red action: DONOTHING, Blue reward:0.5\n", + "step: 3, Red action: DONOTHING, Blue reward:0.5\n", + "step: 4, Red action: DONOTHING, Blue reward:0.5\n", + "step: 5, Red action: DONOTHING, Blue reward:1.0\n", + "step: 6, Red action: DONOTHING, Blue reward:1.0\n", + "step: 7, Red action: DONOTHING, Blue reward:1.0\n", + "step: 8, Red action: DONOTHING, Blue reward:1.0\n", + "step: 9, Red action: DONOTHING, Blue reward:1.0\n", + "step: 10, Red action: DONOTHING, Blue reward:1.0\n", + "step: 11, Red action: DONOTHING, Blue reward:1.0\n", + "step: 12, Red action: DONOTHING, Blue reward:1.0\n", + "step: 13, Red action: DONOTHING, Blue reward:1.0\n", + "step: 14, Red action: DONOTHING, Blue reward:1.0\n", + "step: 15, Red action: DONOTHING, Blue reward:1.0\n", + "step: 16, Red action: DONOTHING, Blue reward:1.0\n", + "step: 17, Red action: DONOTHING, Blue reward:1.0\n", + "step: 18, Red action: DONOTHING, Blue reward:1.0\n", + "step: 19, Red action: DONOTHING, Blue reward:1.0\n", + "step: 20, Red action: DONOTHING, Blue reward:1.0\n", + "step: 21, Red action: DONOTHING, Blue reward:1.0\n", + "step: 22, Red action: NODE_APPLICATION_EXECUTE, Blue reward:0.0\n", + "step: 23, Red action: DONOTHING, Blue reward:0.0\n", + "step: 24, Red action: DONOTHING, Blue reward:0.0\n", + "step: 25, Red action: DONOTHING, Blue reward:0.0\n", + "step: 26, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 27, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 28, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 29, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 30, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 31, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 32, Red action: DONOTHING, Blue reward:-1.0\n" + ] } ], "source": [ - "db_service.backup_database()" + "for step in range(32):\n", + " obs, reward, terminated, truncated, info = env.step(0)\n", + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now the reward is -1, let's have a look at blue agent's observation." ] }, { @@ -119,27 +584,110 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "{1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}}, 'health_status': 1}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}\n" + ] } ], "source": [ - "db_client.query(\"SELECT\")" + "pprint(obs['NODES'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The true statuses of the database file and webapp are not updated. The blue agent needs to perform a scan to see that they have degraded." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 3, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 2}}, 'health_status': 1}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}\n" + ] + } + ], "source": [ - "db_manipulation_bot.run()" + "obs, reward, terminated, truncated, info = env.step(9) # scan database file\n", + "obs, reward, terminated, truncated, info = env.step(1) # scan webapp service\n", + "pprint(obs['NODES'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now service 1 on node 2 has `health_status = 3`, indicating that the webapp is compromised.\n", + "File 1 in folder 1 on node 3 has `health_status = 2`, indicating that the database file is compromised." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The blue agent can now patch the database to restore the file to a good health status." ] }, { @@ -148,18 +696,33 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 35\n", + "Red action: DONOTHING\n", + "Green action: NODE_APPLICATION_EXECUTE\n", + "Blue reward:-1.0\n" + ] } ], "source": [ - "db_client.query(\"SELECT\")" + "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", + "print(f\"step: {env.game.step_counter}\")\n", + "print(f\"Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n", + "print(f\"Blue reward:{reward}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The patching takes two steps, so the reward hasn't changed yet. Let's do nothing for another timestep, the reward should improve.\n", + "\n", + "The reward will be 0 as soon as the file finishes restoring. Then, the reward will increase to 1 when the green agent makes a request. (Because the webapp access part of the reward does not update until a successful request is made.)\n", + "\n", + "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE`, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." ] }, { @@ -168,18 +731,29 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 36\n", + "Red action: DONOTHING\n", + "Green action: NODE_APPLICATION_EXECUTE\n", + "Blue reward:0.0\n" + ] } ], "source": [ - "db_service.restore_backup()" + "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", + "print(f\"step: {env.game.step_counter}\")\n", + "print(f\"Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n", + "print(f\"Blue reward:{reward}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The blue agent can prevent attacks by implementing an ACL rule to stop client_1 from sending POSTGRES traffic to the database. (Let's also patch the database file to get the reward back up.)" ] }, { @@ -188,90 +762,157 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 37, Red action: DONOTHING, Blue reward:0.0\n", + "step: 38, Red action: DONOTHING, Blue reward:0.0\n", + "step: 39, Red action: DONOTHING, Blue reward:1.0\n", + "step: 40, Red action: DONOTHING, Blue reward:1.0\n", + "step: 41, Red action: DONOTHING, Blue reward:1.0\n", + "step: 42, Red action: DONOTHING, Blue reward:1.0\n", + "step: 43, Red action: DONOTHING, Blue reward:1.0\n", + "step: 44, Red action: DONOTHING, Blue reward:1.0\n", + "step: 45, Red action: DONOTHING, Blue reward:1.0\n", + "step: 46, Red action: NODE_APPLICATION_EXECUTE, Blue reward:1.0\n", + "step: 47, Red action: DONOTHING, Blue reward:1.0\n", + "step: 48, Red action: DONOTHING, Blue reward:1.0\n", + "step: 49, Red action: DONOTHING, Blue reward:1.0\n", + "step: 50, Red action: DONOTHING, Blue reward:1.0\n", + "step: 51, Red action: DONOTHING, Blue reward:1.0\n", + "step: 52, Red action: DONOTHING, Blue reward:1.0\n", + "step: 53, Red action: DONOTHING, Blue reward:1.0\n", + "step: 54, Red action: DONOTHING, Blue reward:1.0\n", + "step: 55, Red action: DONOTHING, Blue reward:1.0\n", + "step: 56, Red action: DONOTHING, Blue reward:1.0\n", + "step: 57, Red action: DONOTHING, Blue reward:1.0\n", + "step: 58, Red action: DONOTHING, Blue reward:1.0\n", + "step: 59, Red action: DONOTHING, Blue reward:1.0\n", + "step: 60, Red action: DONOTHING, Blue reward:1.0\n", + "step: 61, Red action: DONOTHING, Blue reward:1.0\n", + "step: 62, Red action: DONOTHING, Blue reward:1.0\n", + "step: 63, Red action: DONOTHING, Blue reward:1.0\n", + "step: 64, Red action: DONOTHING, Blue reward:1.0\n", + "step: 65, Red action: DONOTHING, Blue reward:1.0\n", + "step: 66, Red action: DONOTHING, Blue reward:1.0\n", + "step: 67, Red action: DONOTHING, Blue reward:1.0\n", + "step: 68, Red action: DONOTHING, Blue reward:1.0\n" + ] } ], "source": [ - "db_client.query(\"SELECT\")" + "env.step(13) # Patch the database\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )\n", + "\n", + "env.step(26) # Block client 1\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )\n", + "\n", + "for step in range(30):\n", + " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "db_manipulation_bot.run()" + "Now, even though the red agent executes an attack, the reward stays at 1.0" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "client_1.ping(database_server.ethernet_port[1].ip_address)" + "Let's also have a look at the ACL observation to verify our new ACL rule at position 5." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, - "outputs": [], - "source": [ - "from pydantic import validate_call, BaseModel" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "class A(BaseModel):\n", - " x:int\n", - "\n", - " @validate_call\n", - " def increase_x(self, by:int) -> None:\n", - " self.x += 1" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "my_a = A(x=3)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, "outputs": [ { - "ename": "ValidationError", - "evalue": "1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/cade/repos/PrimAITE/src/primaite/notebooks/uc2_demo.ipynb Cell 15\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m my_a\u001b[39m.\u001b[39;49mincrease_x(\u001b[39m3.2\u001b[39;49m)\n", - "File \u001b[0;32m~/repos/PrimAITE/venv/lib/python3.10/site-packages/pydantic/_internal/_validate_call.py:91\u001b[0m, in \u001b[0;36mValidateCallWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs: Any, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[0;32m---> 91\u001b[0m res \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__pydantic_validator__\u001b[39m.\u001b[39;49mvalidate_python(pydantic_core\u001b[39m.\u001b[39;49mArgsKwargs(args, kwargs))\n\u001b[1;32m 92\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__:\n\u001b[1;32m 93\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__\u001b[39m.\u001b[39mvalidate_python(res)\n", - "\u001b[0;31mValidationError\u001b[0m: 1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float" - ] + "data": { + "text/plain": [ + "{1: {'position': 0,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 2: {'position': 1,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 3: {'position': 2,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 4: {'position': 3,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 5: {'position': 4,\n", + " 'permission': 2,\n", + " 'source_node_id': 7,\n", + " 'source_port': 1,\n", + " 'dest_node_id': 4,\n", + " 'dest_port': 1,\n", + " 'protocol': 3},\n", + " 6: {'position': 5,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 7: {'position': 6,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 8: {'position': 7,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 9: {'position': 8,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 10: {'position': 9,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0}}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "my_a.increase_x(3.2)" + "obs['ACL']" ] }, { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 6701f183..a3831bc1 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -29,7 +29,7 @@ class PrimaiteGymEnv(gymnasium.Env): # make ProxyAgent store the action chosen my the RL policy self.agent.store_action(action) # apply_agent_actions accesses the action we just stored - self.game.apply_agent_actions() + agent_actions = self.game.apply_agent_actions() self.game.advance_timestep() state = self.game.get_sim_state() @@ -39,7 +39,7 @@ class PrimaiteGymEnv(gymnasium.Env): reward = self.agent.reward_function.current_reward terminated = False truncated = self.game.calculate_truncated() - info = {} + info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(action, state, reward) return next_obs, reward, terminated, truncated, info @@ -172,7 +172,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - self.game.apply_agent_actions() + agent_actions = self.game.apply_agent_actions() # 2. Advance timestep self.game.advance_timestep() @@ -186,7 +186,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} terminateds = {name: False for name, _ in self.agents.items()} truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} - infos = {} + infos = {"agent_actions": agent_actions} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index c2eb0d2d..149bf083 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -23,7 +23,6 @@ class FileSystem(SimComponent): "List containing all the folders in the file system." deleted_folders: Dict[str, Folder] = {} "List containing all the folders that have been deleted." - _folders_by_name: Dict[str, Folder] = {} sys_log: SysLog "Instance of SysLog used to create system logs." sim_root: Path @@ -56,7 +55,6 @@ class FileSystem(SimComponent): folder = self.deleted_folders[uuid] self.deleted_folders.pop(uuid) self.folders[uuid] = folder - self._folders_by_name[folder.name] = folder # Clear any other deleted folders that aren't original (have been created by agent) self.deleted_folders.clear() @@ -67,7 +65,6 @@ class FileSystem(SimComponent): if uuid not in original_folder_uuids: folder = self.folders[uuid] self.folders.pop(uuid) - self._folders_by_name.pop(folder.name) # Now reset all remaining folders for folder in self.folders.values(): @@ -173,7 +170,6 @@ class FileSystem(SimComponent): folder = Folder(name=folder_name, sys_log=self.sys_log) self.folders[folder.uuid] = folder - self._folders_by_name[folder.name] = folder self._folder_request_manager.add_request( name=folder.uuid, request_type=RequestType(func=folder._request_manager) ) @@ -188,14 +184,13 @@ class FileSystem(SimComponent): if folder_name == "root": self.sys_log.warning("Cannot delete the root folder.") return - folder = self._folders_by_name.get(folder_name) + folder = self.get_folder(folder_name) if folder: # set folder to deleted state folder.delete() # remove from folder list self.folders.pop(folder.uuid) - self._folders_by_name.pop(folder.name) # add to deleted list folder.remove_all_files() @@ -221,7 +216,10 @@ class FileSystem(SimComponent): :param folder_name: The folder name. :return: The matching Folder. """ - return self._folders_by_name.get(folder_name) + for folder in self.folders.values(): + if folder.name == folder_name: + return folder + return None def get_folder_by_id(self, folder_uuid: str, include_deleted: bool = False) -> Optional[Folder]: """ @@ -261,13 +259,13 @@ class FileSystem(SimComponent): """ if folder_name: # check if file with name already exists - folder = self._folders_by_name.get(folder_name) + folder = self.get_folder(folder_name) # If not then create it if not folder: folder = self.create_folder(folder_name) else: # Use root folder if folder_name not supplied - folder = self._folders_by_name["root"] + folder = self.get_folder("root") # Create the file and add it to the folder file = File( @@ -474,7 +472,6 @@ class FileSystem(SimComponent): folder.restore() self.folders[folder.uuid] = folder - self._folders_by_name[folder.name] = folder if folder.deleted: self.deleted_folders.pop(folder.uuid) diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index 86cd1ee7..c3e1426b 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -87,7 +87,7 @@ class FileSystemItemABC(SimComponent): def set_original_state(self): """Sets the original state.""" - vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"} + vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"} self._original_state = self.model_dump(include=vals_to_keep) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 237a6341..ab862898 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -17,8 +17,6 @@ class Folder(FileSystemItemABC): files: Dict[str, File] = {} "Files stored in the folder." - _files_by_name: Dict[str, File] = {} - "Files by their name as .." deleted_files: Dict[str, File] = {} "Files that have been deleted." @@ -78,7 +76,6 @@ class Folder(FileSystemItemABC): file = self.deleted_files[uuid] self.deleted_files.pop(uuid) self.files[uuid] = file - self._files_by_name[file.name] = file # Clear any other deleted files that aren't original (have been created by agent) self.deleted_files.clear() @@ -89,7 +86,6 @@ class Folder(FileSystemItemABC): if uuid not in original_file_uuids: file = self.files[uuid] self.files.pop(uuid) - self._files_by_name.pop(file.name) # Now reset all remaining files for file in self.files.values(): @@ -219,7 +215,10 @@ class Folder(FileSystemItemABC): :return: The matching File. """ # TODO: Increment read count? - return self._files_by_name.get(file_name) + for file in self.files.values(): + if file.name == file_name: + return file + return None def get_file_by_id(self, file_uuid: str, include_deleted: Optional[bool] = False) -> File: """ @@ -250,15 +249,14 @@ class Folder(FileSystemItemABC): raise Exception(f"Invalid file: {file}") # check if file with id or name already exists in folder - if (force is not True) and file.name in self._files_by_name: + if self.get_file(file.name) is not None and not force: raise Exception(f"File with name {file.name} already exists in folder") - if (force is not True) and file.uuid in self.files: + if (file.uuid in self.files) and not force: raise Exception(f"File with uuid {file.uuid} already exists in folder") # add to list self.files[file.uuid] = file - self._files_by_name[file.name] = file self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager)) file.folder = self @@ -275,11 +273,10 @@ class Folder(FileSystemItemABC): if self.files.get(file.uuid): self.files.pop(file.uuid) - self._files_by_name.pop(file.name) self.deleted_files[file.uuid] = file file.delete() self.sys_log.info(f"Removed file {file.name} (id: {file.uuid})") - self._file_request_manager.remove_request(file.uuid) + # self._file_request_manager.remove_request(file.uuid) else: _LOGGER.debug(f"File with UUID {file.uuid} was not found.") @@ -300,7 +297,6 @@ class Folder(FileSystemItemABC): self.deleted_files[file_id] = file self.files = {} - self._files_by_name = {} def restore_file(self, file_uuid: str): """ @@ -316,7 +312,6 @@ class Folder(FileSystemItemABC): file.restore() self.files[file.uuid] = file - self._files_by_name[file.name] = file if file.deleted: self.deleted_files.pop(file_uuid) diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 0234934d..41c14967 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -9,6 +9,7 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader @@ -18,8 +19,8 @@ from primaite.simulator.system.core.sys_log import SysLog class ACLAction(Enum): """Enum for defining the ACL action types.""" - DENY = 0 PERMIT = 1 + DENY = 2 class ACLRule(SimComponent): @@ -65,11 +66,11 @@ class ACLRule(SimComponent): """ state = super().describe_state() state["action"] = self.action.value - state["protocol"] = self.protocol.value if self.protocol else None + state["protocol"] = self.protocol.name if self.protocol else None state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None - state["src_port"] = self.src_port.value if self.src_port else None + state["src_port"] = self.src_port.name if self.src_port else None state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None - state["dst_port"] = self.dst_port.value if self.dst_port else None + state["dst_port"] = self.dst_port.name if self.dst_port else None return state @@ -89,6 +90,8 @@ class AccessControlList(SimComponent): implicit_rule: ACLRule max_acl_rules: int = 25 _acl: List[Optional[ACLRule]] = [None] * 24 + _default_config: Dict[int, dict] = {} + """Config dict describing how the ACL list should look at episode start""" def __init__(self, **kwargs) -> None: if not kwargs.get("implicit_action"): @@ -106,10 +109,40 @@ class AccessControlList(SimComponent): vals_to_keep = {"implicit_action", "max_acl_rules", "acl"} self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) + for i, rule in enumerate(self._acl): + if not rule: + continue + self._default_config[i] = {"action": rule.action.name} + if rule.src_ip_address: + self._default_config[i]["src_ip"] = str(rule.src_ip_address) + if rule.dst_ip_address: + self._default_config[i]["dst_ip"] = str(rule.dst_ip_address) + if rule.src_port: + self._default_config[i]["src_port"] = rule.src_port.name + if rule.dst_port: + self._default_config[i]["dst_port"] = rule.dst_port.name + if rule.protocol: + self._default_config[i]["protocol"] = rule.protocol.name + def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.implicit_rule.reset_component_for_episode(episode) super().reset_component_for_episode(episode) + self._reset_rules_to_default() + + def _reset_rules_to_default(self) -> None: + """Clear all ACL rules and set them to the default rules config.""" + self._acl = [None] * (self.max_acl_rules - 1) + for r_num, r_cfg in self._default_config.items(): + self.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -391,7 +424,6 @@ class RouteTable(SimComponent): sys_log: SysLog def set_original_state(self): - """Sets the original state.""" """Sets the original state.""" super().set_original_state() self._original_state["routes_orig"] = self.routes @@ -716,8 +748,8 @@ class Router(Node): :return: A dictionary representing the current state. """ state = super().describe_state() - state["num_ports"] = (self.num_ports,) - state["acl"] = (self.acl.describe_state(),) + state["num_ports"] = self.num_ports + state["acl"] = self.acl.describe_state() return state def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None: @@ -864,3 +896,63 @@ class Router(Node): ] ) print(table) + + @classmethod + def from_config(cls, cfg: dict) -> "Router": + """Create a router based on a config dict. + + Schema: + - hostname (str): unique name for this router. + - num_ports (int, optional): Number of network ports on the router. 8 by default + - ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying + ip_address and subnet_mask assigned to that ports (as strings) + - acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL + where the rule will be added (lower number is resolved first). The values should describe valid ACL + Rules as: + - action (str): either PERMIT or DENY + - src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER + - dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER + - protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP + - src_ip_address (str, optional): IP address octet written in base 10 + - dst_ip_address (str, optional): IP address octet written in base 10 + + Example config: + ``` + { + 'hostname': 'router_1', + 'num_ports': 5, + 'ports': { + 1: { + 'ip_address' : '192.168.1.1', + 'subnet_mask' : '255.255.255.0', + } + }, + 'acl' : { + 21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'}, + 22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'}, + 23: {'action': 'PERMIT', 'protocol': 'ICMP'}, + }, + } + ``` + + :param cfg: Router config adhering to schema described in main docstring body + :type cfg: dict + :return: Configured router. + :rtype: Router + """ + new = Router( + hostname=cfg["hostname"], + num_ports=cfg.get("num_ports"), + operating_state=NodeOperatingState.ON, + ) + if "ports" in cfg: + for port_num, port_cfg in cfg["ports"].items(): + new.configure_port( + port=port_num, + ip_address=port_cfg["ip_address"], + subnet_mask=port_cfg["subnet_mask"], + ) + if "acl" in cfg: + new.acl._default_config = cfg["acl"] # save the config to allow resetting + new.acl._reset_rules_to_default() # read the config and apply rules + return new diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 1cdd0390..14190dd2 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -84,6 +84,10 @@ class DatabaseService(Service): ftp_client_service: FTPClient = software_manager.software.get("FTPClient") # send backup copy of database file to FTP server + if not self.db_file: + self.sys_log.error("Attempted to backup database file but it doesn't exist.") + return False + response = ftp_client_service.send_file( dest_ip_address=self.backup_server_ip, src_file_name=self.db_file.name, @@ -121,7 +125,7 @@ class DatabaseService(Service): return False # replace db file - self.file_system.delete_file(folder_name="database", file_name="downloads.db") + self.file_system.delete_file(folder_name="database", file_name="database.db") self.file_system.copy_file(src_folder_name="downloads", src_file_name="database.db", dst_folder_name="database") if self.db_file is None: diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 9070f246..4a1fc275 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -19,7 +19,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -491,6 +491,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index e67f6606..c8ffa23f 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -23,7 +23,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -502,6 +502,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 220ca21e..6cd22694 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -29,7 +29,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -509,6 +509,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: @@ -940,6 +957,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index d7e94cb6..99087798 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -27,7 +27,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -507,6 +507,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index b89349c0..c2842a06 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -23,7 +23,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -503,6 +503,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index e348838e..7667a59f 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -10,6 +10,22 @@ from primaite.simulator.system.applications.database_client import DatabaseClien from primaite.simulator.system.services.database.database_service import DatabaseService +def filter_keys_nested_item(data, keys): + stack = [(data, {})] + while stack: + current, filtered = stack.pop() + if isinstance(current, dict): + for k, v in current.items(): + if k in keys: + filtered[k] = filter_keys_nested_item(v, keys) + elif isinstance(v, (dict, list)): + stack.append((v, {})) + elif isinstance(current, list): + for item in current: + stack.append((item, {})) + return filtered + + @pytest.fixture(scope="function") def network(example_network) -> Network: assert len(example_network.routers) is 1 @@ -59,10 +75,10 @@ def test_reset_network(network): assert client_1.operating_state is NodeOperatingState.ON assert server_1.operating_state is NodeOperatingState.ON - - assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps( - state_before, sort_keys=True, indent=2 - ) + # don't worry if UUIDs change + a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"]) + b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"]) + assert a == b def test_creating_container():