Make nodes only accept requests when they're on

This commit is contained in:
Marek Wolan
2024-03-09 20:47:57 +00:00
parent 289b5c548a
commit 31ae4672ac
6 changed files with 164 additions and 18 deletions

View File

@@ -1,6 +1,6 @@
from typing import Dict, ForwardRef, Literal
from pydantic import BaseModel, ConfigDict, validate_call
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
RequestResponse = ForwardRef("RequestResponse")
"""This makes it possible to type-hint RequestResponse.from_bool return type."""
@@ -9,7 +9,7 @@ RequestResponse = ForwardRef("RequestResponse")
class RequestResponse(BaseModel):
"""Schema for generic request responses."""
model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="forbid", strict=True)
"""Cannot have extra fields in the response. Anything custom goes into the data field."""
status: Literal["pending", "success", "failure", "unreachable"] = "pending"
@@ -29,7 +29,7 @@ class RequestResponse(BaseModel):
@classmethod
@validate_call
def from_bool(cls, status_bool: bool) -> RequestResponse:
def from_bool(cls, status_bool: StrictBool) -> RequestResponse:
"""
Construct a basic request response from a boolean.

View File

@@ -14,7 +14,7 @@ from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.interface.request import RequestResponse
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.core import RequestFormat, RequestManager, RequestPermissionValidator, RequestType, SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
@@ -772,47 +772,69 @@ class Node(SimComponent):
self.sys_log.current_episode = episode
self.sys_log.setup_logger()
class _NodeIsOnValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the node is on.
This is useful because no actions should be being resolved if the node is off.
"""
node: Node
"""Save a reference to the node instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the node is on or off."""
return self.node.operating_state == NodeOperatingState.ON
def _init_request_manager(self) -> RequestManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
# need a better name and better documentation.
_node_is_on = Node._NodeIsOnValidator(node=self)
rm = super()._init_request_manager()
# since there are potentially many services, create an request manager that can map service name
self._service_request_manager = RequestManager()
rm.add_request("service", RequestType(func=self._service_request_manager))
rm.add_request("service", RequestType(func=self._service_request_manager, validator=_node_is_on))
self._nic_request_manager = RequestManager()
rm.add_request("network_interface", RequestType(func=self._nic_request_manager))
rm.add_request("network_interface", RequestType(func=self._nic_request_manager, validator=_node_is_on))
rm.add_request("file_system", RequestType(func=self.file_system._request_manager))
rm.add_request("file_system", RequestType(func=self.file_system._request_manager, validator=_node_is_on))
# currently we don't have any applications nor processes, so these will be empty
self._process_request_manager = RequestManager()
rm.add_request("process", RequestType(func=self._process_request_manager))
rm.add_request("process", RequestType(func=self._process_request_manager, validator=_node_is_on))
self._application_request_manager = RequestManager()
rm.add_request("application", RequestType(func=self._application_request_manager))
rm.add_request("application", RequestType(func=self._application_request_manager, validator=_node_is_on))
rm.add_request(
"scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red()))
"scan",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red()), validator=_node_is_on
),
)
rm.add_request(
"shutdown", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_off()))
"shutdown",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.power_off()), validator=_node_is_on
),
)
rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on())))
rm.add_request(
"reset", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()))
"reset",
RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()), validator=_node_is_on),
) # TODO implement node reset
rm.add_request(
"logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False))
"logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on)
) # TODO implement logon request
rm.add_request(
"logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False))
"logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on)
) # TODO implement logoff request
self._os_request_manager = RequestManager()
self._os_request_manager.add_request(
"scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))
"scan",
RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_node_is_on),
)
rm.add_request("os", RequestType(func=self._os_request_manager))
rm.add_request("os", RequestType(func=self._os_request_manager, validator=_node_is_on))
return rm

View File

@@ -0,0 +1,92 @@
# some test cases:
# 0. test that sending a request to a valid target results in a success
# 1. test that sending a request to a component that doesn't exist results in a failure
# 2. test that sending a request to a software on a turned-off component results in a failure
# 3. test every implemented action under several different situation, some of which should lead to a success and some to a failure.
import pytest
from primaite.interface.request import RequestResponse
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from tests.conftest import TestApplication, TestService
def test_successful_application_requests(example_network):
net = example_network
client_1 = net.get_node_by_hostname("client_1")
client_1.software_manager.install(TestApplication)
client_1.software_manager.software.get("TestApplication").run()
resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"])
assert resp_1 == RequestResponse(status="success", data={})
resp_2 = net.apply_request(["node", "client_1", "application", "TestApplication", "patch"])
assert resp_2 == RequestResponse(status="success", data={})
resp_3 = net.apply_request(["node", "client_1", "application", "TestApplication", "compromise"])
assert resp_3 == RequestResponse(status="success", data={})
def test_successful_service_requests(example_network):
net = example_network
server_1 = net.get_node_by_hostname("server_1")
server_1.software_manager.install(TestService)
# Careful: the order here is important, for example we cannot run "stop" unless we run "start" first
for verb in [
"disable",
"enable",
"start",
"stop",
"start",
"restart",
"pause",
"resume",
"compromise",
"scan",
"patch",
]:
resp_1 = net.apply_request(["node", "server_1", "service", "TestService", verb])
assert resp_1 == RequestResponse(status="success", data={})
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
# lazily apply timestep 7 times to make absolutely sure any time-based things like restart have a chance to finish
def test_non_existent_requests(example_network):
net = example_network
resp_1 = net.apply_request(["fake"])
assert resp_1.status == "unreachable"
resp_2 = net.apply_request(["network", "node", "client_39", "application", "WebBrowser", "execute"])
assert resp_2.status == "unreachable"
@pytest.mark.parametrize(
"node_request",
[
["node", "client_1", "file_system", "folder", "root", "scan"],
["node", "client_1", "os", "scan"],
["node", "client_1", "service", "DNSClient", "stop"],
["node", "client_1", "application", "WebBrowser", "scan"],
["node", "client_1", "network_interface", 1, "disable"],
],
)
def test_request_fails_if_node_off(example_network, node_request):
"""Test that requests succeed when the node is on, and fail if the node is off."""
net = example_network
client_1: HostNode = net.get_node_by_hostname("client_1")
client_1.shut_down_duration = 0
assert client_1.operating_state == NodeOperatingState.ON
resp_1 = net.apply_request(node_request)
assert resp_1.status == "success"
client_1.power_off()
assert client_1.operating_state == NodeOperatingState.OFF
resp_2 = net.apply_request(node_request)
assert resp_2.status == "failure"

View File

@@ -0,0 +1,32 @@
import pytest
from pydantic import ValidationError
from primaite.interface.request import RequestResponse
def test_creating_response_object():
"""Test that we can create a response object with given parameters."""
r1 = RequestResponse(status="success", data={"test_data": 1, "other_data": 2})
r2 = RequestResponse(status="unreachable")
r3 = RequestResponse(data={"test_data": "is_good"})
r4 = RequestResponse()
assert isinstance(r1, RequestResponse)
assert isinstance(r2, RequestResponse)
assert isinstance(r3, RequestResponse)
assert isinstance(r4, RequestResponse)
def test_creating_response_from_boolean():
"""Test that we can build a response with a single boolean."""
r1 = RequestResponse.from_bool(True)
assert r1.status == "success"
r2 = RequestResponse.from_bool(False)
assert r2.status == "failure"
with pytest.raises(ValidationError):
r3 = RequestResponse.from_bool(1)
with pytest.raises(ValidationError):
r4 = RequestResponse.from_bool("good")
with pytest.raises(ValidationError):
r5 = RequestResponse.from_bool({"data": True})