From 30d8f142511e2d3c0add63c0bcb13ddce09bb91c Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 3 Jan 2025 16:26:12 +0000 Subject: [PATCH] #2888 - Put software configuration items in the ConfigSchema --- src/primaite/game/game.py | 64 ++----------------- .../system/applications/application.py | 4 +- .../system/applications/database_client.py | 4 ++ .../red_applications/c2/abstract_c2.py | 17 ++--- .../red_applications/c2/c2_beacon.py | 9 ++- .../red_applications/data_manipulation_bot.py | 12 ++++ .../applications/red_applications/dos_bot.py | 27 ++++++-- .../red_applications/ransomware_script.py | 7 ++ .../system/applications/web_browser.py | 2 + .../simulator/system/core/software_manager.py | 25 +++++--- .../simulator/system/services/service.py | 14 ++-- src/primaite/simulator/system/software.py | 35 ++++++++-- .../applications/extended_application.py | 2 + 13 files changed, 125 insertions(+), 97 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 6555e272..5764ad11 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -50,7 +50,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import Software -from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -422,74 +422,20 @@ class PrimaiteGame: application_type = application_cfg["type"] if application_type in Application._registry: - new_node.software_manager.install(Application._registry[application_type]) + application_class = Application._registry[application_type] + application_options = application_cfg.get("options", {}) + application_options["type"] = application_type + new_node.software_manager.install(application_class, software_config=application_options) new_application = new_node.software_manager.software[application_type] # grab the instance - # fixing duration for the application - if "fix_duration" in application_cfg.get("options", {}): - new_application.fixing_duration = application_cfg["options"]["fix_duration"] else: msg = f"Configuration contains an invalid application type: {application_type}" _LOGGER.error(msg) raise ValueError(msg) - _set_software_listen_on_ports(new_application, application_cfg) - # run the application new_application.run() - if application_type == "DataManipulationBot": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - server_ip_address=IPv4Address(opt.get("server_ip")), - server_password=opt.get("server_password"), - payload=opt.get("payload", "DELETE"), - port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), - data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), - ) - elif application_type == "RansomwareScript": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None, - server_password=opt.get("server_password"), - payload=opt.get("payload", "ENCRYPT"), - ) - elif application_type == "DatabaseClient": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - server_ip_address=IPv4Address(opt.get("db_server_ip")), - server_password=opt.get("server_password"), - ) - elif application_type == "WebBrowser": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.target_url = opt.get("target_url") - elif application_type == "DoSBot": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - target_ip_address=IPv4Address(opt.get("target_ip_address")), - target_port=PORT_LOOKUP[opt.get("target_port", "POSTGRES_SERVER")], - payload=opt.get("payload"), - repeat=bool(opt.get("repeat")), - port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), - dos_intensity=float(opt.get("dos_intensity", "1.0")), - max_sessions=int(opt.get("max_sessions", "1000")), - ) - elif application_type == "C2Beacon": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")), - keep_alive_frequency=(opt.get("keep_alive_frequency", 5)), - masquerade_protocol=PROTOCOL_LOOKUP[ - (opt.get("masquerade_protocol", PROTOCOL_LOOKUP["TCP"])) - ], - masquerade_port=PORT_LOOKUP[(opt.get("masquerade_port", PORT_LOOKUP["HTTP"]))], - ) if "network_interfaces" in node_cfg: for nic_num, nic_cfg in node_cfg["network_interfaces"].items(): new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index e0cac6b4..4e6f5cf0 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Any, ClassVar, Dict, Optional, Set, Type -from pydantic import BaseModel, Field +from pydantic import Field from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType @@ -30,7 +30,7 @@ class Application(IOSoftware, ABC): Applications are user-facing programs that may perform input/output operations. """ - class ConfigSchema(BaseModel, ABC): + class ConfigSchema(IOSoftware.ConfigSchema, ABC): """Config Schema for Application class.""" type: str diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index facc4016..4b7286de 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -73,6 +73,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"): """ConfigSchema for DatabaseClient.""" type: str = "DatabaseClient" + db_server_ip: Optional[IPV4Address] = None + server_password: Optional[str] = None config: ConfigSchema = Field(default_factory=lambda: DatabaseClient.ConfigSchema()) @@ -99,6 +101,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"): kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) + self.server_ip_address = self.config.db_server_ip + self.server_password = self.config.server_password def _init_request_manager(self) -> RequestManager: """ diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index a379769d..71a896bc 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -2,7 +2,7 @@ from abc import abstractmethod from enum import Enum from ipaddress import IPv4Address -from typing import Dict, Optional, Union +from typing import Dict, Optional, Set, Union from pydantic import Field, validate_call @@ -75,6 +75,8 @@ class AbstractC2(Application): masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"]) """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" + listen_on_ports: Set[Port] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]} + config: ConfigSchema = Field(default_factory=lambda: AbstractC2.ConfigSchema()) c2_connection_active: bool = False @@ -101,6 +103,12 @@ class AbstractC2(Application): C2 beacon to reconfigure it's configuration settings. """ + def __init__(self, **kwargs): + """Initialise the C2 applications to by default listen for HTTP traffic.""" + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] + super().__init__(**kwargs) + def _craft_packet( self, c2_payload: C2Payload, c2_command: Optional[C2Command] = None, command_options: Optional[Dict] = {} ) -> C2Packet: @@ -141,13 +149,6 @@ class AbstractC2(Application): """ return super().describe_state() - def __init__(self, **kwargs): - """Initialise the C2 applications to by default listen for HTTP traffic.""" - kwargs["listen_on_ports"] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]} - kwargs["port"] = PORT_LOOKUP["NONE"] - kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] - super().__init__(**kwargs) - @property def _host_ftp_client(self) -> Optional[FTPClient]: """Return the FTPClient that is installed C2 Application's host. diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index 014a4096..b9c968c5 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -12,8 +12,9 @@ from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection -from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP -from primaite.utils.validation.port import PORT_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import Port, PORT_LOOKUP class C2Beacon(AbstractC2, identifier="C2Beacon"): @@ -39,6 +40,10 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): """ConfigSchema for C2Beacon.""" type: str = "C2Beacon" + c2_server_ip_address: Optional[IPV4Address] = None + keep_alive_frequency: int = 5 + masquerade_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"] + masquerade_port: Port = PORT_LOOKUP["HTTP"] config: ConfigSchema = Field(default_factory=lambda: C2Beacon.ConfigSchema()) diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 1978afb9..392cdfba 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -12,6 +12,7 @@ from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -46,6 +47,11 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"): """Configuration schema for DataManipulationBot.""" type: str = "DataManipulationBot" + server_ip: Optional[IPV4Address] = None + server_password: Optional[str] = None + payload: str = "DELETE" + port_scan_p_of_success: float = 0.1 + data_manipulation_p_of_success: float = 0.1 config: "DataManipulationBot.ConfigSchema" = Field(default_factory=lambda: DataManipulationBot.ConfigSchema()) @@ -65,6 +71,12 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"): super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None + self.server_ip_address = self.config.server_ip + self.server_password = self.config.server_password + self.payload = self.config.payload + self.port_scan_p_of_success = self.config.port_scan_p_of_success + self.data_manipulation_p_of_success = self.config.data_manipulation_p_of_success + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index e284ba92..ea7a4d8d 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -9,8 +9,8 @@ from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -35,6 +35,18 @@ class DoSAttackStage(IntEnum): class DoSBot(DatabaseClient, identifier="DoSBot"): """A bot that simulates a Denial of Service attack.""" + class ConfigSchema(DatabaseClient.ConfigSchema): + """ConfigSchema for DoSBot.""" + + type: str = "DoSBot" + target_ip_address: Optional[IPV4Address] = None + target_port: Port = PORT_LOOKUP["POSTGRES_SERVER"] + payload: Optional[str] = None + repeat: bool = False + port_scan_p_of_success: float = 0.1 + dos_intensity: float = 1.0 + max_sessions: int = 1000 + config: "DoSBot.ConfigSchema" = Field(default_factory=lambda: DoSBot.ConfigSchema()) target_ip_address: Optional[IPv4Address] = None @@ -58,15 +70,16 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): dos_intensity: float = 1.0 """How much of the max sessions will be used by the DoS when attacking.""" - class ConfigSchema(Application.ConfigSchema): - """ConfigSchema for DoSBot.""" - - type: str = "DoSBot" - def __init__(self, **kwargs): super().__init__(**kwargs) self.name = "DoSBot" - self.max_sessions = 1000 # override normal max sessions + self.target_ip_address = self.config.target_ip_address + self.target_port = self.config.target_port + self.payload = self.config.payload + self.repeat = self.config.repeat + self.port_scan_p_of_success = self.config.port_scan_p_of_success + self.dos_intensity = self.config.dos_intensity + self.max_sessions = self.config.max_sessions def _init_request_manager(self) -> RequestManager: """ diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index b72dc8e5..114d5716 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -10,6 +10,7 @@ from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import PORT_LOOKUP @@ -23,6 +24,9 @@ class RansomwareScript(Application, identifier="RansomwareScript"): """ConfigSchema for RansomwareScript.""" type: str = "RansomwareScript" + server_ip: Optional[IPV4Address] = None + server_password: Optional[str] = None + payload: str = "ENCRYPT" config: "RansomwareScript.ConfigSchema" = Field(default_factory=lambda: RansomwareScript.ConfigSchema()) @@ -40,6 +44,9 @@ class RansomwareScript(Application, identifier="RansomwareScript"): super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None + self.server_ip_address = self.config.server_ip + self.server_password = self.config.server_password + self.payload = self.config.payload def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 52a566f2..ad20640f 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -34,6 +34,7 @@ class WebBrowser(Application, identifier="WebBrowser"): """ConfigSchema for WebBrowser.""" type: str = "WebBrowser" + target_url: Optional[str] = None config: "WebBrowser.ConfigSchema" = Field(default_factory=lambda: WebBrowser.ConfigSchema()) @@ -56,6 +57,7 @@ class WebBrowser(Application, identifier="WebBrowser"): kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) + self.target_url = self.config.target_url self.run() def _init_request_manager(self) -> RequestManager: diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index f0ee6f7c..ddb30a3b 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -106,7 +106,7 @@ class SoftwareManager: return True return False - def install(self, software_class: Type[IOSoftware], **install_kwargs): + def install(self, software_class: Type[IOSoftware], software_config: Optional[IOSoftware.ConfigSchema] = None): """ Install an Application or Service. @@ -115,13 +115,22 @@ class SoftwareManager: if software_class in self._software_class_to_name_map: self.sys_log.warning(f"Cannot install {software_class} as it is already installed") return - software = software_class( - software_manager=self, - sys_log=self.sys_log, - file_system=self.file_system, - dns_server=self.dns_server, - **install_kwargs, - ) + if software_config is None: + software = software_class( + software_manager=self, + sys_log=self.sys_log, + file_system=self.file_system, + dns_server=self.dns_server, + ) + else: + software = software_class( + software_manager=self, + sys_log=self.sys_log, + file_system=self.file_system, + dns_server=self.dns_server, + config=software_config, + ) + software.parent = self.node if isinstance(software, Application): self.node.applications[software.uuid] = software diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index bbf8c479..c30294bb 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Any, ClassVar, Dict, Optional, Type -from pydantic import BaseModel +from pydantic import Field from primaite import getLogger from primaite.interface.request import RequestFormat, RequestResponse @@ -39,7 +39,12 @@ class Service(IOSoftware): Services are programs that run in the background and may perform input/output operations. """ - config: "Service.ConfigSchema" + class ConfigSchema(IOSoftware.ConfigSchema, ABC): + """Config Schema for Service class.""" + + type: str + + config: "Service.ConfigSchema" = Field(default_factory=lambda: Service.ConfigSchema()) operating_state: ServiceOperatingState = ServiceOperatingState.STOPPED "The current operating state of the Service." @@ -53,11 +58,6 @@ class Service(IOSoftware): _registry: ClassVar[Dict[str, Type["Service"]]] = {} """Registry of service types. Automatically populated when subclasses are defined.""" - class ConfigSchema(BaseModel, ABC): - """Config Schema for Service class.""" - - type: str - def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 34c893eb..4b670fe0 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -1,13 +1,13 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import copy -from abc import abstractmethod +from abc import ABC, abstractmethod from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import Field +from pydantic import BaseModel, ConfigDict, Field from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -70,7 +70,7 @@ class SoftwareCriticality(Enum): "The highest level of criticality." -class Software(SimComponent): +class Software(SimComponent, ABC): """ A base class representing software in a simulator environment. @@ -78,6 +78,16 @@ class Software(SimComponent): It outlines the fundamental attributes and behaviors expected of any software in the simulation. """ + class ConfigSchema(BaseModel, ABC): + """Configurable options for all software.""" + + model_config = ConfigDict(extra="forbid") + starting_health_state: SoftwareHealthState = SoftwareHealthState.UNUSED + criticality: SoftwareCriticality = SoftwareCriticality.LOWEST + fixing_duration: int = 2 + + config: ConfigSchema = Field(default_factory=lambda: Software.ConfigSchema()) + name: str "The name of the software." health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED @@ -105,6 +115,12 @@ class Software(SimComponent): _fixing_countdown: Optional[int] = None "Current number of ticks left to patch the software." + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.health_state_actual = self.config.starting_health_state + self.criticality = self.config.criticality + self.fixing_duration = self.config.fixing_duration + def _init_request_manager(self) -> RequestManager: """ Initialise the request manager. @@ -233,7 +249,7 @@ class Software(SimComponent): super().pre_timestep(timestep) -class IOSoftware(Software): +class IOSoftware(Software, ABC): """ Represents software in a simulator environment that is capable of input/output operations. @@ -243,6 +259,13 @@ class IOSoftware(Software): required. """ + class ConfigSchema(Software.ConfigSchema, ABC): + """Configuration options for all IO Software.""" + + listen_on_ports: Set[Port] = Field(default_factory=set) + + config: ConfigSchema = Field(default_factory=lambda: IOSoftware.ConfigSchema()) + installing_count: int = 0 "The number of times the software has been installed. Default is 0." max_sessions: int = 100 @@ -260,6 +283,10 @@ class IOSoftware(Software): _connections: Dict[str, Dict] = {} "Active connections." + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.listen_on_ports = self.config.listen_on_ports + @abstractmethod def describe_state(self) -> Dict: """ diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py index 13fa3d1b..159cfd06 100644 --- a/tests/integration_tests/extensions/applications/extended_application.py +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -35,6 +35,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): """ConfigSchema for ExtendedApplication.""" type: str = "ExtendedApplication" + target_url: Optional[str] = None config: "ExtendedApplication.ConfigSchema" = Field(default_factory=lambda: ExtendedApplication.ConfigSchema()) @@ -57,6 +58,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) + self.target_url = self.config.target_url self.run() def _init_request_manager(self) -> RequestManager: