#2688: apply the request validators + fixing the fix duration test + refactor test class names

This commit is contained in:
Czar Echavez
2024-07-05 15:06:17 +01:00
parent 20e5e40d0d
commit 2a0695d0d1
10 changed files with 263 additions and 65 deletions

View File

@@ -177,6 +177,9 @@ simulation:
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: NMAP
options:
fix_duration: 1
- type: RansomwareScript
options:
fix_duration: 1

View File

@@ -51,11 +51,11 @@ class TestService(Service):
pass
class DummyApplication(Application, identifier="DummyApplication"):
class TestDummyApplication(Application, identifier="TestDummyApplication"):
"""Test Application class"""
def __init__(self, **kwargs):
kwargs["name"] = "DummyApplication"
kwargs["name"] = "TestDummyApplication"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
@@ -85,15 +85,18 @@ def service_class():
@pytest.fixture(scope="function")
def application(file_system) -> DummyApplication:
return DummyApplication(
name="DummyApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_application")
def application(file_system) -> TestDummyApplication:
return TestDummyApplication(
name="TestDummyApplication",
port=Port.ARP,
file_system=file_system,
sys_log=SysLog(hostname="dummy_application"),
)
@pytest.fixture(scope="function")
def application_class():
return DummyApplication
return TestDummyApplication
@pytest.fixture(scope="function")

View File

@@ -1,35 +1,23 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import copy
from ipaddress import IPv4Address
from pathlib import Path
from typing import Union
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.container import Network
from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml"
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml"
TestApplications = ["TestDummyApplication", "TestBroadcastClient"]
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
@@ -62,9 +50,12 @@ def test_fix_duration_set_from_config():
assert client_1.software_manager.software.get(service).fixing_duration == 3
# in config - applications take 1 timestep to fix
for applications in APPLICATION_TYPES_MAPPING:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 1
# remove test applications from list
applications = set(Application._application_registry) - set(TestApplications)
for application in applications:
assert client_1.software_manager.software.get(application) is not None
assert client_1.software_manager.software.get(application).fixing_duration == 1
def test_fix_duration_for_one_item():
@@ -80,8 +71,9 @@ def test_fix_duration_for_one_item():
assert client_1.software_manager.software.get(service).fixing_duration == 2
# in config - applications take 1 timestep to fix
applications = copy.copy(APPLICATION_TYPES_MAPPING)
applications.pop("DatabaseClient")
# remove test applications from list
applications = set(Application._application_registry) - set(TestApplications)
applications.remove("DatabaseClient")
for applications in applications:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 2

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -14,7 +14,7 @@ from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.service import Service
class BroadcastService(Service):
class TestBroadcastService(Service):
"""A service for sending broadcast and unicast messages over a network."""
def __init__(self, **kwargs):
@@ -41,14 +41,14 @@ class BroadcastService(Service):
super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol)
class BroadcastClient(Application, identifier="BroadcastClient"):
class TestBroadcastClient(Application, identifier="TestBroadcastClient"):
"""A client application to receive broadcast and unicast messages."""
payloads_received: List = []
def __init__(self, **kwargs):
# Set default client properties
kwargs["name"] = "BroadcastClient"
kwargs["name"] = "TestBroadcastClient"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
@@ -75,8 +75,8 @@ def broadcast_network() -> Network:
start_up_duration=0,
)
client_1.power_on()
client_1.software_manager.install(BroadcastClient)
application_1 = client_1.software_manager.software["BroadcastClient"]
client_1.software_manager.install(TestBroadcastClient)
application_1 = client_1.software_manager.software["TestBroadcastClient"]
application_1.run()
client_2 = Computer(
@@ -87,8 +87,8 @@ def broadcast_network() -> Network:
start_up_duration=0,
)
client_2.power_on()
client_2.software_manager.install(BroadcastClient)
application_2 = client_2.software_manager.software["BroadcastClient"]
client_2.software_manager.install(TestBroadcastClient)
application_2 = client_2.software_manager.software["TestBroadcastClient"]
application_2.run()
server_1 = Server(
@@ -100,8 +100,8 @@ def broadcast_network() -> Network:
)
server_1.power_on()
server_1.software_manager.install(BroadcastService)
service: BroadcastService = server_1.software_manager.software["BroadcastService"]
server_1.software_manager.install(TestBroadcastService)
service: TestBroadcastService = server_1.software_manager.software["BroadcastService"]
service.start()
switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
@@ -115,14 +115,16 @@ def broadcast_network() -> Network:
@pytest.fixture(scope="function")
def broadcast_service_and_clients(broadcast_network) -> Tuple[BroadcastService, BroadcastClient, BroadcastClient]:
client_1: BroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[
"BroadcastClient"
def broadcast_service_and_clients(
broadcast_network,
) -> Tuple[TestBroadcastService, TestBroadcastClient, TestBroadcastClient]:
client_1: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[
"TestBroadcastClient"
]
client_2: BroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[
"BroadcastClient"
client_2: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[
"TestBroadcastClient"
]
service: BroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[
service: TestBroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[
"BroadcastService"
]

View File

@@ -21,7 +21,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]:
computer.power_on()
computer.software_manager.install(application_class)
app = computer.software_manager.software.get("DummyApplication")
app = computer.software_manager.software.get("TestDummyApplication")
app.run()
return app, computer
@@ -39,7 +39,7 @@ def test_application_on_offline_node(application_class):
)
computer.software_manager.install(application_class)
app: Application = computer.software_manager.software.get("DummyApplication")
app: Application = computer.software_manager.software.get("TestDummyApplication")
computer.power_off()

View File

@@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from tests.conftest import DummyApplication, TestService
from tests.conftest import TestDummyApplication, TestService
def test_successful_node_file_system_creation_request(example_network):
@@ -47,14 +47,14 @@ def test_successful_application_requests(example_network):
net = example_network
client_1 = net.get_node_by_hostname("client_1")
client_1.software_manager.install(DummyApplication)
client_1.software_manager.software.get("DummyApplication").run()
client_1.software_manager.install(TestDummyApplication)
client_1.software_manager.software.get("TestDummyApplication").run()
resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"])
resp_1 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "scan"])
assert resp_1 == RequestResponse(status="success", data={})
resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"])
resp_2 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "fix"])
assert resp_2 == RequestResponse(status="success", data={})
resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"])
resp_3 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "compromise"])
assert resp_3 == RequestResponse(status="success", data={})