From cd6d6325db51ab7857efaf8af4fba03f06f79aa9 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 11 Mar 2024 17:47:33 +0000 Subject: [PATCH] #2350: add tests to check spaces + acl obs test + nmne space changes --- .../_package_data/data_manipulation.yaml | 2 - .../agent/observations/nic_observations.py | 29 ++++++-- .../observations/software_observation.py | 2 +- .../observations/test_acl_observations.py | 66 +++++++++++++++++ .../test_file_system_observations.py | 4 +- .../observations/test_link_observations.py | 73 +++++++++++++++++++ ...servations.py => test_nic_observations.py} | 22 ++++++ .../observations/test_node_observations.py | 3 + .../test_software_observations.py | 4 + 9 files changed, 193 insertions(+), 12 deletions(-) create mode 100644 tests/integration_tests/game_layer/observations/test_acl_observations.py create mode 100644 tests/integration_tests/game_layer/observations/test_link_observations.py rename tests/integration_tests/game_layer/observations/{test_observations.py => test_nic_observations.py} (76%) diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 47204878..a3a7e44a 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -22,8 +22,6 @@ io_settings: game: max_episode_length: 256 ports: - - ARP - - DNS - HTTP - POSTGRES_SERVER protocols: diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 39298ffe..735b41d4 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -20,6 +20,8 @@ class NicObservation(AbstractObservation): high_nmne_threshold: int = 10 """The minimum number of malicious network events to be considered high.""" + global CAPTURE_NMNE + @property def default_observation(self) -> Dict: """The default NIC observation dict.""" @@ -47,6 +49,15 @@ class NicObservation(AbstractObservation): super().__init__() self.where: Optional[Tuple[str]] = where + global CAPTURE_NMNE + if CAPTURE_NMNE: + self.nmne_inbound_last_step: int = 0 + """NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets + us find the difference.""" + self.nmne_outbound_last_step: int = 0 + """NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets + us find the difference.""" + if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold: self._validate_nmne_categories( low_nmne_threshold=low_nmne_threshold, @@ -128,19 +139,21 @@ class NicObservation(AbstractObservation): inbound_count = inbound_keywords.get("*", 0) outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) outbound_count = outbound_keywords.get("*", 0) - obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count) - obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count) + obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) + obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) + self.nmne_inbound_last_step = inbound_count + self.nmne_outbound_last_step = outbound_count return obs_dict @property def space(self) -> spaces.Space: """Gymnasium space object describing the observation space shape.""" - return spaces.Dict( - { - "nic_status": spaces.Discrete(3), - "nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}), - } - ) + space = spaces.Dict({"nic_status": spaces.Discrete(3)}) + + if CAPTURE_NMNE: + space["nmne"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) + + return space @classmethod def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index ff61714a..6caf791c 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -51,7 +51,7 @@ class ServiceObservation(AbstractObservation): @property def space(self) -> spaces.Space: """Gymnasium space object describing the observation space shape.""" - return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)}) + return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) @classmethod def from_config( diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py new file mode 100644 index 00000000..93867edd --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -0,0 +1,66 @@ +import pytest + +from primaite.game.agent.observations.observations import AclObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.services.ntp.ntp_client import NTPClient +from primaite.simulator.system.services.ntp.ntp_server import NTPServer + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_acl_observations(simulation): + """Test the ACL rule observations.""" + router: Router = simulation.network.get_node_by_hostname("router_1") + client_1: Computer = simulation.network.get_node_by_hostname("client_1") + server: Computer = simulation.network.get_node_by_hostname("server_1") + + # quick set up of ntp + client_1.software_manager.install(NTPClient) + ntp_client: NTPClient = client_1.software_manager.software.get("NTPClient") + ntp_client.configure(server.network_interface.get(1).ip_address) + server.software_manager.install(NTPServer) + + # add router acl rule + router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1) + + acl_obs = AclObservation( + where=["network", "nodes", router.hostname, "acl", "acl"], + node_ip_to_id={}, + ports=["NTP", "HTTP", "POSTGRES_SERVER"], + protocols=["TCP", "UDP", "ICMP"], + ) + + observation_space = acl_obs.observe(simulation.describe_state()) + assert observation_space.get(1) is not None + rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP + assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1) + assert rule_obs.get("permission") == 1 # permit = 1 deny = 2 + assert rule_obs.get("source_node_id") == 1 # applies to all source nodes + assert rule_obs.get("dest_node_id") == 1 # applies to all destination nodes + assert rule_obs.get("source_port") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs) + assert rule_obs.get("dest_port") == 2 # NTP port is mapped to value 2 + assert rule_obs.get("protocol") == 1 # 1 = No Protocol + + router.acl.remove_rule(1) + + observation_space = acl_obs.observe(simulation.describe_state()) + assert observation_space.get(1) is not None + rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP + assert rule_obs.get("position") == 0 + assert rule_obs.get("permission") == 0 + assert rule_obs.get("source_node_id") == 0 + assert rule_obs.get("dest_node_id") == 0 + assert rule_obs.get("source_port") == 0 + assert rule_obs.get("dest_port") == 0 + assert rule_obs.get("protocol") == 0 diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 808007cc..35bb95fd 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -26,7 +26,7 @@ def test_file_observation(simulation): where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"] ) - assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)}) + assert dog_file_obs.space["health_status"] == spaces.Discrete(6) observation_state = dog_file_obs.observe(simulation.describe_state()) assert observation_state.get("health_status") == 1 # good initial @@ -52,6 +52,8 @@ def test_folder_observation(simulation): where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"] ) + assert root_folder_obs.space["health_status"] == spaces.Discrete(6) + observation_state = root_folder_obs.observe(simulation.describe_state()) assert observation_state.get("FILES") is not None assert observation_state.get("health_status") == 1 diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py new file mode 100644 index 00000000..bfe4d5cc --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -0,0 +1,73 @@ +import pytest +from gymnasium import spaces + +from primaite.game.agent.observations.observations import LinkObservation +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.base import Link, Node +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation() -> Simulation: + sim = Simulation() + + network = Network() + + # Create Computer + computer = Computer( + hostname="computer", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + computer.power_on() + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.1.3", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Connect Computer and Server + network.connect(computer.network_interface[1], server.network_interface[1]) + + # Should be linked + assert next(iter(network.links.values())).is_up + + assert computer.ping(server.network_interface.get(1).ip_address) + + # set simulation network as example network + sim.network = network + + return sim + + +def test_link_observation(simulation): + """Test the link observation.""" + # get a link + link: Link = next(iter(simulation.network.links.values())) + + computer: Computer = simulation.network.get_node_by_hostname("computer") + server: Server = simulation.network.get_node_by_hostname("server") + + simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep + + link_obs = LinkObservation(where=["network", "links", link.uuid]) + + assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10 + + observation_state = link_obs.observe(simulation.describe_state()) + assert observation_state.get("PROTOCOLS") is not None + assert observation_state["PROTOCOLS"]["ALL"] == 0 + + computer.ping(server.network_interface.get(1).ip_address) + + observation_state = link_obs.observe(simulation.describe_state()) + assert observation_state["PROTOCOLS"]["ALL"] == 1 diff --git a/tests/integration_tests/game_layer/observations/test_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py similarity index 76% rename from tests/integration_tests/game_layer/observations/test_observations.py rename to tests/integration_tests/game_layer/observations/test_nic_observations.py index 97df7882..c210b751 100644 --- a/tests/integration_tests/game_layer/observations/test_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -1,9 +1,27 @@ +from pathlib import Path +from typing import Union + import pytest +import yaml +from gymnasium import spaces from primaite.game.agent.observations.nic_observations import NicObservation +from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC +from primaite.simulator.network.nmne import CAPTURE_NMNE from primaite.simulator.sim_container import Simulation +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) @pytest.fixture(scope="function") @@ -24,6 +42,10 @@ def test_nic(simulation): nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + assert nic_obs.space["nic_status"] == spaces.Discrete(3) + assert nic_obs.space["nmne"]["inbound"] == spaces.Discrete(4) + assert nic_obs.space["nmne"]["outbound"] == spaces.Discrete(4) + observation_state = nic_obs.observe(simulation.describe_state()) assert observation_state.get("nic_status") == 1 # enabled assert observation_state.get("nmne") is not None diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 835202c6..b1563fbd 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -2,6 +2,7 @@ import copy from uuid import uuid4 import pytest +from gymnasium import spaces from primaite.game.agent.observations.node_observations import NodeObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -24,6 +25,8 @@ def test_node_observation(simulation): node_obs = NodeObservation(where=["network", "nodes", pc.hostname]) + assert node_obs.space["operating_status"] == spaces.Discrete(5) + observation_state = node_obs.observe(simulation.describe_state()) assert observation_state.get("operating_status") == 1 # computer is on diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 17fc386f..4ae0701e 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -1,4 +1,5 @@ import pytest +from gymnasium import spaces from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -29,6 +30,9 @@ def test_service_observation(simulation): service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"]) + assert service_obs.space["operating_status"] == spaces.Discrete(7) + assert service_obs.space["health_status"] == spaces.Discrete(5) + observation_state = service_obs.observe(simulation.describe_state()) assert observation_state.get("health_status") == 0