#2688: apply the request validators + fixing the fix duration test + refactor test class names

This commit is contained in:
Czar Echavez
2024-07-05 15:06:17 +01:00
parent 20e5e40d0d
commit 2a0695d0d1
10 changed files with 263 additions and 65 deletions

View File

@@ -130,10 +130,25 @@ class NetworkInterface(SimComponent, ABC):
More information in user guide and docstring for SimComponent._init_request_manager.
"""
_is_network_interface_enabled = NetworkInterface._EnabledValidator(network_interface=self)
_is_network_interface_disabled = NetworkInterface._DisabledValidator(network_interface=self)
rm = super()._init_request_manager()
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
rm.add_request(
"enable",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.enable()),
validator=_is_network_interface_disabled,
),
)
rm.add_request(
"disable",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.disable()),
validator=_is_network_interface_enabled,
),
)
return rm
@@ -332,6 +347,50 @@ class NetworkInterface(SimComponent, ABC):
super().pre_timestep(timestep)
self.traffic = {}
class _EnabledValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the NetworkInterface is enabled.
This is useful because most actions should be being resolved if the NetworkInterface is disabled.
"""
network_interface: NetworkInterface
"""Save a reference to the node instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the NetworkInterface is enabled or not."""
return self.network_interface.enabled
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return (
f"Cannot perform request on NetworkInterface "
f"'{self.network_interface.mac_address}' because it is not enabled."
)
class _DisabledValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the NetworkInterface is disabled.
This is useful because some actions should be being resolved if the NetworkInterface is disabled.
"""
network_interface: NetworkInterface
"""Save a reference to the node instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the NetworkInterface is disabled or not."""
return not self.network_interface.enabled
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return (
f"Cannot perform request on NetworkInterface "
f"'{self.network_interface.mac_address}' because it is not disabled."
)
class WiredNetworkInterface(NetworkInterface, ABC):
"""
@@ -878,6 +937,25 @@ class Node(SimComponent):
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not turned on."
class _NodeIsOffValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the node is off.
This is useful because some actions require the node to be in an off state.
"""
node: Node
"""Save a reference to the node instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the node is on or off."""
return self.node.operating_state == NodeOperatingState.OFF
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not turned off."
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
@@ -940,6 +1018,7 @@ class Node(SimComponent):
return RequestResponse.from_bool(False)
_node_is_on = Node._NodeIsOnValidator(node=self)
_node_is_off = Node._NodeIsOffValidator(node=self)
rm = super()._init_request_manager()
# since there are potentially many services, create an request manager that can map service name
@@ -969,7 +1048,12 @@ class Node(SimComponent):
func=lambda request, context: RequestResponse.from_bool(self.power_off()), validator=_node_is_on
),
)
rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on())))
rm.add_request(
"startup",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.power_on()), validator=_node_is_off
),
)
rm.add_request(
"reset",
RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()), validator=_node_is_on),

View File

@@ -1,10 +1,12 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import abstractmethod
from enum import Enum
from typing import Any, ClassVar, Dict, Optional, Set, Type
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
@@ -64,9 +66,27 @@ class Application(IOSoftware):
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
_is_application_running = Application._StateValidator(application=self, state=ApplicationOperatingState.RUNNING)
rm.add_request("close", RequestType(func=lambda request, context: RequestResponse.from_bool(self.close())))
rm = super()._init_request_manager()
rm.add_request(
"scan",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_is_application_running
),
)
rm.add_request(
"close",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.close()), validator=_is_application_running
),
)
rm.add_request(
"fix",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.fix()), validator=_is_application_running
),
)
return rm
@abstractmethod
@@ -169,3 +189,28 @@ class Application(IOSoftware):
:return: True if successful, False otherwise.
"""
return super().receive(payload=payload, session_id=session_id, **kwargs)
class _StateValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the application is in the correct state.
This is useful because most actions require the application to be in a specific state.
"""
application: Application
"""Save a reference to the application instance."""
state: ApplicationOperatingState
"""The state of the application to validate."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the application is in the state we are validating for."""
return self.application.operating_state == self.state
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return (
f"Cannot perform request on application '{self.application.name}' because it is not in the "
f"{self.state.name} state."
)

View File

@@ -1,11 +1,13 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, Optional
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
_LOGGER = getLogger(__name__)
@@ -40,6 +42,7 @@ class Service(IOSoftware):
restart_duration: int = 5
"How many timesteps does it take to restart this service."
restart_countdown: Optional[int] = None
"If currently restarting, how many timesteps remain until the restart is finished."
@@ -86,15 +89,55 @@ class Service(IOSoftware):
More information in user guide and docstring for SimComponent._init_request_manager.
"""
_is_service_running = Service._StateValidator(service=self, state=ServiceOperatingState.RUNNING)
_is_service_stopped = Service._StateValidator(service=self, state=ServiceOperatingState.STOPPED)
_is_service_paused = Service._StateValidator(service=self, state=ServiceOperatingState.PAUSED)
rm = super()._init_request_manager()
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
rm.add_request("stop", RequestType(func=lambda request, context: RequestResponse.from_bool(self.stop())))
rm.add_request("start", RequestType(func=lambda request, context: RequestResponse.from_bool(self.start())))
rm.add_request("pause", RequestType(func=lambda request, context: RequestResponse.from_bool(self.pause())))
rm.add_request("resume", RequestType(func=lambda request, context: RequestResponse.from_bool(self.resume())))
rm.add_request("restart", RequestType(func=lambda request, context: RequestResponse.from_bool(self.restart())))
rm.add_request(
"scan",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_is_service_running
),
)
rm.add_request(
"stop",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.stop()), validator=_is_service_running
),
)
rm.add_request(
"start",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.start()), validator=_is_service_stopped
),
)
rm.add_request(
"pause",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.pause()), validator=_is_service_running
),
)
rm.add_request(
"resume",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.resume()), validator=_is_service_paused
),
)
rm.add_request(
"restart",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.restart()), validator=_is_service_running
),
)
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
rm.add_request(
"fix",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.fix()), validator=_is_service_running
),
)
return rm
@abstractmethod
@@ -191,3 +234,28 @@ class Service(IOSoftware):
self.sys_log.debug(f"Restarting finished for service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
self.restart_countdown -= 1
class _StateValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the service is in the correct state.
This is useful because most actions require the service to be in a specific state.
"""
service: Service
"""Save a reference to the service instance."""
state: ServiceOperatingState
"""The state of the service to validate."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the service is in the state we are validating for."""
return self.service.operating_state == self.state
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return (
f"Cannot perform request on service '{self.service.name}' because it is not in the "
f"{self.state.name} state."
)

View File

@@ -177,6 +177,9 @@ simulation:
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: NMAP
options:
fix_duration: 1
- type: RansomwareScript
options:
fix_duration: 1

View File

@@ -51,11 +51,11 @@ class TestService(Service):
pass
class DummyApplication(Application, identifier="DummyApplication"):
class TestDummyApplication(Application, identifier="TestDummyApplication"):
"""Test Application class"""
def __init__(self, **kwargs):
kwargs["name"] = "DummyApplication"
kwargs["name"] = "TestDummyApplication"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
@@ -85,15 +85,18 @@ def service_class():
@pytest.fixture(scope="function")
def application(file_system) -> DummyApplication:
return DummyApplication(
name="DummyApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_application")
def application(file_system) -> TestDummyApplication:
return TestDummyApplication(
name="TestDummyApplication",
port=Port.ARP,
file_system=file_system,
sys_log=SysLog(hostname="dummy_application"),
)
@pytest.fixture(scope="function")
def application_class():
return DummyApplication
return TestDummyApplication
@pytest.fixture(scope="function")

View File

@@ -1,35 +1,23 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import copy
from ipaddress import IPv4Address
from pathlib import Path
from typing import Union
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.container import Network
from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml"
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml"
TestApplications = ["TestDummyApplication", "TestBroadcastClient"]
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
@@ -62,9 +50,12 @@ def test_fix_duration_set_from_config():
assert client_1.software_manager.software.get(service).fixing_duration == 3
# in config - applications take 1 timestep to fix
for applications in APPLICATION_TYPES_MAPPING:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 1
# remove test applications from list
applications = set(Application._application_registry) - set(TestApplications)
for application in applications:
assert client_1.software_manager.software.get(application) is not None
assert client_1.software_manager.software.get(application).fixing_duration == 1
def test_fix_duration_for_one_item():
@@ -80,8 +71,9 @@ def test_fix_duration_for_one_item():
assert client_1.software_manager.software.get(service).fixing_duration == 2
# in config - applications take 1 timestep to fix
applications = copy.copy(APPLICATION_TYPES_MAPPING)
applications.pop("DatabaseClient")
# remove test applications from list
applications = set(Application._application_registry) - set(TestApplications)
applications.remove("DatabaseClient")
for applications in applications:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 2

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -14,7 +14,7 @@ from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.service import Service
class BroadcastService(Service):
class TestBroadcastService(Service):
"""A service for sending broadcast and unicast messages over a network."""
def __init__(self, **kwargs):
@@ -41,14 +41,14 @@ class BroadcastService(Service):
super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol)
class BroadcastClient(Application, identifier="BroadcastClient"):
class TestBroadcastClient(Application, identifier="TestBroadcastClient"):
"""A client application to receive broadcast and unicast messages."""
payloads_received: List = []
def __init__(self, **kwargs):
# Set default client properties
kwargs["name"] = "BroadcastClient"
kwargs["name"] = "TestBroadcastClient"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
@@ -75,8 +75,8 @@ def broadcast_network() -> Network:
start_up_duration=0,
)
client_1.power_on()
client_1.software_manager.install(BroadcastClient)
application_1 = client_1.software_manager.software["BroadcastClient"]
client_1.software_manager.install(TestBroadcastClient)
application_1 = client_1.software_manager.software["TestBroadcastClient"]
application_1.run()
client_2 = Computer(
@@ -87,8 +87,8 @@ def broadcast_network() -> Network:
start_up_duration=0,
)
client_2.power_on()
client_2.software_manager.install(BroadcastClient)
application_2 = client_2.software_manager.software["BroadcastClient"]
client_2.software_manager.install(TestBroadcastClient)
application_2 = client_2.software_manager.software["TestBroadcastClient"]
application_2.run()
server_1 = Server(
@@ -100,8 +100,8 @@ def broadcast_network() -> Network:
)
server_1.power_on()
server_1.software_manager.install(BroadcastService)
service: BroadcastService = server_1.software_manager.software["BroadcastService"]
server_1.software_manager.install(TestBroadcastService)
service: TestBroadcastService = server_1.software_manager.software["BroadcastService"]
service.start()
switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
@@ -115,14 +115,16 @@ def broadcast_network() -> Network:
@pytest.fixture(scope="function")
def broadcast_service_and_clients(broadcast_network) -> Tuple[BroadcastService, BroadcastClient, BroadcastClient]:
client_1: BroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[
"BroadcastClient"
def broadcast_service_and_clients(
broadcast_network,
) -> Tuple[TestBroadcastService, TestBroadcastClient, TestBroadcastClient]:
client_1: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[
"TestBroadcastClient"
]
client_2: BroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[
"BroadcastClient"
client_2: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[
"TestBroadcastClient"
]
service: BroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[
service: TestBroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[
"BroadcastService"
]

View File

@@ -21,7 +21,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]:
computer.power_on()
computer.software_manager.install(application_class)
app = computer.software_manager.software.get("DummyApplication")
app = computer.software_manager.software.get("TestDummyApplication")
app.run()
return app, computer
@@ -39,7 +39,7 @@ def test_application_on_offline_node(application_class):
)
computer.software_manager.install(application_class)
app: Application = computer.software_manager.software.get("DummyApplication")
app: Application = computer.software_manager.software.get("TestDummyApplication")
computer.power_off()

View File

@@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from tests.conftest import DummyApplication, TestService
from tests.conftest import TestDummyApplication, TestService
def test_successful_node_file_system_creation_request(example_network):
@@ -47,14 +47,14 @@ def test_successful_application_requests(example_network):
net = example_network
client_1 = net.get_node_by_hostname("client_1")
client_1.software_manager.install(DummyApplication)
client_1.software_manager.software.get("DummyApplication").run()
client_1.software_manager.install(TestDummyApplication)
client_1.software_manager.software.get("TestDummyApplication").run()
resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"])
resp_1 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "scan"])
assert resp_1 == RequestResponse(status="success", data={})
resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"])
resp_2 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "fix"])
assert resp_2 == RequestResponse(status="success", data={})
resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"])
resp_3 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "compromise"])
assert resp_3 == RequestResponse(status="success", data={})