diff --git a/docs/source/simulation_components/system/list_of_applications.rst b/docs/source/simulation_components/system/list_of_applications.rst index a1d8bfd4..94090d93 100644 --- a/docs/source/simulation_components/system/list_of_applications.rst +++ b/docs/source/simulation_components/system/list_of_applications.rst @@ -8,7 +8,7 @@ applications/* -More info :py:mod:`primaite.game.game.APPLICATION_TYPES_MAPPING` +More info :py:mod:`primaite.simulator.system.applications.application.Application` .. include:: list_of_system_applications.rst diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8a79d068..05210278 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -26,11 +26,14 @@ from primaite.simulator.network.hardware.nodes.network.wireless_router import Wi from primaite.simulator.network.nmne import set_nmne_config from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation -from primaite.simulator.system.applications.database_client import DatabaseClient -from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot -from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot -from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript -from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.applications.application import Application +from primaite.simulator.system.applications.database_client import DatabaseClient # noqa: F401 +from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( # noqa: F401 + DataManipulationBot, +) +from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot # noqa: F401 +from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript # noqa: F401 +from primaite.simulator.system.applications.web_browser import WebBrowser # noqa: F401 from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer @@ -42,15 +45,6 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) -APPLICATION_TYPES_MAPPING = { - "WebBrowser": WebBrowser, - "DatabaseClient": DatabaseClient, - "DataManipulationBot": DataManipulationBot, - "DoSBot": DoSBot, - "RansomwareScript": RansomwareScript, -} -"""List of available applications that can be installed on nodes in the PrimAITE Simulation.""" - SERVICE_TYPES_MAPPING = { "DNSClient": DNSClient, "DNSServer": DNSServer, @@ -333,9 +327,9 @@ class PrimaiteGame: new_application = None application_type = application_cfg["type"] - if application_type in APPLICATION_TYPES_MAPPING: - new_node.software_manager.install(APPLICATION_TYPES_MAPPING[application_type]) - new_application = new_node.software_manager.software[application_type] + if application_type in Application._application_registry: + new_node.software_manager.install(Application._application_registry[application_type]) + new_application = new_node.software_manager.software[application_type] # grab the instance else: msg = f"Configuration contains an invalid application type: {application_type}" _LOGGER.error(msg) diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index 9f4abbf3..77ac4842 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -171,7 +171,7 @@ "from primaite.simulator.file_system.file_system import FileSystem\n", "\n", "# no applications exist yet so we will create our own.\n", - "class MSPaint(Application):\n", + "class MSPaint(Application, identifier=\"MSPaint\"):\n", " def describe_state(self):\n", " return super().describe_state()" ] diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 8c7d64c9..8d8425ec 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -196,7 +196,7 @@ class SimComponent(BaseModel): ..code::python - class WebBrowser(Application): + class WebBrowser(Application, identifier="WebBrowser"): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() # all requests generic to any Application get initialised rm.add_request(...) # initialise any requests specific to the web browser diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 1b9a9657..848e1ef0 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from abc import abstractmethod from enum import Enum -from typing import Any, Dict, Optional, Set +from typing import Any, ClassVar, Dict, Optional, Set, Type from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -39,6 +39,22 @@ class Application(IOSoftware): install_countdown: Optional[int] = None "The countdown to the end of the installation process. None if not currently installing" + _application_registry: ClassVar[Dict[str, Type["Application"]]] = {} + """Registry of application types. Automatically populated when subclasses are defined.""" + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + """ + Register an application type. + + :param identifier: Uniquely specifies an application class by name. Used for finding items by config. + :type identifier: str + :raises ValueError: When attempting to register an application with a name that is already allocated. + """ + super().__init_subclass__(**kwargs) + if identifier in cls._application_registry: + raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.") + cls._application_registry[identifier] = cls + def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index fcfd603b..06d22126 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -55,7 +55,7 @@ class DatabaseClientConnection(BaseModel): self.client._disconnect(self.connection_id) # noqa -class DatabaseClient(Application): +class DatabaseClient(Application, identifier="DatabaseClient"): """ A DatabaseClient application. diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index d8af1b7b..c87eaaf5 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -44,7 +44,7 @@ class PortScanPayload(SimComponent): return state -class NMAP(Application): +class NMAP(Application, identifier="NMAP"): """ A class representing the NMAP application for network scanning. diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index cf03d901..fefb22c3 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -37,7 +37,7 @@ class DataManipulationAttackStage(IntEnum): "Signifies that the attack has failed." -class DataManipulationBot(Application): +class DataManipulationBot(Application, identifier="DataManipulationBot"): """A bot that simulates a script which performs a SQL injection attack.""" payload: Optional[str] = None diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index dccf45f5..17478b71 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -29,7 +29,7 @@ class DoSAttackStage(IntEnum): "Attack is completed." -class DoSBot(DatabaseClient): +class DoSBot(DatabaseClient, identifier="DoSBot"): """A bot that simulates a Denial of Service attack.""" target_ip_address: Optional[IPv4Address] = None diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index 46e42fc2..8d9d0d18 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -10,7 +10,7 @@ from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection -class RansomwareScript(Application): +class RansomwareScript(Application, identifier="RansomwareScript"): """Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network. :ivar payload: The attack stage query payload. (Default ENCRYPT) diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 19cc4065..73791676 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -23,7 +23,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient _LOGGER = getLogger(__name__) -class WebBrowser(Application): +class WebBrowser(Application, identifier="WebBrowser"): """ Represents a web browser in the simulation environment. diff --git a/tests/conftest.py b/tests/conftest.py index b8359323..980e4aa9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,11 +51,11 @@ class TestService(Service): pass -class TestApplication(Application): +class DummyApplication(Application, identifier="DummyApplication"): """Test Application class""" def __init__(self, **kwargs): - kwargs["name"] = "TestApplication" + kwargs["name"] = "DummyApplication" kwargs["port"] = Port.HTTP kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -85,15 +85,15 @@ def service_class(): @pytest.fixture(scope="function") -def application(file_system) -> TestApplication: - return TestApplication( - name="TestApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_application") +def application(file_system) -> DummyApplication: + return DummyApplication( + name="DummyApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_application") ) @pytest.fixture(scope="function") def application_class(): - return TestApplication + return DummyApplication @pytest.fixture(scope="function") diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index 4da1b674..3e06d371 100644 --- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -9,9 +9,10 @@ from primaite.config.load import data_manipulation_config_path from primaite.game.agent.interface import ProxyAgent from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent -from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING +from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot @@ -85,7 +86,7 @@ def test_node_software_install(): assert client_2.software_manager.software.get(software.__name__) is not None # check that applications have been installed on client 1 - for applications in APPLICATION_TYPES_MAPPING: + for applications in Application._application_registry: assert client_1.software_manager.software.get(applications) is not None # check that services have been installed on client 1 diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index 8f65344f..b89d6db6 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -41,7 +41,7 @@ class BroadcastService(Service): super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol) -class BroadcastClient(Application): +class BroadcastClient(Application, identifier="BroadcastClient"): """A client application to receive broadcast and unicast messages.""" payloads_received: List = [] diff --git a/tests/integration_tests/network/test_multi_lan_internet_example_network.py b/tests/integration_tests/network/test_multi_lan_internet_example_network.py index fa290b79..bcc9ad94 100644 --- a/tests/integration_tests/network/test_multi_lan_internet_example_network.py +++ b/tests/integration_tests/network/test_multi_lan_internet_example_network.py @@ -3,9 +3,9 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.networks import multi_lan_internet_network_example from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.ftp.ftp_client import FTPClient -from src.primaite.simulator.system.applications.web_browser import WebBrowser def test_all_with_configured_dns_server_ip_can_resolve_url(): diff --git a/tests/integration_tests/system/test_application_on_node.py b/tests/integration_tests/system/test_application_on_node.py index 275646c6..ffb5cc7f 100644 --- a/tests/integration_tests/system/test_application_on_node.py +++ b/tests/integration_tests/system/test_application_on_node.py @@ -21,7 +21,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]: computer.power_on() computer.software_manager.install(application_class) - app = computer.software_manager.software.get("TestApplication") + app = computer.software_manager.software.get("DummyApplication") app.run() return app, computer @@ -39,7 +39,7 @@ def test_application_on_offline_node(application_class): ) computer.software_manager.install(application_class) - app: Application = computer.software_manager.software.get("TestApplication") + app: Application = computer.software_manager.software.get("DummyApplication") computer.power_off() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 79c72339..a9f0b58d 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.transport_layer import Port -from tests.conftest import TestApplication, TestService +from tests.conftest import DummyApplication, TestService def test_successful_node_file_system_creation_request(example_network): @@ -47,14 +47,14 @@ def test_successful_application_requests(example_network): net = example_network client_1 = net.get_node_by_hostname("client_1") - client_1.software_manager.install(TestApplication) - client_1.software_manager.software.get("TestApplication").run() + client_1.software_manager.install(DummyApplication) + client_1.software_manager.software.get("DummyApplication").run() - resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"]) + resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"]) assert resp_1 == RequestResponse(status="success", data={}) - resp_2 = net.apply_request(["node", "client_1", "application", "TestApplication", "fix"]) + resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"]) assert resp_2 == RequestResponse(status="success", data={}) - resp_3 = net.apply_request(["node", "client_1", "application", "TestApplication", "compromise"]) + resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"]) assert resp_3 == RequestResponse(status="success", data={}) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py new file mode 100644 index 00000000..d8d7dfab --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py @@ -0,0 +1,22 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.simulator.system.applications.application import Application + + +def test_adding_to_app_registry(): + class temp_application(Application, identifier="temp_app"): + pass + + assert Application._application_registry["temp_app"] is temp_application + + with pytest.raises(ValueError): + + class another_application(Application, identifier="temp_app"): + pass + + # This is kinda evil... + # Because pytest doesn't reimport classes from modules, registering this temporary test application will change the + # state of the Application registry for all subsequently run tests. So, we have to delete and unregister the class. + del temp_application + Application._application_registry.pop("temp_app")