From c481847b01266e1cd93aa02ce805c3ff95bbd169 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 3 Jan 2025 13:39:58 +0000 Subject: [PATCH] #2888 - Software: align identifiers, tidy up schemas --- .../simulator/network/hardware/base.py | 22 +++---- .../system/applications/application.py | 22 +++---- .../system/applications/database_client.py | 15 ++--- .../simulator/system/applications/nmap.py | 14 ++-- .../red_applications/c2/abstract_c2.py | 66 +++++++++---------- .../red_applications/c2/c2_beacon.py | 39 +++++------ .../red_applications/c2/c2_server.py | 27 ++++---- .../red_applications/data_manipulation_bot.py | 9 +++ .../applications/red_applications/dos_bot.py | 6 +- .../red_applications/ransomware_script.py | 13 ++-- .../system/applications/web_browser.py | 14 ++-- .../simulator/system/services/arp/arp.py | 9 +-- .../services/database/database_service.py | 14 ++-- .../system/services/dns/dns_client.py | 18 ++--- .../system/services/dns/dns_server.py | 13 ++-- .../system/services/ftp/ftp_client.py | 6 +- .../system/services/ftp/ftp_server.py | 6 +- .../simulator/system/services/icmp/icmp.py | 10 +-- .../system/services/ntp/ntp_client.py | 14 ++-- .../system/services/ntp/ntp_server.py | 8 ++- .../system/services/terminal/terminal.py | 14 ++-- .../system/services/web_server/web_server.py | 12 ++-- tests/conftest.py | 13 ++-- .../applications/extended_application.py | 14 ++-- .../extensions/services/extended_service.py | 14 ++-- .../network/test_broadcast.py | 14 +++- .../system/test_service_listening_on_ports.py | 12 ++-- .../_red_applications/test_c2_suite.py | 28 ++++---- .../_simulator/_system/test_software.py | 8 +-- 29 files changed, 252 insertions(+), 222 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 96b1d9a7..a7278489 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -833,14 +833,14 @@ class UserManager(Service, identifier="UserManager"): :param disabled_admins: A dictionary of currently disabled admin users by their usernames """ - config: "UserManager.ConfigSchema" = None - - users: Dict[str, User] = {} - class ConfigSchema(Service.ConfigSchema): """ConfigSchema for UserManager.""" - type: str = "USER_MANAGER" + type: str = "UserManager" + + config: "UserManager.ConfigSchema" = Field(default_factory=lambda: UserManager.ConfigSchema()) + + users: Dict[str, User] = {} def __init__(self, **kwargs): """ @@ -1144,7 +1144,12 @@ class UserSessionManager(Service, identifier="UserSessionManager"): This class handles authentication, session management, and session timeouts for users interacting with the Node. """ - config: "UserSessionManager.ConfigSchema" = None + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for UserSessionManager.""" + + type: str = "UserSessionManager" + + config: "UserSessionManager.ConfigSchema" = Field(default_factory=lambda: UserSessionManager.ConfigSchema()) local_session: Optional[UserSession] = None """The current local user session, if any.""" @@ -1167,11 +1172,6 @@ class UserSessionManager(Service, identifier="UserSessionManager"): current_timestep: int = 0 """The current timestep in the simulation.""" - class ConfigSchema(Service.ConfigSchema): - """ConfigSchema for UserSessionManager.""" - - type: str = "USER_SESSION_MANAGER" - def __init__(self, **kwargs): """ Initializes a UserSessionManager instance. diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 29753cff..e0cac6b4 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Any, ClassVar, Dict, Optional, Set, Type -from pydantic import BaseModel +from pydantic import BaseModel, Field from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType @@ -23,14 +23,19 @@ class ApplicationOperatingState(Enum): "The application is being installed or updated." -class Application(IOSoftware): +class Application(IOSoftware, ABC): """ Represents an Application in the simulation environment. Applications are user-facing programs that may perform input/output operations. """ - config: "Application.ConfigSchema" = None + class ConfigSchema(BaseModel, ABC): + """Config Schema for Application class.""" + + type: str + + config: ConfigSchema = Field(default_factory=lambda: Application.ConfigSchema()) operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED "The current operating state of the Application." @@ -48,20 +53,15 @@ class Application(IOSoftware): _registry: ClassVar[Dict[str, Type["Application"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - class ConfigSchema(BaseModel, ABC): - """Config Schema for Application class.""" - - type: str - - def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: """ Register an application type. :param identifier: Uniquely specifies an application class by name. Used for finding items by config. - :type identifier: str + :type identifier: Optional[str] :raises ValueError: When attempting to register an application with a name that is already allocated. """ - if identifier == "default": + if identifier is None: return super().__init_subclass__(**kwargs) if identifier in cls._registry: diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index d04f8298..facc4016 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel +from pydantic import BaseModel, Field from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -67,10 +67,14 @@ class DatabaseClient(Application, identifier="DatabaseClient"): Extends the Application class to provide functionality for connecting, querying, and disconnecting from a Database Service. It mainly operates over TCP protocol. - """ - config: "DatabaseClient.ConfigSchema" = None + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for DatabaseClient.""" + + type: str = "DatabaseClient" + + config: ConfigSchema = Field(default_factory=lambda: DatabaseClient.ConfigSchema()) server_ip_address: Optional[IPv4Address] = None """The IPv4 address of the Database Service server, defaults to None.""" @@ -90,11 +94,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"): native_connection: Optional[DatabaseClientConnection] = None """Native Client Connection for using the client directly (similar to psql in a terminal).""" - class ConfigSchema(Application.ConfigSchema): - """ConfigSchema for DatabaseClient.""" - - type: str = "DATABASE_CLIENT" - def __init__(self, **kwargs): kwargs["name"] = "DatabaseClient" kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index 676515cc..3eeda4b6 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -3,7 +3,7 @@ from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union from prettytable import PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -52,7 +52,12 @@ class NMAP(Application, identifier="NMAP"): as ping scans to discover active hosts and port scans to detect open ports on those hosts. """ - config: "NMAP.ConfigSchema" = None + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for NMAP.""" + + type: str = "NMAP" + + config: "NMAP.ConfigSchema" = Field(default_factory=lambda: NMAP.ConfigSchema()) _active_port_scans: Dict[str, PortScanPayload] = {} _port_scan_responses: Dict[str, PortScanPayload] = {} @@ -64,11 +69,6 @@ class NMAP(Application, identifier="NMAP"): (False, False): "Port", } - class ConfigSchema(Application.ConfigSchema): - """ConfigSchema for NMAP.""" - - type: str = "NMAP" - def __init__(self, **kwargs): kwargs["name"] = "NMAP" kwargs["port"] = PORT_LOOKUP["NONE"] diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index 960f8592..a379769d 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -4,7 +4,7 @@ from enum import Enum from ipaddress import IPv4Address from typing import Dict, Optional, Union -from pydantic import BaseModel, Field, validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.file_system.file_system import FileSystem, Folder @@ -48,7 +48,7 @@ class C2Payload(Enum): """C2 Output Command. Used by the C2 Beacon to send the results of an Input command to the c2 server.""" -class AbstractC2(Application, identifier="AbstractC2"): +class AbstractC2(Application): """ An abstract command and control (c2) application. @@ -63,7 +63,19 @@ class AbstractC2(Application, identifier="AbstractC2"): Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite. """ - config: "AbstractC2.ConfigSchema" = None + class ConfigSchema(Application.ConfigSchema): + """Configuration for AbstractC2.""" + + keep_alive_frequency: int = Field(default=5, ge=1) + """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" + + masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"]) + """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" + + masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"]) + """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" + + config: ConfigSchema = Field(default_factory=lambda: AbstractC2.ConfigSchema()) c2_connection_active: bool = False """Indicates if the c2 server and c2 beacon are currently connected.""" @@ -77,24 +89,6 @@ class AbstractC2(Application, identifier="AbstractC2"): keep_alive_inactivity: int = 0 """Indicates how many timesteps since the last time the c2 application received a keep alive.""" - class ConfigSchema(Application.ConfigSchema): - """ConfigSchema for AbstractC2.""" - - type: str = "ABSTRACT_C2" - - class _C2Opts(BaseModel): - """A Pydantic Schema for the different C2 configuration options.""" - - keep_alive_frequency: int = Field(default=5, ge=1) - """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" - - masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"]) - """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" - - masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"]) - """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" - - c2_config: _C2Opts = _C2Opts() """ Holds the current configuration settings of the C2 Suite. @@ -129,9 +123,9 @@ class AbstractC2(Application, identifier="AbstractC2"): :rtype: C2Packet """ constructed_packet = C2Packet( - masquerade_protocol=self.c2_config.masquerade_protocol, - masquerade_port=self.c2_config.masquerade_port, - keep_alive_frequency=self.c2_config.keep_alive_frequency, + masquerade_protocol=self.config.masquerade_protocol, + masquerade_port=self.config.masquerade_port, + keep_alive_frequency=self.config.keep_alive_frequency, payload_type=c2_payload, command=c2_command, payload=command_options, @@ -337,8 +331,8 @@ class AbstractC2(Application, identifier="AbstractC2"): if self.send( payload=keep_alive_packet, dest_ip_address=self.c2_remote_connection, - dest_port=self.c2_config.masquerade_port, - ip_protocol=self.c2_config.masquerade_protocol, + dest_port=self.config.masquerade_port, + ip_protocol=self.config.masquerade_protocol, session_id=session_id, ): # Setting the keep_alive_sent guard condition to True. This is used to prevent packet storms. @@ -347,8 +341,8 @@ class AbstractC2(Application, identifier="AbstractC2"): self.sys_log.info(f"{self.name}: Keep Alive sent to {self.c2_remote_connection}") self.sys_log.debug( f"{self.name}: Keep Alive sent to {self.c2_remote_connection} " - f"Masquerade Port: {self.c2_config.masquerade_port} " - f"Masquerade Protocol: {self.c2_config.masquerade_protocol} " + f"Masquerade Port: {self.config.masquerade_port} " + f"Masquerade Protocol: {self.config.masquerade_protocol} " ) return True else: @@ -383,15 +377,15 @@ class AbstractC2(Application, identifier="AbstractC2"): # Updating the C2 Configuration attribute. - self.c2_config.masquerade_port = payload.masquerade_port - self.c2_config.masquerade_protocol = payload.masquerade_protocol - self.c2_config.keep_alive_frequency = payload.keep_alive_frequency + self.config.masquerade_port = payload.masquerade_port + self.config.masquerade_protocol = payload.masquerade_protocol + self.config.keep_alive_frequency = payload.keep_alive_frequency self.sys_log.debug( f"{self.name}: C2 Config Resolved Config from Keep Alive:" - f"Masquerade Port: {self.c2_config.masquerade_port}" - f"Masquerade Protocol: {self.c2_config.masquerade_protocol}" - f"Keep Alive Frequency: {self.c2_config.keep_alive_frequency}" + f"Masquerade Port: {self.config.masquerade_port}" + f"Masquerade Protocol: {self.config.masquerade_protocol}" + f"Keep Alive Frequency: {self.config.keep_alive_frequency}" ) # This statement is intended to catch on the C2 Application that is listening for connection. @@ -417,8 +411,8 @@ class AbstractC2(Application, identifier="AbstractC2"): self.keep_alive_inactivity = 0 self.keep_alive_frequency = 5 self.c2_remote_connection = None - self.c2_config.masquerade_port = PORT_LOOKUP["HTTP"] - self.c2_config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"] + self.config.masquerade_port = PORT_LOOKUP["HTTP"] + self.config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"] @abstractmethod def _confirm_remote_connection(self, timestep: int) -> bool: diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index abb620cd..014a4096 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -3,12 +3,11 @@ from ipaddress import IPv4Address from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts, RansomwareOpts, TerminalOpts from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript @@ -36,7 +35,12 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite. """ - config: "C2Beacon.ConfigSchema" = None + class ConfigSchema(AbstractC2.ConfigSchema): + """ConfigSchema for C2Beacon.""" + + type: str = "C2Beacon" + + config: ConfigSchema = Field(default_factory=lambda: C2Beacon.ConfigSchema()) keep_alive_attempted: bool = False """Indicates if a keep alive has been attempted to be sent this timestep. Used to prevent packet storms.""" @@ -44,11 +48,6 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): terminal_session: TerminalClientConnection = None "The currently in use terminal session." - class ConfigSchema(Application.ConfigSchema): - """ConfigSchema for C2Beacon.""" - - type: str = "C2_BEACON" - @property def _host_terminal(self) -> Optional[Terminal]: """Return the Terminal that is installed on the same machine as the C2 Beacon.""" @@ -154,7 +153,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): masquerade_port | What port should the C2 traffic use? (TCP or UDP) These configuration options are used to reassign the fields in the inherited inner class - ``c2_config``. + ``config``. If a connection is already in progress then this method also sends a keep alive to the C2 Server in order for the C2 Server to sync with the new configuration settings. @@ -170,9 +169,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): :return: Returns True if the configuration was successful, False otherwise. """ self.c2_remote_connection = IPv4Address(c2_server_ip_address) - self.c2_config.keep_alive_frequency = keep_alive_frequency - self.c2_config.masquerade_port = masquerade_port - self.c2_config.masquerade_protocol = masquerade_protocol + self.config.keep_alive_frequency = keep_alive_frequency + self.config.masquerade_port = masquerade_port + self.config.masquerade_protocol = masquerade_protocol self.sys_log.info( f"{self.name}: Configured {self.name} with remote C2 server connection: {c2_server_ip_address=}." ) @@ -271,14 +270,12 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): if self.send( payload=output_packet, dest_ip_address=self.c2_remote_connection, - dest_port=self.c2_config.masquerade_port, - ip_protocol=self.c2_config.masquerade_protocol, + dest_port=self.config.masquerade_port, + ip_protocol=self.config.masquerade_protocol, session_id=session_id, ): self.sys_log.info(f"{self.name}: Command output sent to {self.c2_remote_connection}") - self.sys_log.debug( - f"{self.name}: on {self.c2_config.masquerade_port} via {self.c2_config.masquerade_protocol}" - ) + self.sys_log.debug(f"{self.name}: on {self.config.masquerade_port} via {self.config.masquerade_protocol}") return True else: self.sys_log.warning( @@ -570,7 +567,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): :rtype bool: """ self.keep_alive_attempted = False # Resetting keep alive sent. - if self.keep_alive_inactivity == self.c2_config.keep_alive_frequency: + if self.keep_alive_inactivity == self.config.keep_alive_frequency: self.sys_log.info( f"{self.name}: Attempting to Send Keep Alive to {self.c2_remote_connection} at timestep {timestep}." ) @@ -635,9 +632,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self.c2_connection_active, self.c2_remote_connection, self.keep_alive_inactivity, - self.c2_config.keep_alive_frequency, - self.c2_config.masquerade_protocol, - self.c2_config.masquerade_port, + self.config.keep_alive_frequency, + self.config.masquerade_protocol, + self.config.masquerade_port, ] ) print(table) diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py index 7308e8bc..9d2097e9 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py @@ -2,12 +2,11 @@ from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.red_applications.c2 import ( CommandOpts, ExfilOpts, @@ -35,16 +34,16 @@ class C2Server(AbstractC2, identifier="C2Server"): Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite. """ - config: "C2Server.ConfigSchema" = None + class ConfigSchema(AbstractC2.ConfigSchema): + """ConfigSchema for C2Server.""" + + type: str = "C2Server" + + config: ConfigSchema = Field(default_factory=lambda: C2Server.ConfigSchema()) current_command_output: RequestResponse = None """The Request Response by the last command send. This attribute is updated by the method _handle_command_output.""" - class ConfigSchema(Application.ConfigSchema): - """ConfigSchema for C2Server.""" - - type: str = "C2_SERVER" - def _init_request_manager(self) -> RequestManager: """ Initialise the request manager. @@ -259,8 +258,8 @@ class C2Server(AbstractC2, identifier="C2Server"): payload=command_packet, dest_ip_address=self.c2_remote_connection, session_id=self.c2_session.uuid, - dest_port=self.c2_config.masquerade_port, - ip_protocol=self.c2_config.masquerade_protocol, + dest_port=self.config.masquerade_port, + ip_protocol=self.config.masquerade_protocol, ): self.sys_log.info(f"{self.name}: Successfully sent {given_command}.") self.sys_log.info(f"{self.name}: Awaiting command response {given_command}.") @@ -342,11 +341,11 @@ class C2Server(AbstractC2, identifier="C2Server"): :return: Returns False if the C2 beacon is considered dead. Otherwise True. :rtype bool: """ - if self.keep_alive_inactivity > self.c2_config.keep_alive_frequency: + if self.keep_alive_inactivity > self.config.keep_alive_frequency: self.sys_log.info(f"{self.name}: C2 Beacon connection considered dead due to inactivity.") self.sys_log.debug( f"{self.name}: Did not receive expected keep alive connection from {self.c2_remote_connection}" - f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.c2_config.keep_alive_frequency}" + f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.config.keep_alive_frequency}" f"{self.name}: Last Keep Alive received at {(timestep - self.keep_alive_inactivity)}" ) self._reset_c2_connection() @@ -397,8 +396,8 @@ class C2Server(AbstractC2, identifier="C2Server"): [ self.c2_connection_active, self.c2_remote_connection, - self.c2_config.masquerade_protocol, - self.c2_config.masquerade_port, + self.config.masquerade_protocol, + self.config.masquerade_port, ] ) print(table) diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 0423087e..1978afb9 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -3,6 +3,8 @@ from enum import IntEnum from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestResponse @@ -40,6 +42,13 @@ class DataManipulationAttackStage(IntEnum): class DataManipulationBot(Application, identifier="DataManipulationBot"): """A bot that simulates a script which performs a SQL injection attack.""" + class ConfigSchema(Application.ConfigSchema): + """Configuration schema for DataManipulationBot.""" + + type: str = "DataManipulationBot" + + config: "DataManipulationBot.ConfigSchema" = Field(default_factory=lambda: DataManipulationBot.ConfigSchema()) + payload: Optional[str] = None port_scan_p_of_success: float = 0.1 data_manipulation_p_of_success: float = 0.1 diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 0c337c53..e284ba92 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -3,6 +3,8 @@ from enum import IntEnum from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestFormat, RequestResponse @@ -33,7 +35,7 @@ class DoSAttackStage(IntEnum): class DoSBot(DatabaseClient, identifier="DoSBot"): """A bot that simulates a Denial of Service attack.""" - config: "DoSBot.ConfigSchema" = None + config: "DoSBot.ConfigSchema" = Field(default_factory=lambda: DoSBot.ConfigSchema()) target_ip_address: Optional[IPv4Address] = None """IP address of the target service.""" @@ -59,7 +61,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): class ConfigSchema(Application.ConfigSchema): """ConfigSchema for DoSBot.""" - type: str = "DOS_BOT" + type: str = "DoSBot" def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index 3e6ed624..b72dc8e5 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -3,6 +3,7 @@ from ipaddress import IPv4Address from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -18,7 +19,12 @@ class RansomwareScript(Application, identifier="RansomwareScript"): :ivar payload: The attack stage query payload. (Default ENCRYPT) """ - config: "RansomwareScript.ConfigSchema" = None + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for RansomwareScript.""" + + type: str = "RansomwareScript" + + config: "RansomwareScript.ConfigSchema" = Field(default_factory=lambda: RansomwareScript.ConfigSchema()) server_ip_address: Optional[IPv4Address] = None """IP address of node which hosts the database.""" @@ -27,11 +33,6 @@ class RansomwareScript(Application, identifier="RansomwareScript"): payload: Optional[str] = "ENCRYPT" "Payload String for the payload stage" - class ConfigSchema(Application.ConfigSchema): - """ConfigSchema for RansomwareScript.""" - - type: str = "RANSOMWARE_SCRIPT" - def __init__(self, **kwargs): kwargs["name"] = "RansomwareScript" kwargs["port"] = PORT_LOOKUP["NONE"] diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 1bfe0e1a..52a566f2 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -4,7 +4,7 @@ from ipaddress import IPv4Address from typing import Dict, List, Optional from urllib.parse import urlparse -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from primaite import getLogger from primaite.interface.request import RequestResponse @@ -30,7 +30,12 @@ class WebBrowser(Application, identifier="WebBrowser"): The application requests and loads web pages using its domain name and requesting IP addresses using DNS. """ - config: "WebBrowser.ConfigSchema" = None + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for WebBrowser.""" + + type: str = "WebBrowser" + + config: "WebBrowser.ConfigSchema" = Field(default_factory=lambda: WebBrowser.ConfigSchema()) target_url: Optional[str] = None @@ -43,11 +48,6 @@ class WebBrowser(Application, identifier="WebBrowser"): history: List["BrowserHistoryItem"] = [] """Keep a log of visited websites and information about the visit, such as response code.""" - class ConfigSchema(Application.ConfigSchema): - """ConfigSchema for WebBrowser.""" - - type: str = "WEB_BROWSER" - def __init__(self, **kwargs): kwargs["name"] = "WebBrowser" kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index 4f59bc15..bbeec301 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -5,6 +5,7 @@ from abc import abstractmethod from typing import Any, Dict, Optional, Union from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket @@ -22,15 +23,15 @@ class ARP(Service, identifier="ARP"): sends ARP requests and replies, and processes incoming ARP packets. """ - config: "ARP.ConfigSchema" = None - - arp: Dict[IPV4Address, ARPEntry] = {} - class ConfigSchema(Service.ConfigSchema): """ConfigSchema for ARP.""" type: str = "ARP" + config: "ARP.ConfigSchema" = Field(default_factory=lambda: ARP.ConfigSchema()) + + arp: Dict[IPV4Address, ARPEntry] = {} + def __init__(self, **kwargs): kwargs["name"] = "ARP" kwargs["port"] = PORT_LOOKUP["ARP"] diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 68d75665..f16b4125 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -3,6 +3,8 @@ from ipaddress import IPv4Address from typing import Any, Dict, List, Literal, Optional, Union from uuid import uuid4 +from pydantic import Field + from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus @@ -24,7 +26,12 @@ class DatabaseService(Service, identifier="DatabaseService"): This class inherits from the `Service` class and provides methods to simulate a SQL database. """ - config: "DatabaseService.ConfigSchema" = None + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for DatabaseService.""" + + type: str = "DatabaseService" + + config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema()) password: Optional[str] = None """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" @@ -38,11 +45,6 @@ class DatabaseService(Service, identifier="DatabaseService"): latest_backup_file_name: str = None """File name of latest backup.""" - class ConfigSchema(Service.ConfigSchema): - """ConfigSchema for DatabaseService.""" - - type: str = "DATABASE_SERVICE" - def __init__(self, **kwargs): kwargs["name"] = "DatabaseService" kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index eb54ec71..0756eb05 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -2,6 +2,8 @@ from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest from primaite.simulator.system.core.software_manager import SoftwareManager @@ -12,19 +14,19 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class DNSClient(Service): +class DNSClient(Service, identifier="DNSClient"): """Represents a DNS Client as a Service.""" - config: "DNSClient.ConfigSchema" = None - dns_cache: Dict[str, IPv4Address] = {} - "A dict of known mappings between domain/URLs names and IPv4 addresses." - dns_server: Optional[IPv4Address] = None - "The DNS Server the client sends requests to." - class ConfigSchema(Service.ConfigSchema): """ConfigSchema for DNSClient.""" - type: str = "DNS_CLIENT" + type: str = "DNSClient" + + config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema()) + dns_cache: Dict[str, IPv4Address] = {} + "A dict of known mappings between domain/URLs names and IPv4 addresses." + dns_server: Optional[IPv4Address] = None + "The DNS Server the client sends requests to." def __init__(self, **kwargs): kwargs["name"] = "DNSClient" diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index da302b6c..46008ddf 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -3,6 +3,7 @@ from ipaddress import IPv4Address from typing import Any, Dict, Optional from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket @@ -16,15 +17,15 @@ _LOGGER = getLogger(__name__) class DNSServer(Service, identifier="DNSServer"): """Represents a DNS Server as a Service.""" - config: "DNSServer.ConfigSchema" = None - - dns_table: Dict[str, IPv4Address] = {} - "A dict of mappings between domain names and IPv4 addresses." - class ConfigSchema(Service.ConfigSchema): """ConfigSchema for DNSServer.""" - type: str = "DNS_SERVER" + type: str = "DNSServer" + + config: "DNSServer.ConfigSchema" = Field(default_factory=lambda: DNSServer.ConfigSchema()) + + dns_table: Dict[str, IPv4Address] = {} + "A dict of mappings between domain names and IPv4 addresses." def __init__(self, **kwargs): kwargs["name"] = "DNSServer" diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 033d4602..16cefdd6 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -2,6 +2,8 @@ from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -24,12 +26,12 @@ class FTPClient(FTPServiceABC, identifier="FTPClient"): RFC 959: https://datatracker.ietf.org/doc/html/rfc959 """ - config: "FTPClient.ConfigSchema" = None + config: "FTPClient.ConfigSchema" = Field(default_factory=lambda: FTPClient.ConfigSchema()) class ConfigSchema(Service.ConfigSchema): """ConfigSchema for FTPClient.""" - type: str = "FTP_CLIENT" + type: str = "FTPClient" def __init__(self, **kwargs): kwargs["name"] = "FTPClient" diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 205ace21..054bfe15 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -1,6 +1,8 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import Any, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC @@ -19,7 +21,7 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"): RFC 959: https://datatracker.ietf.org/doc/html/rfc959 """ - config: "FTPServer.ConfigSchema" = None + config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema()) server_password: Optional[str] = None """Password needed to connect to FTP server. Default is None.""" @@ -27,7 +29,7 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"): class ConfigSchema(Service.ConfigSchema): """ConfigSchema for FTPServer.""" - type: str = "FTP_Server" + type: str = "FTPServer" def __init__(self, **kwargs): kwargs["name"] = "FTPServer" diff --git a/src/primaite/simulator/system/services/icmp/icmp.py b/src/primaite/simulator/system/services/icmp/icmp.py index 6d5355e7..7f626945 100644 --- a/src/primaite/simulator/system/services/icmp/icmp.py +++ b/src/primaite/simulator/system/services/icmp/icmp.py @@ -3,6 +3,8 @@ import secrets from ipaddress import IPv4Address from typing import Any, Dict, Optional, Tuple, Union +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType @@ -22,15 +24,15 @@ class ICMP(Service, identifier="ICMP"): network diagnostics, notably the ping command. """ - config: "ICMP.ConfigSchema" = None - - request_replies: Dict = {} - class ConfigSchema(Service.ConfigSchema): """ConfigSchema for ICMP.""" type: str = "ICMP" + config: "ICMP.ConfigSchema" = Field(default_factory=lambda: ICMP.ConfigSchema()) + + request_replies: Dict = {} + def __init__(self, **kwargs): kwargs["name"] = "ICMP" kwargs["port"] = PORT_LOOKUP["NONE"] diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 6fc1f6fa..fb470faf 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -3,6 +3,8 @@ from datetime import datetime from ipaddress import IPv4Address from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket from primaite.simulator.system.services.service import Service, ServiceOperatingState @@ -15,17 +17,17 @@ _LOGGER = getLogger(__name__) class NTPClient(Service, identifier="NTPClient"): """Represents a NTP client as a service.""" - config: "NTPClient.ConfigSchema" = None + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for NTPClient.""" + + type: str = "NTPClient" + + config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema()) ntp_server: Optional[IPv4Address] = None "The NTP server the client sends requests to." time: Optional[datetime] = None - class ConfigSchema(Service.ConfigSchema): - """ConfigSchema for NTPClient.""" - - type: str = "NTP_CLIENT" - def __init__(self, **kwargs): kwargs["name"] = "NTPClient" kwargs["port"] = PORT_LOOKUP["NTP"] diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index a07d5f5c..7af33893 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -2,6 +2,8 @@ from datetime import datetime from typing import Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket from primaite.simulator.system.services.service import Service @@ -14,12 +16,12 @@ _LOGGER = getLogger(__name__) class NTPServer(Service, identifier="NTPServer"): """Represents a NTP server as a service.""" - config: "NTPServer.ConfigSchema" = None - class ConfigSchema(Service.ConfigSchema): """ConfigSchema for NTPServer.""" - type: str = "NTP_SERVER" + type: str = "NTPServer" + + config: "NTPServer.ConfigSchema" = Field(default_factory=lambda: NTPServer.ConfigSchema()) def __init__(self, **kwargs): kwargs["name"] = "NTPServer" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index c07af73e..f576d5ee 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -7,7 +7,7 @@ from ipaddress import IPv4Address from typing import Any, Dict, List, Optional, Union from uuid import uuid4 -from pydantic import BaseModel +from pydantic import BaseModel, Field from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -132,15 +132,15 @@ class RemoteTerminalConnection(TerminalClientConnection): class Terminal(Service, identifier="Terminal"): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" - config: "Terminal.ConfigSchema" = None - - _client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {} - """Dictionary of connect requests made to remote nodes.""" - class ConfigSchema(Service.ConfigSchema): """ConfigSchema for Terminal.""" - type: str = "TERMINAL" + type: str = "Terminal" + + config: "Terminal.ConfigSchema" = Field(default_factory=lambda: Terminal.ConfigSchema()) + + _client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {} + """Dictionary of connect requests made to remote nodes.""" def __init__(self, **kwargs): kwargs["name"] = "Terminal" diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 70731df9..51724002 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -3,6 +3,8 @@ from ipaddress import IPv4Address from typing import Any, Dict, List, Optional from urllib.parse import urlparse +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.protocols.http import ( HttpRequestMethod, @@ -22,14 +24,14 @@ _LOGGER = getLogger(__name__) class WebServer(Service, identifier="WebServer"): """Class used to represent a Web Server Service in simulation.""" - config: "WebServer.ConfigSchema" = None - - response_codes_this_timestep: List[HttpStatusCode] = [] - class ConfigSchema(Service.ConfigSchema): """ConfigSchema for WebServer.""" - type: str = "WEB_SERVER" + type: str = "WebServer" + + config: "WebServer.ConfigSchema" = Field(default_factory=lambda: WebServer.ConfigSchema()) + + response_codes_this_timestep: List[HttpStatusCode] = [] def describe_state(self) -> Dict: """ diff --git a/tests/conftest.py b/tests/conftest.py index 2ef4904a..d1440bd2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Tuple import pytest import yaml +from pydantic import Field from ray import init as rayinit from primaite import getLogger, PRIMAITE_PATHS @@ -40,12 +41,12 @@ _LOGGER = getLogger(__name__) class DummyService(Service, identifier="DummyService"): """Test Service class""" - config: "DummyService.ConfigSchema" = None - class ConfigSchema(Service.ConfigSchema): """ConfigSchema for DummyService.""" - type: str = "DUMMY_SERVICE" + type: str = "DummyService" + + config: "DummyService.ConfigSchema" = Field(default_factory=lambda: DummyService.ConfigSchema()) def describe_state(self) -> Dict: return super().describe_state() @@ -63,12 +64,12 @@ class DummyService(Service, identifier="DummyService"): class DummyApplication(Application, identifier="DummyApplication"): """Test Application class""" - config: "DummyApplication.ConfigSchema" = None - class ConfigSchema(Application.ConfigSchema): """ConfigSchema for DummyApplication.""" - type: str = "DUMMY_APPLICATION" + type: str = "DummyApplication" + + config: "DummyApplication.ConfigSchema" = Field(default_factory=lambda: DummyApplication.ConfigSchema()) def __init__(self, **kwargs): kwargs["name"] = "DummyApplication" diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py index f2d071b1..13fa3d1b 100644 --- a/tests/integration_tests/extensions/applications/extended_application.py +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -4,7 +4,7 @@ from ipaddress import IPv4Address from typing import Dict, List, Optional from urllib.parse import urlparse -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from primaite import getLogger from primaite.interface.request import RequestResponse @@ -31,7 +31,12 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): The application requests and loads web pages using its domain name and requesting IP addresses using DNS. """ - config: "ExtendedApplication.ConfigSchema" = None + class ConfigSchema(Application.ConfigSchema): + """ConfigSchema for ExtendedApplication.""" + + type: str = "ExtendedApplication" + + config: "ExtendedApplication.ConfigSchema" = Field(default_factory=lambda: ExtendedApplication.ConfigSchema()) target_url: Optional[str] = None @@ -44,11 +49,6 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): history: List["BrowserHistoryItem"] = [] """Keep a log of visited websites and information about the visit, such as response code.""" - class ConfigSchema(Application.ConfigSchema): - """ConfigSchema for ExtendedApplication.""" - - type: str = "EXTENDED_APPLICATION" - def __init__(self, **kwargs): kwargs["name"] = "ExtendedApplication" kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index 5ec157b2..ba247369 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -3,6 +3,8 @@ from ipaddress import IPv4Address from typing import Any, Dict, List, Literal, Optional, Union from uuid import uuid4 +from pydantic import Field + from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus @@ -24,7 +26,12 @@ class ExtendedService(Service, identifier="ExtendedService"): This class inherits from the `Service` class and provides methods to simulate a SQL database. """ - config: "ExtendedService.ConfigSchema" = None + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for ExtendedService.""" + + type: str = "ExtendedService" + + config: "ExtendedService.ConfigSchema" = Field(default_factory=lambda: ExtendedService.ConfigSchema()) password: Optional[str] = None """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" @@ -38,11 +45,6 @@ class ExtendedService(Service, identifier="ExtendedService"): latest_backup_file_name: str = None """File name of latest backup.""" - class ConfigSchema(Service.ConfigSchema): - """ConfigSchema for ExtendedService.""" - - type: str = "EXTENDED_SERVICE" - def __init__(self, **kwargs): kwargs["name"] = "ExtendedService" kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index 37553727..ed40334f 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -3,6 +3,7 @@ from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, List, Tuple import pytest +from pydantic import Field from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -17,12 +18,12 @@ from primaite.utils.validation.port import PORT_LOOKUP class BroadcastTestService(Service, identifier="BroadcastTestService"): """A service for sending broadcast and unicast messages over a network.""" - config: "BroadcastTestService.ConfigSchema" = None - class ConfigSchema(Service.ConfigSchema): """ConfigSchema for BroadcastTestService.""" - type: str = "BROADCAST_TEST_SERVICE" + type: str = "BroadcastTestService" + + config: "BroadcastTestService.ConfigSchema" = Field(default_factory=lambda: BroadcastTestService.ConfigSchema()) def __init__(self, **kwargs): # Set default service properties for broadcasting @@ -53,6 +54,13 @@ class BroadcastTestService(Service, identifier="BroadcastTestService"): class BroadcastTestClient(Application, identifier="BroadcastTestClient"): """A client application to receive broadcast and unicast messages.""" + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for BroadcastTestClient.""" + + type: str = "BroadcastTestClient" + + config: ConfigSchema = Field(default_factory=lambda: BroadcastTestClient.ConfigSchema()) + payloads_received: List = [] def __init__(self, **kwargs): diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py index bdfd56f0..a57bd539 100644 --- a/tests/integration_tests/system/test_service_listening_on_ports.py +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -15,18 +15,18 @@ from tests import TEST_ASSETS_ROOT class _DatabaseListener(Service, identifier="_DatabaseListener"): - config: "_DatabaseListener.ConfigSchema" = None + class ConfigSchema(Service.ConfigSchema): + """ConfigSchema for _DatabaseListener.""" + + type: str = "_DatabaseListener" + + config: "_DatabaseListener.ConfigSchema" = Field(default_factory=lambda: _DatabaseListener.ConfigSchema()) name: str = "DatabaseListener" protocol: str = PROTOCOL_LOOKUP["TCP"] port: int = PORT_LOOKUP["NONE"] listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]} payloads_received: List[Any] = Field(default_factory=list) - class ConfigSchema(Service.ConfigSchema): - """ConfigSchema for _DatabaseListener.""" - - type: str = "_DATABASE_LISTENER" - def receive(self, payload: Any, session_id: str, **kwargs) -> bool: self.payloads_received.append(payload) self.sys_log.info(f"{self.name}: received payload {payload}") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index 4ff387ce..17f8445a 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -128,13 +128,13 @@ def test_c2_handle_switching_port(basic_c2_network): assert c2_server.c2_connection_active is True # Assert to confirm that both the C2 server and the C2 beacon are configured correctly. - assert c2_beacon.c2_config.keep_alive_frequency is 2 - assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["HTTP"] - assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] + assert c2_beacon.config.keep_alive_frequency is 2 + assert c2_beacon.config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] - assert c2_server.c2_config.keep_alive_frequency is 2 - assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["HTTP"] - assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] + assert c2_server.config.keep_alive_frequency is 2 + assert c2_server.config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] # Configuring the C2 Beacon. c2_beacon.configure( @@ -150,11 +150,11 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon # Have reconfigured their C2 settings. - assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["FTP"] - assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] + assert c2_beacon.config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] - assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["FTP"] - assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] + assert c2_server.config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] def test_c2_handle_switching_frequency(basic_c2_network): @@ -174,8 +174,8 @@ def test_c2_handle_switching_frequency(basic_c2_network): assert c2_server.c2_connection_active is True # Assert to confirm that both the C2 server and the C2 beacon are configured correctly. - assert c2_beacon.c2_config.keep_alive_frequency is 2 - assert c2_server.c2_config.keep_alive_frequency is 2 + assert c2_beacon.config.keep_alive_frequency is 2 + assert c2_server.config.keep_alive_frequency is 2 # Configuring the C2 Beacon. c2_beacon.configure(c2_server_ip_address="192.168.0.1", keep_alive_frequency=10) @@ -186,8 +186,8 @@ def test_c2_handle_switching_frequency(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon # Have reconfigured their C2 settings. - assert c2_beacon.c2_config.keep_alive_frequency is 10 - assert c2_server.c2_config.keep_alive_frequency is 10 + assert c2_beacon.config.keep_alive_frequency is 10 + assert c2_server.config.keep_alive_frequency is 10 # Now skipping 9 time steps to confirm keep alive inactivity for i in range(9): diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index 46860836..bdf9cfee 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -2,6 +2,7 @@ from typing import Dict import pytest +from pydantic import Field from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service @@ -11,13 +12,12 @@ from primaite.utils.validation.port import PORT_LOOKUP class TestSoftware(Service, identifier="TestSoftware"): - - config: "TestSoftware.ConfigSchema" = None - class ConfigSchema(Service.ConfigSchema): """ConfigSChema for TestSoftware.""" - type: str = "TEST_SOFTWARE" + type: str = "TestSoftware" + + config: "TestSoftware.ConfigSchema" = Field(default_factory=lambda: TestSoftware.ConfigSchema()) def describe_state(self) -> Dict: pass