diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index eae9dc1f..ff61714a 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -69,3 +69,95 @@ class ServiceObservation(AbstractObservation): :rtype: ServiceObservation """ return cls(where=parent_where + ["services", config["service_name"]]) + + +class ApplicationObservation(AbstractObservation): + """Observation of an application in the network.""" + + default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0} + "Default observation is what should be returned when the application doesn't exist." + + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """Initialise application observation. + + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a service looks like this: + `['network','nodes',,'applications', ]` + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + app_state = access_from_nested_dict(state, self.where) + if app_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return { + "operating_status": app_state["operating_state"], + "health_status": app_state["health_state_visible"], + "num_executions": self._categorise_num_executions(app_state["num_executions"]), + } + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict( + { + "operating_status": spaces.Discrete(7), + "health_status": spaces.Discrete(6), + "num_executions": spaces.Discrete(4), + } + ) + + @classmethod + def from_config( + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None + ) -> "ApplicationObservation": + """Create application observation from a config. + + :param config: Dictionary containing the configuration for this service observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. + :type parent_where: Optional[List[str]], optional + :return: Constructed service observation + :rtype: ApplicationObservation + """ + return cls(where=parent_where + ["services", config["application_name"]]) + + @classmethod + def _categorise_num_executions(cls, num_executions: int) -> int: + """ + Categorise the number of executions of an application. + + Helps classify the number of application executions into different categories. + + Current categories: + - 0: Application is never executed + - 1: Application is executed a low number of times (1-5) + - 2: Application is executed often (6-10) + - 3: Application is executed a high number of times (more than 10) + + :param: num_executions: Number of times the application is executed + """ + if num_executions > 10: + return 3 + elif num_executions > 5: + return 2 + elif num_executions > 0: + return 1 + return 0 diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 3ddc1e5f..529bfe11 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -130,7 +130,7 @@ class Folder(FileSystemItemABC): file = self.get_file_by_id(file_uuid=file_id) file.scan() if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT: - self.visible_health_status = FileSystemItemHealthStatus.CORRUPT + self.health_status = FileSystemItemHealthStatus.CORRUPT self.visible_health_status = self.health_status def _reveal_to_red_timestep(self) -> None: diff --git a/tests/integration_tests/game_layer/observations/__init__.py b/tests/integration_tests/game_layer/observations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py new file mode 100644 index 00000000..808007cc --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -0,0 +1,68 @@ +import pytest +from gymnasium import spaces + +from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_file_observation(simulation): + """Test the file observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # create a file on the pc + file = pc.file_system.create_file(file_name="dog.png") + + dog_file_obs = FileObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"] + ) + + assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)}) + + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # good initial + + file.corrupt() + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # scan file so this changes + + file.scan() + file.apply_timestep(0) # apply time step + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 3 # corrupted + + +def test_folder_observation(simulation): + """Test the folder observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # create a file and folder on the pc + folder = pc.file_system.create_folder("test_folder") + file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder") + + root_folder_obs = FolderObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"] + ) + + observation_state = root_folder_obs.observe(simulation.describe_state()) + assert observation_state.get("FILES") is not None + assert observation_state.get("health_status") == 1 + + file.corrupt() # corrupt just the file + observation_state = root_folder_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # scan folder to change this + + folder.scan() + for i in range(folder.scan_duration + 1): + folder.apply_timestep(i) # apply as many timesteps as needed for a scan + + observation_state = root_folder_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py new file mode 100644 index 00000000..835202c6 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -0,0 +1,43 @@ +import copy +from uuid import uuid4 + +import pytest + +from primaite.game.agent.observations.node_observations import NodeObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_node_observation(simulation): + """Test a Node observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + node_obs = NodeObservation(where=["network", "nodes", pc.hostname]) + + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 1 # computer is on + + assert observation_state.get("SERVICES") is not None + assert observation_state.get("FOLDERS") is not None + assert observation_state.get("NETWORK_INTERFACES") is not None + + # turn off computer + pc.power_off() + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 4 # shutting down + + for i in range(pc.shut_down_duration + 1): + pc.apply_timestep(i) + + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 2 diff --git a/tests/integration_tests/game_layer/observations/test_observations.py b/tests/integration_tests/game_layer/observations/test_observations.py new file mode 100644 index 00000000..eccda238 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_observations.py @@ -0,0 +1,35 @@ +import pytest + +from primaite.game.agent.observations.observations import NicObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.host_node import NIC +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_nic(simulation): + """Test the NIC observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + nic: NIC = pc.network_interface[1] + + nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + + observation_state = nic_obs.observe(simulation.describe_state()) + assert observation_state.get("nic_status") == 1 # enabled + assert observation_state.get("nmne") is not None + assert observation_state["nmne"].get("inbound") == 0 + assert observation_state["nmne"].get("outbound") == 0 + + nic.disable() + observation_state = nic_obs.observe(simulation.describe_state()) + assert observation_state.get("nic_status") == 2 # disabled diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py new file mode 100644 index 00000000..17fc386f --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -0,0 +1,66 @@ +import pytest + +from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.ntp.ntp_server import NTPServer + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_service_observation(simulation): + """Test the service observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # install software on the computer + pc.software_manager.install(NTPServer) + + ntp_server = pc.software_manager.software.get("NTPServer") + assert ntp_server + + service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"]) + + observation_state = service_obs.observe(simulation.describe_state()) + + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 1 # running + + ntp_server.restart() + observation_state = service_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 6 # resetting + + +def test_application_observation(simulation): + """Test the application observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # install software on the computer + pc.software_manager.install(DatabaseClient) + + web_browser: WebBrowser = pc.software_manager.software.get("WebBrowser") + assert web_browser + + app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "WebBrowser"]) + + web_browser.close() + observation_state = app_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 2 # stopped + assert observation_state.get("num_executions") == 0 + + web_browser.run() + web_browser.scan() # scan to update health status + web_browser.get_webpage("test") + observation_state = app_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 + assert observation_state.get("operating_status") == 1 # running + assert observation_state.get("num_executions") == 1