From 4fb54c9492d79b6dfe6b76784e3914a084bcb7ec Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 31 Jan 2025 12:18:52 +0000 Subject: [PATCH 1/5] #3029 - Add string-ip validator, improve validation, fix minor bugs in pulling schema data --- src/primaite/game/agent/actions/software.py | 21 ++++--- .../agent/observations/acl_observation.py | 22 +++---- .../observations/firewall_observation.py | 21 ++++--- .../agent/observations/node_observations.py | 7 ++- .../agent/observations/router_observation.py | 9 ++- .../agent/scripted_agents/random_agent.py | 6 ++ src/primaite/simulator/network/creation.py | 3 +- .../red_applications/c2/c2_beacon.py | 4 +- .../applications/red_applications/dos_bot.py | 8 +-- .../services/database/database_service.py | 2 + .../system/services/dns/dns_client.py | 60 +++++++++++-------- .../system/services/dns/dns_server.py | 1 + .../system/services/ftp/ftp_server.py | 7 +-- .../system/services/ntp/ntp_client.py | 3 + src/primaite/utils/validation/ipv4_address.py | 9 +++ .../configs/basic_switched_network.yaml | 2 - tests/assets/configs/extended_config.yaml | 2 - ...software_installation_and_configuration.py | 1 - .../test_c2_suite_integration.py | 4 +- 19 files changed, 116 insertions(+), 76 deletions(-) diff --git a/src/primaite/game/agent/actions/software.py b/src/primaite/game/agent/actions/software.py index e0d602ed..81a3a315 100644 --- a/src/primaite/game/agent/actions/software.py +++ b/src/primaite/game/agent/actions/software.py @@ -6,6 +6,9 @@ from pydantic import ConfigDict, Field from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import StrIP +from primaite.utils.validation.port import Port __all__ = ( "ConfigureRansomwareScriptAction", @@ -64,8 +67,8 @@ class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"): model_config = ConfigDict(extra="forbid") node_name: str - target_ip_address: Optional[str] = None - target_port: Optional[str] = None + target_ip_address: Optional[StrIP] = None + target_port: Optional[Port] = None payload: Optional[str] = None repeat: Optional[bool] = None port_scan_p_of_success: Optional[float] = None @@ -95,10 +98,10 @@ class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"): """Configuration schema for ConfigureC2BeaconAction.""" node_name: str - c2_server_ip_address: str + c2_server_ip_address: StrIP keep_alive_frequency: int = Field(default=5, ge=1) - masquerade_protocol: str = Field(default="TCP") - masquerade_port: str = Field(default="HTTP") + masquerade_protocol: IPProtocol = Field(default="tcp") + masquerade_port: Port = Field(default=80) @classmethod def form_request(self, config: ConfigSchema) -> RequestFormat: @@ -121,7 +124,7 @@ class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_c """Configuration schema for NodeSendRemoteCommandAction.""" node_name: str - remote_ip: str + remote_ip: StrIP command: RequestFormat @classmethod @@ -149,7 +152,7 @@ class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_comm node_name: str commands: Union[List[RequestFormat], RequestFormat] - ip_address: Optional[str] + ip_address: Optional[StrIP] username: Optional[str] password: Optional[str] @@ -198,7 +201,7 @@ class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfi node_name: str username: Optional[str] password: Optional[str] - target_ip_address: str + target_ip_address: StrIP target_file_name: str target_folder_name: str exfiltration_folder_name: Optional[str] @@ -229,7 +232,7 @@ class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_databa """Schema for options that can be passed to this action.""" node_name: str - server_ip_address: Optional[str] = None + server_ip_address: Optional[StrIP] = None server_password: Optional[str] = None @classmethod diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index cb2cb38e..fde49a6b 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -1,7 +1,6 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from ipaddress import IPv4Address from typing import Dict, List, Optional from gymnasium import spaces @@ -10,6 +9,9 @@ from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import StrIP +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) @@ -20,13 +22,13 @@ class ACLObservation(AbstractObservation, identifier="ACL"): class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for ACLObservation.""" - ip_list: Optional[List[IPv4Address]] = None + ip_list: Optional[List[StrIP]] = None """List of IP addresses.""" wildcard_list: Optional[List[str]] = None """List of wildcard strings.""" - port_list: Optional[List[str]] = None + port_list: Optional[List[Port]] = None """List of port names.""" - protocol_list: Optional[List[str]] = None + protocol_list: Optional[List[IPProtocol]] = None """List of protocol names.""" num_rules: Optional[int] = None """Number of ACL rules.""" @@ -35,10 +37,10 @@ class ACLObservation(AbstractObservation, identifier="ACL"): self, where: WhereType, num_rules: int, - ip_list: List[IPv4Address], + ip_list: List[StrIP], wildcard_list: List[str], - port_list: List[str], - protocol_list: List[str], + port_list: List[Port], + protocol_list: List[IPProtocol], ) -> None: """ Initialise an ACL observation instance. @@ -48,13 +50,13 @@ class ACLObservation(AbstractObservation, identifier="ACL"): :param num_rules: Number of ACL rules. :type num_rules: int :param ip_list: List of IP addresses. - :type ip_list: List[IPv4Address] + :type ip_list: List[StrIP] :param wildcard_list: List of wildcard strings. :type wildcard_list: List[str] :param port_list: List of port names. - :type port_list: List[str] + :type port_list: List[Port] :param protocol_list: List of protocol names. - :type protocol_list: List[str] + :type protocol_list: List[IPProtocol] """ self.where = where self.num_rules: int = num_rules diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 44541f24..6e5fffb9 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -11,6 +11,9 @@ from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import StrIP +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) @@ -23,13 +26,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): hostname: str """Hostname of the firewall node, used for querying simulation state dictionary.""" - ip_list: Optional[List[str]] = None + ip_list: Optional[List[StrIP]] = None """List of IP addresses for encoding ACLs.""" wildcard_list: Optional[List[str]] = None """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[str]] = None + port_list: Optional[List[Port]] = None """List of ports for encoding ACLs.""" - protocol_list: Optional[List[str]] = None + protocol_list: Optional[List[IPProtocol]] = None """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" @@ -39,10 +42,10 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): def __init__( self, where: WhereType, - ip_list: List[str], + ip_list: List[StrIP], wildcard_list: List[str], - port_list: List[str], - protocol_list: List[str], + port_list: List[Port], + protocol_list: List[IPProtocol], num_rules: int, include_users: bool, ) -> None: @@ -53,13 +56,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): A typical location for a firewall might be ['network', 'nodes', ]. :type where: WhereType :param ip_list: List of IP addresses. - :type ip_list: List[str] + :type ip_list: List[StrIP] :param wildcard_list: List of wildcard rules. :type wildcard_list: List[str] :param port_list: List of port names. - :type port_list: List[str] + :type port_list: List[Port] :param protocol_list: List of protocol types. - :type protocol_list: List[str] + :type protocol_list: List[IPProtocol] :param num_rules: Number of rules configured in the firewall. :type num_rules: int :param include_users: If True, report user session information. diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 0c5d11da..1a0f48b4 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -13,6 +13,7 @@ from primaite.game.agent.observations.host_observations import HostObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.router_observation import RouterObservation from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import StrIP from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) @@ -52,13 +53,13 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """If True, report user session information.""" num_ports: Optional[int] = None """Number of ports.""" - ip_list: Optional[List[str]] = None + ip_list: Optional[List[StrIP]] = None """List of IP addresses for encoding ACLs.""" wildcard_list: Optional[List[str]] = None """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[str]] = None + port_list: Optional[List[Port]] = None """List of ports for encoding ACLs.""" - protocol_list: Optional[List[str]] = None + protocol_list: Optional[List[IPProtocol]] = None """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index 9687d083..ab759779 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -11,6 +11,9 @@ from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import StrIP +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) @@ -29,13 +32,13 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): """Number of port observations configured for this router.""" acl: Optional[ACLObservation.ConfigSchema] = None """Configuration of ACL observation on this router.""" - ip_list: Optional[List[str]] = None + ip_list: Optional[List[StrIP]] = None """List of IP addresses for encoding ACLs.""" wildcard_list: Optional[List[str]] = None """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[str]] = None + port_list: Optional[List[Port]] = None """List of ports for encoding ACLs.""" - protocol_list: Optional[List[str]] = None + protocol_list: Optional[List[IPProtocol]] = None """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" diff --git a/src/primaite/game/agent/scripted_agents/random_agent.py b/src/primaite/game/agent/scripted_agents/random_agent.py index 9d82a063..9cf8e798 100644 --- a/src/primaite/game/agent/scripted_agents/random_agent.py +++ b/src/primaite/game/agent/scripted_agents/random_agent.py @@ -83,6 +83,12 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"): next_execution_timestep: int = 0 """Timestep of the next action execution by the agent.""" + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self._set_next_execution_timestep( + timestep=self.config.agent_settings.start_step, variance=self.config.agent_settings.start_variance + ) + @computed_field @cached_property def start_node(self) -> str: diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index ebd17638..2cf8774e 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from ipaddress import IPv4Address from typing import Any, ClassVar, Dict, Literal, Optional, Type -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, ConfigDict, model_validator from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -44,6 +44,7 @@ class NetworkNodeAdder(BaseModel): by the from_config method to select the correct node adder at runtime. """ + model_config = ConfigDict(extra="forbid") type: str """Uniquely identifies the node adder class to use for adding nodes to network.""" diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index 13918cd7..b989671e 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -127,8 +127,8 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self.configure( c2_server_ip_address=c2_remote_ip, keep_alive_frequency=frequency, - masquerade_protocol=PROTOCOL_LOOKUP[protocol], - masquerade_port=PORT_LOOKUP[port], + masquerade_protocol=protocol, + masquerade_port=port, ) ) diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index ea7a4d8d..a6cb2b75 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -10,8 +10,8 @@ from primaite.game.science import simulate_trial from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.system.applications.database_client import DatabaseClient -from primaite.utils.validation.ipv4_address import IPV4Address -from primaite.utils.validation.port import Port, PORT_LOOKUP +from primaite.utils.validation.ipv4_address import ipv4_validator, IPV4Address +from primaite.utils.validation.port import Port, PORT_LOOKUP, port_validator _LOGGER = getLogger(__name__) @@ -106,9 +106,9 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): :rtype: RequestResponse """ if "target_ip_address" in request[-1]: - request[-1]["target_ip_address"] = IPv4Address(request[-1]["target_ip_address"]) + request[-1]["target_ip_address"] = ipv4_validator(request[-1]["target_ip_address"]) if "target_port" in request[-1]: - request[-1]["target_port"] = PORT_LOOKUP[request[-1]["target_port"]] + request[-1]["target_port"] = port_validator(request[-1]["target_port"]) return RequestResponse.from_bool(self.configure(**request[-1])) rm.add_request("configure", request_type=RequestType(func=_configure)) diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 4ba4c4d4..91f71302 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -31,6 +31,7 @@ class DatabaseService(Service, identifier="DatabaseService"): type: str = "DatabaseService" backup_server_ip: Optional[IPv4Address] = None + db_password: Optional[str] = None config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema()) @@ -53,6 +54,7 @@ class DatabaseService(Service, identifier="DatabaseService"): super().__init__(**kwargs) self._create_db_file() self.backup_server_ip = self.config.backup_server_ip + self.password = self.config.db_password def install(self): """ diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 0756eb05..6e6f7729 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -1,6 +1,6 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address -from typing import Dict, Optional +from typing import Dict, Optional, TYPE_CHECKING from pydantic import Field @@ -9,8 +9,12 @@ from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import Port, PORT_LOOKUP +if TYPE_CHECKING: + from primaite.simulator.network.hardware.base import Node + _LOGGER = getLogger(__name__) @@ -21,6 +25,7 @@ class DNSClient(Service, identifier="DNSClient"): """ConfigSchema for DNSClient.""" type: str = "DNSClient" + dns_server: Optional[IPV4Address] = None config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema()) dns_cache: Dict[str, IPv4Address] = {} @@ -36,6 +41,7 @@ class DNSClient(Service, identifier="DNSClient"): # TCP for now kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) + self.dns_server = self.config.dns_server self.start() def describe_state(self) -> Dict: @@ -79,6 +85,14 @@ class DNSClient(Service, identifier="DNSClient"): if not self._can_perform_action(): return False + # check if the domain is already in the DNS cache + if target_domain in self.dns_cache: + self.sys_log.info( + f"{self.name}: Domain lookup for {target_domain} successful," + f"resolves to {self.dns_cache[target_domain]}" + ) + return True + # check if DNS server is configured if self.dns_server is None: self.sys_log.warning(f"{self.name}: DNS Server is not configured") @@ -87,31 +101,23 @@ class DNSClient(Service, identifier="DNSClient"): # check if the target domain is in the client's DNS cache payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain)) - # check if the domain is already in the DNS cache - if target_domain in self.dns_cache: - self.sys_log.info( - f"{self.name}: Domain lookup for {target_domain} successful," - f"resolves to {self.dns_cache[target_domain]}" - ) - return True + # return False if already reattempted + if is_reattempt: + self.sys_log.warning(f"{self.name}: Domain lookup for {target_domain} failed") + return False else: - # return False if already reattempted - if is_reattempt: - self.sys_log.warning(f"{self.name}: Domain lookup for {target_domain} failed") - return False - else: - # send a request to check if domain name exists in the DNS Server - software_manager: SoftwareManager = self.software_manager - software_manager.send_payload_to_session_manager( - payload=payload, dest_ip_address=self.dns_server, dest_port=PORT_LOOKUP["DNS"] - ) + # send a request to check if domain name exists in the DNS Server + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload=payload, dest_ip_address=self.dns_server, dest_port=PORT_LOOKUP["DNS"] + ) - # recursively re-call the function passing is_reattempt=True - return self.check_domain_exists( - target_domain=target_domain, - session_id=session_id, - is_reattempt=True, - ) + # recursively re-call the function passing is_reattempt=True + return self.check_domain_exists( + target_domain=target_domain, + session_id=session_id, + is_reattempt=True, + ) def send( self, @@ -168,3 +174,9 @@ class DNSClient(Service, identifier="DNSClient"): self.sys_log.warning(f"Failed to resolve domain name {payload.dns_request.domain_name_request}") return False + + def install(self) -> None: + """Set the DNS server to be the node's DNS server unless a different one was already provided.""" + self.parent: Node + if self.parent and not self.dns_server: + self.dns_server = self.parent.dns_server diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 3a1c0e18..41a5b25f 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -36,6 +36,7 @@ class DNSServer(Service, identifier="DNSServer"): # TCP for now kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) + self.dns_table = self.config.domain_mapping self.start() def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 054bfe15..5f4ac846 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -6,7 +6,6 @@ from pydantic import Field from primaite import getLogger from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC -from primaite.simulator.system.services.service import Service from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP @@ -22,14 +21,13 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"): """ config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema()) - server_password: Optional[str] = None - """Password needed to connect to FTP server. Default is None.""" - class ConfigSchema(Service.ConfigSchema): + class ConfigSchema(FTPServiceABC.ConfigSchema): """ConfigSchema for FTPServer.""" type: str = "FTPServer" + server_password: Optional[str] = None def __init__(self, **kwargs): kwargs["name"] = "FTPServer" @@ -37,6 +35,7 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"): kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() + self.server_password = self.config.server_password def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index fb470faf..b5f921c9 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -9,6 +9,7 @@ from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -21,6 +22,7 @@ class NTPClient(Service, identifier="NTPClient"): """ConfigSchema for NTPClient.""" type: str = "NTPClient" + ntp_server_ip: Optional[IPV4Address] = None config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema()) @@ -33,6 +35,7 @@ class NTPClient(Service, identifier="NTPClient"): kwargs["port"] = PORT_LOOKUP["NTP"] kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) + self.ntp_server = self.config.ntp_server_ip self.start() def configure(self, ntp_server_ip_address: IPv4Address) -> None: diff --git a/src/primaite/utils/validation/ipv4_address.py b/src/primaite/utils/validation/ipv4_address.py index b2b8b72e..1dc6c74e 100644 --- a/src/primaite/utils/validation/ipv4_address.py +++ b/src/primaite/utils/validation/ipv4_address.py @@ -39,3 +39,12 @@ will automatically check and convert the input value to an instance of IPv4Addre any Pydantic model uses it. This ensures that any field marked with this type is not just an IPv4Address in form, but also valid according to the rules defined in ipv4_validator. """ + + +def str_ip(value: Any) -> str: + """Make sure it's a valid IP, but represent it as a string.""" + # TODO: this is a bit of a hack, we should change RequestResponse to be able to handle IPV4Address objects + return str(IPV4Address(value)) + + +StrIP: Final[Annotated] = Annotated[str, BeforeValidator(str_ip)] diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index a39bf876..b0591da6 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -178,8 +178,6 @@ simulation: backup_server_ip: 192.168.1.10 - type: WebServer - type: FTPServer - options: - server_password: arcd - type: NTPClient options: ntp_server_ip: 192.168.1.10 diff --git a/tests/assets/configs/extended_config.yaml b/tests/assets/configs/extended_config.yaml index bff58ebd..fcfc93ef 100644 --- a/tests/assets/configs/extended_config.yaml +++ b/tests/assets/configs/extended_config.yaml @@ -771,8 +771,6 @@ simulation: options: backup_server_ip: 192.168.1.16 - type: ExtendedService - options: - backup_server_ip: 192.168.1.16 - hostname: client_2 type: computer diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index 0ff6754d..2a3691ae 100644 --- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -201,7 +201,6 @@ def test_ftp_server_install(): ftp_server_service: FTPServer = client_1.software_manager.software.get("FTPServer") assert ftp_server_service is not None - assert ftp_server_service.server_password == "arcd" def test_ntp_client_install(): diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index 40226be6..faf0466f 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -173,8 +173,8 @@ def test_c2_suite_configure_request(basic_network): c2_beacon_config = { "c2_server_ip_address": "192.168.0.2", "keep_alive_frequency": 5, - "masquerade_protocol": "TCP", - "masquerade_port": "HTTP", + "masquerade_protocol": "tcp", + "masquerade_port": 80, } network.apply_request(["node", "node_b", "application", "C2Beacon", "configure", c2_beacon_config]) From 037dd8278bc092f2a7979264256d775df6d29690 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 31 Jan 2025 12:30:08 +0000 Subject: [PATCH 2/5] #3029 - update changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 315579d5..c91bf4f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed action space options which were previously used for assigning meaning to action space IDs - Updated tests that don't use YAMLs to still use the new action and agent schemas +### Fixed +- DNS client no longer fails to check its cache if a DNS server address is missing. +- DNS client now correctly inherits the node's DNS address configuration setting. + + ## [3.3.0] - 2024-09-04 ### Added From 3260e1f30b06d453c6c4115754d464409c7585f0 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 31 Jan 2025 14:41:49 +0000 Subject: [PATCH 3/5] #3029 - make new config items properties as per PR comments --- .../simulator/system/services/database/database_service.py | 6 +++++- src/primaite/simulator/system/services/dns/dns_client.py | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 91f71302..d65f05bd 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -54,7 +54,11 @@ class DatabaseService(Service, identifier="DatabaseService"): super().__init__(**kwargs) self._create_db_file() self.backup_server_ip = self.config.backup_server_ip - self.password = self.config.db_password + + @property + def password(self) -> Optional[str]: + """Convenience property for accessing the password.""" + return self.config.db_password def install(self): """ diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 6e6f7729..f4b427cd 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -44,6 +44,11 @@ class DNSClient(Service, identifier="DNSClient"): self.dns_server = self.config.dns_server self.start() + @property + def dns_server(self) -> Optional[IPV4Address]: + """Convenience property for accessing the dns server configuration.""" + return self.config.dns_server + def describe_state(self) -> Dict: """ Describes the current state of the software. @@ -179,4 +184,4 @@ class DNSClient(Service, identifier="DNSClient"): """Set the DNS server to be the node's DNS server unless a different one was already provided.""" self.parent: Node if self.parent and not self.dns_server: - self.dns_server = self.parent.dns_server + self.config.dns_server = self.parent.dns_server From a77fa65c39de68c62bc1264d771b21a7b1b9e519 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 31 Jan 2025 14:46:43 +0000 Subject: [PATCH 4/5] #3029 - Remove old initialisation of dns server attr that caused a bug --- src/primaite/simulator/system/services/dns/dns_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index f4b427cd..1c64c9a9 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -41,7 +41,6 @@ class DNSClient(Service, identifier="DNSClient"): # TCP for now kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) - self.dns_server = self.config.dns_server self.start() @property From 8feb2db954bcb153a4912d2096e8e34f0771b395 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 31 Jan 2025 15:29:10 +0000 Subject: [PATCH 5/5] Fix properties --- src/primaite/game/agent/agent_log.py | 5 ++-- src/primaite/game/game.py | 26 ------------------- .../services/database/database_service.py | 7 ++--- .../system/services/dns/dns_client.py | 16 +++++++----- 4 files changed, 16 insertions(+), 38 deletions(-) diff --git a/src/primaite/game/agent/agent_log.py b/src/primaite/game/agent/agent_log.py index 5d9dc848..ddf14489 100644 --- a/src/primaite/game/agent/agent_log.py +++ b/src/primaite/game/agent/agent_log.py @@ -65,8 +65,9 @@ class AgentLog: The logger is set to the DEBUG level, and is equipped with a handler that writes to a file and filters out JSON-like messages. """ - if not SIM_OUTPUT.save_agent_logs: - return + # TODO: uncomment this once we figure out why it's broken + # if not SIM_OUTPUT.save_agent_logs: + # return log_path = self._get_log_path() file_handler = logging.FileHandler(filename=log_path) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index b869cfd4..a8e23d56 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -406,32 +406,6 @@ class PrimaiteGame: if "service_install_duration" in defaults_config: new_service.install_duration = defaults_config["service_install_duration"] - # service-dependent options - if service_type == "DNSClient": - if "options" in service_cfg: - opt = service_cfg["options"] - if "dns_server" in opt: - new_service.dns_server = IPv4Address(opt["dns_server"]) - if service_type == "DNSServer": - if "options" in service_cfg: - opt = service_cfg["options"] - if "domain_mapping" in opt: - for domain, ip in opt["domain_mapping"].items(): - new_service.dns_register(domain, IPv4Address(ip)) - if service_type == "DatabaseService": - if "options" in service_cfg: - opt = service_cfg["options"] - new_service.password = opt.get("db_password", None) - if "backup_server_ip" in opt: - new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip"))) - if service_type == "FTPServer": - if "options" in service_cfg: - opt = service_cfg["options"] - new_service.server_password = opt.get("server_password") - if service_type == "NTPClient": - if "options" in service_cfg: - opt = service_cfg["options"] - new_service.ntp_server = IPv4Address(opt.get("ntp_server_ip")) if "applications" in node_cfg: for application_cfg in node_cfg["applications"]: new_application = None diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index d65f05bd..1745b9d1 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -35,9 +35,6 @@ class DatabaseService(Service, identifier="DatabaseService"): config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema()) - password: Optional[str] = None - """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" - backup_server_ip: IPv4Address = None """IP address of the backup server.""" @@ -60,6 +57,10 @@ class DatabaseService(Service, identifier="DatabaseService"): """Convenience property for accessing the password.""" return self.config.db_password + @password.setter + def password(self, val: str) -> None: + self.config.db_password = val + def install(self): """ Perform first-time setup of the DatabaseService. diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 1c64c9a9..825896e0 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -30,8 +30,6 @@ class DNSClient(Service, identifier="DNSClient"): config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema()) dns_cache: Dict[str, IPv4Address] = {} "A dict of known mappings between domain/URLs names and IPv4 addresses." - dns_server: Optional[IPv4Address] = None - "The DNS Server the client sends requests to." def __init__(self, **kwargs): kwargs["name"] = "DNSClient" @@ -43,11 +41,6 @@ class DNSClient(Service, identifier="DNSClient"): super().__init__(**kwargs) self.start() - @property - def dns_server(self) -> Optional[IPV4Address]: - """Convenience property for accessing the dns server configuration.""" - return self.config.dns_server - def describe_state(self) -> Dict: """ Describes the current state of the software. @@ -61,6 +54,15 @@ class DNSClient(Service, identifier="DNSClient"): state = super().describe_state() return state + @property + def dns_server(self) -> Optional[IPV4Address]: + """Convenience property for accessing the dns server configuration.""" + return self.config.dns_server + + @dns_server.setter + def dns_server(self, val: IPV4Address) -> None: + self.config.dns_server = val + def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address) -> bool: """ Adds a domain name to the DNS Client cache.