diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 3371a99c..c05b493a 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -281,6 +281,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): folder_config.file_system_requires_scan = config.file_system_requires_scan for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne + for service_config in config.services: + service_config.services_requires_scan = config.services_requires_scan + for application_config in config.applications: + application_config.application_config_requires_scan = config.application_config_requires_scan services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 69d9f106..9d60823b 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -39,6 +39,8 @@ def test_host_observation(simulation): folders=[], network_interfaces=[], file_system_requires_scan=True, + services_requires_scan=True, + applications_requires_scan=True, include_users=False, ) diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 998aa755..ab9f6e9c 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -29,7 +29,9 @@ def test_service_observation(simulation): ntp_server = pc.software_manager.software.get("NTPServer") assert ntp_server - service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"]) + service_obs = ServiceObservation( + where=["network", "nodes", pc.hostname, "services", "NTPServer"], services_requires_scan=True + ) assert service_obs.space["operating_status"] == spaces.Discrete(7) assert service_obs.space["health_status"] == spaces.Discrete(5) @@ -54,7 +56,9 @@ def test_application_observation(simulation): web_browser: WebBrowser = pc.software_manager.software.get("WebBrowser") assert web_browser - app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "WebBrowser"]) + app_obs = ApplicationObservation( + where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], applications_requires_scan=True + ) web_browser.close() observation_state = app_obs.observe(simulation.describe_state()) diff --git a/tests/unit_tests/_primaite/_game/_agent/test_observations.py b/tests/unit_tests/_primaite/_game/_agent/test_observations.py index 583b9cbd..912b672e 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_observations.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_observations.py @@ -1,4 +1,5 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import json from typing import List import pytest @@ -142,7 +143,7 @@ class TestServiceRequiresScan: ), ) def test_obs_config(self, yaml_option_string, expected_val): - """Check that the default behaviour is to set FileSystemRequiresScan to True.""" + """Check that the default behaviour is to set service_requires_scan to True.""" obs_cfg_yaml = f""" type: CUSTOM options: @@ -155,19 +156,20 @@ class TestServiceRequiresScan: - hostname: web_server services: - service_name: WebServer + - service_name: DNSClient - hostname: database_server folders: - folder_name: database files: - file_name: database.db - hostname: backup_server + services: + - service_name: FTPServer - hostname: security_suite - hostname: client_1 - applications: - - application_name: WebBrowser - hostname: client_2 - num_services: 1 - num_applications: 1 + num_services: 3 + num_applications: 0 num_folders: 1 num_files: 1 num_nics: 2 @@ -226,10 +228,12 @@ class TestServiceRequiresScan: manager = ObservationManager.from_config(cfg) hosts: List[HostObservation] = manager.obs.components["NODES"].hosts - for host in hosts: + for i, host in enumerate(hosts): services: List[ServiceObservation] = host.services - for service in services: - assert service.services_requires_scan == expected_val # Make sure services require scan by default + for j, service in enumerate(services): + val = service.services_requires_scan + print(f"host {i} service {j} {val}") + assert val == expected_val # Make sure services require scan by default def test_services_requires_scan(self): state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1}