Merge remote-tracking branch 'origin/4.0.0a1-dev' into feature/2869-Marek
This commit is contained in:
@@ -70,7 +70,7 @@ Python
|
||||
Configuration
|
||||
=============
|
||||
|
||||
The RansomwareScript inherits configuration options such as ``fix_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``.
|
||||
The RansomwareScript inherits configuration options such as ``fixing_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``.
|
||||
|
||||
|
||||
``server_ip``
|
||||
|
||||
@@ -22,8 +22,8 @@ options
|
||||
|
||||
The configuration options are the attributes that fall under the options for an application or service.
|
||||
|
||||
fix_duration
|
||||
""""""""""""
|
||||
fixing_duration
|
||||
"""""""""""""""
|
||||
|
||||
Optional. Default value is ``2``.
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from primaite.simulator.system.software import Software
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -370,12 +370,12 @@ class PrimaiteGame:
|
||||
|
||||
if service_class is not None:
|
||||
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
|
||||
new_node.software_manager.install(service_class, **service_cfg.get("options", {}))
|
||||
new_node.software_manager.install(service_class)
|
||||
new_service = new_node.software_manager.software[service_class.__name__]
|
||||
|
||||
# fixing duration for the service
|
||||
if "fix_duration" in service_cfg.get("options", {}):
|
||||
new_service.fixing_duration = service_cfg["options"]["fix_duration"]
|
||||
if "fixing_duration" in service_cfg.get("options", {}):
|
||||
new_service.config.fixing_duration = service_cfg["options"]["fixing_duration"]
|
||||
|
||||
_set_software_listen_on_ports(new_service, service_cfg)
|
||||
# start the service
|
||||
@@ -416,74 +416,20 @@ class PrimaiteGame:
|
||||
application_type = application_cfg["type"]
|
||||
|
||||
if application_type in Application._registry:
|
||||
new_node.software_manager.install(Application._registry[application_type])
|
||||
application_class = Application._registry[application_type]
|
||||
application_options = application_cfg.get("options", {})
|
||||
application_options["type"] = application_type
|
||||
new_node.software_manager.install(application_class, software_config=application_options)
|
||||
new_application = new_node.software_manager.software[application_type] # grab the instance
|
||||
|
||||
# fixing duration for the application
|
||||
if "fix_duration" in application_cfg.get("options", {}):
|
||||
new_application.fixing_duration = application_cfg["options"]["fix_duration"]
|
||||
else:
|
||||
msg = f"Configuration contains an invalid application type: {application_type}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
_set_software_listen_on_ports(new_application, application_cfg)
|
||||
|
||||
# run the application
|
||||
new_application.run()
|
||||
|
||||
if application_type == "DataManipulationBot":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload", "DELETE"),
|
||||
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
|
||||
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
|
||||
)
|
||||
elif application_type == "RansomwareScript":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None,
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload", "ENCRYPT"),
|
||||
)
|
||||
elif application_type == "DatabaseClient":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("db_server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
)
|
||||
elif application_type == "WebBrowser":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.target_url = opt.get("target_url")
|
||||
elif application_type == "DoSBot":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
target_ip_address=IPv4Address(opt.get("target_ip_address")),
|
||||
target_port=PORT_LOOKUP[opt.get("target_port", "POSTGRES_SERVER")],
|
||||
payload=opt.get("payload"),
|
||||
repeat=bool(opt.get("repeat")),
|
||||
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
|
||||
dos_intensity=float(opt.get("dos_intensity", "1.0")),
|
||||
max_sessions=int(opt.get("max_sessions", "1000")),
|
||||
)
|
||||
elif application_type == "C2Beacon":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")),
|
||||
keep_alive_frequency=(opt.get("keep_alive_frequency", 5)),
|
||||
masquerade_protocol=PROTOCOL_LOOKUP[
|
||||
(opt.get("masquerade_protocol", PROTOCOL_LOOKUP["TCP"]))
|
||||
],
|
||||
masquerade_port=PORT_LOOKUP[(opt.get("masquerade_port", PORT_LOOKUP["HTTP"]))],
|
||||
)
|
||||
if "network_interfaces" in node_cfg:
|
||||
for nic_num, nic_cfg in node_cfg["network_interfaces"].items():
|
||||
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
|
||||
|
||||
@@ -824,7 +824,7 @@ class User(SimComponent):
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class UserManager(Service):
|
||||
class UserManager(Service, identifier="UserManager"):
|
||||
"""
|
||||
Manages users within the PrimAITE system, handling creation, authentication, and administration.
|
||||
|
||||
@@ -833,11 +833,18 @@ class UserManager(Service):
|
||||
:param disabled_admins: A dictionary of currently disabled admin users by their usernames
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for UserManager."""
|
||||
|
||||
type: str = "UserManager"
|
||||
|
||||
config: "UserManager.ConfigSchema" = Field(default_factory=lambda: UserManager.ConfigSchema())
|
||||
|
||||
users: Dict[str, User] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initializes a UserManager instanc.
|
||||
Initializes a UserManager instance.
|
||||
|
||||
:param username: The username for the default admin user
|
||||
:param password: The password for the default admin user
|
||||
@@ -1130,13 +1137,20 @@ class RemoteUserSession(UserSession):
|
||||
return state
|
||||
|
||||
|
||||
class UserSessionManager(Service):
|
||||
class UserSessionManager(Service, identifier="UserSessionManager"):
|
||||
"""
|
||||
Manages user sessions on a Node, including local and remote sessions.
|
||||
|
||||
This class handles authentication, session management, and session timeouts for users interacting with the Node.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for UserSessionManager."""
|
||||
|
||||
type: str = "UserSessionManager"
|
||||
|
||||
config: "UserSessionManager.ConfigSchema" = Field(default_factory=lambda: UserSessionManager.ConfigSchema())
|
||||
|
||||
local_session: Optional[UserSession] = None
|
||||
"""The current local user session, if any."""
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Dict, Optional, Set, Type
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
@@ -21,13 +23,20 @@ class ApplicationOperatingState(Enum):
|
||||
"The application is being installed or updated."
|
||||
|
||||
|
||||
class Application(IOSoftware):
|
||||
class Application(IOSoftware, ABC):
|
||||
"""
|
||||
Represents an Application in the simulation environment.
|
||||
|
||||
Applications are user-facing programs that may perform input/output operations.
|
||||
"""
|
||||
|
||||
class ConfigSchema(IOSoftware.ConfigSchema, ABC):
|
||||
"""Config Schema for Application class."""
|
||||
|
||||
type: str
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: Application.ConfigSchema())
|
||||
|
||||
operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED
|
||||
"The current operating state of the Application."
|
||||
execution_control_status: str = "manual"
|
||||
@@ -49,7 +58,7 @@ class Application(IOSoftware):
|
||||
Register an application type.
|
||||
|
||||
:param identifier: Uniquely specifies an application class by name. Used for finding items by config.
|
||||
:type identifier: str
|
||||
:type identifier: Optional[str]
|
||||
:raises ValueError: When attempting to register an application with a name that is already allocated.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
@@ -59,6 +68,21 @@ class Application(IOSoftware):
|
||||
raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "Application":
|
||||
"""Create an application from a config dictionary.
|
||||
|
||||
:param config: dict of options for application components constructor
|
||||
:type config: dict
|
||||
:return: The application component.
|
||||
:rtype: Application
|
||||
"""
|
||||
if config["type"] not in cls._registry:
|
||||
raise ValueError(f"Invalid Application type {config['type']}")
|
||||
application_class = cls._registry[config["type"]]
|
||||
application_object = application_class(config=application_class.ConfigSchema(**config))
|
||||
return application_object
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -67,11 +67,19 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
|
||||
Extends the Application class to provide functionality for connecting, querying, and disconnecting from a
|
||||
Database Service. It mainly operates over TCP protocol.
|
||||
|
||||
:ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for DatabaseClient."""
|
||||
|
||||
type: str = "DatabaseClient"
|
||||
db_server_ip: Optional[IPV4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: DatabaseClient.ConfigSchema())
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
"""The IPv4 address of the Database Service server, defaults to None."""
|
||||
server_password: Optional[str] = None
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
||||
@@ -93,6 +101,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
self.server_ip_address = self.config.db_server_ip
|
||||
self.server_password = self.config.server_password
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union
|
||||
|
||||
from prettytable import PrettyTable
|
||||
from pydantic import validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
@@ -52,6 +52,13 @@ class NMAP(Application, identifier="NMAP"):
|
||||
as ping scans to discover active hosts and port scans to detect open ports on those hosts.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for NMAP."""
|
||||
|
||||
type: str = "NMAP"
|
||||
|
||||
config: "NMAP.ConfigSchema" = Field(default_factory=lambda: NMAP.ConfigSchema())
|
||||
|
||||
_active_port_scans: Dict[str, PortScanPayload] = {}
|
||||
_port_scan_responses: Dict[str, PortScanPayload] = {}
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
@@ -45,10 +45,10 @@ class C2Payload(Enum):
|
||||
"""C2 Input Command payload. Used by the C2 Server to send a command to the c2 beacon."""
|
||||
|
||||
OUTPUT = "output_command"
|
||||
"""C2 Output Command. Used by the C2 Beacon to send the results of a Input command to the c2 server."""
|
||||
"""C2 Output Command. Used by the C2 Beacon to send the results of an Input command to the c2 server."""
|
||||
|
||||
|
||||
class AbstractC2(Application, identifier="AbstractC2"):
|
||||
class AbstractC2(Application):
|
||||
"""
|
||||
An abstract command and control (c2) application.
|
||||
|
||||
@@ -60,9 +60,25 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
|
||||
Defaults to masquerading as HTTP (Port 80) via TCP.
|
||||
|
||||
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
|
||||
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""Configuration for AbstractC2."""
|
||||
|
||||
keep_alive_frequency: int = Field(default=5, ge=1)
|
||||
"""The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon."""
|
||||
|
||||
masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"])
|
||||
"""The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP."""
|
||||
|
||||
masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"])
|
||||
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
|
||||
|
||||
listen_on_ports: Set[Port] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: AbstractC2.ConfigSchema())
|
||||
|
||||
c2_connection_active: bool = False
|
||||
"""Indicates if the c2 server and c2 beacon are currently connected."""
|
||||
|
||||
@@ -75,19 +91,6 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
keep_alive_inactivity: int = 0
|
||||
"""Indicates how many timesteps since the last time the c2 application received a keep alive."""
|
||||
|
||||
class _C2Opts(BaseModel):
|
||||
"""A Pydantic Schema for the different C2 configuration options."""
|
||||
|
||||
keep_alive_frequency: int = Field(default=5, ge=1)
|
||||
"""The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon."""
|
||||
|
||||
masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"])
|
||||
"""The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP."""
|
||||
|
||||
masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"])
|
||||
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
|
||||
|
||||
c2_config: _C2Opts = _C2Opts()
|
||||
"""
|
||||
Holds the current configuration settings of the C2 Suite.
|
||||
|
||||
@@ -100,6 +103,12 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
C2 beacon to reconfigure it's configuration settings.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialise the C2 applications to by default listen for HTTP traffic."""
|
||||
kwargs["port"] = PORT_LOOKUP["NONE"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _craft_packet(
|
||||
self, c2_payload: C2Payload, c2_command: Optional[C2Command] = None, command_options: Optional[Dict] = {}
|
||||
) -> C2Packet:
|
||||
@@ -118,13 +127,13 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
:type c2_command: C2Command.
|
||||
:param command_options: The relevant C2 Beacon parameters.F
|
||||
:type command_options: Dict
|
||||
:return: Returns the construct C2Packet
|
||||
:return: Returns the constructed C2Packet
|
||||
:rtype: C2Packet
|
||||
"""
|
||||
constructed_packet = C2Packet(
|
||||
masquerade_protocol=self.c2_config.masquerade_protocol,
|
||||
masquerade_port=self.c2_config.masquerade_port,
|
||||
keep_alive_frequency=self.c2_config.keep_alive_frequency,
|
||||
masquerade_protocol=self.config.masquerade_protocol,
|
||||
masquerade_port=self.config.masquerade_port,
|
||||
keep_alive_frequency=self.config.keep_alive_frequency,
|
||||
payload_type=c2_payload,
|
||||
command=c2_command,
|
||||
payload=command_options,
|
||||
@@ -140,13 +149,6 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
"""
|
||||
return super().describe_state()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialise the C2 applications to by default listen for HTTP traffic."""
|
||||
kwargs["listen_on_ports"] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
|
||||
kwargs["port"] = PORT_LOOKUP["NONE"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _host_ftp_client(self) -> Optional[FTPClient]:
|
||||
"""Return the FTPClient that is installed C2 Application's host.
|
||||
@@ -330,8 +332,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
if self.send(
|
||||
payload=keep_alive_packet,
|
||||
dest_ip_address=self.c2_remote_connection,
|
||||
dest_port=self.c2_config.masquerade_port,
|
||||
ip_protocol=self.c2_config.masquerade_protocol,
|
||||
dest_port=self.config.masquerade_port,
|
||||
ip_protocol=self.config.masquerade_protocol,
|
||||
session_id=session_id,
|
||||
):
|
||||
# Setting the keep_alive_sent guard condition to True. This is used to prevent packet storms.
|
||||
@@ -340,8 +342,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
self.sys_log.info(f"{self.name}: Keep Alive sent to {self.c2_remote_connection}")
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: Keep Alive sent to {self.c2_remote_connection} "
|
||||
f"Masquerade Port: {self.c2_config.masquerade_port} "
|
||||
f"Masquerade Protocol: {self.c2_config.masquerade_protocol} "
|
||||
f"Masquerade Port: {self.config.masquerade_port} "
|
||||
f"Masquerade Protocol: {self.config.masquerade_protocol} "
|
||||
)
|
||||
return True
|
||||
else:
|
||||
@@ -376,15 +378,15 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
|
||||
# Updating the C2 Configuration attribute.
|
||||
|
||||
self.c2_config.masquerade_port = payload.masquerade_port
|
||||
self.c2_config.masquerade_protocol = payload.masquerade_protocol
|
||||
self.c2_config.keep_alive_frequency = payload.keep_alive_frequency
|
||||
self.config.masquerade_port = payload.masquerade_port
|
||||
self.config.masquerade_protocol = payload.masquerade_protocol
|
||||
self.config.keep_alive_frequency = payload.keep_alive_frequency
|
||||
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: C2 Config Resolved Config from Keep Alive:"
|
||||
f"Masquerade Port: {self.c2_config.masquerade_port}"
|
||||
f"Masquerade Protocol: {self.c2_config.masquerade_protocol}"
|
||||
f"Keep Alive Frequency: {self.c2_config.keep_alive_frequency}"
|
||||
f"Masquerade Port: {self.config.masquerade_port}"
|
||||
f"Masquerade Protocol: {self.config.masquerade_protocol}"
|
||||
f"Keep Alive Frequency: {self.config.keep_alive_frequency}"
|
||||
)
|
||||
|
||||
# This statement is intended to catch on the C2 Application that is listening for connection.
|
||||
@@ -410,8 +412,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
self.keep_alive_inactivity = 0
|
||||
self.keep_alive_frequency = 5
|
||||
self.c2_remote_connection = None
|
||||
self.c2_config.masquerade_port = PORT_LOOKUP["HTTP"]
|
||||
self.c2_config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"]
|
||||
self.config.masquerade_port = PORT_LOOKUP["HTTP"]
|
||||
self.config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"]
|
||||
|
||||
@abstractmethod
|
||||
def _confirm_remote_connection(self, timestep: int) -> bool:
|
||||
|
||||
@@ -3,7 +3,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -12,8 +12,9 @@ from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts
|
||||
from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
|
||||
class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
@@ -32,15 +33,30 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
2. Leveraging the terminal application to execute requests (dependent on the command given)
|
||||
3. Sending the RequestResponse back to the C2 Server (Command output)
|
||||
|
||||
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
|
||||
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
|
||||
"""
|
||||
|
||||
class ConfigSchema(AbstractC2.ConfigSchema):
|
||||
"""ConfigSchema for C2Beacon."""
|
||||
|
||||
type: str = "C2Beacon"
|
||||
c2_server_ip_address: Optional[IPV4Address] = None
|
||||
keep_alive_frequency: int = 5
|
||||
masquerade_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"]
|
||||
masquerade_port: Port = PORT_LOOKUP["HTTP"]
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: C2Beacon.ConfigSchema())
|
||||
|
||||
keep_alive_attempted: bool = False
|
||||
"""Indicates if a keep alive has been attempted to be sent this timestep. Used to prevent packet storms."""
|
||||
|
||||
terminal_session: TerminalClientConnection = None
|
||||
"The currently in use terminal session."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "C2Beacon"
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _host_terminal(self) -> Optional[Terminal]:
|
||||
"""Return the Terminal that is installed on the same machine as the C2 Beacon."""
|
||||
@@ -119,10 +135,6 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
rm.add_request("configure", request_type=RequestType(func=_configure))
|
||||
return rm
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "C2Beacon"
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Configure is practically setter method for the ``c2.config`` attribute that also ties into the request manager.
|
||||
@validate_call
|
||||
def configure(
|
||||
@@ -146,7 +158,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
masquerade_port | What port should the C2 traffic use? (TCP or UDP)
|
||||
|
||||
These configuration options are used to reassign the fields in the inherited inner class
|
||||
``c2_config``.
|
||||
``config``.
|
||||
|
||||
If a connection is already in progress then this method also sends a keep alive to the C2
|
||||
Server in order for the C2 Server to sync with the new configuration settings.
|
||||
@@ -162,9 +174,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
:return: Returns True if the configuration was successful, False otherwise.
|
||||
"""
|
||||
self.c2_remote_connection = IPv4Address(c2_server_ip_address)
|
||||
self.c2_config.keep_alive_frequency = keep_alive_frequency
|
||||
self.c2_config.masquerade_port = masquerade_port
|
||||
self.c2_config.masquerade_protocol = masquerade_protocol
|
||||
self.config.keep_alive_frequency = keep_alive_frequency
|
||||
self.config.masquerade_port = masquerade_port
|
||||
self.config.masquerade_protocol = masquerade_protocol
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Configured {self.name} with remote C2 server connection: {c2_server_ip_address=}."
|
||||
)
|
||||
@@ -263,14 +275,12 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
if self.send(
|
||||
payload=output_packet,
|
||||
dest_ip_address=self.c2_remote_connection,
|
||||
dest_port=self.c2_config.masquerade_port,
|
||||
ip_protocol=self.c2_config.masquerade_protocol,
|
||||
dest_port=self.config.masquerade_port,
|
||||
ip_protocol=self.config.masquerade_protocol,
|
||||
session_id=session_id,
|
||||
):
|
||||
self.sys_log.info(f"{self.name}: Command output sent to {self.c2_remote_connection}")
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: on {self.c2_config.masquerade_port} via {self.c2_config.masquerade_protocol}"
|
||||
)
|
||||
self.sys_log.debug(f"{self.name}: on {self.config.masquerade_port} via {self.config.masquerade_protocol}")
|
||||
return True
|
||||
else:
|
||||
self.sys_log.warning(
|
||||
@@ -562,7 +572,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
:rtype bool:
|
||||
"""
|
||||
self.keep_alive_attempted = False # Resetting keep alive sent.
|
||||
if self.keep_alive_inactivity == self.c2_config.keep_alive_frequency:
|
||||
if self.keep_alive_inactivity == self.config.keep_alive_frequency:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Attempting to Send Keep Alive to {self.c2_remote_connection} at timestep {timestep}."
|
||||
)
|
||||
@@ -627,9 +637,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
self.c2_connection_active,
|
||||
self.c2_remote_connection,
|
||||
self.keep_alive_inactivity,
|
||||
self.c2_config.keep_alive_frequency,
|
||||
self.c2_config.masquerade_protocol,
|
||||
self.c2_config.masquerade_port,
|
||||
self.config.keep_alive_frequency,
|
||||
self.config.masquerade_protocol,
|
||||
self.config.masquerade_port,
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -31,9 +31,16 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
1. Sending commands to the C2 Beacon. (Command input)
|
||||
2. Parsing terminal RequestResponses back to the Agent.
|
||||
|
||||
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
|
||||
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
|
||||
"""
|
||||
|
||||
class ConfigSchema(AbstractC2.ConfigSchema):
|
||||
"""ConfigSchema for C2Server."""
|
||||
|
||||
type: str = "C2Server"
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: C2Server.ConfigSchema())
|
||||
|
||||
current_command_output: RequestResponse = None
|
||||
"""The Request Response by the last command send. This attribute is updated by the method _handle_command_output."""
|
||||
|
||||
@@ -251,8 +258,8 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
payload=command_packet,
|
||||
dest_ip_address=self.c2_remote_connection,
|
||||
session_id=self.c2_session.uuid,
|
||||
dest_port=self.c2_config.masquerade_port,
|
||||
ip_protocol=self.c2_config.masquerade_protocol,
|
||||
dest_port=self.config.masquerade_port,
|
||||
ip_protocol=self.config.masquerade_protocol,
|
||||
):
|
||||
self.sys_log.info(f"{self.name}: Successfully sent {given_command}.")
|
||||
self.sys_log.info(f"{self.name}: Awaiting command response {given_command}.")
|
||||
@@ -334,11 +341,11 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
:return: Returns False if the C2 beacon is considered dead. Otherwise True.
|
||||
:rtype bool:
|
||||
"""
|
||||
if self.keep_alive_inactivity > self.c2_config.keep_alive_frequency:
|
||||
if self.keep_alive_inactivity > self.config.keep_alive_frequency:
|
||||
self.sys_log.info(f"{self.name}: C2 Beacon connection considered dead due to inactivity.")
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: Did not receive expected keep alive connection from {self.c2_remote_connection}"
|
||||
f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.c2_config.keep_alive_frequency}"
|
||||
f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.config.keep_alive_frequency}"
|
||||
f"{self.name}: Last Keep Alive received at {(timestep - self.keep_alive_inactivity)}"
|
||||
)
|
||||
self._reset_c2_connection()
|
||||
@@ -389,8 +396,8 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
[
|
||||
self.c2_connection_active,
|
||||
self.c2_remote_connection,
|
||||
self.c2_config.masquerade_protocol,
|
||||
self.c2_config.masquerade_port,
|
||||
self.config.masquerade_protocol,
|
||||
self.config.masquerade_port,
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -3,6 +3,8 @@ from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -10,6 +12,7 @@ from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -40,6 +43,18 @@ class DataManipulationAttackStage(IntEnum):
|
||||
class DataManipulationBot(Application, identifier="DataManipulationBot"):
|
||||
"""A bot that simulates a script which performs a SQL injection attack."""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""Configuration schema for DataManipulationBot."""
|
||||
|
||||
type: str = "DataManipulationBot"
|
||||
server_ip: Optional[IPV4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
payload: str = "DELETE"
|
||||
port_scan_p_of_success: float = 0.1
|
||||
data_manipulation_p_of_success: float = 0.1
|
||||
|
||||
config: "DataManipulationBot.ConfigSchema" = Field(default_factory=lambda: DataManipulationBot.ConfigSchema())
|
||||
|
||||
payload: Optional[str] = None
|
||||
port_scan_p_of_success: float = 0.1
|
||||
data_manipulation_p_of_success: float = 0.1
|
||||
@@ -56,6 +71,12 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"):
|
||||
super().__init__(**kwargs)
|
||||
self._db_connection: Optional[DatabaseClientConnection] = None
|
||||
|
||||
self.server_ip_address = self.config.server_ip
|
||||
self.server_password = self.config.server_password
|
||||
self.payload = self.config.payload
|
||||
self.port_scan_p_of_success = self.config.port_scan_p_of_success
|
||||
self.data_manipulation_p_of_success = self.config.data_manipulation_p_of_success
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
@@ -3,11 +3,14 @@ from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -32,6 +35,20 @@ class DoSAttackStage(IntEnum):
|
||||
class DoSBot(DatabaseClient, identifier="DoSBot"):
|
||||
"""A bot that simulates a Denial of Service attack."""
|
||||
|
||||
class ConfigSchema(DatabaseClient.ConfigSchema):
|
||||
"""ConfigSchema for DoSBot."""
|
||||
|
||||
type: str = "DoSBot"
|
||||
target_ip_address: Optional[IPV4Address] = None
|
||||
target_port: Port = PORT_LOOKUP["POSTGRES_SERVER"]
|
||||
payload: Optional[str] = None
|
||||
repeat: bool = False
|
||||
port_scan_p_of_success: float = 0.1
|
||||
dos_intensity: float = 1.0
|
||||
max_sessions: int = 1000
|
||||
|
||||
config: "DoSBot.ConfigSchema" = Field(default_factory=lambda: DoSBot.ConfigSchema())
|
||||
|
||||
target_ip_address: Optional[IPv4Address] = None
|
||||
"""IP address of the target service."""
|
||||
|
||||
@@ -56,7 +73,13 @@ class DoSBot(DatabaseClient, identifier="DoSBot"):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = "DoSBot"
|
||||
self.max_sessions = 1000 # override normal max sessions
|
||||
self.target_ip_address = self.config.target_ip_address
|
||||
self.target_port = self.config.target_port
|
||||
self.payload = self.config.payload
|
||||
self.repeat = self.config.repeat
|
||||
self.port_scan_p_of_success = self.config.port_scan_p_of_success
|
||||
self.dos_intensity = self.config.dos_intensity
|
||||
self.max_sessions = self.config.max_sessions
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
|
||||
@@ -3,12 +3,14 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
|
||||
@@ -18,6 +20,16 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
|
||||
:ivar payload: The attack stage query payload. (Default ENCRYPT)
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for RansomwareScript."""
|
||||
|
||||
type: str = "RansomwareScript"
|
||||
server_ip: Optional[IPV4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
payload: str = "ENCRYPT"
|
||||
|
||||
config: "RansomwareScript.ConfigSchema" = Field(default_factory=lambda: RansomwareScript.ConfigSchema())
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
"""IP address of node which hosts the database."""
|
||||
server_password: Optional[str] = None
|
||||
@@ -32,6 +44,9 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._db_connection: Optional[DatabaseClientConnection] = None
|
||||
self.server_ip_address = self.config.server_ip
|
||||
self.server_password = self.config.server_password
|
||||
self.payload = self.config.payload
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -30,7 +30,13 @@ class WebBrowser(Application, identifier="WebBrowser"):
|
||||
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
|
||||
"""
|
||||
|
||||
target_url: Optional[str] = None
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for WebBrowser."""
|
||||
|
||||
type: str = "WebBrowser"
|
||||
target_url: Optional[str] = None
|
||||
|
||||
config: "WebBrowser.ConfigSchema" = Field(default_factory=lambda: WebBrowser.ConfigSchema())
|
||||
|
||||
domain_name_ip_address: Optional[IPv4Address] = None
|
||||
"The IP address of the domain name for the webpage."
|
||||
@@ -86,7 +92,7 @@ class WebBrowser(Application, identifier="WebBrowser"):
|
||||
:param: url: The address of the web page the browser requests
|
||||
:type: url: str
|
||||
"""
|
||||
url = url or self.target_url
|
||||
url = url or self.config.target_url
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ class SoftwareManager:
|
||||
return True
|
||||
return False
|
||||
|
||||
def install(self, software_class: Type[IOSoftware], **install_kwargs):
|
||||
def install(self, software_class: Type[IOSoftware], software_config: Optional[IOSoftware.ConfigSchema] = None):
|
||||
"""
|
||||
Install an Application or Service.
|
||||
|
||||
@@ -115,13 +115,22 @@ class SoftwareManager:
|
||||
if software_class in self._software_class_to_name_map:
|
||||
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
|
||||
return
|
||||
software = software_class(
|
||||
software_manager=self,
|
||||
sys_log=self.sys_log,
|
||||
file_system=self.file_system,
|
||||
dns_server=self.dns_server,
|
||||
**install_kwargs,
|
||||
)
|
||||
if software_config is None:
|
||||
software = software_class(
|
||||
software_manager=self,
|
||||
sys_log=self.sys_log,
|
||||
file_system=self.file_system,
|
||||
dns_server=self.dns_server,
|
||||
)
|
||||
else:
|
||||
software = software_class(
|
||||
software_manager=self,
|
||||
sys_log=self.sys_log,
|
||||
file_system=self.file_system,
|
||||
dns_server=self.dns_server,
|
||||
config=software_config,
|
||||
)
|
||||
|
||||
software.parent = self.node
|
||||
if isinstance(software, Application):
|
||||
self.node.applications[software.uuid] = software
|
||||
|
||||
@@ -5,6 +5,7 @@ from abc import abstractmethod
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface
|
||||
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
|
||||
@@ -14,7 +15,7 @@ from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
|
||||
class ARP(Service):
|
||||
class ARP(Service, identifier="ARP"):
|
||||
"""
|
||||
The ARP (Address Resolution Protocol) Service.
|
||||
|
||||
@@ -22,6 +23,13 @@ class ARP(Service):
|
||||
sends ARP requests and replies, and processes incoming ARP packets.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for ARP."""
|
||||
|
||||
type: str = "ARP"
|
||||
|
||||
config: "ARP.ConfigSchema" = Field(default_factory=lambda: ARP.ConfigSchema())
|
||||
|
||||
arp: Dict[IPV4Address, ARPEntry] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
@@ -17,13 +19,21 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseService(Service):
|
||||
class DatabaseService(Service, identifier="DatabaseService"):
|
||||
"""
|
||||
A class for simulating a generic SQL Server service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to simulate a SQL database.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for DatabaseService."""
|
||||
|
||||
type: str = "DatabaseService"
|
||||
backup_server_ip: Optional[IPv4Address] = None
|
||||
|
||||
config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema())
|
||||
|
||||
password: Optional[str] = None
|
||||
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
|
||||
|
||||
@@ -42,6 +52,7 @@ class DatabaseService(Service):
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
self._create_db_file()
|
||||
self.backup_server_ip = self.config.backup_server_ip
|
||||
|
||||
def install(self):
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
@@ -12,9 +14,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DNSClient(Service):
|
||||
class DNSClient(Service, identifier="DNSClient"):
|
||||
"""Represents a DNS Client as a Service."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for DNSClient."""
|
||||
|
||||
type: str = "DNSClient"
|
||||
|
||||
config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema())
|
||||
dns_cache: Dict[str, IPv4Address] = {}
|
||||
"A dict of known mappings between domain/URLs names and IPv4 addresses."
|
||||
dns_server: Optional[IPv4Address] = None
|
||||
|
||||
@@ -3,6 +3,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket
|
||||
@@ -13,9 +14,17 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DNSServer(Service):
|
||||
class DNSServer(Service, identifier="DNSServer"):
|
||||
"""Represents a DNS Server as a Service."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for DNSServer."""
|
||||
|
||||
type: str = "DNSServer"
|
||||
domain_mapping: dict = {}
|
||||
|
||||
config: "DNSServer.ConfigSchema" = Field(default_factory=lambda: DNSServer.ConfigSchema())
|
||||
|
||||
dns_table: Dict[str, IPv4Address] = {}
|
||||
"A dict of mappings between domain names and IPv4 addresses."
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -9,20 +11,28 @@ from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FTPClient(FTPServiceABC):
|
||||
class FTPClient(FTPServiceABC, identifier="FTPClient"):
|
||||
"""
|
||||
A class for simulating an FTP client service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to emulate FTP
|
||||
This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP
|
||||
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
|
||||
"""
|
||||
|
||||
config: "FTPClient.ConfigSchema" = Field(default_factory=lambda: FTPClient.ConfigSchema())
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for FTPClient."""
|
||||
|
||||
type: str = "FTPClient"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPClient"
|
||||
kwargs["port"] = PORT_LOOKUP["FTP"]
|
||||
|
||||
@@ -1,26 +1,36 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FTPServer(FTPServiceABC):
|
||||
class FTPServer(FTPServiceABC, identifier="FTPServer"):
|
||||
"""
|
||||
A class for simulating an FTP server service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to emulate FTP
|
||||
This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP
|
||||
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
|
||||
"""
|
||||
|
||||
config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema())
|
||||
|
||||
server_password: Optional[str] = None
|
||||
"""Password needed to connect to FTP server. Default is None."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for FTPServer."""
|
||||
|
||||
type: str = "FTPServer"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPServer"
|
||||
kwargs["port"] = PORT_LOOKUP["FTP"]
|
||||
|
||||
@@ -3,6 +3,8 @@ import secrets
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface
|
||||
from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType
|
||||
@@ -14,7 +16,7 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class ICMP(Service):
|
||||
class ICMP(Service, identifier="ICMP"):
|
||||
"""
|
||||
The Internet Control Message Protocol (ICMP) service.
|
||||
|
||||
@@ -22,6 +24,13 @@ class ICMP(Service):
|
||||
network diagnostics, notably the ping command.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for ICMP."""
|
||||
|
||||
type: str = "ICMP"
|
||||
|
||||
config: "ICMP.ConfigSchema" = Field(default_factory=lambda: ICMP.ConfigSchema())
|
||||
|
||||
request_replies: Dict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -3,6 +3,8 @@ from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ntp import NTPPacket
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
@@ -12,9 +14,16 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class NTPClient(Service):
|
||||
class NTPClient(Service, identifier="NTPClient"):
|
||||
"""Represents a NTP client as a service."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for NTPClient."""
|
||||
|
||||
type: str = "NTPClient"
|
||||
|
||||
config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema())
|
||||
|
||||
ntp_server: Optional[IPv4Address] = None
|
||||
"The NTP server the client sends requests to."
|
||||
time: Optional[datetime] = None
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ntp import NTPPacket
|
||||
from primaite.simulator.system.services.service import Service
|
||||
@@ -11,9 +13,16 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class NTPServer(Service):
|
||||
class NTPServer(Service, identifier="NTPServer"):
|
||||
"""Represents a NTP server as a service."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for NTPServer."""
|
||||
|
||||
type: str = "NTPServer"
|
||||
|
||||
config: "NTPServer.ConfigSchema" = Field(default_factory=lambda: NTPServer.ConfigSchema())
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "NTPServer"
|
||||
kwargs["port"] = PORT_LOOKUP["NTP"]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Dict, Optional, Type
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
|
||||
@@ -37,6 +39,13 @@ class Service(IOSoftware):
|
||||
Services are programs that run in the background and may perform input/output operations.
|
||||
"""
|
||||
|
||||
class ConfigSchema(IOSoftware.ConfigSchema, ABC):
|
||||
"""Config Schema for Service class."""
|
||||
|
||||
type: str
|
||||
|
||||
config: "Service.ConfigSchema" = Field(default_factory=lambda: Service.ConfigSchema())
|
||||
|
||||
operating_state: ServiceOperatingState = ServiceOperatingState.STOPPED
|
||||
"The current operating state of the Service."
|
||||
|
||||
@@ -69,6 +78,21 @@ class Service(IOSoftware):
|
||||
raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "Service":
|
||||
"""Create a service from a config dictionary.
|
||||
|
||||
:param config: dict of options for service components constructor
|
||||
:type config: dict
|
||||
:return: The service component.
|
||||
:rtype: Service
|
||||
"""
|
||||
if config["type"] not in cls._registry:
|
||||
raise ValueError(f"Invalid service type {config['type']}")
|
||||
service_class = cls._registry[config["type"]]
|
||||
service_object = service_class(config=service_class.ConfigSchema(**config))
|
||||
return service_object
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
Checks if the service can perform actions.
|
||||
@@ -232,14 +256,14 @@ class Service(IOSoftware):
|
||||
|
||||
def disable(self) -> bool:
|
||||
"""Disable the service."""
|
||||
self.sys_log.info(f"Disabling Application {self.name}")
|
||||
self.sys_log.info(f"Disabling Service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.DISABLED
|
||||
return True
|
||||
|
||||
def enable(self) -> bool:
|
||||
"""Enable the disabled service."""
|
||||
if self.operating_state == ServiceOperatingState.DISABLED:
|
||||
self.sys_log.info(f"Enabling Application {self.name}")
|
||||
self.sys_log.info(f"Enabling Service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -7,7 +7,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -129,9 +129,16 @@ class RemoteTerminalConnection(TerminalClientConnection):
|
||||
return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id)
|
||||
|
||||
|
||||
class Terminal(Service):
|
||||
class Terminal(Service, identifier="Terminal"):
|
||||
"""Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for Terminal."""
|
||||
|
||||
type: str = "Terminal"
|
||||
|
||||
config: "Terminal.ConfigSchema" = Field(default_factory=lambda: Terminal.ConfigSchema())
|
||||
|
||||
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
|
||||
"""Dictionary of connect requests made to remote nodes."""
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.http import (
|
||||
HttpRequestMethod,
|
||||
@@ -19,9 +21,16 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class WebServer(Service):
|
||||
class WebServer(Service, identifier="WebServer"):
|
||||
"""Class used to represent a Web Server Service in simulation."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for WebServer."""
|
||||
|
||||
type: str = "WebServer"
|
||||
|
||||
config: "WebServer.ConfigSchema" = Field(default_factory=lambda: WebServer.ConfigSchema())
|
||||
|
||||
response_codes_this_timestep: List[HttpStatusCode] = []
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import copy
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
@@ -70,7 +70,7 @@ class SoftwareCriticality(Enum):
|
||||
"The highest level of criticality."
|
||||
|
||||
|
||||
class Software(SimComponent):
|
||||
class Software(SimComponent, ABC):
|
||||
"""
|
||||
A base class representing software in a simulator environment.
|
||||
|
||||
@@ -78,14 +78,22 @@ class Software(SimComponent):
|
||||
It outlines the fundamental attributes and behaviors expected of any software in the simulation.
|
||||
"""
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Configurable options for all software."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
starting_health_state: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
|
||||
fixing_duration: int = 2
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: Software.ConfigSchema())
|
||||
|
||||
name: str
|
||||
"The name of the software."
|
||||
health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
"The actual health state of the software."
|
||||
health_state_visible: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
"The health state of the software visible to the red agent."
|
||||
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
|
||||
"The criticality level of the software."
|
||||
fixing_count: int = 0
|
||||
"The count of patches applied to the software, defaults to 0."
|
||||
scanning_count: int = 0
|
||||
@@ -100,11 +108,13 @@ class Software(SimComponent):
|
||||
"The FileSystem of the Node the Software is installed on."
|
||||
folder: Optional[Folder] = None
|
||||
"The folder on the file system the Software uses."
|
||||
fixing_duration: int = 2
|
||||
"The number of ticks it takes to patch the software."
|
||||
_fixing_countdown: Optional[int] = None
|
||||
"Current number of ticks left to patch the software."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.health_state_actual = self.config.starting_health_state # don't remove this
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
@@ -152,7 +162,7 @@ class Software(SimComponent):
|
||||
{
|
||||
"health_state_actual": self.health_state_actual.value,
|
||||
"health_state_visible": self.health_state_visible.value,
|
||||
"criticality": self.criticality.value,
|
||||
"criticality": self.config.criticality.value,
|
||||
"fixing_count": self.fixing_count,
|
||||
"scanning_count": self.scanning_count,
|
||||
"revealed_to_red": self.revealed_to_red,
|
||||
@@ -201,7 +211,7 @@ class Software(SimComponent):
|
||||
def fix(self) -> bool:
|
||||
"""Perform a fix on the software."""
|
||||
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
|
||||
self._fixing_countdown = self.fixing_duration
|
||||
self._fixing_countdown = self.config.fixing_duration
|
||||
self.set_health_state(SoftwareHealthState.FIXING)
|
||||
return True
|
||||
return False
|
||||
@@ -233,7 +243,7 @@ class Software(SimComponent):
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
|
||||
class IOSoftware(Software):
|
||||
class IOSoftware(Software, ABC):
|
||||
"""
|
||||
Represents software in a simulator environment that is capable of input/output operations.
|
||||
|
||||
@@ -243,6 +253,13 @@ class IOSoftware(Software):
|
||||
required.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Software.ConfigSchema, ABC):
|
||||
"""Configuration options for all IO Software."""
|
||||
|
||||
listen_on_ports: Set[Port] = Field(default_factory=set)
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: IOSoftware.ConfigSchema())
|
||||
|
||||
installing_count: int = 0
|
||||
"The number of times the software has been installed. Default is 0."
|
||||
max_sessions: int = 100
|
||||
@@ -260,6 +277,10 @@ class IOSoftware(Software):
|
||||
_connections: Dict[str, Dict] = {}
|
||||
"Active connections."
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.listen_on_ports = self.config.listen_on_ports
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -148,7 +148,7 @@ simulation:
|
||||
options:
|
||||
db_server_ip: 192.168.1.10
|
||||
server_password: arcd
|
||||
fix_duration: 1
|
||||
fixing_duration: 1
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
@@ -169,7 +169,7 @@ simulation:
|
||||
arcd.com: 192.168.1.10
|
||||
- type: DatabaseService
|
||||
options:
|
||||
fix_duration: 5
|
||||
fixing_duration: 5
|
||||
backup_server_ip: 192.168.1.10
|
||||
- type: WebServer
|
||||
- type: FTPClient
|
||||
@@ -142,19 +142,19 @@ simulation:
|
||||
applications:
|
||||
- type: NMAP
|
||||
options:
|
||||
fix_duration: 1
|
||||
fixing_duration: 1
|
||||
- type: RansomwareScript
|
||||
options:
|
||||
fix_duration: 1
|
||||
fixing_duration: 1
|
||||
- type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
fix_duration: 1
|
||||
fixing_duration: 1
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.10
|
||||
server_password: arcd
|
||||
fix_duration: 1
|
||||
fixing_duration: 1
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
@@ -162,43 +162,44 @@ simulation:
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.21
|
||||
server_password: arcd
|
||||
fix_duration: 1
|
||||
fixing_duration: 1
|
||||
- type: DoSBot
|
||||
options:
|
||||
target_ip_address: 192.168.10.21
|
||||
payload: SPOOF DATA
|
||||
port_scan_p_of_success: 0.8
|
||||
fix_duration: 1
|
||||
fixing_duration: 1
|
||||
services:
|
||||
- type: DNSClient
|
||||
options:
|
||||
fix_duration: 3
|
||||
dns_server: 192.168.1.10
|
||||
fixing_duration: 3
|
||||
- type: DNSServer
|
||||
options:
|
||||
fix_duration: 3
|
||||
fixing_duration: 3
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.10
|
||||
- type: DatabaseService
|
||||
options:
|
||||
backup_server_ip: 192.168.1.10
|
||||
fix_duration: 3
|
||||
fixing_duration: 3
|
||||
- type: WebServer
|
||||
options:
|
||||
fix_duration: 3
|
||||
fixing_duration: 3
|
||||
- type: FTPClient
|
||||
options:
|
||||
fix_duration: 3
|
||||
fixing_duration: 3
|
||||
- type: FTPServer
|
||||
options:
|
||||
server_password: arcd
|
||||
fix_duration: 3
|
||||
fixing_duration: 3
|
||||
- type: NTPClient
|
||||
options:
|
||||
ntp_server_ip: 192.168.1.10
|
||||
fix_duration: 3
|
||||
fixing_duration: 3
|
||||
- type: NTPServer
|
||||
options:
|
||||
fix_duration: 3
|
||||
fixing_duration: 3
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
@@ -39,9 +39,16 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DummyService(Service):
|
||||
class DummyService(Service, identifier="DummyService"):
|
||||
"""Test Service class"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for DummyService."""
|
||||
|
||||
type: str = "DummyService"
|
||||
|
||||
config: "DummyService.ConfigSchema" = Field(default_factory=lambda: DummyService.ConfigSchema())
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
return super().describe_state()
|
||||
|
||||
@@ -58,6 +65,13 @@ class DummyService(Service):
|
||||
class DummyApplication(Application, identifier="DummyApplication"):
|
||||
"""Test Application class"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for DummyApplication."""
|
||||
|
||||
type: str = "DummyApplication"
|
||||
|
||||
config: "DummyApplication.ConfigSchema" = Field(default_factory=lambda: DummyApplication.ConfigSchema())
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DummyApplication"
|
||||
kwargs["port"] = PORT_LOOKUP["HTTP"]
|
||||
|
||||
@@ -13,8 +13,8 @@ from primaite.simulator.system.services.database.database_service import Databas
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml"
|
||||
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml"
|
||||
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fixing_duration.yaml"
|
||||
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fixing_duration_one_item.yaml"
|
||||
|
||||
TestApplications = ["DummyApplication", "BroadcastTestClient"]
|
||||
|
||||
@@ -27,27 +27,27 @@ def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
return PrimaiteGame.from_config(cfg)
|
||||
|
||||
|
||||
def test_default_fix_duration():
|
||||
"""Test that software with no defined fix duration in config uses the default fix duration of 2."""
|
||||
def test_default_fixing_duration():
|
||||
"""Test that software with no defined fixing duration in config uses the default fixing duration of 2."""
|
||||
game = load_config(TEST_CONFIG)
|
||||
client_2: Computer = game.simulation.network.get_node_by_hostname("client_2")
|
||||
|
||||
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
|
||||
assert database_client.fixing_duration == 2
|
||||
assert database_client.config.fixing_duration == 2
|
||||
|
||||
dns_client: DNSClient = client_2.software_manager.software.get("DNSClient")
|
||||
assert dns_client.fixing_duration == 2
|
||||
assert dns_client.config.fixing_duration == 2
|
||||
|
||||
|
||||
def test_fix_duration_set_from_config():
|
||||
"""Test to check that the fix duration set for applications and services works as intended."""
|
||||
def test_fixing_duration_set_from_config():
|
||||
"""Test to check that the fixing duration set for applications and services works as intended."""
|
||||
game = load_config(TEST_CONFIG)
|
||||
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# in config - services take 3 timesteps to fix
|
||||
for service in ["DNSClient", "DNSServer", "DatabaseService", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
|
||||
assert client_1.software_manager.software.get(service) is not None
|
||||
assert client_1.software_manager.software.get(service).fixing_duration == 3
|
||||
assert client_1.software_manager.software.get(service).config.fixing_duration == 3
|
||||
|
||||
# in config - applications take 1 timestep to fix
|
||||
# remove test applications from list
|
||||
@@ -55,27 +55,27 @@ def test_fix_duration_set_from_config():
|
||||
|
||||
for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]:
|
||||
assert client_1.software_manager.software.get(application) is not None
|
||||
assert client_1.software_manager.software.get(application).fixing_duration == 1
|
||||
assert client_1.software_manager.software.get(application).config.fixing_duration == 1
|
||||
|
||||
|
||||
def test_fix_duration_for_one_item():
|
||||
"""Test that setting fix duration for one application does not affect other components."""
|
||||
def test_fixing_duration_for_one_item():
|
||||
"""Test that setting fixing duration for one application does not affect other components."""
|
||||
game = load_config(ONE_ITEM_CONFIG)
|
||||
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# in config - services take 3 timesteps to fix
|
||||
for service in ["DNSClient", "DNSServer", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
|
||||
assert client_1.software_manager.software.get(service) is not None
|
||||
assert client_1.software_manager.software.get(service).fixing_duration == 2
|
||||
assert client_1.software_manager.software.get(service).config.fixing_duration == 2
|
||||
|
||||
# in config - applications take 1 timestep to fix
|
||||
# remove test applications from list
|
||||
for applications in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot"]:
|
||||
assert client_1.software_manager.software.get(applications) is not None
|
||||
assert client_1.software_manager.software.get(applications).fixing_duration == 2
|
||||
assert client_1.software_manager.software.get(applications).config.fixing_duration == 2
|
||||
|
||||
database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
|
||||
assert database_client.fixing_duration == 1
|
||||
assert database_client.config.fixing_duration == 1
|
||||
|
||||
database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService")
|
||||
assert database_service.fixing_duration == 5
|
||||
assert database_service.config.fixing_duration == 5
|
||||
@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -31,6 +31,14 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
|
||||
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for ExtendedApplication."""
|
||||
|
||||
type: str = "ExtendedApplication"
|
||||
target_url: Optional[str] = None
|
||||
|
||||
config: "ExtendedApplication.ConfigSchema" = Field(default_factory=lambda: ExtendedApplication.ConfigSchema())
|
||||
|
||||
target_url: Optional[str] = None
|
||||
|
||||
domain_name_ip_address: Optional[IPv4Address] = None
|
||||
@@ -50,6 +58,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
|
||||
kwargs["port"] = PORT_LOOKUP["HTTP"]
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.target_url = self.config.target_url
|
||||
self.run()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
|
||||
@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
@@ -17,13 +19,20 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class ExtendedService(Service, identifier="extendedservice"):
|
||||
class ExtendedService(Service, identifier="ExtendedService"):
|
||||
"""
|
||||
A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to simulate a SQL database.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for ExtendedService."""
|
||||
|
||||
type: str = "ExtendedService"
|
||||
|
||||
config: "ExtendedService.ConfigSchema" = Field(default_factory=lambda: ExtendedService.ConfigSchema())
|
||||
|
||||
password: Optional[str] = None
|
||||
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ def test_nic_monitored_traffic(simulation):
|
||||
|
||||
# send a database query
|
||||
browser: WebBrowser = pc.software_manager.software.get("WebBrowser")
|
||||
browser.target_url = f"http://arcd.com/"
|
||||
browser.config.target_url = f"http://arcd.com/"
|
||||
browser.get_webpage()
|
||||
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
|
||||
|
||||
@@ -183,7 +183,7 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P
|
||||
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
browser.config.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage() # check that the browser can access example.com before we block it
|
||||
|
||||
# 2: Remove rule that allows HTTP traffic across the network
|
||||
@@ -216,7 +216,7 @@ def test_host_nic_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
|
||||
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
browser.config.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage() # check that the browser can access example.com before we block it
|
||||
|
||||
# 2: Disable the NIC on client_1
|
||||
@@ -416,7 +416,7 @@ def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteG
|
||||
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
browser.config.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage() # check that the browser can access example.com before we block it
|
||||
|
||||
# 2: Disable the NIC on client_1
|
||||
@@ -476,7 +476,7 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
|
||||
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
browser.config.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage() # check that the browser can access example.com
|
||||
|
||||
assert browser.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
@@ -29,7 +29,7 @@ def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, Controlle
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
browser.config.target_url = "http://www.example.com"
|
||||
agent.reward_function.register_component(comp, 0.7)
|
||||
|
||||
# Check that before trying to fetch the webpage, the reward is 0.0
|
||||
|
||||
@@ -3,6 +3,7 @@ from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
@@ -14,9 +15,16 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
|
||||
class BroadcastTestService(Service):
|
||||
class BroadcastTestService(Service, identifier="BroadcastTestService"):
|
||||
"""A service for sending broadcast and unicast messages over a network."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for BroadcastTestService."""
|
||||
|
||||
type: str = "BroadcastTestService"
|
||||
|
||||
config: "BroadcastTestService.ConfigSchema" = Field(default_factory=lambda: BroadcastTestService.ConfigSchema())
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Set default service properties for broadcasting
|
||||
kwargs["name"] = "BroadcastService"
|
||||
@@ -46,6 +54,13 @@ class BroadcastTestService(Service):
|
||||
class BroadcastTestClient(Application, identifier="BroadcastTestClient"):
|
||||
"""A client application to receive broadcast and unicast messages."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for BroadcastTestClient."""
|
||||
|
||||
type: str = "BroadcastTestClient"
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: BroadcastTestClient.ConfigSchema())
|
||||
|
||||
payloads_received: List = []
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -495,6 +495,12 @@ def test_c2_suite_yaml():
|
||||
|
||||
computer_b: Computer = yaml_network.get_node_by_hostname("node_b")
|
||||
c2_beacon: C2Beacon = computer_b.software_manager.software.get("C2Beacon")
|
||||
c2_beacon.configure(
|
||||
c2_server_ip_address=c2_beacon.config.c2_server_ip_address,
|
||||
keep_alive_frequency=c2_beacon.config.keep_alive_frequency,
|
||||
masquerade_port=c2_beacon.config.masquerade_port,
|
||||
masquerade_protocol=c2_beacon.config.masquerade_protocol,
|
||||
)
|
||||
|
||||
assert c2_server.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
|
||||
@@ -232,7 +232,7 @@ def test_database_service_fix(uc2_network):
|
||||
assert db_service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
for i in range(db_service.config.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
|
||||
@@ -266,7 +266,7 @@ def test_database_cannot_be_queried_while_fixing(uc2_network):
|
||||
assert db_connection.query(sql="SELECT") is False
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
for i in range(db_service.config.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
@@ -308,7 +308,7 @@ def test_database_can_create_connection_while_fixing(uc2_network):
|
||||
assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
for i in range(db_service.config.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
@@ -14,7 +14,14 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
|
||||
class _DatabaseListener(Service):
|
||||
class _DatabaseListener(Service, identifier="_DatabaseListener"):
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for _DatabaseListener."""
|
||||
|
||||
type: str = "_DatabaseListener"
|
||||
listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]}
|
||||
|
||||
config: "_DatabaseListener.ConfigSchema" = Field(default_factory=lambda: _DatabaseListener.ConfigSchema())
|
||||
name: str = "DatabaseListener"
|
||||
protocol: str = PROTOCOL_LOOKUP["TCP"]
|
||||
port: int = PORT_LOOKUP["NONE"]
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_web_page_get_users_page_request_with_domain_name(web_client_and_web_ser
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
|
||||
web_browser_app.target_url = f"http://arcd.com/"
|
||||
web_browser_app.config.target_url = f"http://arcd.com/"
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
assert web_browser_app.get_webpage() is True
|
||||
@@ -66,7 +66,7 @@ def test_web_page_get_users_page_request_with_ip_address(web_client_and_web_serv
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
|
||||
web_browser_app.target_url = f"http://{web_server_ip}/"
|
||||
web_browser_app.config.target_url = f"http://{web_server_ip}/"
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
assert web_browser_app.get_webpage() is True
|
||||
@@ -81,7 +81,7 @@ def test_web_page_request_from_shut_down_server(web_client_and_web_server):
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
|
||||
web_browser_app.target_url = f"http://arcd.com/"
|
||||
web_browser_app.config.target_url = f"http://arcd.com/"
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
assert web_browser_app.get_webpage() is True
|
||||
@@ -108,7 +108,7 @@ def test_web_page_request_from_closed_web_browser(web_client_and_web_server):
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
web_browser_app.target_url = f"http://arcd.com/"
|
||||
web_browser_app.config.target_url = f"http://arcd.com/"
|
||||
assert web_browser_app.get_webpage() is True
|
||||
|
||||
# latest response should have status code 200
|
||||
|
||||
@@ -74,7 +74,7 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer,
|
||||
# Install Web Browser on computer
|
||||
computer.software_manager.install(WebBrowser)
|
||||
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
|
||||
web_browser.target_url = "http://arcd.com/users/"
|
||||
web_browser.config.target_url = "http://arcd.com/users/"
|
||||
web_browser.run()
|
||||
|
||||
# Install DNS Client service on computer
|
||||
@@ -131,7 +131,7 @@ def test_database_fix_disrupts_web_client(uc2_network):
|
||||
|
||||
assert web_browser.get_webpage() is False
|
||||
|
||||
for i in range(database_service.fixing_duration + 1):
|
||||
for i in range(database_service.config.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert database_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
@@ -128,13 +128,13 @@ def test_c2_handle_switching_port(basic_c2_network):
|
||||
assert c2_server.c2_connection_active is True
|
||||
|
||||
# Assert to confirm that both the C2 server and the C2 beacon are configured correctly.
|
||||
assert c2_beacon.c2_config.keep_alive_frequency is 2
|
||||
assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["HTTP"]
|
||||
assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
assert c2_beacon.config.keep_alive_frequency is 2
|
||||
assert c2_beacon.config.masquerade_port is PORT_LOOKUP["HTTP"]
|
||||
assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
|
||||
assert c2_server.c2_config.keep_alive_frequency is 2
|
||||
assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["HTTP"]
|
||||
assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
assert c2_server.config.keep_alive_frequency is 2
|
||||
assert c2_server.config.masquerade_port is PORT_LOOKUP["HTTP"]
|
||||
assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
|
||||
# Configuring the C2 Beacon.
|
||||
c2_beacon.configure(
|
||||
@@ -150,11 +150,11 @@ def test_c2_handle_switching_port(basic_c2_network):
|
||||
|
||||
# Assert to confirm that both the C2 server and the C2 beacon
|
||||
# Have reconfigured their C2 settings.
|
||||
assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["FTP"]
|
||||
assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
assert c2_beacon.config.masquerade_port is PORT_LOOKUP["FTP"]
|
||||
assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
|
||||
assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["FTP"]
|
||||
assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
assert c2_server.config.masquerade_port is PORT_LOOKUP["FTP"]
|
||||
assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
|
||||
|
||||
def test_c2_handle_switching_frequency(basic_c2_network):
|
||||
@@ -174,8 +174,8 @@ def test_c2_handle_switching_frequency(basic_c2_network):
|
||||
assert c2_server.c2_connection_active is True
|
||||
|
||||
# Assert to confirm that both the C2 server and the C2 beacon are configured correctly.
|
||||
assert c2_beacon.c2_config.keep_alive_frequency is 2
|
||||
assert c2_server.c2_config.keep_alive_frequency is 2
|
||||
assert c2_beacon.config.keep_alive_frequency is 2
|
||||
assert c2_server.config.keep_alive_frequency is 2
|
||||
|
||||
# Configuring the C2 Beacon.
|
||||
c2_beacon.configure(c2_server_ip_address="192.168.0.1", keep_alive_frequency=10)
|
||||
@@ -186,8 +186,8 @@ def test_c2_handle_switching_frequency(basic_c2_network):
|
||||
|
||||
# Assert to confirm that both the C2 server and the C2 beacon
|
||||
# Have reconfigured their C2 settings.
|
||||
assert c2_beacon.c2_config.keep_alive_frequency is 10
|
||||
assert c2_server.c2_config.keep_alive_frequency is 10
|
||||
assert c2_beacon.config.keep_alive_frequency is 10
|
||||
assert c2_server.config.keep_alive_frequency is 10
|
||||
|
||||
# Now skipping 9 time steps to confirm keep alive inactivity
|
||||
for i in range(9):
|
||||
|
||||
@@ -148,7 +148,7 @@ def test_service_fixing(service):
|
||||
service.fix()
|
||||
assert service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
for i in range(service.fixing_duration + 1):
|
||||
for i in range(service.config.fixing_duration + 1):
|
||||
service.apply_timestep(i)
|
||||
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.service import Service
|
||||
@@ -10,7 +11,14 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
|
||||
class TestSoftware(Service):
|
||||
class TestSoftware(Service, identifier="TestSoftware"):
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSChema for TestSoftware."""
|
||||
|
||||
type: str = "TestSoftware"
|
||||
|
||||
config: "TestSoftware.ConfigSchema" = Field(default_factory=lambda: TestSoftware.ConfigSchema())
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user