From 539577ddc36aeb7591a430e8319d98dee9e29e41 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 13 Jun 2024 11:48:13 +0100 Subject: [PATCH] #2658: added monitored traffic into config + default obs space --- .../_package_data/data_manipulation.yaml | 5 + .../agent/observations/host_observations.py | 12 ++- .../agent/observations/nic_observations.py | 89 +++++++++++++-- .../agent/observations/node_observations.py | 4 + .../simulator/network/hardware/base.py | 9 +- .../configs/basic_switched_network.yaml | 5 + .../observations/test_nic_observations.py | 102 +++++++++++++++++- .../observations/test_node_observations.py | 1 + 8 files changed, 216 insertions(+), 11 deletions(-) diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index e3d68706..6cded5f2 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -188,6 +188,11 @@ agents: num_nics: 2 include_num_access: false include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS routers: - hostname: router_1 num_ports: 0 diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 02c0d17f..35e08424 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -43,6 +43,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): """Number of spaces for network interface observations on this host.""" include_nmne: Optional[bool] = None """Whether network interface observations should include number of malicious network events.""" + monitored_traffic: Optional[Dict] = None + """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Whether to include the number of accesses to files observations on this host.""" @@ -59,6 +61,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): num_files: int, num_nics: int, include_nmne: bool, + monitored_traffic: Optional[Dict], include_num_access: bool, ) -> None: """ @@ -87,6 +90,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): :type num_nics: int :param include_nmne: Flag to include network metrics and errors. :type include_nmne: bool + :param monitored_traffic: Dict which contains the protocol and ports to observe + :type monitored_traffic: Dict :param include_num_access: Flag to include the number of accesses to files. :type include_num_access: bool """ @@ -123,7 +128,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): self.nics: List[NICObservation] = network_interfaces while len(self.nics) < num_nics: - self.nics.append(NICObservation(where=None, include_nmne=include_nmne)) + self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic)) while len(self.nics) > num_nics: truncated_nic = self.nics.pop() msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}" @@ -231,7 +236,9 @@ class HostObservation(AbstractObservation, identifier="HOST"): # monitor the first N interfaces. Network interface numbering starts at 1. count = 1 while len(nics) < config.num_nics: - nic_config = NICObservation.ConfigSchema(nic_num=count, include_nmne=config.include_nmne) + nic_config = NICObservation.ConfigSchema( + nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic + ) nics.append(NICObservation.from_config(config=nic_config, parent_where=where)) count += 1 @@ -247,5 +254,6 @@ class HostObservation(AbstractObservation, identifier="HOST"): num_files=config.num_files, num_nics=config.num_nics, include_nmne=config.include_nmne, + monitored_traffic=config.monitored_traffic, include_num_access=config.include_num_access, ) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index afce9095..56494748 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -7,6 +7,7 @@ from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.transmission.transport_layer import Port class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): @@ -19,12 +20,10 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """Number of the network interface.""" include_nmne: Optional[bool] = None """Whether to include number of malicious network events (NMNE) in the observation.""" + monitored_traffic: Optional[Dict] = None + """A dict containing which traffic types are to be included in the observation.""" - def __init__( - self, - where: WhereType, - include_nmne: bool, - ) -> None: + def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: """ Initialise a network interface observation instance. @@ -49,6 +48,26 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): self.med_nmne_threshold = 5 self.low_nmne_threshold = 0 + self.monitored_traffic = monitored_traffic + if self.monitored_traffic: + self.default_observation.update( + self._default_monitored_traffic_observation(monitored_traffic_config=monitored_traffic) + ) + + def _default_monitored_traffic_observation(self, monitored_traffic_config: Dict) -> Dict: + default_traffic_obs = {"monitored_traffic": {}} + + for protocol in monitored_traffic_config: + default_traffic_obs["monitored_traffic"][str(protocol).lower()] = {} + + if protocol == "icmp": + default_traffic_obs["monitored_traffic"]["icmp"] = {"inbound": 0, "outbound": 0} + else: + for port in monitored_traffic_config[protocol]: + default_traffic_obs["monitored_traffic"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + + return default_traffic_obs + def _categorise_mne_count(self, nmne_count: int) -> int: """ Categorise the number of Malicious Network Events (NMNEs) into discrete bins. @@ -72,6 +91,16 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): return 1 return 0 + def _categorise_traffic(self, traffic_value: float, nic_state: Dict) -> int: + """Categorise the traffic into discrete categories.""" + if traffic_value == 0: + return 0 + + nic_max_bandwidth = nic_state.get("speed") + + bandwidth_utilisation = traffic_value / nic_max_bandwidth + return int(bandwidth_utilisation * 9) + 1 + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. @@ -87,6 +116,50 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): return self.default_observation obs = {"nic_status": 1 if nic_state["enabled"] else 2} + + # if the observation was configured to monitor traffic from ports/protocols + if self.monitored_traffic: + obs["monitored_traffic"] = {} + + # iterate through the protocols + for protocol in self.monitored_traffic: + obs["monitored_traffic"][str(protocol).lower()] = {} + # check if the nic has seen traffic with this protocol + if nic_state["traffic"].get(protocol): + # deal with icmp + if str(protocol).lower() == "icmp": + obs["monitored_traffic"][protocol] = { + "inbound": self._categorise_traffic( + traffic_value=nic_state["traffic"]["icmp"]["inbound"], nic_state=nic_state + ), + "outbound": self._categorise_traffic( + traffic_value=nic_state["traffic"]["icmp"]["outbound"], nic_state=nic_state + ), + } + else: + for port in self.monitored_traffic[protocol]: + port_enum = Port[port] + obs["monitored_traffic"][str(protocol).lower()][port_enum.value] = {} + traffic = {"inbound": 0, "outbound": 0} + + if nic_state["traffic"][protocol].get(port_enum.value) is not None: + traffic = nic_state["traffic"][protocol][port_enum.value] + + obs["monitored_traffic"][protocol][port_enum.value]["inbound"] = self._categorise_traffic( + traffic_value=traffic["inbound"], nic_state=nic_state + ) + obs["monitored_traffic"][protocol][port_enum.value]["outbound"] = self._categorise_traffic( + traffic_value=traffic["outbound"], nic_state=nic_state + ) + + # set all the ports under the protocol to 0 + else: + if str(protocol).lower() == "icmp": + obs["monitored_traffic"]["icmp"] = {"inbound": 0, "outbound": 0} + else: + for port in self.monitored_traffic[protocol]: + obs["monitored_traffic"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + if self.include_nmne: obs.update({"NMNE": {}}) direction_dict = nic_state["nmne"].get("direction", {}) @@ -128,7 +201,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): :return: Constructed network interface observation instance. :rtype: NICObservation """ - return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne) + return cls( + where=parent_where + ["NICs", config.nic_num], + include_nmne=config.include_nmne, + monitored_traffic=config.monitored_traffic, + ) class PortObservation(AbstractObservation, identifier="PORT"): diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 8f7ac0fc..9d82b4aa 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -39,6 +39,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """Number of network interface cards (NICs).""" include_nmne: Optional[bool] = None """Flag to include nmne.""" + monitored_traffic: Optional[Dict] = None + """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" num_ports: Optional[int] = None @@ -180,6 +182,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): host_config.num_nics = config.num_nics if host_config.include_nmne is None: host_config.include_nmne = config.include_nmne + if host_config.monitored_traffic is None: + host_config.monitored_traffic = config.monitored_traffic if host_config.include_num_access is None: host_config.include_num_access = config.include_num_access diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index e9560b91..f63efc5a 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -277,12 +277,12 @@ class NetworkInterface(SimComponent, ABC): if protocol != IPProtocol.ICMP: if port not in self.traffic[protocol]: self.traffic[protocol][port] = {"inbound": 0, "outbound": 0} - self.traffic[protocol][port][direction] += frame.size + self.traffic[protocol][port][direction] += frame.size_Mbits else: # Handle ICMP protocol separately (ICMP does not use ports) if not self.traffic[protocol]: self.traffic[protocol] = {"inbound": 0, "outbound": 0} - self.traffic[protocol][direction] += frame.size + self.traffic[protocol][direction] += frame.size_Mbits @abstractmethod def send_frame(self, frame: Frame) -> bool: @@ -325,6 +325,11 @@ class NetworkInterface(SimComponent, ABC): """ super().apply_timestep(timestep=timestep) + def pre_timestep(self, timestep: int) -> None: + """Apply pre-timestep logic.""" + super().pre_timestep(timestep) + self.traffic = {} + class WiredNetworkInterface(NetworkInterface, ABC): """ diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index 0cbaefdb..7d40075d 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -84,6 +84,11 @@ agents: num_files: 1 num_nics: 2 include_num_access: false + monitored_traffic: + icmp: + - NONE + tcp: + - DNS include_nmne: true routers: - hostname: router_1 diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index 66b7df55..d7b1e347 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -5,12 +5,18 @@ import pytest import yaml from gymnasium import spaces +from primaite.game.agent.interface import ProxyAgent from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC -from primaite.simulator.network.nmne import CAPTURE_NMNE +from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.simulator.system.services.web_server.web_server import WebServer from tests import TEST_ASSETS_ROOT BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" @@ -31,6 +37,32 @@ def simulation(example_network) -> Simulation: # set simulation network as example network sim.network = example_network + computer: Computer = example_network.get_node_by_hostname("client_1") + server: Server = example_network.get_node_by_hostname("server_1") + + web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser.run() + + # Install DNS Client service on computer + computer.software_manager.install(DNSClient) + dns_client: DNSClient = computer.software_manager.software.get("DNSClient") + # set dns server + dns_client.dns_server = server.network_interface[1].ip_address + + # Install Web Server service on server + server.software_manager.install(WebServer) + web_server_service: WebServer = server.software_manager.software.get("WebServer") + web_server_service.start() + + # Install DNS Server service on server + server.software_manager.install(DNSServer) + dns_server: DNSServer = server.software_manager.software.get("DNSServer") + # register arcd.com to DNS + dns_server.dns_register( + domain_name="arcd.com", + domain_ip_address=server.network_interfaces[next(iter(server.network_interfaces))].ip_address, + ) + return sim @@ -102,3 +134,71 @@ def test_config_nic_categories(simulation): high_nmne_threshold=9, include_nmne=True, ) + + +def test_nic_monitored_traffic(simulation): + monitored_traffic = {"icmp": ["NONE"], "tcp": ["DNS"]} + + pc: Computer = simulation.network.get_node_by_hostname("client_1") + pc2: Computer = simulation.network.get_node_by_hostname("client_2") + + nic_obs = NICObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True, monitored_traffic=monitored_traffic + ) + + simulation.pre_timestep(0) # apply timestep to whole sim + simulation.apply_timestep(0) # apply timestep to whole sim + traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic") + + assert traffic_obs["icmp"]["inbound"] == 0 + assert traffic_obs["icmp"]["outbound"] == 0 + + # send a ping + pc.ping(target_ip_address=pc2.network_interface[1].ip_address) + traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic") + + assert traffic_obs["icmp"]["inbound"] == 1 + assert traffic_obs["icmp"]["outbound"] == 1 + + simulation.pre_timestep(1) # apply timestep to whole sim + simulation.apply_timestep(1) # apply timestep to whole sim + traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic") + + assert traffic_obs["icmp"]["inbound"] == 0 + assert traffic_obs["icmp"]["outbound"] == 0 + assert traffic_obs["tcp"][53]["inbound"] == 0 + assert traffic_obs["tcp"][53]["outbound"] == 0 + + # send a database query + browser: WebBrowser = pc.software_manager.software.get("WebBrowser") + browser.target_url = f"http://arcd.com/" + browser.get_webpage() + + traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic") + assert traffic_obs["icmp"]["inbound"] == 0 + assert traffic_obs["icmp"]["outbound"] == 0 + assert traffic_obs["tcp"][53]["inbound"] == 0 + assert traffic_obs["tcp"][53]["outbound"] == 1 # getting a webpage sent a dns request out + + simulation.pre_timestep(2) # apply timestep to whole sim + simulation.apply_timestep(2) # apply timestep to whole sim + traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic") + + assert traffic_obs["icmp"]["inbound"] == 0 + assert traffic_obs["icmp"]["outbound"] == 0 + assert traffic_obs["tcp"][53]["inbound"] == 0 + assert traffic_obs["tcp"][53]["outbound"] == 0 + + +def test_nic_monitored_traffic_config(): + """Test that the config loads the monitored traffic config correctly.""" + game: PrimaiteGame = load_config(BASIC_CONFIG) + + # should have icmp and DNS + defender_agent: ProxyAgent = game.agents.get("defender") + cur_obs = defender_agent.observation_manager.current_observation + + assert cur_obs["NODES"]["HOST0"]["NICS"][1]["monitored_traffic"] == { + "icmp": {"inbound": 0, "outbound": 0}, + "tcp": {53: {"inbound": 0, "outbound": 0}}, + } diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 458cf0ab..2e417192 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -32,6 +32,7 @@ def test_host_observation(simulation): num_services=1, include_num_access=False, include_nmne=False, + monitored_traffic=None, services=[], applications=[], folders=[],