diff --git a/docs/source/simulation_components/system/applications/ransomware_script.rst b/docs/source/simulation_components/system/applications/ransomware_script.rst index b79ca802..192618fc 100644 --- a/docs/source/simulation_components/system/applications/ransomware_script.rst +++ b/docs/source/simulation_components/system/applications/ransomware_script.rst @@ -70,7 +70,7 @@ Python Configuration ============= -The RansomwareScript inherits configuration options such as ``fix_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``. +The RansomwareScript inherits configuration options such as ``fixing_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``. ``server_ip`` diff --git a/docs/source/simulation_components/system/common/common_configuration.rst b/docs/source/simulation_components/system/common/common_configuration.rst index 411fd529..c1bbd4b2 100644 --- a/docs/source/simulation_components/system/common/common_configuration.rst +++ b/docs/source/simulation_components/system/common/common_configuration.rst @@ -22,8 +22,8 @@ options The configuration options are the attributes that fall under the options for an application or service. -fix_duration -"""""""""""" +fixing_duration +""""""""""""""" Optional. Default value is ``2``. diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 650d6a10..198e8750 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -44,7 +44,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import Software -from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -370,12 +370,12 @@ class PrimaiteGame: if service_class is not None: _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") - new_node.software_manager.install(service_class, **service_cfg.get("options", {})) + new_node.software_manager.install(service_class) new_service = new_node.software_manager.software[service_class.__name__] # fixing duration for the service - if "fix_duration" in service_cfg.get("options", {}): - new_service.fixing_duration = service_cfg["options"]["fix_duration"] + if "fixing_duration" in service_cfg.get("options", {}): + new_service.config.fixing_duration = service_cfg["options"]["fixing_duration"] _set_software_listen_on_ports(new_service, service_cfg) # start the service @@ -416,74 +416,20 @@ class PrimaiteGame: application_type = application_cfg["type"] if application_type in Application._registry: - new_node.software_manager.install(Application._registry[application_type]) + application_class = Application._registry[application_type] + application_options = application_cfg.get("options", {}) + application_options["type"] = application_type + new_node.software_manager.install(application_class, software_config=application_options) new_application = new_node.software_manager.software[application_type] # grab the instance - # fixing duration for the application - if "fix_duration" in application_cfg.get("options", {}): - new_application.fixing_duration = application_cfg["options"]["fix_duration"] else: msg = f"Configuration contains an invalid application type: {application_type}" _LOGGER.error(msg) raise ValueError(msg) - _set_software_listen_on_ports(new_application, application_cfg) - # run the application new_application.run() - if application_type == "DataManipulationBot": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - server_ip_address=IPv4Address(opt.get("server_ip")), - server_password=opt.get("server_password"), - payload=opt.get("payload", "DELETE"), - port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), - data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), - ) - elif application_type == "RansomwareScript": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None, - server_password=opt.get("server_password"), - payload=opt.get("payload", "ENCRYPT"), - ) - elif application_type == "DatabaseClient": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - server_ip_address=IPv4Address(opt.get("db_server_ip")), - server_password=opt.get("server_password"), - ) - elif application_type == "WebBrowser": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.target_url = opt.get("target_url") - elif application_type == "DoSBot": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - target_ip_address=IPv4Address(opt.get("target_ip_address")), - target_port=PORT_LOOKUP[opt.get("target_port", "POSTGRES_SERVER")], - payload=opt.get("payload"), - repeat=bool(opt.get("repeat")), - port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), - dos_intensity=float(opt.get("dos_intensity", "1.0")), - max_sessions=int(opt.get("max_sessions", "1000")), - ) - elif application_type == "C2Beacon": - if "options" in application_cfg: - opt = application_cfg["options"] - new_application.configure( - c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")), - keep_alive_frequency=(opt.get("keep_alive_frequency", 5)), - masquerade_protocol=PROTOCOL_LOOKUP[ - (opt.get("masquerade_protocol", PROTOCOL_LOOKUP["TCP"])) - ], - masquerade_port=PORT_LOOKUP[(opt.get("masquerade_port", PORT_LOOKUP["HTTP"]))], - ) if "network_interfaces" in node_cfg: for nic_num, nic_cfg in node_cfg["network_interfaces"].items(): new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 0bbc0768..ecbd0629 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -824,7 +824,7 @@ class User(SimComponent): return self.model_dump() -class UserManager(Service): +class UserManager(Service, identifier="UserManager"): """ Manages users within the PrimAITE system, handling creation, authentication, and administration. @@ -833,11 +833,18 @@ class UserManager(Service): :param disabled_admins: A dictionary of currently disabled admin users by their usernames """ + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for UserManager.""" + + type: str = "UserManager" + + config: "UserManager.ConfigSchema" = Field(default_factory=lambda: UserManager.ConfigSchema()) + users: Dict[str, User] = {} def __init__(self, **kwargs): """ - Initializes a UserManager instanc. + Initializes a UserManager instance. :param username: The username for the default admin user :param password: The password for the default admin user @@ -1130,13 +1137,20 @@ class RemoteUserSession(UserSession): return state -class UserSessionManager(Service): +class UserSessionManager(Service, identifier="UserSessionManager"): """ Manages user sessions on a Node, including local and remote sessions. This class handles authentication, session management, and session timeouts for users interacting with the Node. """ + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for UserSessionManager.""" + + type: str = "UserSessionManager" + + config: "UserSessionManager.ConfigSchema" = Field(default_factory=lambda: UserSessionManager.ConfigSchema()) + local_session: Optional[UserSession] = None """The current local user session, if any.""" diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 250a4dd9..05a47d7a 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -1,10 +1,12 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from abc import abstractmethod +from abc import ABC, abstractmethod from enum import Enum from typing import Any, ClassVar, Dict, Optional, Set, Type +from pydantic import Field + from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType from primaite.simulator.system.software import IOSoftware, SoftwareHealthState @@ -21,13 +23,20 @@ class ApplicationOperatingState(Enum): "The application is being installed or updated." -class Application(IOSoftware): +class Application(IOSoftware, ABC): """ Represents an Application in the simulation environment. Applications are user-facing programs that may perform input/output operations. """ + class ConfigSchema(IOSoftware.ConfigSchema, ABC): + """Config Schema for Application class.""" + + type: str + + config: ConfigSchema = Field(default_factory=lambda: Application.ConfigSchema()) + operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED "The current operating state of the Application." execution_control_status: str = "manual" @@ -49,7 +58,7 @@ class Application(IOSoftware): Register an application type. :param identifier: Uniquely specifies an application class by name. Used for finding items by config. - :type identifier: str + :type identifier: Optional[str] :raises ValueError: When attempting to register an application with a name that is already allocated. """ super().__init_subclass__(**kwargs) @@ -59,6 +68,21 @@ class Application(IOSoftware): raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.") cls._registry[identifier] = cls + @classmethod + def from_config(cls, config: Dict) -> "Application": + """Create an application from a config dictionary. + + :param config: dict of options for application components constructor + :type config: dict + :return: The application component. + :rtype: Application + """ + if config["type"] not in cls._registry: + raise ValueError(f"Invalid Application type {config['type']}") + application_class = cls._registry[config["type"]] + application_object = application_class(config=application_class.ConfigSchema(**config)) + return application_object + def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index f3ccc531..96130e16 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel +from pydantic import BaseModel, Field from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -67,11 +67,19 @@ class DatabaseClient(Application, identifier="DatabaseClient"): Extends the Application class to provide functionality for connecting, querying, and disconnecting from a Database Service. It mainly operates over TCP protocol. - - :ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None. """ + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for DatabaseClient.""" + + type: str = "DatabaseClient" + db_server_ip: Optional[IPV4Address] = None + server_password: Optional[str] = None + + config: ConfigSchema = Field(default_factory=lambda: DatabaseClient.ConfigSchema()) + server_ip_address: Optional[IPv4Address] = None + """The IPv4 address of the Database Service server, defaults to None.""" server_password: Optional[str] = None _query_success_tracker: Dict[str, bool] = {} """Keep track of connections that were established or verified during this step. Used for rewards.""" @@ -93,6 +101,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"): kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) + self.server_ip_address = self.config.db_server_ip + self.server_password = self.config.server_password def _init_request_manager(self) -> RequestManager: """ diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index f064eae3..3eeda4b6 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -3,7 +3,7 @@ from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union from prettytable import PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -52,6 +52,13 @@ class NMAP(Application, identifier="NMAP"): as ping scans to discover active hosts and port scans to detect open ports on those hosts. """ + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for NMAP.""" + + type: str = "NMAP" + + config: "NMAP.ConfigSchema" = Field(default_factory=lambda: NMAP.ConfigSchema()) + _active_port_scans: Dict[str, PortScanPayload] = {} _port_scan_responses: Dict[str, PortScanPayload] = {} diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index 4cd54d69..71a896bc 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -2,9 +2,9 @@ from abc import abstractmethod from enum import Enum from ipaddress import IPv4Address -from typing import Dict, Optional, Union +from typing import Dict, Optional, Set, Union -from pydantic import BaseModel, Field, validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.file_system.file_system import FileSystem, Folder @@ -45,10 +45,10 @@ class C2Payload(Enum): """C2 Input Command payload. Used by the C2 Server to send a command to the c2 beacon.""" OUTPUT = "output_command" - """C2 Output Command. Used by the C2 Beacon to send the results of a Input command to the c2 server.""" + """C2 Output Command. Used by the C2 Beacon to send the results of an Input command to the c2 server.""" -class AbstractC2(Application, identifier="AbstractC2"): +class AbstractC2(Application): """ An abstract command and control (c2) application. @@ -60,9 +60,25 @@ class AbstractC2(Application, identifier="AbstractC2"): Defaults to masquerading as HTTP (Port 80) via TCP. - Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite. + Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite. """ + class ConfigSchema(Application.ConfigSchema): + """Configuration for AbstractC2.""" + + keep_alive_frequency: int = Field(default=5, ge=1) + """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" + + masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"]) + """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" + + masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"]) + """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" + + listen_on_ports: Set[Port] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]} + + config: ConfigSchema = Field(default_factory=lambda: AbstractC2.ConfigSchema()) + c2_connection_active: bool = False """Indicates if the c2 server and c2 beacon are currently connected.""" @@ -75,19 +91,6 @@ class AbstractC2(Application, identifier="AbstractC2"): keep_alive_inactivity: int = 0 """Indicates how many timesteps since the last time the c2 application received a keep alive.""" - class _C2Opts(BaseModel): - """A Pydantic Schema for the different C2 configuration options.""" - - keep_alive_frequency: int = Field(default=5, ge=1) - """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" - - masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"]) - """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" - - masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"]) - """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" - - c2_config: _C2Opts = _C2Opts() """ Holds the current configuration settings of the C2 Suite. @@ -100,6 +103,12 @@ class AbstractC2(Application, identifier="AbstractC2"): C2 beacon to reconfigure it's configuration settings. """ + def __init__(self, **kwargs): + """Initialise the C2 applications to by default listen for HTTP traffic.""" + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] + super().__init__(**kwargs) + def _craft_packet( self, c2_payload: C2Payload, c2_command: Optional[C2Command] = None, command_options: Optional[Dict] = {} ) -> C2Packet: @@ -118,13 +127,13 @@ class AbstractC2(Application, identifier="AbstractC2"): :type c2_command: C2Command. :param command_options: The relevant C2 Beacon parameters.F :type command_options: Dict - :return: Returns the construct C2Packet + :return: Returns the constructed C2Packet :rtype: C2Packet """ constructed_packet = C2Packet( - masquerade_protocol=self.c2_config.masquerade_protocol, - masquerade_port=self.c2_config.masquerade_port, - keep_alive_frequency=self.c2_config.keep_alive_frequency, + masquerade_protocol=self.config.masquerade_protocol, + masquerade_port=self.config.masquerade_port, + keep_alive_frequency=self.config.keep_alive_frequency, payload_type=c2_payload, command=c2_command, payload=command_options, @@ -140,13 +149,6 @@ class AbstractC2(Application, identifier="AbstractC2"): """ return super().describe_state() - def __init__(self, **kwargs): - """Initialise the C2 applications to by default listen for HTTP traffic.""" - kwargs["listen_on_ports"] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]} - kwargs["port"] = PORT_LOOKUP["NONE"] - kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] - super().__init__(**kwargs) - @property def _host_ftp_client(self) -> Optional[FTPClient]: """Return the FTPClient that is installed C2 Application's host. @@ -330,8 +332,8 @@ class AbstractC2(Application, identifier="AbstractC2"): if self.send( payload=keep_alive_packet, dest_ip_address=self.c2_remote_connection, - dest_port=self.c2_config.masquerade_port, - ip_protocol=self.c2_config.masquerade_protocol, + dest_port=self.config.masquerade_port, + ip_protocol=self.config.masquerade_protocol, session_id=session_id, ): # Setting the keep_alive_sent guard condition to True. This is used to prevent packet storms. @@ -340,8 +342,8 @@ class AbstractC2(Application, identifier="AbstractC2"): self.sys_log.info(f"{self.name}: Keep Alive sent to {self.c2_remote_connection}") self.sys_log.debug( f"{self.name}: Keep Alive sent to {self.c2_remote_connection} " - f"Masquerade Port: {self.c2_config.masquerade_port} " - f"Masquerade Protocol: {self.c2_config.masquerade_protocol} " + f"Masquerade Port: {self.config.masquerade_port} " + f"Masquerade Protocol: {self.config.masquerade_protocol} " ) return True else: @@ -376,15 +378,15 @@ class AbstractC2(Application, identifier="AbstractC2"): # Updating the C2 Configuration attribute. - self.c2_config.masquerade_port = payload.masquerade_port - self.c2_config.masquerade_protocol = payload.masquerade_protocol - self.c2_config.keep_alive_frequency = payload.keep_alive_frequency + self.config.masquerade_port = payload.masquerade_port + self.config.masquerade_protocol = payload.masquerade_protocol + self.config.keep_alive_frequency = payload.keep_alive_frequency self.sys_log.debug( f"{self.name}: C2 Config Resolved Config from Keep Alive:" - f"Masquerade Port: {self.c2_config.masquerade_port}" - f"Masquerade Protocol: {self.c2_config.masquerade_protocol}" - f"Keep Alive Frequency: {self.c2_config.keep_alive_frequency}" + f"Masquerade Port: {self.config.masquerade_port}" + f"Masquerade Protocol: {self.config.masquerade_protocol}" + f"Keep Alive Frequency: {self.config.keep_alive_frequency}" ) # This statement is intended to catch on the C2 Application that is listening for connection. @@ -410,8 +412,8 @@ class AbstractC2(Application, identifier="AbstractC2"): self.keep_alive_inactivity = 0 self.keep_alive_frequency = 5 self.c2_remote_connection = None - self.c2_config.masquerade_port = PORT_LOOKUP["HTTP"] - self.c2_config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"] + self.config.masquerade_port = PORT_LOOKUP["HTTP"] + self.config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"] @abstractmethod def _confirm_remote_connection(self, timestep: int) -> bool: diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index b25eea6e..13918cd7 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -3,7 +3,7 @@ from ipaddress import IPv4Address from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -12,8 +12,9 @@ from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection -from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP -from primaite.utils.validation.port import PORT_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import Port, PORT_LOOKUP class C2Beacon(AbstractC2, identifier="C2Beacon"): @@ -32,15 +33,30 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): 2. Leveraging the terminal application to execute requests (dependent on the command given) 3. Sending the RequestResponse back to the C2 Server (Command output) - Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite. + Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite. """ + class ConfigSchema(AbstractC2.ConfigSchema): + """ConfigSchema for C2Beacon.""" + + type: str = "C2Beacon" + c2_server_ip_address: Optional[IPV4Address] = None + keep_alive_frequency: int = 5 + masquerade_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"] + masquerade_port: Port = PORT_LOOKUP["HTTP"] + + config: ConfigSchema = Field(default_factory=lambda: C2Beacon.ConfigSchema()) + keep_alive_attempted: bool = False """Indicates if a keep alive has been attempted to be sent this timestep. Used to prevent packet storms.""" terminal_session: TerminalClientConnection = None "The currently in use terminal session." + def __init__(self, **kwargs): + kwargs["name"] = "C2Beacon" + super().__init__(**kwargs) + @property def _host_terminal(self) -> Optional[Terminal]: """Return the Terminal that is installed on the same machine as the C2 Beacon.""" @@ -119,10 +135,6 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): rm.add_request("configure", request_type=RequestType(func=_configure)) return rm - def __init__(self, **kwargs): - kwargs["name"] = "C2Beacon" - super().__init__(**kwargs) - # Configure is practically setter method for the ``c2.config`` attribute that also ties into the request manager. @validate_call def configure( @@ -146,7 +158,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): masquerade_port | What port should the C2 traffic use? (TCP or UDP) These configuration options are used to reassign the fields in the inherited inner class - ``c2_config``. + ``config``. If a connection is already in progress then this method also sends a keep alive to the C2 Server in order for the C2 Server to sync with the new configuration settings. @@ -162,9 +174,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): :return: Returns True if the configuration was successful, False otherwise. """ self.c2_remote_connection = IPv4Address(c2_server_ip_address) - self.c2_config.keep_alive_frequency = keep_alive_frequency - self.c2_config.masquerade_port = masquerade_port - self.c2_config.masquerade_protocol = masquerade_protocol + self.config.keep_alive_frequency = keep_alive_frequency + self.config.masquerade_port = masquerade_port + self.config.masquerade_protocol = masquerade_protocol self.sys_log.info( f"{self.name}: Configured {self.name} with remote C2 server connection: {c2_server_ip_address=}." ) @@ -263,14 +275,12 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): if self.send( payload=output_packet, dest_ip_address=self.c2_remote_connection, - dest_port=self.c2_config.masquerade_port, - ip_protocol=self.c2_config.masquerade_protocol, + dest_port=self.config.masquerade_port, + ip_protocol=self.config.masquerade_protocol, session_id=session_id, ): self.sys_log.info(f"{self.name}: Command output sent to {self.c2_remote_connection}") - self.sys_log.debug( - f"{self.name}: on {self.c2_config.masquerade_port} via {self.c2_config.masquerade_protocol}" - ) + self.sys_log.debug(f"{self.name}: on {self.config.masquerade_port} via {self.config.masquerade_protocol}") return True else: self.sys_log.warning( @@ -562,7 +572,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): :rtype bool: """ self.keep_alive_attempted = False # Resetting keep alive sent. - if self.keep_alive_inactivity == self.c2_config.keep_alive_frequency: + if self.keep_alive_inactivity == self.config.keep_alive_frequency: self.sys_log.info( f"{self.name}: Attempting to Send Keep Alive to {self.c2_remote_connection} at timestep {timestep}." ) @@ -627,9 +637,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self.c2_connection_active, self.c2_remote_connection, self.keep_alive_inactivity, - self.c2_config.keep_alive_frequency, - self.c2_config.masquerade_protocol, - self.c2_config.masquerade_port, + self.config.keep_alive_frequency, + self.config.masquerade_protocol, + self.config.masquerade_port, ] ) print(table) diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py index 654b86e7..9d2097e9 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py @@ -2,7 +2,7 @@ from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -31,9 +31,16 @@ class C2Server(AbstractC2, identifier="C2Server"): 1. Sending commands to the C2 Beacon. (Command input) 2. Parsing terminal RequestResponses back to the Agent. - Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite. + Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite. """ + class ConfigSchema(AbstractC2.ConfigSchema): + """ConfigSchema for C2Server.""" + + type: str = "C2Server" + + config: ConfigSchema = Field(default_factory=lambda: C2Server.ConfigSchema()) + current_command_output: RequestResponse = None """The Request Response by the last command send. This attribute is updated by the method _handle_command_output.""" @@ -251,8 +258,8 @@ class C2Server(AbstractC2, identifier="C2Server"): payload=command_packet, dest_ip_address=self.c2_remote_connection, session_id=self.c2_session.uuid, - dest_port=self.c2_config.masquerade_port, - ip_protocol=self.c2_config.masquerade_protocol, + dest_port=self.config.masquerade_port, + ip_protocol=self.config.masquerade_protocol, ): self.sys_log.info(f"{self.name}: Successfully sent {given_command}.") self.sys_log.info(f"{self.name}: Awaiting command response {given_command}.") @@ -334,11 +341,11 @@ class C2Server(AbstractC2, identifier="C2Server"): :return: Returns False if the C2 beacon is considered dead. Otherwise True. :rtype bool: """ - if self.keep_alive_inactivity > self.c2_config.keep_alive_frequency: + if self.keep_alive_inactivity > self.config.keep_alive_frequency: self.sys_log.info(f"{self.name}: C2 Beacon connection considered dead due to inactivity.") self.sys_log.debug( f"{self.name}: Did not receive expected keep alive connection from {self.c2_remote_connection}" - f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.c2_config.keep_alive_frequency}" + f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.config.keep_alive_frequency}" f"{self.name}: Last Keep Alive received at {(timestep - self.keep_alive_inactivity)}" ) self._reset_c2_connection() @@ -389,8 +396,8 @@ class C2Server(AbstractC2, identifier="C2Server"): [ self.c2_connection_active, self.c2_remote_connection, - self.c2_config.masquerade_protocol, - self.c2_config.masquerade_port, + self.config.masquerade_protocol, + self.config.masquerade_port, ] ) print(table) diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 0423087e..392cdfba 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -3,6 +3,8 @@ from enum import IntEnum from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestResponse @@ -10,6 +12,7 @@ from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -40,6 +43,18 @@ class DataManipulationAttackStage(IntEnum): class DataManipulationBot(Application, identifier="DataManipulationBot"): """A bot that simulates a script which performs a SQL injection attack.""" + class ConfigSchema(Application.ConfigSchema): + """Configuration schema for DataManipulationBot.""" + + type: str = "DataManipulationBot" + server_ip: Optional[IPV4Address] = None + server_password: Optional[str] = None + payload: str = "DELETE" + port_scan_p_of_success: float = 0.1 + data_manipulation_p_of_success: float = 0.1 + + config: "DataManipulationBot.ConfigSchema" = Field(default_factory=lambda: DataManipulationBot.ConfigSchema()) + payload: Optional[str] = None port_scan_p_of_success: float = 0.1 data_manipulation_p_of_success: float = 0.1 @@ -56,6 +71,12 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"): super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None + self.server_ip_address = self.config.server_ip + self.server_password = self.config.server_password + self.payload = self.config.payload + self.port_scan_p_of_success = self.config.port_scan_p_of_success + self.data_manipulation_p_of_success = self.config.data_manipulation_p_of_success + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 99c4acb3..ea7a4d8d 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -3,11 +3,14 @@ from enum import IntEnum from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -32,6 +35,20 @@ class DoSAttackStage(IntEnum): class DoSBot(DatabaseClient, identifier="DoSBot"): """A bot that simulates a Denial of Service attack.""" + class ConfigSchema(DatabaseClient.ConfigSchema): + """ConfigSchema for DoSBot.""" + + type: str = "DoSBot" + target_ip_address: Optional[IPV4Address] = None + target_port: Port = PORT_LOOKUP["POSTGRES_SERVER"] + payload: Optional[str] = None + repeat: bool = False + port_scan_p_of_success: float = 0.1 + dos_intensity: float = 1.0 + max_sessions: int = 1000 + + config: "DoSBot.ConfigSchema" = Field(default_factory=lambda: DoSBot.ConfigSchema()) + target_ip_address: Optional[IPv4Address] = None """IP address of the target service.""" @@ -56,7 +73,13 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): def __init__(self, **kwargs): super().__init__(**kwargs) self.name = "DoSBot" - self.max_sessions = 1000 # override normal max sessions + self.target_ip_address = self.config.target_ip_address + self.target_port = self.config.target_port + self.payload = self.config.payload + self.repeat = self.config.repeat + self.port_scan_p_of_success = self.config.port_scan_p_of_success + self.dos_intensity = self.config.dos_intensity + self.max_sessions = self.config.max_sessions def _init_request_manager(self) -> RequestManager: """ diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index 3a8ac5ae..114d5716 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -3,12 +3,14 @@ from ipaddress import IPv4Address from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import PORT_LOOKUP @@ -18,6 +20,16 @@ class RansomwareScript(Application, identifier="RansomwareScript"): :ivar payload: The attack stage query payload. (Default ENCRYPT) """ + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for RansomwareScript.""" + + type: str = "RansomwareScript" + server_ip: Optional[IPV4Address] = None + server_password: Optional[str] = None + payload: str = "ENCRYPT" + + config: "RansomwareScript.ConfigSchema" = Field(default_factory=lambda: RansomwareScript.ConfigSchema()) + server_ip_address: Optional[IPv4Address] = None """IP address of node which hosts the database.""" server_password: Optional[str] = None @@ -32,6 +44,9 @@ class RansomwareScript(Application, identifier="RansomwareScript"): super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None + self.server_ip_address = self.config.server_ip + self.server_password = self.config.server_password + self.payload = self.config.payload def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index ff185e2a..49f303b5 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -4,7 +4,7 @@ from ipaddress import IPv4Address from typing import Dict, List, Optional from urllib.parse import urlparse -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from primaite import getLogger from primaite.interface.request import RequestResponse @@ -30,7 +30,13 @@ class WebBrowser(Application, identifier="WebBrowser"): The application requests and loads web pages using its domain name and requesting IP addresses using DNS. """ - target_url: Optional[str] = None + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for WebBrowser.""" + + type: str = "WebBrowser" + target_url: Optional[str] = None + + config: "WebBrowser.ConfigSchema" = Field(default_factory=lambda: WebBrowser.ConfigSchema()) domain_name_ip_address: Optional[IPv4Address] = None "The IP address of the domain name for the webpage." @@ -86,7 +92,7 @@ class WebBrowser(Application, identifier="WebBrowser"): :param: url: The address of the web page the browser requests :type: url: str """ - url = url or self.target_url + url = url or self.config.target_url if not self._can_perform_action(): return False diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index f0ee6f7c..ddb30a3b 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -106,7 +106,7 @@ class SoftwareManager: return True return False - def install(self, software_class: Type[IOSoftware], **install_kwargs): + def install(self, software_class: Type[IOSoftware], software_config: Optional[IOSoftware.ConfigSchema] = None): """ Install an Application or Service. @@ -115,13 +115,22 @@ class SoftwareManager: if software_class in self._software_class_to_name_map: self.sys_log.warning(f"Cannot install {software_class} as it is already installed") return - software = software_class( - software_manager=self, - sys_log=self.sys_log, - file_system=self.file_system, - dns_server=self.dns_server, - **install_kwargs, - ) + if software_config is None: + software = software_class( + software_manager=self, + sys_log=self.sys_log, + file_system=self.file_system, + dns_server=self.dns_server, + ) + else: + software = software_class( + software_manager=self, + sys_log=self.sys_log, + file_system=self.file_system, + dns_server=self.dns_server, + config=software_config, + ) + software.parent = self.node if isinstance(software, Application): self.node.applications[software.uuid] = software diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index 31938e83..bbeec301 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -5,6 +5,7 @@ from abc import abstractmethod from typing import Any, Dict, Optional, Union from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket @@ -14,7 +15,7 @@ from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import PORT_LOOKUP -class ARP(Service): +class ARP(Service, identifier="ARP"): """ The ARP (Address Resolution Protocol) Service. @@ -22,6 +23,13 @@ class ARP(Service): sends ARP requests and replies, and processes incoming ARP packets. """ + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for ARP.""" + + type: str = "ARP" + + config: "ARP.ConfigSchema" = Field(default_factory=lambda: ARP.ConfigSchema()) + arp: Dict[IPV4Address, ARPEntry] = {} def __init__(self, **kwargs): diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 3a5f5b31..4ba4c4d4 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -3,6 +3,8 @@ from ipaddress import IPv4Address from typing import Any, Dict, List, Literal, Optional, Union from uuid import uuid4 +from pydantic import Field + from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus @@ -17,13 +19,21 @@ from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class DatabaseService(Service): +class DatabaseService(Service, identifier="DatabaseService"): """ A class for simulating a generic SQL Server service. This class inherits from the `Service` class and provides methods to simulate a SQL database. """ + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for DatabaseService.""" + + type: str = "DatabaseService" + backup_server_ip: Optional[IPv4Address] = None + + config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema()) + password: Optional[str] = None """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" @@ -42,6 +52,7 @@ class DatabaseService(Service): kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self._create_db_file() + self.backup_server_ip = self.config.backup_server_ip def install(self): """ diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 02cf54ae..0756eb05 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -2,6 +2,8 @@ from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest from primaite.simulator.system.core.software_manager import SoftwareManager @@ -12,9 +14,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class DNSClient(Service): +class DNSClient(Service, identifier="DNSClient"): """Represents a DNS Client as a Service.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for DNSClient.""" + + type: str = "DNSClient" + + config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema()) dns_cache: Dict[str, IPv4Address] = {} "A dict of known mappings between domain/URLs names and IPv4 addresses." dns_server: Optional[IPv4Address] = None diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index b7c9a42c..3a1c0e18 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -3,6 +3,7 @@ from ipaddress import IPv4Address from typing import Any, Dict, Optional from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket @@ -13,9 +14,17 @@ from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class DNSServer(Service): +class DNSServer(Service, identifier="DNSServer"): """Represents a DNS Server as a Service.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for DNSServer.""" + + type: str = "DNSServer" + domain_mapping: dict = {} + + config: "DNSServer.ConfigSchema" = Field(default_factory=lambda: DNSServer.ConfigSchema()) + dns_table: Dict[str, IPv4Address] = {} "A dict of mappings between domain names and IPv4 addresses." diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 9c7b91ce..16cefdd6 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -2,6 +2,8 @@ from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -9,20 +11,28 @@ from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC +from primaite.simulator.system.services.service import Service from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class FTPClient(FTPServiceABC): +class FTPClient(FTPServiceABC, identifier="FTPClient"): """ A class for simulating an FTP client service. - This class inherits from the `Service` class and provides methods to emulate FTP + This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP RFC 959: https://datatracker.ietf.org/doc/html/rfc959 """ + config: "FTPClient.ConfigSchema" = Field(default_factory=lambda: FTPClient.ConfigSchema()) + + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for FTPClient.""" + + type: str = "FTPClient" + def __init__(self, **kwargs): kwargs["name"] = "FTPClient" kwargs["port"] = PORT_LOOKUP["FTP"] diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 9ce7d658..054bfe15 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -1,26 +1,36 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Any, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC +from primaite.simulator.system.services.service import Service from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class FTPServer(FTPServiceABC): +class FTPServer(FTPServiceABC, identifier="FTPServer"): """ A class for simulating an FTP server service. - This class inherits from the `Service` class and provides methods to emulate FTP + This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP RFC 959: https://datatracker.ietf.org/doc/html/rfc959 """ + config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema()) + server_password: Optional[str] = None """Password needed to connect to FTP server. Default is None.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for FTPServer.""" + + type: str = "FTPServer" + def __init__(self, **kwargs): kwargs["name"] = "FTPServer" kwargs["port"] = PORT_LOOKUP["FTP"] diff --git a/src/primaite/simulator/system/services/icmp/icmp.py b/src/primaite/simulator/system/services/icmp/icmp.py index 933d0591..7f626945 100644 --- a/src/primaite/simulator/system/services/icmp/icmp.py +++ b/src/primaite/simulator/system/services/icmp/icmp.py @@ -3,6 +3,8 @@ import secrets from ipaddress import IPv4Address from typing import Any, Dict, Optional, Tuple, Union +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType @@ -14,7 +16,7 @@ from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class ICMP(Service): +class ICMP(Service, identifier="ICMP"): """ The Internet Control Message Protocol (ICMP) service. @@ -22,6 +24,13 @@ class ICMP(Service): network diagnostics, notably the ping command. """ + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for ICMP.""" + + type: str = "ICMP" + + config: "ICMP.ConfigSchema" = Field(default_factory=lambda: ICMP.ConfigSchema()) + request_replies: Dict = {} def __init__(self, **kwargs): diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 9606c61f..fb470faf 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -3,6 +3,8 @@ from datetime import datetime from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket from primaite.simulator.system.services.service import Service, ServiceOperatingState @@ -12,9 +14,16 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class NTPClient(Service): +class NTPClient(Service, identifier="NTPClient"): """Represents a NTP client as a service.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for NTPClient.""" + + type: str = "NTPClient" + + config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema()) + ntp_server: Optional[IPv4Address] = None "The NTP server the client sends requests to." time: Optional[datetime] = None diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 6e73ccc6..7af33893 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -2,6 +2,8 @@ from datetime import datetime from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket from primaite.simulator.system.services.service import Service @@ -11,9 +13,16 @@ from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class NTPServer(Service): +class NTPServer(Service, identifier="NTPServer"): """Represents a NTP server as a service.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for NTPServer.""" + + type: str = "NTPServer" + + config: "NTPServer.ConfigSchema" = Field(default_factory=lambda: NTPServer.ConfigSchema()) + def __init__(self, **kwargs): kwargs["name"] = "NTPServer" kwargs["port"] = PORT_LOOKUP["NTP"] diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 329aefef..c4a73301 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -1,10 +1,12 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from abc import abstractmethod +from abc import ABC, abstractmethod from enum import Enum from typing import Any, ClassVar, Dict, Optional, Type +from pydantic import Field + from primaite import getLogger from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType @@ -37,6 +39,13 @@ class Service(IOSoftware): Services are programs that run in the background and may perform input/output operations. """ + class ConfigSchema(IOSoftware.ConfigSchema, ABC): + """Config Schema for Service class.""" + + type: str + + config: "Service.ConfigSchema" = Field(default_factory=lambda: Service.ConfigSchema()) + operating_state: ServiceOperatingState = ServiceOperatingState.STOPPED "The current operating state of the Service." @@ -69,6 +78,21 @@ class Service(IOSoftware): raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.") cls._registry[identifier] = cls + @classmethod + def from_config(cls, config: Dict) -> "Service": + """Create a service from a config dictionary. + + :param config: dict of options for service components constructor + :type config: dict + :return: The service component. + :rtype: Service + """ + if config["type"] not in cls._registry: + raise ValueError(f"Invalid service type {config['type']}") + service_class = cls._registry[config["type"]] + service_object = service_class(config=service_class.ConfigSchema(**config)) + return service_object + def _can_perform_action(self) -> bool: """ Checks if the service can perform actions. @@ -232,14 +256,14 @@ class Service(IOSoftware): def disable(self) -> bool: """Disable the service.""" - self.sys_log.info(f"Disabling Application {self.name}") + self.sys_log.info(f"Disabling Service {self.name}") self.operating_state = ServiceOperatingState.DISABLED return True def enable(self) -> bool: """Enable the disabled service.""" if self.operating_state == ServiceOperatingState.DISABLED: - self.sys_log.info(f"Enabling Application {self.name}") + self.sys_log.info(f"Enabling Service {self.name}") self.operating_state = ServiceOperatingState.STOPPED return True return False diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 1c249ebb..0d93248e 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -7,7 +7,7 @@ from ipaddress import IPv4Address from typing import Any, Dict, List, Optional, Union from uuid import uuid4 -from pydantic import BaseModel +from pydantic import BaseModel, Field from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -129,9 +129,16 @@ class RemoteTerminalConnection(TerminalClientConnection): return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id) -class Terminal(Service): +class Terminal(Service, identifier="Terminal"): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for Terminal.""" + + type: str = "Terminal" + + config: "Terminal.ConfigSchema" = Field(default_factory=lambda: Terminal.ConfigSchema()) + _client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {} """Dictionary of connect requests made to remote nodes.""" diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 1aab374d..51724002 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -3,6 +3,8 @@ from ipaddress import IPv4Address from typing import Any, Dict, List, Optional from urllib.parse import urlparse +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.http import ( HttpRequestMethod, @@ -19,9 +21,16 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class WebServer(Service): +class WebServer(Service, identifier="WebServer"): """Class used to represent a Web Server Service in simulation.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for WebServer.""" + + type: str = "WebServer" + + config: "WebServer.ConfigSchema" = Field(default_factory=lambda: WebServer.ConfigSchema()) + response_codes_this_timestep: List[HttpStatusCode] = [] def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 34c893eb..25b2366c 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -1,13 +1,13 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import copy -from abc import abstractmethod +from abc import ABC, abstractmethod from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import Field +from pydantic import BaseModel, ConfigDict, Field from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -70,7 +70,7 @@ class SoftwareCriticality(Enum): "The highest level of criticality." -class Software(SimComponent): +class Software(SimComponent, ABC): """ A base class representing software in a simulator environment. @@ -78,14 +78,22 @@ class Software(SimComponent): It outlines the fundamental attributes and behaviors expected of any software in the simulation. """ + class ConfigSchema(BaseModel, ABC): + """Configurable options for all software.""" + + model_config = ConfigDict(extra="forbid") + starting_health_state: SoftwareHealthState = SoftwareHealthState.UNUSED + criticality: SoftwareCriticality = SoftwareCriticality.LOWEST + fixing_duration: int = 2 + + config: ConfigSchema = Field(default_factory=lambda: Software.ConfigSchema()) + name: str "The name of the software." health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED "The actual health state of the software." health_state_visible: SoftwareHealthState = SoftwareHealthState.UNUSED "The health state of the software visible to the red agent." - criticality: SoftwareCriticality = SoftwareCriticality.LOWEST - "The criticality level of the software." fixing_count: int = 0 "The count of patches applied to the software, defaults to 0." scanning_count: int = 0 @@ -100,11 +108,13 @@ class Software(SimComponent): "The FileSystem of the Node the Software is installed on." folder: Optional[Folder] = None "The folder on the file system the Software uses." - fixing_duration: int = 2 - "The number of ticks it takes to patch the software." _fixing_countdown: Optional[int] = None "Current number of ticks left to patch the software." + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.health_state_actual = self.config.starting_health_state # don't remove this + def _init_request_manager(self) -> RequestManager: """ Initialise the request manager. @@ -152,7 +162,7 @@ class Software(SimComponent): { "health_state_actual": self.health_state_actual.value, "health_state_visible": self.health_state_visible.value, - "criticality": self.criticality.value, + "criticality": self.config.criticality.value, "fixing_count": self.fixing_count, "scanning_count": self.scanning_count, "revealed_to_red": self.revealed_to_red, @@ -201,7 +211,7 @@ class Software(SimComponent): def fix(self) -> bool: """Perform a fix on the software.""" if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD): - self._fixing_countdown = self.fixing_duration + self._fixing_countdown = self.config.fixing_duration self.set_health_state(SoftwareHealthState.FIXING) return True return False @@ -233,7 +243,7 @@ class Software(SimComponent): super().pre_timestep(timestep) -class IOSoftware(Software): +class IOSoftware(Software, ABC): """ Represents software in a simulator environment that is capable of input/output operations. @@ -243,6 +253,13 @@ class IOSoftware(Software): required. """ + class ConfigSchema(Software.ConfigSchema, ABC): + """Configuration options for all IO Software.""" + + listen_on_ports: Set[Port] = Field(default_factory=set) + + config: ConfigSchema = Field(default_factory=lambda: IOSoftware.ConfigSchema()) + installing_count: int = 0 "The number of times the software has been installed. Default is 0." max_sessions: int = 100 @@ -260,6 +277,10 @@ class IOSoftware(Software): _connections: Dict[str, Dict] = {} "Active connections." + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.listen_on_ports = self.config.listen_on_ports + @abstractmethod def describe_state(self) -> Dict: """ diff --git a/tests/assets/configs/fix_duration_one_item.yaml b/tests/assets/configs/fixing_duration_one_item.yaml similarity index 98% rename from tests/assets/configs/fix_duration_one_item.yaml rename to tests/assets/configs/fixing_duration_one_item.yaml index 6444e9e8..02aa8e4b 100644 --- a/tests/assets/configs/fix_duration_one_item.yaml +++ b/tests/assets/configs/fixing_duration_one_item.yaml @@ -148,7 +148,7 @@ simulation: options: db_server_ip: 192.168.1.10 server_password: arcd - fix_duration: 1 + fixing_duration: 1 - type: DataManipulationBot options: port_scan_p_of_success: 0.8 @@ -169,7 +169,7 @@ simulation: arcd.com: 192.168.1.10 - type: DatabaseService options: - fix_duration: 5 + fixing_duration: 5 backup_server_ip: 192.168.1.10 - type: WebServer - type: FTPClient diff --git a/tests/assets/configs/software_fix_duration.yaml b/tests/assets/configs/software_fixing_duration.yaml similarity index 91% rename from tests/assets/configs/software_fix_duration.yaml rename to tests/assets/configs/software_fixing_duration.yaml index 0059d18a..073a5f83 100644 --- a/tests/assets/configs/software_fix_duration.yaml +++ b/tests/assets/configs/software_fixing_duration.yaml @@ -142,19 +142,19 @@ simulation: applications: - type: NMAP options: - fix_duration: 1 + fixing_duration: 1 - type: RansomwareScript options: - fix_duration: 1 + fixing_duration: 1 - type: WebBrowser options: target_url: http://arcd.com/users/ - fix_duration: 1 + fixing_duration: 1 - type: DatabaseClient options: db_server_ip: 192.168.1.10 server_password: arcd - fix_duration: 1 + fixing_duration: 1 - type: DataManipulationBot options: port_scan_p_of_success: 0.8 @@ -162,43 +162,44 @@ simulation: payload: "DELETE" server_ip: 192.168.1.21 server_password: arcd - fix_duration: 1 + fixing_duration: 1 - type: DoSBot options: target_ip_address: 192.168.10.21 payload: SPOOF DATA port_scan_p_of_success: 0.8 - fix_duration: 1 + fixing_duration: 1 services: - type: DNSClient options: - fix_duration: 3 + dns_server: 192.168.1.10 + fixing_duration: 3 - type: DNSServer options: - fix_duration: 3 + fixing_duration: 3 domain_mapping: arcd.com: 192.168.1.10 - type: DatabaseService options: backup_server_ip: 192.168.1.10 - fix_duration: 3 + fixing_duration: 3 - type: WebServer options: - fix_duration: 3 + fixing_duration: 3 - type: FTPClient options: - fix_duration: 3 + fixing_duration: 3 - type: FTPServer options: server_password: arcd - fix_duration: 3 + fixing_duration: 3 - type: NTPClient options: ntp_server_ip: 192.168.1.10 - fix_duration: 3 + fixing_duration: 3 - type: NTPServer options: - fix_duration: 3 + fixing_duration: 3 - hostname: client_2 type: computer ip_address: 192.168.10.22 diff --git a/tests/conftest.py b/tests/conftest.py index aa4a0ef0..165ab30e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,9 +39,16 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1 _LOGGER = getLogger(__name__) -class DummyService(Service): +class DummyService(Service, identifier="DummyService"): """Test Service class""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for DummyService.""" + + type: str = "DummyService" + + config: "DummyService.ConfigSchema" = Field(default_factory=lambda: DummyService.ConfigSchema()) + def describe_state(self) -> Dict: return super().describe_state() @@ -58,6 +65,13 @@ class DummyService(Service): class DummyApplication(Application, identifier="DummyApplication"): """Test Application class""" + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for DummyApplication.""" + + type: str = "DummyApplication" + + config: "DummyApplication.ConfigSchema" = Field(default_factory=lambda: DummyApplication.ConfigSchema()) + def __init__(self, **kwargs): kwargs["name"] = "DummyApplication" kwargs["port"] = PORT_LOOKUP["HTTP"] diff --git a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py b/tests/integration_tests/configuration_file_parsing/test_software_fixing_duration.py similarity index 73% rename from tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py rename to tests/integration_tests/configuration_file_parsing/test_software_fixing_duration.py index b1c644cc..10896956 100644 --- a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py +++ b/tests/integration_tests/configuration_file_parsing/test_software_fixing_duration.py @@ -13,8 +13,8 @@ from primaite.simulator.system.services.database.database_service import Databas from primaite.simulator.system.services.dns.dns_client import DNSClient from tests import TEST_ASSETS_ROOT -TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml" -ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml" +TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fixing_duration.yaml" +ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fixing_duration_one_item.yaml" TestApplications = ["DummyApplication", "BroadcastTestClient"] @@ -27,27 +27,27 @@ def load_config(config_path: Union[str, Path]) -> PrimaiteGame: return PrimaiteGame.from_config(cfg) -def test_default_fix_duration(): - """Test that software with no defined fix duration in config uses the default fix duration of 2.""" +def test_default_fixing_duration(): + """Test that software with no defined fixing duration in config uses the default fixing duration of 2.""" game = load_config(TEST_CONFIG) client_2: Computer = game.simulation.network.get_node_by_hostname("client_2") database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") - assert database_client.fixing_duration == 2 + assert database_client.config.fixing_duration == 2 dns_client: DNSClient = client_2.software_manager.software.get("DNSClient") - assert dns_client.fixing_duration == 2 + assert dns_client.config.fixing_duration == 2 -def test_fix_duration_set_from_config(): - """Test to check that the fix duration set for applications and services works as intended.""" +def test_fixing_duration_set_from_config(): + """Test to check that the fixing duration set for applications and services works as intended.""" game = load_config(TEST_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") # in config - services take 3 timesteps to fix for service in ["DNSClient", "DNSServer", "DatabaseService", "WebServer", "FTPClient", "FTPServer", "NTPServer"]: assert client_1.software_manager.software.get(service) is not None - assert client_1.software_manager.software.get(service).fixing_duration == 3 + assert client_1.software_manager.software.get(service).config.fixing_duration == 3 # in config - applications take 1 timestep to fix # remove test applications from list @@ -55,27 +55,27 @@ def test_fix_duration_set_from_config(): for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]: assert client_1.software_manager.software.get(application) is not None - assert client_1.software_manager.software.get(application).fixing_duration == 1 + assert client_1.software_manager.software.get(application).config.fixing_duration == 1 -def test_fix_duration_for_one_item(): - """Test that setting fix duration for one application does not affect other components.""" +def test_fixing_duration_for_one_item(): + """Test that setting fixing duration for one application does not affect other components.""" game = load_config(ONE_ITEM_CONFIG) client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") # in config - services take 3 timesteps to fix for service in ["DNSClient", "DNSServer", "WebServer", "FTPClient", "FTPServer", "NTPServer"]: assert client_1.software_manager.software.get(service) is not None - assert client_1.software_manager.software.get(service).fixing_duration == 2 + assert client_1.software_manager.software.get(service).config.fixing_duration == 2 # in config - applications take 1 timestep to fix # remove test applications from list for applications in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot"]: assert client_1.software_manager.software.get(applications) is not None - assert client_1.software_manager.software.get(applications).fixing_duration == 2 + assert client_1.software_manager.software.get(applications).config.fixing_duration == 2 database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") - assert database_client.fixing_duration == 1 + assert database_client.config.fixing_duration == 1 database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService") - assert database_service.fixing_duration == 5 + assert database_service.config.fixing_duration == 5 diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py index 9863dbba..159cfd06 100644 --- a/tests/integration_tests/extensions/applications/extended_application.py +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -4,7 +4,7 @@ from ipaddress import IPv4Address from typing import Dict, List, Optional from urllib.parse import urlparse -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from primaite import getLogger from primaite.interface.request import RequestResponse @@ -31,6 +31,14 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): The application requests and loads web pages using its domain name and requesting IP addresses using DNS. """ + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for ExtendedApplication.""" + + type: str = "ExtendedApplication" + target_url: Optional[str] = None + + config: "ExtendedApplication.ConfigSchema" = Field(default_factory=lambda: ExtendedApplication.ConfigSchema()) + target_url: Optional[str] = None domain_name_ip_address: Optional[IPv4Address] = None @@ -50,6 +58,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) + self.target_url = self.config.target_url self.run() def _init_request_manager(self) -> RequestManager: diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index 0924a91b..ba247369 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -3,6 +3,8 @@ from ipaddress import IPv4Address from typing import Any, Dict, List, Literal, Optional, Union from uuid import uuid4 +from pydantic import Field + from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus @@ -17,13 +19,20 @@ from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class ExtendedService(Service, identifier="extendedservice"): +class ExtendedService(Service, identifier="ExtendedService"): """ A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE. This class inherits from the `Service` class and provides methods to simulate a SQL database. """ + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for ExtendedService.""" + + type: str = "ExtendedService" + + config: "ExtendedService.ConfigSchema" = Field(default_factory=lambda: ExtendedService.ConfigSchema()) + password: Optional[str] = None """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index 0ad03198..bd9417ba 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -191,7 +191,7 @@ def test_nic_monitored_traffic(simulation): # send a database query browser: WebBrowser = pc.software_manager.software.get("WebBrowser") - browser.target_url = f"http://arcd.com/" + browser.config.target_url = f"http://arcd.com/" browser.get_webpage() traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC") diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 80f359da..cef41e1b 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -183,7 +183,7 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() - browser.target_url = "http://www.example.com" + browser.config.target_url = "http://www.example.com" assert browser.get_webpage() # check that the browser can access example.com before we block it # 2: Remove rule that allows HTTP traffic across the network @@ -216,7 +216,7 @@ def test_host_nic_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() - browser.target_url = "http://www.example.com" + browser.config.target_url = "http://www.example.com" assert browser.get_webpage() # check that the browser can access example.com before we block it # 2: Disable the NIC on client_1 @@ -416,7 +416,7 @@ def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteG browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() - browser.target_url = "http://www.example.com" + browser.config.target_url = "http://www.example.com" assert browser.get_webpage() # check that the browser can access example.com before we block it # 2: Disable the NIC on client_1 @@ -476,7 +476,7 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() - browser.target_url = "http://www.example.com" + browser.config.target_url = "http://www.example.com" assert browser.get_webpage() # check that the browser can access example.com assert browser.health_state_actual == SoftwareHealthState.GOOD diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 91a022d5..de61b215 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -29,7 +29,7 @@ def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, Controlle client_1 = game.simulation.network.get_node_by_hostname("client_1") browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() - browser.target_url = "http://www.example.com" + browser.config.target_url = "http://www.example.com" agent.reward_function.register_component(comp, 0.7) # Check that before trying to fetch the webpage, the reward is 0.0 diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index 33fe70c3..ed40334f 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -3,6 +3,7 @@ from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, List, Tuple import pytest +from pydantic import Field from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -14,9 +15,16 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import PORT_LOOKUP -class BroadcastTestService(Service): +class BroadcastTestService(Service, identifier="BroadcastTestService"): """A service for sending broadcast and unicast messages over a network.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for BroadcastTestService.""" + + type: str = "BroadcastTestService" + + config: "BroadcastTestService.ConfigSchema" = Field(default_factory=lambda: BroadcastTestService.ConfigSchema()) + def __init__(self, **kwargs): # Set default service properties for broadcasting kwargs["name"] = "BroadcastService" @@ -46,6 +54,13 @@ class BroadcastTestService(Service): class BroadcastTestClient(Application, identifier="BroadcastTestClient"): """A client application to receive broadcast and unicast messages.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for BroadcastTestClient.""" + + type: str = "BroadcastTestClient" + + config: ConfigSchema = Field(default_factory=lambda: BroadcastTestClient.ConfigSchema()) + payloads_received: List = [] def __init__(self, **kwargs): diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index d88f8249..40226be6 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -495,6 +495,12 @@ def test_c2_suite_yaml(): computer_b: Computer = yaml_network.get_node_by_hostname("node_b") c2_beacon: C2Beacon = computer_b.software_manager.software.get("C2Beacon") + c2_beacon.configure( + c2_server_ip_address=c2_beacon.config.c2_server_ip_address, + keep_alive_frequency=c2_beacon.config.keep_alive_frequency, + masquerade_port=c2_beacon.config.masquerade_port, + masquerade_protocol=c2_beacon.config.masquerade_protocol, + ) assert c2_server.operating_state == ApplicationOperatingState.RUNNING diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 674603fa..31732f77 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -232,7 +232,7 @@ def test_database_service_fix(uc2_network): assert db_service.health_state_actual == SoftwareHealthState.FIXING # apply timestep until the fix is applied - for i in range(db_service.fixing_duration + 1): + for i in range(db_service.config.fixing_duration + 1): uc2_network.apply_timestep(i) assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD @@ -266,7 +266,7 @@ def test_database_cannot_be_queried_while_fixing(uc2_network): assert db_connection.query(sql="SELECT") is False # apply timestep until the fix is applied - for i in range(db_service.fixing_duration + 1): + for i in range(db_service.config.fixing_duration + 1): uc2_network.apply_timestep(i) assert db_service.health_state_actual == SoftwareHealthState.GOOD @@ -308,7 +308,7 @@ def test_database_can_create_connection_while_fixing(uc2_network): assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING # apply timestep until the fix is applied - for i in range(db_service.fixing_duration + 1): + for i in range(db_service.config.fixing_duration + 1): uc2_network.apply_timestep(i) assert db_service.health_state_actual == SoftwareHealthState.GOOD diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py index 2d3679ed..84413ac9 100644 --- a/tests/integration_tests/system/test_service_listening_on_ports.py +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -14,7 +14,14 @@ from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT -class _DatabaseListener(Service): +class _DatabaseListener(Service, identifier="_DatabaseListener"): + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for _DatabaseListener.""" + + type: str = "_DatabaseListener" + listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]} + + config: "_DatabaseListener.ConfigSchema" = Field(default_factory=lambda: _DatabaseListener.ConfigSchema()) name: str = "DatabaseListener" protocol: str = PROTOCOL_LOOKUP["TCP"] port: int = PORT_LOOKUP["NONE"] diff --git a/tests/integration_tests/system/test_web_client_server.py b/tests/integration_tests/system/test_web_client_server.py index c1028e8e..8aea34c1 100644 --- a/tests/integration_tests/system/test_web_client_server.py +++ b/tests/integration_tests/system/test_web_client_server.py @@ -51,7 +51,7 @@ def test_web_page_get_users_page_request_with_domain_name(web_client_and_web_ser web_browser_app, computer, web_server_service, server = web_client_and_web_server web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address - web_browser_app.target_url = f"http://arcd.com/" + web_browser_app.config.target_url = f"http://arcd.com/" assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING assert web_browser_app.get_webpage() is True @@ -66,7 +66,7 @@ def test_web_page_get_users_page_request_with_ip_address(web_client_and_web_serv web_browser_app, computer, web_server_service, server = web_client_and_web_server web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address - web_browser_app.target_url = f"http://{web_server_ip}/" + web_browser_app.config.target_url = f"http://{web_server_ip}/" assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING assert web_browser_app.get_webpage() is True @@ -81,7 +81,7 @@ def test_web_page_request_from_shut_down_server(web_client_and_web_server): web_browser_app, computer, web_server_service, server = web_client_and_web_server web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address - web_browser_app.target_url = f"http://arcd.com/" + web_browser_app.config.target_url = f"http://arcd.com/" assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING assert web_browser_app.get_webpage() is True @@ -108,7 +108,7 @@ def test_web_page_request_from_closed_web_browser(web_client_and_web_server): web_browser_app, computer, web_server_service, server = web_client_and_web_server assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING - web_browser_app.target_url = f"http://arcd.com/" + web_browser_app.config.target_url = f"http://arcd.com/" assert web_browser_app.get_webpage() is True # latest response should have status code 200 diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py index 8fb6dc18..41f1a837 100644 --- a/tests/integration_tests/system/test_web_client_server_and_database.py +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -74,7 +74,7 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer, # Install Web Browser on computer computer.software_manager.install(WebBrowser) web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") - web_browser.target_url = "http://arcd.com/users/" + web_browser.config.target_url = "http://arcd.com/users/" web_browser.run() # Install DNS Client service on computer @@ -131,7 +131,7 @@ def test_database_fix_disrupts_web_client(uc2_network): assert web_browser.get_webpage() is False - for i in range(database_service.fixing_duration + 1): + for i in range(database_service.config.fixing_duration + 1): uc2_network.apply_timestep(i) assert database_service.health_state_actual == SoftwareHealthState.GOOD diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index 4ff387ce..17f8445a 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -128,13 +128,13 @@ def test_c2_handle_switching_port(basic_c2_network): assert c2_server.c2_connection_active is True # Assert to confirm that both the C2 server and the C2 beacon are configured correctly. - assert c2_beacon.c2_config.keep_alive_frequency is 2 - assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["HTTP"] - assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] + assert c2_beacon.config.keep_alive_frequency is 2 + assert c2_beacon.config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] - assert c2_server.c2_config.keep_alive_frequency is 2 - assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["HTTP"] - assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] + assert c2_server.config.keep_alive_frequency is 2 + assert c2_server.config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] # Configuring the C2 Beacon. c2_beacon.configure( @@ -150,11 +150,11 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon # Have reconfigured their C2 settings. - assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["FTP"] - assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] + assert c2_beacon.config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] - assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["FTP"] - assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] + assert c2_server.config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] def test_c2_handle_switching_frequency(basic_c2_network): @@ -174,8 +174,8 @@ def test_c2_handle_switching_frequency(basic_c2_network): assert c2_server.c2_connection_active is True # Assert to confirm that both the C2 server and the C2 beacon are configured correctly. - assert c2_beacon.c2_config.keep_alive_frequency is 2 - assert c2_server.c2_config.keep_alive_frequency is 2 + assert c2_beacon.config.keep_alive_frequency is 2 + assert c2_server.config.keep_alive_frequency is 2 # Configuring the C2 Beacon. c2_beacon.configure(c2_server_ip_address="192.168.0.1", keep_alive_frequency=10) @@ -186,8 +186,8 @@ def test_c2_handle_switching_frequency(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon # Have reconfigured their C2 settings. - assert c2_beacon.c2_config.keep_alive_frequency is 10 - assert c2_server.c2_config.keep_alive_frequency is 10 + assert c2_beacon.config.keep_alive_frequency is 10 + assert c2_server.config.keep_alive_frequency is 10 # Now skipping 9 time steps to confirm keep alive inactivity for i in range(9): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py index ad6fe135..5598e1a7 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py @@ -148,7 +148,7 @@ def test_service_fixing(service): service.fix() assert service.health_state_actual == SoftwareHealthState.FIXING - for i in range(service.fixing_duration + 1): + for i in range(service.config.fixing_duration + 1): service.apply_timestep(i) assert service.health_state_actual == SoftwareHealthState.GOOD diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index a203a636..bdf9cfee 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -2,6 +2,7 @@ from typing import Dict import pytest +from pydantic import Field from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service @@ -10,7 +11,14 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import PORT_LOOKUP -class TestSoftware(Service): +class TestSoftware(Service, identifier="TestSoftware"): + class ConfigSchema(Service.ConfigSchema): + """ConfigSChema for TestSoftware.""" + + type: str = "TestSoftware" + + config: "TestSoftware.ConfigSchema" = Field(default_factory=lambda: TestSoftware.ConfigSchema()) + def describe_state(self) -> Dict: pass