From 2a0695d0d123def6f9310ea761db4d3b9775e2f6 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 5 Jul 2024 15:06:17 +0100 Subject: [PATCH 1/3] #2688: apply the request validators + fixing the fix duration test + refactor test class names --- .../simulator/network/hardware/base.py | 90 ++++++++++++++++++- .../system/applications/application.py | 53 ++++++++++- .../simulator/system/services/service.py | 84 +++++++++++++++-- .../assets/configs/software_fix_duration.yaml | 3 + tests/conftest.py | 15 ++-- .../test_software_fix_duration.py | 34 +++---- .../actions/test_node_request_permission.py | 1 + .../network/test_broadcast.py | 32 +++---- .../system/test_application_on_node.py | 4 +- .../test_simulation/test_request_response.py | 12 +-- 10 files changed, 263 insertions(+), 65 deletions(-) create mode 100644 tests/integration_tests/game_layer/actions/test_node_request_permission.py diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 6942d280..e728ae97 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -130,10 +130,25 @@ class NetworkInterface(SimComponent, ABC): More information in user guide and docstring for SimComponent._init_request_manager. """ + _is_network_interface_enabled = NetworkInterface._EnabledValidator(network_interface=self) + _is_network_interface_disabled = NetworkInterface._DisabledValidator(network_interface=self) + rm = super()._init_request_manager() - rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable()))) - rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable()))) + rm.add_request( + "enable", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.enable()), + validator=_is_network_interface_disabled, + ), + ) + rm.add_request( + "disable", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.disable()), + validator=_is_network_interface_enabled, + ), + ) return rm @@ -332,6 +347,50 @@ class NetworkInterface(SimComponent, ABC): super().pre_timestep(timestep) self.traffic = {} + class _EnabledValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the NetworkInterface is enabled. + + This is useful because most actions should be being resolved if the NetworkInterface is disabled. + """ + + network_interface: NetworkInterface + """Save a reference to the node instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the NetworkInterface is enabled or not.""" + return self.network_interface.enabled + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return ( + f"Cannot perform request on NetworkInterface " + f"'{self.network_interface.mac_address}' because it is not enabled." + ) + + class _DisabledValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the NetworkInterface is disabled. + + This is useful because some actions should be being resolved if the NetworkInterface is disabled. + """ + + network_interface: NetworkInterface + """Save a reference to the node instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the NetworkInterface is disabled or not.""" + return not self.network_interface.enabled + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return ( + f"Cannot perform request on NetworkInterface " + f"'{self.network_interface.mac_address}' because it is not disabled." + ) + class WiredNetworkInterface(NetworkInterface, ABC): """ @@ -878,6 +937,25 @@ class Node(SimComponent): """Message that is reported when a request is rejected by this validator.""" return f"Cannot perform request on node '{self.node.hostname}' because it is not turned on." + class _NodeIsOffValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the node is off. + + This is useful because some actions require the node to be in an off state. + """ + + node: Node + """Save a reference to the node instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the node is on or off.""" + return self.node.operating_state == NodeOperatingState.OFF + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return f"Cannot perform request on node '{self.node.hostname}' because it is not turned off." + def _init_request_manager(self) -> RequestManager: """ Initialise the request manager. @@ -940,6 +1018,7 @@ class Node(SimComponent): return RequestResponse.from_bool(False) _node_is_on = Node._NodeIsOnValidator(node=self) + _node_is_off = Node._NodeIsOffValidator(node=self) rm = super()._init_request_manager() # since there are potentially many services, create an request manager that can map service name @@ -969,7 +1048,12 @@ class Node(SimComponent): func=lambda request, context: RequestResponse.from_bool(self.power_off()), validator=_node_is_on ), ) - rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on()))) + rm.add_request( + "startup", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.power_on()), validator=_node_is_off + ), + ) rm.add_request( "reset", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()), validator=_node_is_on), diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 848e1ef0..dc16a725 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -1,10 +1,12 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from __future__ import annotations + from abc import abstractmethod from enum import Enum from typing import Any, ClassVar, Dict, Optional, Set, Type -from primaite.interface.request import RequestResponse -from primaite.simulator.core import RequestManager, RequestType +from primaite.interface.request import RequestFormat, RequestResponse +from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType from primaite.simulator.system.software import IOSoftware, SoftwareHealthState @@ -64,9 +66,27 @@ class Application(IOSoftware): More information in user guide and docstring for SimComponent._init_request_manager. """ - rm = super()._init_request_manager() + _is_application_running = Application._StateValidator(application=self, state=ApplicationOperatingState.RUNNING) - rm.add_request("close", RequestType(func=lambda request, context: RequestResponse.from_bool(self.close()))) + rm = super()._init_request_manager() + rm.add_request( + "scan", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_is_application_running + ), + ) + rm.add_request( + "close", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.close()), validator=_is_application_running + ), + ) + rm.add_request( + "fix", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.fix()), validator=_is_application_running + ), + ) return rm @abstractmethod @@ -169,3 +189,28 @@ class Application(IOSoftware): :return: True if successful, False otherwise. """ return super().receive(payload=payload, session_id=session_id, **kwargs) + + class _StateValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the application is in the correct state. + + This is useful because most actions require the application to be in a specific state. + """ + + application: Application + """Save a reference to the application instance.""" + + state: ApplicationOperatingState + """The state of the application to validate.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the application is in the state we are validating for.""" + return self.application.operating_state == self.state + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return ( + f"Cannot perform request on application '{self.application.name}' because it is not in the " + f"{self.state.name} state." + ) diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index e6ce2c87..8167a8a9 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -1,11 +1,13 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from __future__ import annotations + from abc import abstractmethod from enum import Enum from typing import Any, Dict, Optional from primaite import getLogger -from primaite.interface.request import RequestResponse -from primaite.simulator.core import RequestManager, RequestType +from primaite.interface.request import RequestFormat, RequestResponse +from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType from primaite.simulator.system.software import IOSoftware, SoftwareHealthState _LOGGER = getLogger(__name__) @@ -40,6 +42,7 @@ class Service(IOSoftware): restart_duration: int = 5 "How many timesteps does it take to restart this service." + restart_countdown: Optional[int] = None "If currently restarting, how many timesteps remain until the restart is finished." @@ -86,15 +89,55 @@ class Service(IOSoftware): More information in user guide and docstring for SimComponent._init_request_manager. """ + _is_service_running = Service._StateValidator(service=self, state=ServiceOperatingState.RUNNING) + _is_service_stopped = Service._StateValidator(service=self, state=ServiceOperatingState.STOPPED) + _is_service_paused = Service._StateValidator(service=self, state=ServiceOperatingState.PAUSED) + rm = super()._init_request_manager() - rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))) - rm.add_request("stop", RequestType(func=lambda request, context: RequestResponse.from_bool(self.stop()))) - rm.add_request("start", RequestType(func=lambda request, context: RequestResponse.from_bool(self.start()))) - rm.add_request("pause", RequestType(func=lambda request, context: RequestResponse.from_bool(self.pause()))) - rm.add_request("resume", RequestType(func=lambda request, context: RequestResponse.from_bool(self.resume()))) - rm.add_request("restart", RequestType(func=lambda request, context: RequestResponse.from_bool(self.restart()))) + rm.add_request( + "scan", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_is_service_running + ), + ) + rm.add_request( + "stop", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.stop()), validator=_is_service_running + ), + ) + rm.add_request( + "start", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.start()), validator=_is_service_stopped + ), + ) + rm.add_request( + "pause", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.pause()), validator=_is_service_running + ), + ) + rm.add_request( + "resume", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.resume()), validator=_is_service_paused + ), + ) + rm.add_request( + "restart", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.restart()), validator=_is_service_running + ), + ) rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable()))) rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable()))) + rm.add_request( + "fix", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.fix()), validator=_is_service_running + ), + ) return rm @abstractmethod @@ -191,3 +234,28 @@ class Service(IOSoftware): self.sys_log.debug(f"Restarting finished for service {self.name}") self.operating_state = ServiceOperatingState.RUNNING self.restart_countdown -= 1 + + class _StateValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the service is in the correct state. + + This is useful because most actions require the service to be in a specific state. + """ + + service: Service + """Save a reference to the service instance.""" + + state: ServiceOperatingState + """The state of the service to validate.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the service is in the state we are validating for.""" + return self.service.operating_state == self.state + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return ( + f"Cannot perform request on service '{self.service.name}' because it is not in the " + f"{self.state.name} state." + ) diff --git a/tests/assets/configs/software_fix_duration.yaml b/tests/assets/configs/software_fix_duration.yaml index beb176d1..1acb05a9 100644 --- a/tests/assets/configs/software_fix_duration.yaml +++ b/tests/assets/configs/software_fix_duration.yaml @@ -177,6 +177,9 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: + - type: NMAP + options: + fix_duration: 1 - type: RansomwareScript options: fix_duration: 1 diff --git a/tests/conftest.py b/tests/conftest.py index 980e4aa9..e36a2460 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,11 +51,11 @@ class TestService(Service): pass -class DummyApplication(Application, identifier="DummyApplication"): +class TestDummyApplication(Application, identifier="TestDummyApplication"): """Test Application class""" def __init__(self, **kwargs): - kwargs["name"] = "DummyApplication" + kwargs["name"] = "TestDummyApplication" kwargs["port"] = Port.HTTP kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -85,15 +85,18 @@ def service_class(): @pytest.fixture(scope="function") -def application(file_system) -> DummyApplication: - return DummyApplication( - name="DummyApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_application") +def application(file_system) -> TestDummyApplication: + return TestDummyApplication( + name="TestDummyApplication", + port=Port.ARP, + file_system=file_system, + sys_log=SysLog(hostname="dummy_application"), ) @pytest.fixture(scope="function") def application_class(): - return DummyApplication + return TestDummyApplication @pytest.fixture(scope="function") diff --git a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py index bf325946..04160f8f 100644 --- a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py +++ b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py @@ -1,35 +1,23 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK import copy -from ipaddress import IPv4Address from pathlib import Path from typing import Union import yaml -from primaite.config.load import data_manipulation_config_path -from primaite.game.agent.interface import ProxyAgent -from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent -from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent -from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING -from primaite.simulator.network.container import Network +from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient -from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot -from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot -from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_client import DNSClient -from primaite.simulator.system.services.dns.dns_server import DNSServer -from primaite.simulator.system.services.ftp.ftp_client import FTPClient -from primaite.simulator.system.services.ftp.ftp_server import FTPServer -from primaite.simulator.system.services.ntp.ntp_client import NTPClient -from primaite.simulator.system.services.ntp.ntp_server import NTPServer -from primaite.simulator.system.services.web_server.web_server import WebServer from tests import TEST_ASSETS_ROOT TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml" ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml" +TestApplications = ["TestDummyApplication", "TestBroadcastClient"] + def load_config(config_path: Union[str, Path]) -> PrimaiteGame: """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" @@ -62,9 +50,12 @@ def test_fix_duration_set_from_config(): assert client_1.software_manager.software.get(service).fixing_duration == 3 # in config - applications take 1 timestep to fix - for applications in APPLICATION_TYPES_MAPPING: - assert client_1.software_manager.software.get(applications) is not None - assert client_1.software_manager.software.get(applications).fixing_duration == 1 + # remove test applications from list + applications = set(Application._application_registry) - set(TestApplications) + + for application in applications: + assert client_1.software_manager.software.get(application) is not None + assert client_1.software_manager.software.get(application).fixing_duration == 1 def test_fix_duration_for_one_item(): @@ -80,8 +71,9 @@ def test_fix_duration_for_one_item(): assert client_1.software_manager.software.get(service).fixing_duration == 2 # in config - applications take 1 timestep to fix - applications = copy.copy(APPLICATION_TYPES_MAPPING) - applications.pop("DatabaseClient") + # remove test applications from list + applications = set(Application._application_registry) - set(TestApplications) + applications.remove("DatabaseClient") for applications in applications: assert client_1.software_manager.software.get(applications) is not None assert client_1.software_manager.software.get(applications).fixing_duration == 2 diff --git a/tests/integration_tests/game_layer/actions/test_node_request_permission.py b/tests/integration_tests/game_layer/actions/test_node_request_permission.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_node_request_permission.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index b89d6db6..bcf7b9b0 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -14,7 +14,7 @@ from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service -class BroadcastService(Service): +class TestBroadcastService(Service): """A service for sending broadcast and unicast messages over a network.""" def __init__(self, **kwargs): @@ -41,14 +41,14 @@ class BroadcastService(Service): super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol) -class BroadcastClient(Application, identifier="BroadcastClient"): +class TestBroadcastClient(Application, identifier="TestBroadcastClient"): """A client application to receive broadcast and unicast messages.""" payloads_received: List = [] def __init__(self, **kwargs): # Set default client properties - kwargs["name"] = "BroadcastClient" + kwargs["name"] = "TestBroadcastClient" kwargs["port"] = Port.HTTP kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -75,8 +75,8 @@ def broadcast_network() -> Network: start_up_duration=0, ) client_1.power_on() - client_1.software_manager.install(BroadcastClient) - application_1 = client_1.software_manager.software["BroadcastClient"] + client_1.software_manager.install(TestBroadcastClient) + application_1 = client_1.software_manager.software["TestBroadcastClient"] application_1.run() client_2 = Computer( @@ -87,8 +87,8 @@ def broadcast_network() -> Network: start_up_duration=0, ) client_2.power_on() - client_2.software_manager.install(BroadcastClient) - application_2 = client_2.software_manager.software["BroadcastClient"] + client_2.software_manager.install(TestBroadcastClient) + application_2 = client_2.software_manager.software["TestBroadcastClient"] application_2.run() server_1 = Server( @@ -100,8 +100,8 @@ def broadcast_network() -> Network: ) server_1.power_on() - server_1.software_manager.install(BroadcastService) - service: BroadcastService = server_1.software_manager.software["BroadcastService"] + server_1.software_manager.install(TestBroadcastService) + service: TestBroadcastService = server_1.software_manager.software["BroadcastService"] service.start() switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) @@ -115,14 +115,16 @@ def broadcast_network() -> Network: @pytest.fixture(scope="function") -def broadcast_service_and_clients(broadcast_network) -> Tuple[BroadcastService, BroadcastClient, BroadcastClient]: - client_1: BroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[ - "BroadcastClient" +def broadcast_service_and_clients( + broadcast_network, +) -> Tuple[TestBroadcastService, TestBroadcastClient, TestBroadcastClient]: + client_1: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[ + "TestBroadcastClient" ] - client_2: BroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[ - "BroadcastClient" + client_2: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[ + "TestBroadcastClient" ] - service: BroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[ + service: TestBroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[ "BroadcastService" ] diff --git a/tests/integration_tests/system/test_application_on_node.py b/tests/integration_tests/system/test_application_on_node.py index ffb5cc7f..400ab082 100644 --- a/tests/integration_tests/system/test_application_on_node.py +++ b/tests/integration_tests/system/test_application_on_node.py @@ -21,7 +21,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]: computer.power_on() computer.software_manager.install(application_class) - app = computer.software_manager.software.get("DummyApplication") + app = computer.software_manager.software.get("TestDummyApplication") app.run() return app, computer @@ -39,7 +39,7 @@ def test_application_on_offline_node(application_class): ) computer.software_manager.install(application_class) - app: Application = computer.software_manager.software.get("DummyApplication") + app: Application = computer.software_manager.software.get("TestDummyApplication") computer.power_off() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index a9f0b58d..29c70566 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.transport_layer import Port -from tests.conftest import DummyApplication, TestService +from tests.conftest import TestDummyApplication, TestService def test_successful_node_file_system_creation_request(example_network): @@ -47,14 +47,14 @@ def test_successful_application_requests(example_network): net = example_network client_1 = net.get_node_by_hostname("client_1") - client_1.software_manager.install(DummyApplication) - client_1.software_manager.software.get("DummyApplication").run() + client_1.software_manager.install(TestDummyApplication) + client_1.software_manager.software.get("TestDummyApplication").run() - resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"]) + resp_1 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "scan"]) assert resp_1 == RequestResponse(status="success", data={}) - resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"]) + resp_2 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "fix"]) assert resp_2 == RequestResponse(status="success", data={}) - resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"]) + resp_3 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "compromise"]) assert resp_3 == RequestResponse(status="success", data={}) From 829a6371deb0a0c18b5ce3e97bb8baa15e580e1d Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 8 Jul 2024 14:39:37 +0100 Subject: [PATCH 2/3] #2688: tests --- .../test_application_request_permission.py | 54 +++++++++ .../actions/test_nic_request_permission.py | 97 ++++++++++++++++ .../actions/test_node_request_permission.py | 93 +++++++++++++++ .../test_service_request_permission.py | 106 ++++++++++++++++++ 4 files changed, 350 insertions(+) create mode 100644 tests/integration_tests/game_layer/actions/test_application_request_permission.py create mode 100644 tests/integration_tests/game_layer/actions/test_nic_request_permission.py create mode 100644 tests/integration_tests/game_layer/actions/test_service_request_permission.py diff --git a/tests/integration_tests/game_layer/actions/test_application_request_permission.py b/tests/integration_tests/game_layer/actions/test_application_request_permission.py new file mode 100644 index 00000000..36a7ae57 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_application_request_permission.py @@ -0,0 +1,54 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.system.applications.application import ApplicationOperatingState +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.service import ServiceOperatingState + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_application_cannot_perform_actions_unless_running(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test the the request permissions prevent any actions unless application is running.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + + browser.close() + assert browser.operating_state == ApplicationOperatingState.CLOSED + + action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0}) + agent.store_action(action) + game.step() + assert browser.operating_state == ApplicationOperatingState.CLOSED + + action = ("NODE_APPLICATION_CLOSE", {"node_id": 0, "application_id": 0}) + agent.store_action(action) + game.step() + assert browser.operating_state == ApplicationOperatingState.CLOSED + + action = ("NODE_APPLICATION_FIX", {"node_id": 0, "application_id": 0}) + agent.store_action(action) + game.step() + assert browser.operating_state == ApplicationOperatingState.CLOSED + + action = ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}) + agent.store_action(action) + game.step() + assert browser.operating_state == ApplicationOperatingState.CLOSED diff --git a/tests/integration_tests/game_layer/actions/test_nic_request_permission.py b/tests/integration_tests/game_layer/actions/test_nic_request_permission.py new file mode 100644 index 00000000..4c1619e7 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_nic_request_permission.py @@ -0,0 +1,97 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.system.services.service import ServiceOperatingState + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_nic_cannot_be_turned_off_if_not_on(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that a NIC cannot be disabled if it is not enabled.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + nic = client_1.network_interface[1] + nic.disable() + assert nic.enabled is False + + action = ( + "HOST_NIC_DISABLE", + { + "node_id": 0, # client_1 + "nic_id": 0, # the only nic (eth-1) + }, + ) + agent.store_action(action) + game.step() + + assert nic.enabled is False + + +def test_nic_cannot_be_turned_on_if_already_on(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that a NIC cannot be enabled if it is already enabled.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + nic = client_1.network_interface[1] + assert nic.enabled + + action = ( + "HOST_NIC_ENABLE", + { + "node_id": 0, # client_1 + "nic_id": 0, # the only nic (eth-1) + }, + ) + agent.store_action(action) + game.step() + + assert nic.enabled + + +def test_that_a_nic_can_be_enabled_and_disabled(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Tests that a NIC can be enabled and disabled.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + nic = client_1.network_interface[1] + assert nic.enabled + + action = ( + "HOST_NIC_DISABLE", + { + "node_id": 0, # client_1 + "nic_id": 0, # the only nic (eth-1) + }, + ) + agent.store_action(action) + game.step() + + assert nic.enabled is False + + action = ( + "HOST_NIC_ENABLE", + { + "node_id": 0, # client_1 + "nic_id": 0, # the only nic (eth-1) + }, + ) + agent.store_action(action) + game.step() + + assert nic.enabled diff --git a/tests/integration_tests/game_layer/actions/test_node_request_permission.py b/tests/integration_tests/game_layer/actions/test_node_request_permission.py index be6c00e7..fdf04ad5 100644 --- a/tests/integration_tests/game_layer/actions/test_node_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_node_request_permission.py @@ -1 +1,94 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.computer import Computer + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the node can be shut down and started up.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + assert client_1.operating_state == NodeOperatingState.ON + + # turn it off + action = ("NODE_SHUTDOWN", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN + + for i in range(client_1.shut_down_duration + 1): + action = ("DONOTHING", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.OFF + + # turn it on + action = ("NODE_STARTUP", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.BOOTING + + for i in range(client_1.start_up_duration + 1): + action = ("DONOTHING", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.ON + + +def test_node_cannot_be_started_up_if_node_is_already_on(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that a node cannot be started up if it is already on.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + assert client_1.operating_state == NodeOperatingState.ON + + # turn it on + action = ("NODE_STARTUP", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.ON + + +def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that a node cannot be shut down if it is already off.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.power_off() + + for i in range(client_1.shut_down_duration + 1): + action = ("DONOTHING", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.OFF + + # turn it ff + action = ("NODE_SHUTDOWN", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.OFF diff --git a/tests/integration_tests/game_layer/actions/test_service_request_permission.py b/tests/integration_tests/game_layer/actions/test_service_request_permission.py new file mode 100644 index 00000000..3054c73b --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_service_request_permission.py @@ -0,0 +1,106 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.system.services.service import ServiceOperatingState + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_service_start(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator makes sure that the service is stopped before starting the service.""" + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + dns_server = server_1.software_manager.software.get("DNSServer") + + dns_server.pause() + assert dns_server.operating_state == ServiceOperatingState.PAUSED + + action = ("NODE_SERVICE_START", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.PAUSED + + dns_server.stop() + + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_START", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + + assert dns_server.operating_state == ServiceOperatingState.RUNNING + + +def test_service_resume(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator checks if the service is paused before resuming.""" + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + dns_server = server_1.software_manager.software.get("DNSServer") + + action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.RUNNING + + dns_server.pause() + + assert dns_server.operating_state == ServiceOperatingState.PAUSED + + action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + + assert dns_server.operating_state == ServiceOperatingState.RUNNING + + +def test_service_cannot_perform_actions_unless_running(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test to make sure that the service cannot perform certain actions while not running.""" + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + dns_server = server_1.software_manager.software.get("DNSServer") + + dns_server.stop() + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_PAUSE", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_RESTART", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_FIX", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.STOPPED From a3f74087fa27f0830f688cd4adb4837c297759f2 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 8 Jul 2024 15:26:30 +0100 Subject: [PATCH 3/3] #2688: refactor test classes --- tests/conftest.py | 12 ++++---- .../test_software_fix_duration.py | 2 +- .../network/test_broadcast.py | 30 +++++++++---------- .../system/test_application_on_node.py | 4 +-- .../test_simulation/test_request_response.py | 12 ++++---- 5 files changed, 30 insertions(+), 30 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e36a2460..e3c84e6d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,11 +51,11 @@ class TestService(Service): pass -class TestDummyApplication(Application, identifier="TestDummyApplication"): +class DummyApplication(Application, identifier="DummyApplication"): """Test Application class""" def __init__(self, **kwargs): - kwargs["name"] = "TestDummyApplication" + kwargs["name"] = "DummyApplication" kwargs["port"] = Port.HTTP kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -85,9 +85,9 @@ def service_class(): @pytest.fixture(scope="function") -def application(file_system) -> TestDummyApplication: - return TestDummyApplication( - name="TestDummyApplication", +def application(file_system) -> DummyApplication: + return DummyApplication( + name="DummyApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_application"), @@ -96,7 +96,7 @@ def application(file_system) -> TestDummyApplication: @pytest.fixture(scope="function") def application_class(): - return TestDummyApplication + return DummyApplication @pytest.fixture(scope="function") diff --git a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py index 04160f8f..ae4825ff 100644 --- a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py +++ b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py @@ -16,7 +16,7 @@ from tests import TEST_ASSETS_ROOT TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml" ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml" -TestApplications = ["TestDummyApplication", "TestBroadcastClient"] +TestApplications = ["DummyApplication", "BroadcastTestClient"] def load_config(config_path: Union[str, Path]) -> PrimaiteGame: diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index bcf7b9b0..80007c46 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -14,7 +14,7 @@ from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service -class TestBroadcastService(Service): +class BroadcastTestService(Service): """A service for sending broadcast and unicast messages over a network.""" def __init__(self, **kwargs): @@ -41,14 +41,14 @@ class TestBroadcastService(Service): super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol) -class TestBroadcastClient(Application, identifier="TestBroadcastClient"): +class BroadcastTestClient(Application, identifier="BroadcastTestClient"): """A client application to receive broadcast and unicast messages.""" payloads_received: List = [] def __init__(self, **kwargs): # Set default client properties - kwargs["name"] = "TestBroadcastClient" + kwargs["name"] = "BroadcastTestClient" kwargs["port"] = Port.HTTP kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -75,8 +75,8 @@ def broadcast_network() -> Network: start_up_duration=0, ) client_1.power_on() - client_1.software_manager.install(TestBroadcastClient) - application_1 = client_1.software_manager.software["TestBroadcastClient"] + client_1.software_manager.install(BroadcastTestClient) + application_1 = client_1.software_manager.software["BroadcastTestClient"] application_1.run() client_2 = Computer( @@ -87,8 +87,8 @@ def broadcast_network() -> Network: start_up_duration=0, ) client_2.power_on() - client_2.software_manager.install(TestBroadcastClient) - application_2 = client_2.software_manager.software["TestBroadcastClient"] + client_2.software_manager.install(BroadcastTestClient) + application_2 = client_2.software_manager.software["BroadcastTestClient"] application_2.run() server_1 = Server( @@ -100,8 +100,8 @@ def broadcast_network() -> Network: ) server_1.power_on() - server_1.software_manager.install(TestBroadcastService) - service: TestBroadcastService = server_1.software_manager.software["BroadcastService"] + server_1.software_manager.install(BroadcastTestService) + service: BroadcastTestService = server_1.software_manager.software["BroadcastService"] service.start() switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) @@ -117,14 +117,14 @@ def broadcast_network() -> Network: @pytest.fixture(scope="function") def broadcast_service_and_clients( broadcast_network, -) -> Tuple[TestBroadcastService, TestBroadcastClient, TestBroadcastClient]: - client_1: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[ - "TestBroadcastClient" +) -> Tuple[BroadcastTestService, BroadcastTestClient, BroadcastTestClient]: + client_1: BroadcastTestClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[ + "BroadcastTestClient" ] - client_2: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[ - "TestBroadcastClient" + client_2: BroadcastTestClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[ + "BroadcastTestClient" ] - service: TestBroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[ + service: BroadcastTestService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[ "BroadcastService" ] diff --git a/tests/integration_tests/system/test_application_on_node.py b/tests/integration_tests/system/test_application_on_node.py index 400ab082..ffb5cc7f 100644 --- a/tests/integration_tests/system/test_application_on_node.py +++ b/tests/integration_tests/system/test_application_on_node.py @@ -21,7 +21,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]: computer.power_on() computer.software_manager.install(application_class) - app = computer.software_manager.software.get("TestDummyApplication") + app = computer.software_manager.software.get("DummyApplication") app.run() return app, computer @@ -39,7 +39,7 @@ def test_application_on_offline_node(application_class): ) computer.software_manager.install(application_class) - app: Application = computer.software_manager.software.get("TestDummyApplication") + app: Application = computer.software_manager.software.get("DummyApplication") computer.power_off() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 29c70566..a9f0b58d 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.transport_layer import Port -from tests.conftest import TestDummyApplication, TestService +from tests.conftest import DummyApplication, TestService def test_successful_node_file_system_creation_request(example_network): @@ -47,14 +47,14 @@ def test_successful_application_requests(example_network): net = example_network client_1 = net.get_node_by_hostname("client_1") - client_1.software_manager.install(TestDummyApplication) - client_1.software_manager.software.get("TestDummyApplication").run() + client_1.software_manager.install(DummyApplication) + client_1.software_manager.software.get("DummyApplication").run() - resp_1 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "scan"]) + resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"]) assert resp_1 == RequestResponse(status="success", data={}) - resp_2 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "fix"]) + resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"]) assert resp_2 == RequestResponse(status="success", data={}) - resp_3 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "compromise"]) + resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"]) assert resp_3 == RequestResponse(status="success", data={})