diff --git a/src/primaite/interface/request.py b/src/primaite/interface/request.py index 8e61c1cb..8e922ef9 100644 --- a/src/primaite/interface/request.py +++ b/src/primaite/interface/request.py @@ -1,6 +1,6 @@ from typing import Dict, ForwardRef, Literal -from pydantic import BaseModel, ConfigDict, validate_call +from pydantic import BaseModel, ConfigDict, StrictBool, validate_call RequestResponse = ForwardRef("RequestResponse") """This makes it possible to type-hint RequestResponse.from_bool return type.""" @@ -9,7 +9,7 @@ RequestResponse = ForwardRef("RequestResponse") class RequestResponse(BaseModel): """Schema for generic request responses.""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", strict=True) """Cannot have extra fields in the response. Anything custom goes into the data field.""" status: Literal["pending", "success", "failure", "unreachable"] = "pending" @@ -29,7 +29,7 @@ class RequestResponse(BaseModel): @classmethod @validate_call - def from_bool(cls, status_bool: bool) -> RequestResponse: + def from_bool(cls, status_bool: StrictBool) -> RequestResponse: """ Construct a basic request response from a boolean. diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index d5945653..f3cf29bb 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -14,7 +14,7 @@ from primaite import getLogger from primaite.exceptions import NetworkError from primaite.interface.request import RequestResponse from primaite.simulator import SIM_OUTPUT -from primaite.simulator.core import RequestManager, RequestType, SimComponent +from primaite.simulator.core import RequestFormat, RequestManager, RequestPermissionValidator, RequestType, SimComponent from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState @@ -772,47 +772,69 @@ class Node(SimComponent): self.sys_log.current_episode = episode self.sys_log.setup_logger() + class _NodeIsOnValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the node is on. + + This is useful because no actions should be being resolved if the node is off. + """ + + node: Node + """Save a reference to the node instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the node is on or off.""" + return self.node.operating_state == NodeOperatingState.ON + def _init_request_manager(self) -> RequestManager: - # TODO: I see that this code is really confusing and hard to read right now... I think some of these things will - # need a better name and better documentation. + _node_is_on = Node._NodeIsOnValidator(node=self) + rm = super()._init_request_manager() # since there are potentially many services, create an request manager that can map service name self._service_request_manager = RequestManager() - rm.add_request("service", RequestType(func=self._service_request_manager)) + rm.add_request("service", RequestType(func=self._service_request_manager, validator=_node_is_on)) self._nic_request_manager = RequestManager() - rm.add_request("network_interface", RequestType(func=self._nic_request_manager)) + rm.add_request("network_interface", RequestType(func=self._nic_request_manager, validator=_node_is_on)) - rm.add_request("file_system", RequestType(func=self.file_system._request_manager)) + rm.add_request("file_system", RequestType(func=self.file_system._request_manager, validator=_node_is_on)) # currently we don't have any applications nor processes, so these will be empty self._process_request_manager = RequestManager() - rm.add_request("process", RequestType(func=self._process_request_manager)) + rm.add_request("process", RequestType(func=self._process_request_manager, validator=_node_is_on)) self._application_request_manager = RequestManager() - rm.add_request("application", RequestType(func=self._application_request_manager)) + rm.add_request("application", RequestType(func=self._application_request_manager, validator=_node_is_on)) rm.add_request( - "scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red())) + "scan", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red()), validator=_node_is_on + ), ) rm.add_request( - "shutdown", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_off())) + "shutdown", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.power_off()), validator=_node_is_on + ), ) rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on()))) rm.add_request( - "reset", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset())) + "reset", + RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()), validator=_node_is_on), ) # TODO implement node reset rm.add_request( - "logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False)) + "logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on) ) # TODO implement logon request rm.add_request( - "logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False)) + "logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on) ) # TODO implement logoff request self._os_request_manager = RequestManager() self._os_request_manager.add_request( - "scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())) + "scan", + RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_node_is_on), ) - rm.add_request("os", RequestType(func=self._os_request_manager)) + rm.add_request("os", RequestType(func=self._os_request_manager, validator=_node_is_on)) return rm diff --git a/tests/integration_tests/test_simulation/__init__.py b/tests/integration_tests/test_simulation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py new file mode 100644 index 00000000..09680740 --- /dev/null +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -0,0 +1,92 @@ +# some test cases: +# 0. test that sending a request to a valid target results in a success +# 1. test that sending a request to a component that doesn't exist results in a failure +# 2. test that sending a request to a software on a turned-off component results in a failure +# 3. test every implemented action under several different situation, some of which should lead to a success and some to a failure. + +import pytest + +from primaite.interface.request import RequestResponse +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode +from tests.conftest import TestApplication, TestService + + +def test_successful_application_requests(example_network): + net = example_network + + client_1 = net.get_node_by_hostname("client_1") + client_1.software_manager.install(TestApplication) + client_1.software_manager.software.get("TestApplication").run() + + resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"]) + assert resp_1 == RequestResponse(status="success", data={}) + resp_2 = net.apply_request(["node", "client_1", "application", "TestApplication", "patch"]) + assert resp_2 == RequestResponse(status="success", data={}) + resp_3 = net.apply_request(["node", "client_1", "application", "TestApplication", "compromise"]) + assert resp_3 == RequestResponse(status="success", data={}) + + +def test_successful_service_requests(example_network): + net = example_network + server_1 = net.get_node_by_hostname("server_1") + server_1.software_manager.install(TestService) + + # Careful: the order here is important, for example we cannot run "stop" unless we run "start" first + for verb in [ + "disable", + "enable", + "start", + "stop", + "start", + "restart", + "pause", + "resume", + "compromise", + "scan", + "patch", + ]: + resp_1 = net.apply_request(["node", "server_1", "service", "TestService", verb]) + assert resp_1 == RequestResponse(status="success", data={}) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + # lazily apply timestep 7 times to make absolutely sure any time-based things like restart have a chance to finish + + +def test_non_existent_requests(example_network): + net = example_network + resp_1 = net.apply_request(["fake"]) + assert resp_1.status == "unreachable" + resp_2 = net.apply_request(["network", "node", "client_39", "application", "WebBrowser", "execute"]) + assert resp_2.status == "unreachable" + + +@pytest.mark.parametrize( + "node_request", + [ + ["node", "client_1", "file_system", "folder", "root", "scan"], + ["node", "client_1", "os", "scan"], + ["node", "client_1", "service", "DNSClient", "stop"], + ["node", "client_1", "application", "WebBrowser", "scan"], + ["node", "client_1", "network_interface", 1, "disable"], + ], +) +def test_request_fails_if_node_off(example_network, node_request): + """Test that requests succeed when the node is on, and fail if the node is off.""" + net = example_network + client_1: HostNode = net.get_node_by_hostname("client_1") + client_1.shut_down_duration = 0 + + assert client_1.operating_state == NodeOperatingState.ON + resp_1 = net.apply_request(node_request) + assert resp_1.status == "success" + + client_1.power_off() + assert client_1.operating_state == NodeOperatingState.OFF + resp_2 = net.apply_request(node_request) + assert resp_2.status == "failure" diff --git a/tests/unit_tests/_primaite/_interface/__init__.py b/tests/unit_tests/_primaite/_interface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/_primaite/_interface/test_request.py b/tests/unit_tests/_primaite/_interface/test_request.py new file mode 100644 index 00000000..5c65b572 --- /dev/null +++ b/tests/unit_tests/_primaite/_interface/test_request.py @@ -0,0 +1,32 @@ +import pytest +from pydantic import ValidationError + +from primaite.interface.request import RequestResponse + + +def test_creating_response_object(): + """Test that we can create a response object with given parameters.""" + r1 = RequestResponse(status="success", data={"test_data": 1, "other_data": 2}) + r2 = RequestResponse(status="unreachable") + r3 = RequestResponse(data={"test_data": "is_good"}) + r4 = RequestResponse() + assert isinstance(r1, RequestResponse) + assert isinstance(r2, RequestResponse) + assert isinstance(r3, RequestResponse) + assert isinstance(r4, RequestResponse) + + +def test_creating_response_from_boolean(): + """Test that we can build a response with a single boolean.""" + r1 = RequestResponse.from_bool(True) + assert r1.status == "success" + + r2 = RequestResponse.from_bool(False) + assert r2.status == "failure" + + with pytest.raises(ValidationError): + r3 = RequestResponse.from_bool(1) + with pytest.raises(ValidationError): + r4 = RequestResponse.from_bool("good") + with pytest.raises(ValidationError): + r5 = RequestResponse.from_bool({"data": True})