Make nodes only accept requests when they're on
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import Dict, ForwardRef, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, validate_call
|
||||
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
|
||||
|
||||
RequestResponse = ForwardRef("RequestResponse")
|
||||
"""This makes it possible to type-hint RequestResponse.from_bool return type."""
|
||||
@@ -9,7 +9,7 @@ RequestResponse = ForwardRef("RequestResponse")
|
||||
class RequestResponse(BaseModel):
|
||||
"""Schema for generic request responses."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
model_config = ConfigDict(extra="forbid", strict=True)
|
||||
"""Cannot have extra fields in the response. Anything custom goes into the data field."""
|
||||
|
||||
status: Literal["pending", "success", "failure", "unreachable"] = "pending"
|
||||
@@ -29,7 +29,7 @@ class RequestResponse(BaseModel):
|
||||
|
||||
@classmethod
|
||||
@validate_call
|
||||
def from_bool(cls, status_bool: bool) -> RequestResponse:
|
||||
def from_bool(cls, status_bool: StrictBool) -> RequestResponse:
|
||||
"""
|
||||
Construct a basic request response from a boolean.
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.core import RequestFormat, RequestManager, RequestPermissionValidator, RequestType, SimComponent
|
||||
from primaite.simulator.domain.account import Account
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
@@ -772,47 +772,69 @@ class Node(SimComponent):
|
||||
self.sys_log.current_episode = episode
|
||||
self.sys_log.setup_logger()
|
||||
|
||||
class _NodeIsOnValidator(RequestPermissionValidator):
|
||||
"""
|
||||
When requests come in, this validator will only let them through if the node is on.
|
||||
|
||||
This is useful because no actions should be being resolved if the node is off.
|
||||
"""
|
||||
|
||||
node: Node
|
||||
"""Save a reference to the node instance."""
|
||||
|
||||
def __call__(self, request: RequestFormat, context: Dict) -> bool:
|
||||
"""Return whether the node is on or off."""
|
||||
return self.node.operating_state == NodeOperatingState.ON
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
|
||||
# need a better name and better documentation.
|
||||
_node_is_on = Node._NodeIsOnValidator(node=self)
|
||||
|
||||
rm = super()._init_request_manager()
|
||||
# since there are potentially many services, create an request manager that can map service name
|
||||
self._service_request_manager = RequestManager()
|
||||
rm.add_request("service", RequestType(func=self._service_request_manager))
|
||||
rm.add_request("service", RequestType(func=self._service_request_manager, validator=_node_is_on))
|
||||
self._nic_request_manager = RequestManager()
|
||||
rm.add_request("network_interface", RequestType(func=self._nic_request_manager))
|
||||
rm.add_request("network_interface", RequestType(func=self._nic_request_manager, validator=_node_is_on))
|
||||
|
||||
rm.add_request("file_system", RequestType(func=self.file_system._request_manager))
|
||||
rm.add_request("file_system", RequestType(func=self.file_system._request_manager, validator=_node_is_on))
|
||||
|
||||
# currently we don't have any applications nor processes, so these will be empty
|
||||
self._process_request_manager = RequestManager()
|
||||
rm.add_request("process", RequestType(func=self._process_request_manager))
|
||||
rm.add_request("process", RequestType(func=self._process_request_manager, validator=_node_is_on))
|
||||
self._application_request_manager = RequestManager()
|
||||
rm.add_request("application", RequestType(func=self._application_request_manager))
|
||||
rm.add_request("application", RequestType(func=self._application_request_manager, validator=_node_is_on))
|
||||
|
||||
rm.add_request(
|
||||
"scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red()))
|
||||
"scan",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red()), validator=_node_is_on
|
||||
),
|
||||
)
|
||||
|
||||
rm.add_request(
|
||||
"shutdown", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_off()))
|
||||
"shutdown",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.power_off()), validator=_node_is_on
|
||||
),
|
||||
)
|
||||
rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on())))
|
||||
rm.add_request(
|
||||
"reset", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()))
|
||||
"reset",
|
||||
RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()), validator=_node_is_on),
|
||||
) # TODO implement node reset
|
||||
rm.add_request(
|
||||
"logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False))
|
||||
"logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on)
|
||||
) # TODO implement logon request
|
||||
rm.add_request(
|
||||
"logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False))
|
||||
"logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on)
|
||||
) # TODO implement logoff request
|
||||
|
||||
self._os_request_manager = RequestManager()
|
||||
self._os_request_manager.add_request(
|
||||
"scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))
|
||||
"scan",
|
||||
RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_node_is_on),
|
||||
)
|
||||
rm.add_request("os", RequestType(func=self._os_request_manager))
|
||||
rm.add_request("os", RequestType(func=self._os_request_manager, validator=_node_is_on))
|
||||
|
||||
return rm
|
||||
|
||||
|
||||
0
tests/integration_tests/test_simulation/__init__.py
Normal file
0
tests/integration_tests/test_simulation/__init__.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# some test cases:
|
||||
# 0. test that sending a request to a valid target results in a success
|
||||
# 1. test that sending a request to a component that doesn't exist results in a failure
|
||||
# 2. test that sending a request to a software on a turned-off component results in a failure
|
||||
# 3. test every implemented action under several different situation, some of which should lead to a success and some to a failure.
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from tests.conftest import TestApplication, TestService
|
||||
|
||||
|
||||
def test_successful_application_requests(example_network):
|
||||
net = example_network
|
||||
|
||||
client_1 = net.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(TestApplication)
|
||||
client_1.software_manager.software.get("TestApplication").run()
|
||||
|
||||
resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"])
|
||||
assert resp_1 == RequestResponse(status="success", data={})
|
||||
resp_2 = net.apply_request(["node", "client_1", "application", "TestApplication", "patch"])
|
||||
assert resp_2 == RequestResponse(status="success", data={})
|
||||
resp_3 = net.apply_request(["node", "client_1", "application", "TestApplication", "compromise"])
|
||||
assert resp_3 == RequestResponse(status="success", data={})
|
||||
|
||||
|
||||
def test_successful_service_requests(example_network):
|
||||
net = example_network
|
||||
server_1 = net.get_node_by_hostname("server_1")
|
||||
server_1.software_manager.install(TestService)
|
||||
|
||||
# Careful: the order here is important, for example we cannot run "stop" unless we run "start" first
|
||||
for verb in [
|
||||
"disable",
|
||||
"enable",
|
||||
"start",
|
||||
"stop",
|
||||
"start",
|
||||
"restart",
|
||||
"pause",
|
||||
"resume",
|
||||
"compromise",
|
||||
"scan",
|
||||
"patch",
|
||||
]:
|
||||
resp_1 = net.apply_request(["node", "server_1", "service", "TestService", verb])
|
||||
assert resp_1 == RequestResponse(status="success", data={})
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
# lazily apply timestep 7 times to make absolutely sure any time-based things like restart have a chance to finish
|
||||
|
||||
|
||||
def test_non_existent_requests(example_network):
|
||||
net = example_network
|
||||
resp_1 = net.apply_request(["fake"])
|
||||
assert resp_1.status == "unreachable"
|
||||
resp_2 = net.apply_request(["network", "node", "client_39", "application", "WebBrowser", "execute"])
|
||||
assert resp_2.status == "unreachable"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"node_request",
|
||||
[
|
||||
["node", "client_1", "file_system", "folder", "root", "scan"],
|
||||
["node", "client_1", "os", "scan"],
|
||||
["node", "client_1", "service", "DNSClient", "stop"],
|
||||
["node", "client_1", "application", "WebBrowser", "scan"],
|
||||
["node", "client_1", "network_interface", 1, "disable"],
|
||||
],
|
||||
)
|
||||
def test_request_fails_if_node_off(example_network, node_request):
|
||||
"""Test that requests succeed when the node is on, and fail if the node is off."""
|
||||
net = example_network
|
||||
client_1: HostNode = net.get_node_by_hostname("client_1")
|
||||
client_1.shut_down_duration = 0
|
||||
|
||||
assert client_1.operating_state == NodeOperatingState.ON
|
||||
resp_1 = net.apply_request(node_request)
|
||||
assert resp_1.status == "success"
|
||||
|
||||
client_1.power_off()
|
||||
assert client_1.operating_state == NodeOperatingState.OFF
|
||||
resp_2 = net.apply_request(node_request)
|
||||
assert resp_2.status == "failure"
|
||||
0
tests/unit_tests/_primaite/_interface/__init__.py
Normal file
0
tests/unit_tests/_primaite/_interface/__init__.py
Normal file
32
tests/unit_tests/_primaite/_interface/test_request.py
Normal file
32
tests/unit_tests/_primaite/_interface/test_request.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
|
||||
|
||||
def test_creating_response_object():
|
||||
"""Test that we can create a response object with given parameters."""
|
||||
r1 = RequestResponse(status="success", data={"test_data": 1, "other_data": 2})
|
||||
r2 = RequestResponse(status="unreachable")
|
||||
r3 = RequestResponse(data={"test_data": "is_good"})
|
||||
r4 = RequestResponse()
|
||||
assert isinstance(r1, RequestResponse)
|
||||
assert isinstance(r2, RequestResponse)
|
||||
assert isinstance(r3, RequestResponse)
|
||||
assert isinstance(r4, RequestResponse)
|
||||
|
||||
|
||||
def test_creating_response_from_boolean():
|
||||
"""Test that we can build a response with a single boolean."""
|
||||
r1 = RequestResponse.from_bool(True)
|
||||
assert r1.status == "success"
|
||||
|
||||
r2 = RequestResponse.from_bool(False)
|
||||
assert r2.status == "failure"
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
r3 = RequestResponse.from_bool(1)
|
||||
with pytest.raises(ValidationError):
|
||||
r4 = RequestResponse.from_bool("good")
|
||||
with pytest.raises(ValidationError):
|
||||
r5 = RequestResponse.from_bool({"data": True})
|
||||
Reference in New Issue
Block a user