#2888 - Put software configuration items in the ConfigSchema

This commit is contained in:
Marek Wolan
2025-01-03 16:26:12 +00:00
parent c481847b01
commit 30d8f14251
13 changed files with 125 additions and 97 deletions

View File

@@ -50,7 +50,7 @@ from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.system.services.web_server.web_server import WebServer
from primaite.simulator.system.software import Software
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -422,74 +422,20 @@ class PrimaiteGame:
application_type = application_cfg["type"]
if application_type in Application._registry:
new_node.software_manager.install(Application._registry[application_type])
application_class = Application._registry[application_type]
application_options = application_cfg.get("options", {})
application_options["type"] = application_type
new_node.software_manager.install(application_class, software_config=application_options)
new_application = new_node.software_manager.software[application_type] # grab the instance
# fixing duration for the application
if "fix_duration" in application_cfg.get("options", {}):
new_application.fixing_duration = application_cfg["options"]["fix_duration"]
else:
msg = f"Configuration contains an invalid application type: {application_type}"
_LOGGER.error(msg)
raise ValueError(msg)
_set_software_listen_on_ports(new_application, application_cfg)
# run the application
new_application.run()
if application_type == "DataManipulationBot":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")),
server_password=opt.get("server_password"),
payload=opt.get("payload", "DELETE"),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
)
elif application_type == "RansomwareScript":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None,
server_password=opt.get("server_password"),
payload=opt.get("payload", "ENCRYPT"),
)
elif application_type == "DatabaseClient":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("db_server_ip")),
server_password=opt.get("server_password"),
)
elif application_type == "WebBrowser":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.target_url = opt.get("target_url")
elif application_type == "DoSBot":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
target_ip_address=IPv4Address(opt.get("target_ip_address")),
target_port=PORT_LOOKUP[opt.get("target_port", "POSTGRES_SERVER")],
payload=opt.get("payload"),
repeat=bool(opt.get("repeat")),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
dos_intensity=float(opt.get("dos_intensity", "1.0")),
max_sessions=int(opt.get("max_sessions", "1000")),
)
elif application_type == "C2Beacon":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")),
keep_alive_frequency=(opt.get("keep_alive_frequency", 5)),
masquerade_protocol=PROTOCOL_LOOKUP[
(opt.get("masquerade_protocol", PROTOCOL_LOOKUP["TCP"]))
],
masquerade_port=PORT_LOOKUP[(opt.get("masquerade_port", PORT_LOOKUP["HTTP"]))],
)
if "network_interfaces" in node_cfg:
for nic_num, nic_cfg in node_cfg["network_interfaces"].items():
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, ClassVar, Dict, Optional, Set, Type
from pydantic import BaseModel, Field
from pydantic import Field
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
@@ -30,7 +30,7 @@ class Application(IOSoftware, ABC):
Applications are user-facing programs that may perform input/output operations.
"""
class ConfigSchema(BaseModel, ABC):
class ConfigSchema(IOSoftware.ConfigSchema, ABC):
"""Config Schema for Application class."""
type: str

View File

@@ -73,6 +73,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
"""ConfigSchema for DatabaseClient."""
type: str = "DatabaseClient"
db_server_ip: Optional[IPV4Address] = None
server_password: Optional[str] = None
config: ConfigSchema = Field(default_factory=lambda: DatabaseClient.ConfigSchema())
@@ -99,6 +101,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
self.server_ip_address = self.config.db_server_ip
self.server_password = self.config.server_password
def _init_request_manager(self) -> RequestManager:
"""

View File

@@ -2,7 +2,7 @@
from abc import abstractmethod
from enum import Enum
from ipaddress import IPv4Address
from typing import Dict, Optional, Union
from typing import Dict, Optional, Set, Union
from pydantic import Field, validate_call
@@ -75,6 +75,8 @@ class AbstractC2(Application):
masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"])
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
listen_on_ports: Set[Port] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
config: ConfigSchema = Field(default_factory=lambda: AbstractC2.ConfigSchema())
c2_connection_active: bool = False
@@ -101,6 +103,12 @@ class AbstractC2(Application):
C2 beacon to reconfigure it's configuration settings.
"""
def __init__(self, **kwargs):
"""Initialise the C2 applications to by default listen for HTTP traffic."""
kwargs["port"] = PORT_LOOKUP["NONE"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
def _craft_packet(
self, c2_payload: C2Payload, c2_command: Optional[C2Command] = None, command_options: Optional[Dict] = {}
) -> C2Packet:
@@ -141,13 +149,6 @@ class AbstractC2(Application):
"""
return super().describe_state()
def __init__(self, **kwargs):
"""Initialise the C2 applications to by default listen for HTTP traffic."""
kwargs["listen_on_ports"] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
kwargs["port"] = PORT_LOOKUP["NONE"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
@property
def _host_ftp_client(self) -> Optional[FTPClient]:
"""Return the FTPClient that is installed C2 Application's host.

View File

@@ -12,8 +12,9 @@ from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts
from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP
class C2Beacon(AbstractC2, identifier="C2Beacon"):
@@ -39,6 +40,10 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
"""ConfigSchema for C2Beacon."""
type: str = "C2Beacon"
c2_server_ip_address: Optional[IPV4Address] = None
keep_alive_frequency: int = 5
masquerade_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"]
masquerade_port: Port = PORT_LOOKUP["HTTP"]
config: ConfigSchema = Field(default_factory=lambda: C2Beacon.ConfigSchema())

View File

@@ -12,6 +12,7 @@ from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -46,6 +47,11 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"):
"""Configuration schema for DataManipulationBot."""
type: str = "DataManipulationBot"
server_ip: Optional[IPV4Address] = None
server_password: Optional[str] = None
payload: str = "DELETE"
port_scan_p_of_success: float = 0.1
data_manipulation_p_of_success: float = 0.1
config: "DataManipulationBot.ConfigSchema" = Field(default_factory=lambda: DataManipulationBot.ConfigSchema())
@@ -65,6 +71,12 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"):
super().__init__(**kwargs)
self._db_connection: Optional[DatabaseClientConnection] = None
self.server_ip_address = self.config.server_ip
self.server_password = self.config.server_password
self.payload = self.config.payload
self.port_scan_p_of_success = self.config.port_scan_p_of_success
self.data_manipulation_p_of_success = self.config.data_manipulation_p_of_success
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -9,8 +9,8 @@ from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -35,6 +35,18 @@ class DoSAttackStage(IntEnum):
class DoSBot(DatabaseClient, identifier="DoSBot"):
"""A bot that simulates a Denial of Service attack."""
class ConfigSchema(DatabaseClient.ConfigSchema):
"""ConfigSchema for DoSBot."""
type: str = "DoSBot"
target_ip_address: Optional[IPV4Address] = None
target_port: Port = PORT_LOOKUP["POSTGRES_SERVER"]
payload: Optional[str] = None
repeat: bool = False
port_scan_p_of_success: float = 0.1
dos_intensity: float = 1.0
max_sessions: int = 1000
config: "DoSBot.ConfigSchema" = Field(default_factory=lambda: DoSBot.ConfigSchema())
target_ip_address: Optional[IPv4Address] = None
@@ -58,15 +70,16 @@ class DoSBot(DatabaseClient, identifier="DoSBot"):
dos_intensity: float = 1.0
"""How much of the max sessions will be used by the DoS when attacking."""
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for DoSBot."""
type: str = "DoSBot"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "DoSBot"
self.max_sessions = 1000 # override normal max sessions
self.target_ip_address = self.config.target_ip_address
self.target_port = self.config.target_port
self.payload = self.config.payload
self.repeat = self.config.repeat
self.port_scan_p_of_success = self.config.port_scan_p_of_success
self.dos_intensity = self.config.dos_intensity
self.max_sessions = self.config.max_sessions
def _init_request_manager(self) -> RequestManager:
"""

View File

@@ -10,6 +10,7 @@ from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import PORT_LOOKUP
@@ -23,6 +24,9 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
"""ConfigSchema for RansomwareScript."""
type: str = "RansomwareScript"
server_ip: Optional[IPV4Address] = None
server_password: Optional[str] = None
payload: str = "ENCRYPT"
config: "RansomwareScript.ConfigSchema" = Field(default_factory=lambda: RansomwareScript.ConfigSchema())
@@ -40,6 +44,9 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
super().__init__(**kwargs)
self._db_connection: Optional[DatabaseClientConnection] = None
self.server_ip_address = self.config.server_ip
self.server_password = self.config.server_password
self.payload = self.config.payload
def describe_state(self) -> Dict:
"""

View File

@@ -34,6 +34,7 @@ class WebBrowser(Application, identifier="WebBrowser"):
"""ConfigSchema for WebBrowser."""
type: str = "WebBrowser"
target_url: Optional[str] = None
config: "WebBrowser.ConfigSchema" = Field(default_factory=lambda: WebBrowser.ConfigSchema())
@@ -56,6 +57,7 @@ class WebBrowser(Application, identifier="WebBrowser"):
kwargs["port"] = PORT_LOOKUP["HTTP"]
super().__init__(**kwargs)
self.target_url = self.config.target_url
self.run()
def _init_request_manager(self) -> RequestManager:

View File

@@ -106,7 +106,7 @@ class SoftwareManager:
return True
return False
def install(self, software_class: Type[IOSoftware], **install_kwargs):
def install(self, software_class: Type[IOSoftware], software_config: Optional[IOSoftware.ConfigSchema] = None):
"""
Install an Application or Service.
@@ -115,13 +115,22 @@ class SoftwareManager:
if software_class in self._software_class_to_name_map:
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
return
software = software_class(
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
**install_kwargs,
)
if software_config is None:
software = software_class(
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
)
else:
software = software_class(
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
config=software_config,
)
software.parent = self.node
if isinstance(software, Application):
self.node.applications[software.uuid] = software

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, ClassVar, Dict, Optional, Type
from pydantic import BaseModel
from pydantic import Field
from primaite import getLogger
from primaite.interface.request import RequestFormat, RequestResponse
@@ -39,7 +39,12 @@ class Service(IOSoftware):
Services are programs that run in the background and may perform input/output operations.
"""
config: "Service.ConfigSchema"
class ConfigSchema(IOSoftware.ConfigSchema, ABC):
"""Config Schema for Service class."""
type: str
config: "Service.ConfigSchema" = Field(default_factory=lambda: Service.ConfigSchema())
operating_state: ServiceOperatingState = ServiceOperatingState.STOPPED
"The current operating state of the Service."
@@ -53,11 +58,6 @@ class Service(IOSoftware):
_registry: ClassVar[Dict[str, Type["Service"]]] = {}
"""Registry of service types. Automatically populated when subclasses are defined."""
class ConfigSchema(BaseModel, ABC):
"""Config Schema for Service class."""
type: str
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,13 +1,13 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import copy
from abc import abstractmethod
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from pydantic import BaseModel, ConfigDict, Field
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
@@ -70,7 +70,7 @@ class SoftwareCriticality(Enum):
"The highest level of criticality."
class Software(SimComponent):
class Software(SimComponent, ABC):
"""
A base class representing software in a simulator environment.
@@ -78,6 +78,16 @@ class Software(SimComponent):
It outlines the fundamental attributes and behaviors expected of any software in the simulation.
"""
class ConfigSchema(BaseModel, ABC):
"""Configurable options for all software."""
model_config = ConfigDict(extra="forbid")
starting_health_state: SoftwareHealthState = SoftwareHealthState.UNUSED
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
fixing_duration: int = 2
config: ConfigSchema = Field(default_factory=lambda: Software.ConfigSchema())
name: str
"The name of the software."
health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED
@@ -105,6 +115,12 @@ class Software(SimComponent):
_fixing_countdown: Optional[int] = None
"Current number of ticks left to patch the software."
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.health_state_actual = self.config.starting_health_state
self.criticality = self.config.criticality
self.fixing_duration = self.config.fixing_duration
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
@@ -233,7 +249,7 @@ class Software(SimComponent):
super().pre_timestep(timestep)
class IOSoftware(Software):
class IOSoftware(Software, ABC):
"""
Represents software in a simulator environment that is capable of input/output operations.
@@ -243,6 +259,13 @@ class IOSoftware(Software):
required.
"""
class ConfigSchema(Software.ConfigSchema, ABC):
"""Configuration options for all IO Software."""
listen_on_ports: Set[Port] = Field(default_factory=set)
config: ConfigSchema = Field(default_factory=lambda: IOSoftware.ConfigSchema())
installing_count: int = 0
"The number of times the software has been installed. Default is 0."
max_sessions: int = 100
@@ -260,6 +283,10 @@ class IOSoftware(Software):
_connections: Dict[str, Dict] = {}
"Active connections."
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.listen_on_ports = self.config.listen_on_ports
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -35,6 +35,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
"""ConfigSchema for ExtendedApplication."""
type: str = "ExtendedApplication"
target_url: Optional[str] = None
config: "ExtendedApplication.ConfigSchema" = Field(default_factory=lambda: ExtendedApplication.ConfigSchema())
@@ -57,6 +58,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
kwargs["port"] = PORT_LOOKUP["HTTP"]
super().__init__(**kwargs)
self.target_url = self.config.target_url
self.run()
def _init_request_manager(self) -> RequestManager: