Merge remote-tracking branch 'origin/4.0.0a1-dev' into feature/2869-Marek

This commit is contained in:
Marek Wolan
2025-01-20 10:39:20 +00:00
46 changed files with 559 additions and 252 deletions

View File

@@ -70,7 +70,7 @@ Python
Configuration
=============
The RansomwareScript inherits configuration options such as ``fix_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``.
The RansomwareScript inherits configuration options such as ``fixing_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``.
``server_ip``

View File

@@ -22,8 +22,8 @@ options
The configuration options are the attributes that fall under the options for an application or service.
fix_duration
""""""""""""
fixing_duration
"""""""""""""""
Optional. Default value is ``2``.

View File

@@ -44,7 +44,7 @@ from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.system.services.web_server.web_server import WebServer
from primaite.simulator.system.software import Software
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -370,12 +370,12 @@ class PrimaiteGame:
if service_class is not None:
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_class, **service_cfg.get("options", {}))
new_node.software_manager.install(service_class)
new_service = new_node.software_manager.software[service_class.__name__]
# fixing duration for the service
if "fix_duration" in service_cfg.get("options", {}):
new_service.fixing_duration = service_cfg["options"]["fix_duration"]
if "fixing_duration" in service_cfg.get("options", {}):
new_service.config.fixing_duration = service_cfg["options"]["fixing_duration"]
_set_software_listen_on_ports(new_service, service_cfg)
# start the service
@@ -416,74 +416,20 @@ class PrimaiteGame:
application_type = application_cfg["type"]
if application_type in Application._registry:
new_node.software_manager.install(Application._registry[application_type])
application_class = Application._registry[application_type]
application_options = application_cfg.get("options", {})
application_options["type"] = application_type
new_node.software_manager.install(application_class, software_config=application_options)
new_application = new_node.software_manager.software[application_type] # grab the instance
# fixing duration for the application
if "fix_duration" in application_cfg.get("options", {}):
new_application.fixing_duration = application_cfg["options"]["fix_duration"]
else:
msg = f"Configuration contains an invalid application type: {application_type}"
_LOGGER.error(msg)
raise ValueError(msg)
_set_software_listen_on_ports(new_application, application_cfg)
# run the application
new_application.run()
if application_type == "DataManipulationBot":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")),
server_password=opt.get("server_password"),
payload=opt.get("payload", "DELETE"),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
)
elif application_type == "RansomwareScript":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None,
server_password=opt.get("server_password"),
payload=opt.get("payload", "ENCRYPT"),
)
elif application_type == "DatabaseClient":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("db_server_ip")),
server_password=opt.get("server_password"),
)
elif application_type == "WebBrowser":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.target_url = opt.get("target_url")
elif application_type == "DoSBot":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
target_ip_address=IPv4Address(opt.get("target_ip_address")),
target_port=PORT_LOOKUP[opt.get("target_port", "POSTGRES_SERVER")],
payload=opt.get("payload"),
repeat=bool(opt.get("repeat")),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
dos_intensity=float(opt.get("dos_intensity", "1.0")),
max_sessions=int(opt.get("max_sessions", "1000")),
)
elif application_type == "C2Beacon":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")),
keep_alive_frequency=(opt.get("keep_alive_frequency", 5)),
masquerade_protocol=PROTOCOL_LOOKUP[
(opt.get("masquerade_protocol", PROTOCOL_LOOKUP["TCP"]))
],
masquerade_port=PORT_LOOKUP[(opt.get("masquerade_port", PORT_LOOKUP["HTTP"]))],
)
if "network_interfaces" in node_cfg:
for nic_num, nic_cfg in node_cfg["network_interfaces"].items():
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))

View File

@@ -824,7 +824,7 @@ class User(SimComponent):
return self.model_dump()
class UserManager(Service):
class UserManager(Service, identifier="UserManager"):
"""
Manages users within the PrimAITE system, handling creation, authentication, and administration.
@@ -833,11 +833,18 @@ class UserManager(Service):
:param disabled_admins: A dictionary of currently disabled admin users by their usernames
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for UserManager."""
type: str = "UserManager"
config: "UserManager.ConfigSchema" = Field(default_factory=lambda: UserManager.ConfigSchema())
users: Dict[str, User] = {}
def __init__(self, **kwargs):
"""
Initializes a UserManager instanc.
Initializes a UserManager instance.
:param username: The username for the default admin user
:param password: The password for the default admin user
@@ -1130,13 +1137,20 @@ class RemoteUserSession(UserSession):
return state
class UserSessionManager(Service):
class UserSessionManager(Service, identifier="UserSessionManager"):
"""
Manages user sessions on a Node, including local and remote sessions.
This class handles authentication, session management, and session timeouts for users interacting with the Node.
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for UserSessionManager."""
type: str = "UserSessionManager"
config: "UserSessionManager.ConfigSchema" = Field(default_factory=lambda: UserSessionManager.ConfigSchema())
local_session: Optional[UserSession] = None
"""The current local user session, if any."""

View File

@@ -1,10 +1,12 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import abstractmethod
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, ClassVar, Dict, Optional, Set, Type
from pydantic import Field
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
@@ -21,13 +23,20 @@ class ApplicationOperatingState(Enum):
"The application is being installed or updated."
class Application(IOSoftware):
class Application(IOSoftware, ABC):
"""
Represents an Application in the simulation environment.
Applications are user-facing programs that may perform input/output operations.
"""
class ConfigSchema(IOSoftware.ConfigSchema, ABC):
"""Config Schema for Application class."""
type: str
config: ConfigSchema = Field(default_factory=lambda: Application.ConfigSchema())
operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED
"The current operating state of the Application."
execution_control_status: str = "manual"
@@ -49,7 +58,7 @@ class Application(IOSoftware):
Register an application type.
:param identifier: Uniquely specifies an application class by name. Used for finding items by config.
:type identifier: str
:type identifier: Optional[str]
:raises ValueError: When attempting to register an application with a name that is already allocated.
"""
super().__init_subclass__(**kwargs)
@@ -59,6 +68,21 @@ class Application(IOSoftware):
raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls
@classmethod
def from_config(cls, config: Dict) -> "Application":
"""Create an application from a config dictionary.
:param config: dict of options for application components constructor
:type config: dict
:return: The application component.
:rtype: Application
"""
if config["type"] not in cls._registry:
raise ValueError(f"Invalid Application type {config['type']}")
application_class = cls._registry[config["type"]]
application_object = application_class(config=application_class.ConfigSchema(**config))
return application_object
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union
from uuid import uuid4
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel
from pydantic import BaseModel, Field
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
@@ -67,11 +67,19 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
Extends the Application class to provide functionality for connecting, querying, and disconnecting from a
Database Service. It mainly operates over TCP protocol.
:ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None.
"""
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for DatabaseClient."""
type: str = "DatabaseClient"
db_server_ip: Optional[IPV4Address] = None
server_password: Optional[str] = None
config: ConfigSchema = Field(default_factory=lambda: DatabaseClient.ConfigSchema())
server_ip_address: Optional[IPv4Address] = None
"""The IPv4 address of the Database Service server, defaults to None."""
server_password: Optional[str] = None
_query_success_tracker: Dict[str, bool] = {}
"""Keep track of connections that were established or verified during this step. Used for rewards."""
@@ -93,6 +101,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
self.server_ip_address = self.config.db_server_ip
self.server_password = self.config.server_password
def _init_request_manager(self) -> RequestManager:
"""

View File

@@ -3,7 +3,7 @@ from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union
from prettytable import PrettyTable
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
@@ -52,6 +52,13 @@ class NMAP(Application, identifier="NMAP"):
as ping scans to discover active hosts and port scans to detect open ports on those hosts.
"""
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for NMAP."""
type: str = "NMAP"
config: "NMAP.ConfigSchema" = Field(default_factory=lambda: NMAP.ConfigSchema())
_active_port_scans: Dict[str, PortScanPayload] = {}
_port_scan_responses: Dict[str, PortScanPayload] = {}

View File

@@ -2,9 +2,9 @@
from abc import abstractmethod
from enum import Enum
from ipaddress import IPv4Address
from typing import Dict, Optional, Union
from typing import Dict, Optional, Set, Union
from pydantic import BaseModel, Field, validate_call
from pydantic import Field, validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.file_system.file_system import FileSystem, Folder
@@ -45,10 +45,10 @@ class C2Payload(Enum):
"""C2 Input Command payload. Used by the C2 Server to send a command to the c2 beacon."""
OUTPUT = "output_command"
"""C2 Output Command. Used by the C2 Beacon to send the results of a Input command to the c2 server."""
"""C2 Output Command. Used by the C2 Beacon to send the results of an Input command to the c2 server."""
class AbstractC2(Application, identifier="AbstractC2"):
class AbstractC2(Application):
"""
An abstract command and control (c2) application.
@@ -60,9 +60,25 @@ class AbstractC2(Application, identifier="AbstractC2"):
Defaults to masquerading as HTTP (Port 80) via TCP.
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
"""
class ConfigSchema(Application.ConfigSchema):
"""Configuration for AbstractC2."""
keep_alive_frequency: int = Field(default=5, ge=1)
"""The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon."""
masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"])
"""The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP."""
masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"])
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
listen_on_ports: Set[Port] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
config: ConfigSchema = Field(default_factory=lambda: AbstractC2.ConfigSchema())
c2_connection_active: bool = False
"""Indicates if the c2 server and c2 beacon are currently connected."""
@@ -75,19 +91,6 @@ class AbstractC2(Application, identifier="AbstractC2"):
keep_alive_inactivity: int = 0
"""Indicates how many timesteps since the last time the c2 application received a keep alive."""
class _C2Opts(BaseModel):
"""A Pydantic Schema for the different C2 configuration options."""
keep_alive_frequency: int = Field(default=5, ge=1)
"""The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon."""
masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"])
"""The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP."""
masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"])
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
c2_config: _C2Opts = _C2Opts()
"""
Holds the current configuration settings of the C2 Suite.
@@ -100,6 +103,12 @@ class AbstractC2(Application, identifier="AbstractC2"):
C2 beacon to reconfigure it's configuration settings.
"""
def __init__(self, **kwargs):
"""Initialise the C2 applications to by default listen for HTTP traffic."""
kwargs["port"] = PORT_LOOKUP["NONE"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
def _craft_packet(
self, c2_payload: C2Payload, c2_command: Optional[C2Command] = None, command_options: Optional[Dict] = {}
) -> C2Packet:
@@ -118,13 +127,13 @@ class AbstractC2(Application, identifier="AbstractC2"):
:type c2_command: C2Command.
:param command_options: The relevant C2 Beacon parameters.F
:type command_options: Dict
:return: Returns the construct C2Packet
:return: Returns the constructed C2Packet
:rtype: C2Packet
"""
constructed_packet = C2Packet(
masquerade_protocol=self.c2_config.masquerade_protocol,
masquerade_port=self.c2_config.masquerade_port,
keep_alive_frequency=self.c2_config.keep_alive_frequency,
masquerade_protocol=self.config.masquerade_protocol,
masquerade_port=self.config.masquerade_port,
keep_alive_frequency=self.config.keep_alive_frequency,
payload_type=c2_payload,
command=c2_command,
payload=command_options,
@@ -140,13 +149,6 @@ class AbstractC2(Application, identifier="AbstractC2"):
"""
return super().describe_state()
def __init__(self, **kwargs):
"""Initialise the C2 applications to by default listen for HTTP traffic."""
kwargs["listen_on_ports"] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
kwargs["port"] = PORT_LOOKUP["NONE"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
@property
def _host_ftp_client(self) -> Optional[FTPClient]:
"""Return the FTPClient that is installed C2 Application's host.
@@ -330,8 +332,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
if self.send(
payload=keep_alive_packet,
dest_ip_address=self.c2_remote_connection,
dest_port=self.c2_config.masquerade_port,
ip_protocol=self.c2_config.masquerade_protocol,
dest_port=self.config.masquerade_port,
ip_protocol=self.config.masquerade_protocol,
session_id=session_id,
):
# Setting the keep_alive_sent guard condition to True. This is used to prevent packet storms.
@@ -340,8 +342,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
self.sys_log.info(f"{self.name}: Keep Alive sent to {self.c2_remote_connection}")
self.sys_log.debug(
f"{self.name}: Keep Alive sent to {self.c2_remote_connection} "
f"Masquerade Port: {self.c2_config.masquerade_port} "
f"Masquerade Protocol: {self.c2_config.masquerade_protocol} "
f"Masquerade Port: {self.config.masquerade_port} "
f"Masquerade Protocol: {self.config.masquerade_protocol} "
)
return True
else:
@@ -376,15 +378,15 @@ class AbstractC2(Application, identifier="AbstractC2"):
# Updating the C2 Configuration attribute.
self.c2_config.masquerade_port = payload.masquerade_port
self.c2_config.masquerade_protocol = payload.masquerade_protocol
self.c2_config.keep_alive_frequency = payload.keep_alive_frequency
self.config.masquerade_port = payload.masquerade_port
self.config.masquerade_protocol = payload.masquerade_protocol
self.config.keep_alive_frequency = payload.keep_alive_frequency
self.sys_log.debug(
f"{self.name}: C2 Config Resolved Config from Keep Alive:"
f"Masquerade Port: {self.c2_config.masquerade_port}"
f"Masquerade Protocol: {self.c2_config.masquerade_protocol}"
f"Keep Alive Frequency: {self.c2_config.keep_alive_frequency}"
f"Masquerade Port: {self.config.masquerade_port}"
f"Masquerade Protocol: {self.config.masquerade_protocol}"
f"Keep Alive Frequency: {self.config.keep_alive_frequency}"
)
# This statement is intended to catch on the C2 Application that is listening for connection.
@@ -410,8 +412,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
self.keep_alive_inactivity = 0
self.keep_alive_frequency = 5
self.c2_remote_connection = None
self.c2_config.masquerade_port = PORT_LOOKUP["HTTP"]
self.c2_config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"]
self.config.masquerade_port = PORT_LOOKUP["HTTP"]
self.config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"]
@abstractmethod
def _confirm_remote_connection(self, timestep: int) -> bool:

View File

@@ -3,7 +3,7 @@ from ipaddress import IPv4Address
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
@@ -12,8 +12,9 @@ from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts
from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP
class C2Beacon(AbstractC2, identifier="C2Beacon"):
@@ -32,15 +33,30 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
2. Leveraging the terminal application to execute requests (dependent on the command given)
3. Sending the RequestResponse back to the C2 Server (Command output)
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
"""
class ConfigSchema(AbstractC2.ConfigSchema):
"""ConfigSchema for C2Beacon."""
type: str = "C2Beacon"
c2_server_ip_address: Optional[IPV4Address] = None
keep_alive_frequency: int = 5
masquerade_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"]
masquerade_port: Port = PORT_LOOKUP["HTTP"]
config: ConfigSchema = Field(default_factory=lambda: C2Beacon.ConfigSchema())
keep_alive_attempted: bool = False
"""Indicates if a keep alive has been attempted to be sent this timestep. Used to prevent packet storms."""
terminal_session: TerminalClientConnection = None
"The currently in use terminal session."
def __init__(self, **kwargs):
kwargs["name"] = "C2Beacon"
super().__init__(**kwargs)
@property
def _host_terminal(self) -> Optional[Terminal]:
"""Return the Terminal that is installed on the same machine as the C2 Beacon."""
@@ -119,10 +135,6 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
rm.add_request("configure", request_type=RequestType(func=_configure))
return rm
def __init__(self, **kwargs):
kwargs["name"] = "C2Beacon"
super().__init__(**kwargs)
# Configure is practically setter method for the ``c2.config`` attribute that also ties into the request manager.
@validate_call
def configure(
@@ -146,7 +158,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
masquerade_port | What port should the C2 traffic use? (TCP or UDP)
These configuration options are used to reassign the fields in the inherited inner class
``c2_config``.
``config``.
If a connection is already in progress then this method also sends a keep alive to the C2
Server in order for the C2 Server to sync with the new configuration settings.
@@ -162,9 +174,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
:return: Returns True if the configuration was successful, False otherwise.
"""
self.c2_remote_connection = IPv4Address(c2_server_ip_address)
self.c2_config.keep_alive_frequency = keep_alive_frequency
self.c2_config.masquerade_port = masquerade_port
self.c2_config.masquerade_protocol = masquerade_protocol
self.config.keep_alive_frequency = keep_alive_frequency
self.config.masquerade_port = masquerade_port
self.config.masquerade_protocol = masquerade_protocol
self.sys_log.info(
f"{self.name}: Configured {self.name} with remote C2 server connection: {c2_server_ip_address=}."
)
@@ -263,14 +275,12 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
if self.send(
payload=output_packet,
dest_ip_address=self.c2_remote_connection,
dest_port=self.c2_config.masquerade_port,
ip_protocol=self.c2_config.masquerade_protocol,
dest_port=self.config.masquerade_port,
ip_protocol=self.config.masquerade_protocol,
session_id=session_id,
):
self.sys_log.info(f"{self.name}: Command output sent to {self.c2_remote_connection}")
self.sys_log.debug(
f"{self.name}: on {self.c2_config.masquerade_port} via {self.c2_config.masquerade_protocol}"
)
self.sys_log.debug(f"{self.name}: on {self.config.masquerade_port} via {self.config.masquerade_protocol}")
return True
else:
self.sys_log.warning(
@@ -562,7 +572,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
:rtype bool:
"""
self.keep_alive_attempted = False # Resetting keep alive sent.
if self.keep_alive_inactivity == self.c2_config.keep_alive_frequency:
if self.keep_alive_inactivity == self.config.keep_alive_frequency:
self.sys_log.info(
f"{self.name}: Attempting to Send Keep Alive to {self.c2_remote_connection} at timestep {timestep}."
)
@@ -627,9 +637,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
self.c2_connection_active,
self.c2_remote_connection,
self.keep_alive_inactivity,
self.c2_config.keep_alive_frequency,
self.c2_config.masquerade_protocol,
self.c2_config.masquerade_port,
self.config.keep_alive_frequency,
self.config.masquerade_protocol,
self.config.masquerade_port,
]
)
print(table)

View File

@@ -2,7 +2,7 @@
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
@@ -31,9 +31,16 @@ class C2Server(AbstractC2, identifier="C2Server"):
1. Sending commands to the C2 Beacon. (Command input)
2. Parsing terminal RequestResponses back to the Agent.
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
"""
class ConfigSchema(AbstractC2.ConfigSchema):
"""ConfigSchema for C2Server."""
type: str = "C2Server"
config: ConfigSchema = Field(default_factory=lambda: C2Server.ConfigSchema())
current_command_output: RequestResponse = None
"""The Request Response by the last command send. This attribute is updated by the method _handle_command_output."""
@@ -251,8 +258,8 @@ class C2Server(AbstractC2, identifier="C2Server"):
payload=command_packet,
dest_ip_address=self.c2_remote_connection,
session_id=self.c2_session.uuid,
dest_port=self.c2_config.masquerade_port,
ip_protocol=self.c2_config.masquerade_protocol,
dest_port=self.config.masquerade_port,
ip_protocol=self.config.masquerade_protocol,
):
self.sys_log.info(f"{self.name}: Successfully sent {given_command}.")
self.sys_log.info(f"{self.name}: Awaiting command response {given_command}.")
@@ -334,11 +341,11 @@ class C2Server(AbstractC2, identifier="C2Server"):
:return: Returns False if the C2 beacon is considered dead. Otherwise True.
:rtype bool:
"""
if self.keep_alive_inactivity > self.c2_config.keep_alive_frequency:
if self.keep_alive_inactivity > self.config.keep_alive_frequency:
self.sys_log.info(f"{self.name}: C2 Beacon connection considered dead due to inactivity.")
self.sys_log.debug(
f"{self.name}: Did not receive expected keep alive connection from {self.c2_remote_connection}"
f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.c2_config.keep_alive_frequency}"
f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.config.keep_alive_frequency}"
f"{self.name}: Last Keep Alive received at {(timestep - self.keep_alive_inactivity)}"
)
self._reset_c2_connection()
@@ -389,8 +396,8 @@ class C2Server(AbstractC2, identifier="C2Server"):
[
self.c2_connection_active,
self.c2_remote_connection,
self.c2_config.masquerade_protocol,
self.c2_config.masquerade_port,
self.config.masquerade_protocol,
self.config.masquerade_port,
]
)
print(table)

View File

@@ -3,6 +3,8 @@ from enum import IntEnum
from ipaddress import IPv4Address
from typing import Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestResponse
@@ -10,6 +12,7 @@ from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -40,6 +43,18 @@ class DataManipulationAttackStage(IntEnum):
class DataManipulationBot(Application, identifier="DataManipulationBot"):
"""A bot that simulates a script which performs a SQL injection attack."""
class ConfigSchema(Application.ConfigSchema):
"""Configuration schema for DataManipulationBot."""
type: str = "DataManipulationBot"
server_ip: Optional[IPV4Address] = None
server_password: Optional[str] = None
payload: str = "DELETE"
port_scan_p_of_success: float = 0.1
data_manipulation_p_of_success: float = 0.1
config: "DataManipulationBot.ConfigSchema" = Field(default_factory=lambda: DataManipulationBot.ConfigSchema())
payload: Optional[str] = None
port_scan_p_of_success: float = 0.1
data_manipulation_p_of_success: float = 0.1
@@ -56,6 +71,12 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"):
super().__init__(**kwargs)
self._db_connection: Optional[DatabaseClientConnection] = None
self.server_ip_address = self.config.server_ip
self.server_password = self.config.server_password
self.payload = self.config.payload
self.port_scan_p_of_success = self.config.port_scan_p_of_success
self.data_manipulation_p_of_success = self.config.data_manipulation_p_of_success
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -3,11 +3,14 @@ from enum import IntEnum
from ipaddress import IPv4Address
from typing import Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -32,6 +35,20 @@ class DoSAttackStage(IntEnum):
class DoSBot(DatabaseClient, identifier="DoSBot"):
"""A bot that simulates a Denial of Service attack."""
class ConfigSchema(DatabaseClient.ConfigSchema):
"""ConfigSchema for DoSBot."""
type: str = "DoSBot"
target_ip_address: Optional[IPV4Address] = None
target_port: Port = PORT_LOOKUP["POSTGRES_SERVER"]
payload: Optional[str] = None
repeat: bool = False
port_scan_p_of_success: float = 0.1
dos_intensity: float = 1.0
max_sessions: int = 1000
config: "DoSBot.ConfigSchema" = Field(default_factory=lambda: DoSBot.ConfigSchema())
target_ip_address: Optional[IPv4Address] = None
"""IP address of the target service."""
@@ -56,7 +73,13 @@ class DoSBot(DatabaseClient, identifier="DoSBot"):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "DoSBot"
self.max_sessions = 1000 # override normal max sessions
self.target_ip_address = self.config.target_ip_address
self.target_port = self.config.target_port
self.payload = self.config.payload
self.repeat = self.config.repeat
self.port_scan_p_of_success = self.config.port_scan_p_of_success
self.dos_intensity = self.config.dos_intensity
self.max_sessions = self.config.max_sessions
def _init_request_manager(self) -> RequestManager:
"""

View File

@@ -3,12 +3,14 @@ from ipaddress import IPv4Address
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import PORT_LOOKUP
@@ -18,6 +20,16 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
:ivar payload: The attack stage query payload. (Default ENCRYPT)
"""
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for RansomwareScript."""
type: str = "RansomwareScript"
server_ip: Optional[IPV4Address] = None
server_password: Optional[str] = None
payload: str = "ENCRYPT"
config: "RansomwareScript.ConfigSchema" = Field(default_factory=lambda: RansomwareScript.ConfigSchema())
server_ip_address: Optional[IPv4Address] = None
"""IP address of node which hosts the database."""
server_password: Optional[str] = None
@@ -32,6 +44,9 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
super().__init__(**kwargs)
self._db_connection: Optional[DatabaseClientConnection] = None
self.server_ip_address = self.config.server_ip
self.server_password = self.config.server_password
self.payload = self.config.payload
def describe_state(self) -> Dict:
"""

View File

@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
from typing import Dict, List, Optional
from urllib.parse import urlparse
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from primaite import getLogger
from primaite.interface.request import RequestResponse
@@ -30,7 +30,13 @@ class WebBrowser(Application, identifier="WebBrowser"):
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
"""
target_url: Optional[str] = None
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for WebBrowser."""
type: str = "WebBrowser"
target_url: Optional[str] = None
config: "WebBrowser.ConfigSchema" = Field(default_factory=lambda: WebBrowser.ConfigSchema())
domain_name_ip_address: Optional[IPv4Address] = None
"The IP address of the domain name for the webpage."
@@ -86,7 +92,7 @@ class WebBrowser(Application, identifier="WebBrowser"):
:param: url: The address of the web page the browser requests
:type: url: str
"""
url = url or self.target_url
url = url or self.config.target_url
if not self._can_perform_action():
return False

View File

@@ -106,7 +106,7 @@ class SoftwareManager:
return True
return False
def install(self, software_class: Type[IOSoftware], **install_kwargs):
def install(self, software_class: Type[IOSoftware], software_config: Optional[IOSoftware.ConfigSchema] = None):
"""
Install an Application or Service.
@@ -115,13 +115,22 @@ class SoftwareManager:
if software_class in self._software_class_to_name_map:
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
return
software = software_class(
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
**install_kwargs,
)
if software_config is None:
software = software_class(
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
)
else:
software = software_class(
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
config=software_config,
)
software.parent = self.node
if isinstance(software, Application):
self.node.applications[software.uuid] = software

View File

@@ -5,6 +5,7 @@ from abc import abstractmethod
from typing import Any, Dict, Optional, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
@@ -14,7 +15,7 @@ from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import PORT_LOOKUP
class ARP(Service):
class ARP(Service, identifier="ARP"):
"""
The ARP (Address Resolution Protocol) Service.
@@ -22,6 +23,13 @@ class ARP(Service):
sends ARP requests and replies, and processes incoming ARP packets.
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for ARP."""
type: str = "ARP"
config: "ARP.ConfigSchema" = Field(default_factory=lambda: ARP.ConfigSchema())
arp: Dict[IPV4Address, ARPEntry] = {}
def __init__(self, **kwargs):

View File

@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
from uuid import uuid4
from pydantic import Field
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
@@ -17,13 +19,21 @@ from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
class DatabaseService(Service):
class DatabaseService(Service, identifier="DatabaseService"):
"""
A class for simulating a generic SQL Server service.
This class inherits from the `Service` class and provides methods to simulate a SQL database.
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for DatabaseService."""
type: str = "DatabaseService"
backup_server_ip: Optional[IPv4Address] = None
config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema())
password: Optional[str] = None
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
@@ -42,6 +52,7 @@ class DatabaseService(Service):
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
self._create_db_file()
self.backup_server_ip = self.config.backup_server_ip
def install(self):
"""

View File

@@ -2,6 +2,8 @@
from ipaddress import IPv4Address
from typing import Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
from primaite.simulator.system.core.software_manager import SoftwareManager
@@ -12,9 +14,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
class DNSClient(Service):
class DNSClient(Service, identifier="DNSClient"):
"""Represents a DNS Client as a Service."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for DNSClient."""
type: str = "DNSClient"
config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema())
dns_cache: Dict[str, IPv4Address] = {}
"A dict of known mappings between domain/URLs names and IPv4 addresses."
dns_server: Optional[IPv4Address] = None

View File

@@ -3,6 +3,7 @@ from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.protocols.dns import DNSPacket
@@ -13,9 +14,17 @@ from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
class DNSServer(Service):
class DNSServer(Service, identifier="DNSServer"):
"""Represents a DNS Server as a Service."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for DNSServer."""
type: str = "DNSServer"
domain_mapping: dict = {}
config: "DNSServer.ConfigSchema" = Field(default_factory=lambda: DNSServer.ConfigSchema())
dns_table: Dict[str, IPv4Address] = {}
"A dict of mappings between domain names and IPv4 addresses."

View File

@@ -2,6 +2,8 @@
from ipaddress import IPv4Address
from typing import Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
@@ -9,20 +11,28 @@ from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import Service
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
class FTPClient(FTPServiceABC):
class FTPClient(FTPServiceABC, identifier="FTPClient"):
"""
A class for simulating an FTP client service.
This class inherits from the `Service` class and provides methods to emulate FTP
This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
"""
config: "FTPClient.ConfigSchema" = Field(default_factory=lambda: FTPClient.ConfigSchema())
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for FTPClient."""
type: str = "FTPClient"
def __init__(self, **kwargs):
kwargs["name"] = "FTPClient"
kwargs["port"] = PORT_LOOKUP["FTP"]

View File

@@ -1,26 +1,36 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import Any, Optional
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import Service
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
class FTPServer(FTPServiceABC):
class FTPServer(FTPServiceABC, identifier="FTPServer"):
"""
A class for simulating an FTP server service.
This class inherits from the `Service` class and provides methods to emulate FTP
This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
"""
config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema())
server_password: Optional[str] = None
"""Password needed to connect to FTP server. Default is None."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for FTPServer."""
type: str = "FTPServer"
def __init__(self, **kwargs):
kwargs["name"] = "FTPServer"
kwargs["port"] = PORT_LOOKUP["FTP"]

View File

@@ -3,6 +3,8 @@ import secrets
from ipaddress import IPv4Address
from typing import Any, Dict, Optional, Tuple, Union
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType
@@ -14,7 +16,7 @@ from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
class ICMP(Service):
class ICMP(Service, identifier="ICMP"):
"""
The Internet Control Message Protocol (ICMP) service.
@@ -22,6 +24,13 @@ class ICMP(Service):
network diagnostics, notably the ping command.
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for ICMP."""
type: str = "ICMP"
config: "ICMP.ConfigSchema" = Field(default_factory=lambda: ICMP.ConfigSchema())
request_replies: Dict = {}
def __init__(self, **kwargs):

View File

@@ -3,6 +3,8 @@ from datetime import datetime
from ipaddress import IPv4Address
from typing import Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.protocols.ntp import NTPPacket
from primaite.simulator.system.services.service import Service, ServiceOperatingState
@@ -12,9 +14,16 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
class NTPClient(Service):
class NTPClient(Service, identifier="NTPClient"):
"""Represents a NTP client as a service."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for NTPClient."""
type: str = "NTPClient"
config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema())
ntp_server: Optional[IPv4Address] = None
"The NTP server the client sends requests to."
time: Optional[datetime] = None

View File

@@ -2,6 +2,8 @@
from datetime import datetime
from typing import Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.protocols.ntp import NTPPacket
from primaite.simulator.system.services.service import Service
@@ -11,9 +13,16 @@ from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
class NTPServer(Service):
class NTPServer(Service, identifier="NTPServer"):
"""Represents a NTP server as a service."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for NTPServer."""
type: str = "NTPServer"
config: "NTPServer.ConfigSchema" = Field(default_factory=lambda: NTPServer.ConfigSchema())
def __init__(self, **kwargs):
kwargs["name"] = "NTPServer"
kwargs["port"] = PORT_LOOKUP["NTP"]

View File

@@ -1,10 +1,12 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import abstractmethod
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, ClassVar, Dict, Optional, Type
from pydantic import Field
from primaite import getLogger
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
@@ -37,6 +39,13 @@ class Service(IOSoftware):
Services are programs that run in the background and may perform input/output operations.
"""
class ConfigSchema(IOSoftware.ConfigSchema, ABC):
"""Config Schema for Service class."""
type: str
config: "Service.ConfigSchema" = Field(default_factory=lambda: Service.ConfigSchema())
operating_state: ServiceOperatingState = ServiceOperatingState.STOPPED
"The current operating state of the Service."
@@ -69,6 +78,21 @@ class Service(IOSoftware):
raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls
@classmethod
def from_config(cls, config: Dict) -> "Service":
"""Create a service from a config dictionary.
:param config: dict of options for service components constructor
:type config: dict
:return: The service component.
:rtype: Service
"""
if config["type"] not in cls._registry:
raise ValueError(f"Invalid service type {config['type']}")
service_class = cls._registry[config["type"]]
service_object = service_class(config=service_class.ConfigSchema(**config))
return service_object
def _can_perform_action(self) -> bool:
"""
Checks if the service can perform actions.
@@ -232,14 +256,14 @@ class Service(IOSoftware):
def disable(self) -> bool:
"""Disable the service."""
self.sys_log.info(f"Disabling Application {self.name}")
self.sys_log.info(f"Disabling Service {self.name}")
self.operating_state = ServiceOperatingState.DISABLED
return True
def enable(self) -> bool:
"""Enable the disabled service."""
if self.operating_state == ServiceOperatingState.DISABLED:
self.sys_log.info(f"Enabling Application {self.name}")
self.sys_log.info(f"Enabling Service {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
return True
return False

View File

@@ -7,7 +7,7 @@ from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4
from pydantic import BaseModel
from pydantic import BaseModel, Field
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
@@ -129,9 +129,16 @@ class RemoteTerminalConnection(TerminalClientConnection):
return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id)
class Terminal(Service):
class Terminal(Service, identifier="Terminal"):
"""Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for Terminal."""
type: str = "Terminal"
config: "Terminal.ConfigSchema" = Field(default_factory=lambda: Terminal.ConfigSchema())
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
"""Dictionary of connect requests made to remote nodes."""

View File

@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.protocols.http import (
HttpRequestMethod,
@@ -19,9 +21,16 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
class WebServer(Service):
class WebServer(Service, identifier="WebServer"):
"""Class used to represent a Web Server Service in simulation."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for WebServer."""
type: str = "WebServer"
config: "WebServer.ConfigSchema" = Field(default_factory=lambda: WebServer.ConfigSchema())
response_codes_this_timestep: List[HttpStatusCode] = []
def describe_state(self) -> Dict:

View File

@@ -1,13 +1,13 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import copy
from abc import abstractmethod
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from pydantic import BaseModel, ConfigDict, Field
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
@@ -70,7 +70,7 @@ class SoftwareCriticality(Enum):
"The highest level of criticality."
class Software(SimComponent):
class Software(SimComponent, ABC):
"""
A base class representing software in a simulator environment.
@@ -78,14 +78,22 @@ class Software(SimComponent):
It outlines the fundamental attributes and behaviors expected of any software in the simulation.
"""
class ConfigSchema(BaseModel, ABC):
"""Configurable options for all software."""
model_config = ConfigDict(extra="forbid")
starting_health_state: SoftwareHealthState = SoftwareHealthState.UNUSED
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
fixing_duration: int = 2
config: ConfigSchema = Field(default_factory=lambda: Software.ConfigSchema())
name: str
"The name of the software."
health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED
"The actual health state of the software."
health_state_visible: SoftwareHealthState = SoftwareHealthState.UNUSED
"The health state of the software visible to the red agent."
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
"The criticality level of the software."
fixing_count: int = 0
"The count of patches applied to the software, defaults to 0."
scanning_count: int = 0
@@ -100,11 +108,13 @@ class Software(SimComponent):
"The FileSystem of the Node the Software is installed on."
folder: Optional[Folder] = None
"The folder on the file system the Software uses."
fixing_duration: int = 2
"The number of ticks it takes to patch the software."
_fixing_countdown: Optional[int] = None
"Current number of ticks left to patch the software."
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.health_state_actual = self.config.starting_health_state # don't remove this
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
@@ -152,7 +162,7 @@ class Software(SimComponent):
{
"health_state_actual": self.health_state_actual.value,
"health_state_visible": self.health_state_visible.value,
"criticality": self.criticality.value,
"criticality": self.config.criticality.value,
"fixing_count": self.fixing_count,
"scanning_count": self.scanning_count,
"revealed_to_red": self.revealed_to_red,
@@ -201,7 +211,7 @@ class Software(SimComponent):
def fix(self) -> bool:
"""Perform a fix on the software."""
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
self._fixing_countdown = self.fixing_duration
self._fixing_countdown = self.config.fixing_duration
self.set_health_state(SoftwareHealthState.FIXING)
return True
return False
@@ -233,7 +243,7 @@ class Software(SimComponent):
super().pre_timestep(timestep)
class IOSoftware(Software):
class IOSoftware(Software, ABC):
"""
Represents software in a simulator environment that is capable of input/output operations.
@@ -243,6 +253,13 @@ class IOSoftware(Software):
required.
"""
class ConfigSchema(Software.ConfigSchema, ABC):
"""Configuration options for all IO Software."""
listen_on_ports: Set[Port] = Field(default_factory=set)
config: ConfigSchema = Field(default_factory=lambda: IOSoftware.ConfigSchema())
installing_count: int = 0
"The number of times the software has been installed. Default is 0."
max_sessions: int = 100
@@ -260,6 +277,10 @@ class IOSoftware(Software):
_connections: Dict[str, Dict] = {}
"Active connections."
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.listen_on_ports = self.config.listen_on_ports
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -148,7 +148,7 @@ simulation:
options:
db_server_ip: 192.168.1.10
server_password: arcd
fix_duration: 1
fixing_duration: 1
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
@@ -169,7 +169,7 @@ simulation:
arcd.com: 192.168.1.10
- type: DatabaseService
options:
fix_duration: 5
fixing_duration: 5
backup_server_ip: 192.168.1.10
- type: WebServer
- type: FTPClient

View File

@@ -142,19 +142,19 @@ simulation:
applications:
- type: NMAP
options:
fix_duration: 1
fixing_duration: 1
- type: RansomwareScript
options:
fix_duration: 1
fixing_duration: 1
- type: WebBrowser
options:
target_url: http://arcd.com/users/
fix_duration: 1
fixing_duration: 1
- type: DatabaseClient
options:
db_server_ip: 192.168.1.10
server_password: arcd
fix_duration: 1
fixing_duration: 1
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
@@ -162,43 +162,44 @@ simulation:
payload: "DELETE"
server_ip: 192.168.1.21
server_password: arcd
fix_duration: 1
fixing_duration: 1
- type: DoSBot
options:
target_ip_address: 192.168.10.21
payload: SPOOF DATA
port_scan_p_of_success: 0.8
fix_duration: 1
fixing_duration: 1
services:
- type: DNSClient
options:
fix_duration: 3
dns_server: 192.168.1.10
fixing_duration: 3
- type: DNSServer
options:
fix_duration: 3
fixing_duration: 3
domain_mapping:
arcd.com: 192.168.1.10
- type: DatabaseService
options:
backup_server_ip: 192.168.1.10
fix_duration: 3
fixing_duration: 3
- type: WebServer
options:
fix_duration: 3
fixing_duration: 3
- type: FTPClient
options:
fix_duration: 3
fixing_duration: 3
- type: FTPServer
options:
server_password: arcd
fix_duration: 3
fixing_duration: 3
- type: NTPClient
options:
ntp_server_ip: 192.168.1.10
fix_duration: 3
fixing_duration: 3
- type: NTPServer
options:
fix_duration: 3
fixing_duration: 3
- hostname: client_2
type: computer
ip_address: 192.168.10.22

View File

@@ -39,9 +39,16 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__)
class DummyService(Service):
class DummyService(Service, identifier="DummyService"):
"""Test Service class"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for DummyService."""
type: str = "DummyService"
config: "DummyService.ConfigSchema" = Field(default_factory=lambda: DummyService.ConfigSchema())
def describe_state(self) -> Dict:
return super().describe_state()
@@ -58,6 +65,13 @@ class DummyService(Service):
class DummyApplication(Application, identifier="DummyApplication"):
"""Test Application class"""
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for DummyApplication."""
type: str = "DummyApplication"
config: "DummyApplication.ConfigSchema" = Field(default_factory=lambda: DummyApplication.ConfigSchema())
def __init__(self, **kwargs):
kwargs["name"] = "DummyApplication"
kwargs["port"] = PORT_LOOKUP["HTTP"]

View File

@@ -13,8 +13,8 @@ from primaite.simulator.system.services.database.database_service import Databas
from primaite.simulator.system.services.dns.dns_client import DNSClient
from tests import TEST_ASSETS_ROOT
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml"
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml"
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fixing_duration.yaml"
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fixing_duration_one_item.yaml"
TestApplications = ["DummyApplication", "BroadcastTestClient"]
@@ -27,27 +27,27 @@ def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
return PrimaiteGame.from_config(cfg)
def test_default_fix_duration():
"""Test that software with no defined fix duration in config uses the default fix duration of 2."""
def test_default_fixing_duration():
"""Test that software with no defined fixing duration in config uses the default fixing duration of 2."""
game = load_config(TEST_CONFIG)
client_2: Computer = game.simulation.network.get_node_by_hostname("client_2")
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
assert database_client.fixing_duration == 2
assert database_client.config.fixing_duration == 2
dns_client: DNSClient = client_2.software_manager.software.get("DNSClient")
assert dns_client.fixing_duration == 2
assert dns_client.config.fixing_duration == 2
def test_fix_duration_set_from_config():
"""Test to check that the fix duration set for applications and services works as intended."""
def test_fixing_duration_set_from_config():
"""Test to check that the fixing duration set for applications and services works as intended."""
game = load_config(TEST_CONFIG)
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
for service in ["DNSClient", "DNSServer", "DatabaseService", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 3
assert client_1.software_manager.software.get(service).config.fixing_duration == 3
# in config - applications take 1 timestep to fix
# remove test applications from list
@@ -55,27 +55,27 @@ def test_fix_duration_set_from_config():
for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]:
assert client_1.software_manager.software.get(application) is not None
assert client_1.software_manager.software.get(application).fixing_duration == 1
assert client_1.software_manager.software.get(application).config.fixing_duration == 1
def test_fix_duration_for_one_item():
"""Test that setting fix duration for one application does not affect other components."""
def test_fixing_duration_for_one_item():
"""Test that setting fixing duration for one application does not affect other components."""
game = load_config(ONE_ITEM_CONFIG)
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
for service in ["DNSClient", "DNSServer", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 2
assert client_1.software_manager.software.get(service).config.fixing_duration == 2
# in config - applications take 1 timestep to fix
# remove test applications from list
for applications in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot"]:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 2
assert client_1.software_manager.software.get(applications).config.fixing_duration == 2
database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
assert database_client.fixing_duration == 1
assert database_client.config.fixing_duration == 1
database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService")
assert database_service.fixing_duration == 5
assert database_service.config.fixing_duration == 5

View File

@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
from typing import Dict, List, Optional
from urllib.parse import urlparse
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from primaite import getLogger
from primaite.interface.request import RequestResponse
@@ -31,6 +31,14 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
"""
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for ExtendedApplication."""
type: str = "ExtendedApplication"
target_url: Optional[str] = None
config: "ExtendedApplication.ConfigSchema" = Field(default_factory=lambda: ExtendedApplication.ConfigSchema())
target_url: Optional[str] = None
domain_name_ip_address: Optional[IPv4Address] = None
@@ -50,6 +58,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
kwargs["port"] = PORT_LOOKUP["HTTP"]
super().__init__(**kwargs)
self.target_url = self.config.target_url
self.run()
def _init_request_manager(self) -> RequestManager:

View File

@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
from uuid import uuid4
from pydantic import Field
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
@@ -17,13 +19,20 @@ from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
class ExtendedService(Service, identifier="extendedservice"):
class ExtendedService(Service, identifier="ExtendedService"):
"""
A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE.
This class inherits from the `Service` class and provides methods to simulate a SQL database.
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for ExtendedService."""
type: str = "ExtendedService"
config: "ExtendedService.ConfigSchema" = Field(default_factory=lambda: ExtendedService.ConfigSchema())
password: Optional[str] = None
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""

View File

@@ -191,7 +191,7 @@ def test_nic_monitored_traffic(simulation):
# send a database query
browser: WebBrowser = pc.software_manager.software.get("WebBrowser")
browser.target_url = f"http://arcd.com/"
browser.config.target_url = f"http://arcd.com/"
browser.get_webpage()
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")

View File

@@ -183,7 +183,7 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com before we block it
# 2: Remove rule that allows HTTP traffic across the network
@@ -216,7 +216,7 @@ def test_host_nic_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com before we block it
# 2: Disable the NIC on client_1
@@ -416,7 +416,7 @@ def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteG
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com before we block it
# 2: Disable the NIC on client_1
@@ -476,7 +476,7 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com
assert browser.health_state_actual == SoftwareHealthState.GOOD

View File

@@ -29,7 +29,7 @@ def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, Controlle
client_1 = game.simulation.network.get_node_by_hostname("client_1")
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
agent.reward_function.register_component(comp, 0.7)
# Check that before trying to fetch the webpage, the reward is 0.0

View File

@@ -3,6 +3,7 @@ from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Tuple
import pytest
from pydantic import Field
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -14,9 +15,16 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
class BroadcastTestService(Service):
class BroadcastTestService(Service, identifier="BroadcastTestService"):
"""A service for sending broadcast and unicast messages over a network."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for BroadcastTestService."""
type: str = "BroadcastTestService"
config: "BroadcastTestService.ConfigSchema" = Field(default_factory=lambda: BroadcastTestService.ConfigSchema())
def __init__(self, **kwargs):
# Set default service properties for broadcasting
kwargs["name"] = "BroadcastService"
@@ -46,6 +54,13 @@ class BroadcastTestService(Service):
class BroadcastTestClient(Application, identifier="BroadcastTestClient"):
"""A client application to receive broadcast and unicast messages."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for BroadcastTestClient."""
type: str = "BroadcastTestClient"
config: ConfigSchema = Field(default_factory=lambda: BroadcastTestClient.ConfigSchema())
payloads_received: List = []
def __init__(self, **kwargs):

View File

@@ -495,6 +495,12 @@ def test_c2_suite_yaml():
computer_b: Computer = yaml_network.get_node_by_hostname("node_b")
c2_beacon: C2Beacon = computer_b.software_manager.software.get("C2Beacon")
c2_beacon.configure(
c2_server_ip_address=c2_beacon.config.c2_server_ip_address,
keep_alive_frequency=c2_beacon.config.keep_alive_frequency,
masquerade_port=c2_beacon.config.masquerade_port,
masquerade_protocol=c2_beacon.config.masquerade_protocol,
)
assert c2_server.operating_state == ApplicationOperatingState.RUNNING

View File

@@ -232,7 +232,7 @@ def test_database_service_fix(uc2_network):
assert db_service.health_state_actual == SoftwareHealthState.FIXING
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
for i in range(db_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
@@ -266,7 +266,7 @@ def test_database_cannot_be_queried_while_fixing(uc2_network):
assert db_connection.query(sql="SELECT") is False
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
for i in range(db_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.health_state_actual == SoftwareHealthState.GOOD
@@ -308,7 +308,7 @@ def test_database_can_create_connection_while_fixing(uc2_network):
assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
for i in range(db_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.health_state_actual == SoftwareHealthState.GOOD

View File

@@ -14,7 +14,14 @@ from primaite.utils.validation.port import PORT_LOOKUP
from tests import TEST_ASSETS_ROOT
class _DatabaseListener(Service):
class _DatabaseListener(Service, identifier="_DatabaseListener"):
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for _DatabaseListener."""
type: str = "_DatabaseListener"
listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]}
config: "_DatabaseListener.ConfigSchema" = Field(default_factory=lambda: _DatabaseListener.ConfigSchema())
name: str = "DatabaseListener"
protocol: str = PROTOCOL_LOOKUP["TCP"]
port: int = PORT_LOOKUP["NONE"]

View File

@@ -51,7 +51,7 @@ def test_web_page_get_users_page_request_with_domain_name(web_client_and_web_ser
web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
web_browser_app.target_url = f"http://arcd.com/"
web_browser_app.config.target_url = f"http://arcd.com/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
@@ -66,7 +66,7 @@ def test_web_page_get_users_page_request_with_ip_address(web_client_and_web_serv
web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
web_browser_app.target_url = f"http://{web_server_ip}/"
web_browser_app.config.target_url = f"http://{web_server_ip}/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
@@ -81,7 +81,7 @@ def test_web_page_request_from_shut_down_server(web_client_and_web_server):
web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
web_browser_app.target_url = f"http://arcd.com/"
web_browser_app.config.target_url = f"http://arcd.com/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
@@ -108,7 +108,7 @@ def test_web_page_request_from_closed_web_browser(web_client_and_web_server):
web_browser_app, computer, web_server_service, server = web_client_and_web_server
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
web_browser_app.target_url = f"http://arcd.com/"
web_browser_app.config.target_url = f"http://arcd.com/"
assert web_browser_app.get_webpage() is True
# latest response should have status code 200

View File

@@ -74,7 +74,7 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer,
# Install Web Browser on computer
computer.software_manager.install(WebBrowser)
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
web_browser.target_url = "http://arcd.com/users/"
web_browser.config.target_url = "http://arcd.com/users/"
web_browser.run()
# Install DNS Client service on computer
@@ -131,7 +131,7 @@ def test_database_fix_disrupts_web_client(uc2_network):
assert web_browser.get_webpage() is False
for i in range(database_service.fixing_duration + 1):
for i in range(database_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert database_service.health_state_actual == SoftwareHealthState.GOOD

View File

@@ -128,13 +128,13 @@ def test_c2_handle_switching_port(basic_c2_network):
assert c2_server.c2_connection_active is True
# Assert to confirm that both the C2 server and the C2 beacon are configured correctly.
assert c2_beacon.c2_config.keep_alive_frequency is 2
assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["HTTP"]
assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
assert c2_beacon.config.keep_alive_frequency is 2
assert c2_beacon.config.masquerade_port is PORT_LOOKUP["HTTP"]
assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
assert c2_server.c2_config.keep_alive_frequency is 2
assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["HTTP"]
assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
assert c2_server.config.keep_alive_frequency is 2
assert c2_server.config.masquerade_port is PORT_LOOKUP["HTTP"]
assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
# Configuring the C2 Beacon.
c2_beacon.configure(
@@ -150,11 +150,11 @@ def test_c2_handle_switching_port(basic_c2_network):
# Assert to confirm that both the C2 server and the C2 beacon
# Have reconfigured their C2 settings.
assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["FTP"]
assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
assert c2_beacon.config.masquerade_port is PORT_LOOKUP["FTP"]
assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["FTP"]
assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
assert c2_server.config.masquerade_port is PORT_LOOKUP["FTP"]
assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
def test_c2_handle_switching_frequency(basic_c2_network):
@@ -174,8 +174,8 @@ def test_c2_handle_switching_frequency(basic_c2_network):
assert c2_server.c2_connection_active is True
# Assert to confirm that both the C2 server and the C2 beacon are configured correctly.
assert c2_beacon.c2_config.keep_alive_frequency is 2
assert c2_server.c2_config.keep_alive_frequency is 2
assert c2_beacon.config.keep_alive_frequency is 2
assert c2_server.config.keep_alive_frequency is 2
# Configuring the C2 Beacon.
c2_beacon.configure(c2_server_ip_address="192.168.0.1", keep_alive_frequency=10)
@@ -186,8 +186,8 @@ def test_c2_handle_switching_frequency(basic_c2_network):
# Assert to confirm that both the C2 server and the C2 beacon
# Have reconfigured their C2 settings.
assert c2_beacon.c2_config.keep_alive_frequency is 10
assert c2_server.c2_config.keep_alive_frequency is 10
assert c2_beacon.config.keep_alive_frequency is 10
assert c2_server.config.keep_alive_frequency is 10
# Now skipping 9 time steps to confirm keep alive inactivity
for i in range(9):

View File

@@ -148,7 +148,7 @@ def test_service_fixing(service):
service.fix()
assert service.health_state_actual == SoftwareHealthState.FIXING
for i in range(service.fixing_duration + 1):
for i in range(service.config.fixing_duration + 1):
service.apply_timestep(i)
assert service.health_state_actual == SoftwareHealthState.GOOD

View File

@@ -2,6 +2,7 @@
from typing import Dict
import pytest
from pydantic import Field
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.service import Service
@@ -10,7 +11,14 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
class TestSoftware(Service):
class TestSoftware(Service, identifier="TestSoftware"):
class ConfigSchema(Service.ConfigSchema):
"""ConfigSChema for TestSoftware."""
type: str = "TestSoftware"
config: "TestSoftware.ConfigSchema" = Field(default_factory=lambda: TestSoftware.ConfigSchema())
def describe_state(self) -> Dict:
pass