Merge branch '4.0.0a1-dev' into feature/3075_Migrate_notebooks_to_MilPac_(Core_changes)
This commit is contained in:
@@ -23,6 +23,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Removed action space options which were previously used for assigning meaning to action space IDs
|
||||
- Updated tests that don't use YAMLs to still use the new action and agent schemas
|
||||
|
||||
### Fixed
|
||||
- DNS client no longer fails to check its cache if a DNS server address is missing.
|
||||
- DNS client now correctly inherits the node's DNS address configuration setting.
|
||||
|
||||
|
||||
## [3.3.0] - 2024-09-04
|
||||
|
||||
### Added
|
||||
|
||||
@@ -6,6 +6,9 @@ from pydantic import ConfigDict, Field
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.ipv4_address import StrIP
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
__all__ = (
|
||||
"ConfigureRansomwareScriptAction",
|
||||
@@ -64,8 +67,8 @@ class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"):
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
node_name: str
|
||||
target_ip_address: Optional[str] = None
|
||||
target_port: Optional[str] = None
|
||||
target_ip_address: Optional[StrIP] = None
|
||||
target_port: Optional[Port] = None
|
||||
payload: Optional[str] = None
|
||||
repeat: Optional[bool] = None
|
||||
port_scan_p_of_success: Optional[float] = None
|
||||
@@ -95,10 +98,10 @@ class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"):
|
||||
"""Configuration schema for ConfigureC2BeaconAction."""
|
||||
|
||||
node_name: str
|
||||
c2_server_ip_address: str
|
||||
c2_server_ip_address: StrIP
|
||||
keep_alive_frequency: int = Field(default=5, ge=1)
|
||||
masquerade_protocol: str = Field(default="TCP")
|
||||
masquerade_port: str = Field(default="HTTP")
|
||||
masquerade_protocol: IPProtocol = Field(default="tcp")
|
||||
masquerade_port: Port = Field(default=80)
|
||||
|
||||
@classmethod
|
||||
def form_request(self, config: ConfigSchema) -> RequestFormat:
|
||||
@@ -121,7 +124,7 @@ class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_c
|
||||
"""Configuration schema for NodeSendRemoteCommandAction."""
|
||||
|
||||
node_name: str
|
||||
remote_ip: str
|
||||
remote_ip: StrIP
|
||||
command: RequestFormat
|
||||
|
||||
@classmethod
|
||||
@@ -149,7 +152,7 @@ class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_comm
|
||||
|
||||
node_name: str
|
||||
commands: Union[List[RequestFormat], RequestFormat]
|
||||
ip_address: Optional[str]
|
||||
ip_address: Optional[StrIP]
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
|
||||
@@ -198,7 +201,7 @@ class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfi
|
||||
node_name: str
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
target_ip_address: str
|
||||
target_ip_address: StrIP
|
||||
target_file_name: str
|
||||
target_folder_name: str
|
||||
exfiltration_folder_name: Optional[str]
|
||||
@@ -229,7 +232,7 @@ class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_databa
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
node_name: str
|
||||
server_ip_address: Optional[str] = None
|
||||
server_ip_address: Optional[StrIP] = None
|
||||
server_password: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -65,8 +65,9 @@ class AgentLog:
|
||||
The logger is set to the DEBUG level, and is equipped with a handler that writes to a file and filters out
|
||||
JSON-like messages.
|
||||
"""
|
||||
if not SIM_OUTPUT.save_agent_logs:
|
||||
return
|
||||
# TODO: uncomment this once we figure out why it's broken
|
||||
# if not SIM_OUTPUT.save_agent_logs:
|
||||
# return
|
||||
|
||||
log_path = self._get_log_path()
|
||||
file_handler = logging.FileHandler(filename=log_path)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
@@ -10,6 +9,9 @@ from gymnasium.core import ObsType
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.ipv4_address import StrIP
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -20,13 +22,13 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ACLObservation."""
|
||||
|
||||
ip_list: Optional[List[IPv4Address]] = None
|
||||
ip_list: Optional[List[StrIP]] = None
|
||||
"""List of IP addresses."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of wildcard strings."""
|
||||
port_list: Optional[List[str]] = None
|
||||
port_list: Optional[List[Port]] = None
|
||||
"""List of port names."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
protocol_list: Optional[List[IPProtocol]] = None
|
||||
"""List of protocol names."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of ACL rules."""
|
||||
@@ -35,10 +37,10 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
self,
|
||||
where: WhereType,
|
||||
num_rules: int,
|
||||
ip_list: List[IPv4Address],
|
||||
ip_list: List[StrIP],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[str],
|
||||
protocol_list: List[str],
|
||||
port_list: List[Port],
|
||||
protocol_list: List[IPProtocol],
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an ACL observation instance.
|
||||
@@ -48,13 +50,13 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
:param num_rules: Number of ACL rules.
|
||||
:type num_rules: int
|
||||
:param ip_list: List of IP addresses.
|
||||
:type ip_list: List[IPv4Address]
|
||||
:type ip_list: List[StrIP]
|
||||
:param wildcard_list: List of wildcard strings.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port names.
|
||||
:type port_list: List[str]
|
||||
:type port_list: List[Port]
|
||||
:param protocol_list: List of protocol names.
|
||||
:type protocol_list: List[str]
|
||||
:type protocol_list: List[IPProtocol]
|
||||
"""
|
||||
self.where = where
|
||||
self.num_rules: int = num_rules
|
||||
|
||||
@@ -11,6 +11,9 @@ from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.ipv4_address import StrIP
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -23,13 +26,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
|
||||
hostname: str
|
||||
"""Hostname of the firewall node, used for querying simulation state dictionary."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
ip_list: Optional[List[StrIP]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[str]] = None
|
||||
port_list: Optional[List[Port]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
protocol_list: Optional[List[IPProtocol]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
@@ -39,10 +42,10 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
ip_list: List[str],
|
||||
ip_list: List[StrIP],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[str],
|
||||
protocol_list: List[str],
|
||||
port_list: List[Port],
|
||||
protocol_list: List[IPProtocol],
|
||||
num_rules: int,
|
||||
include_users: bool,
|
||||
) -> None:
|
||||
@@ -53,13 +56,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
A typical location for a firewall might be ['network', 'nodes', <firewall_hostname>].
|
||||
:type where: WhereType
|
||||
:param ip_list: List of IP addresses.
|
||||
:type ip_list: List[str]
|
||||
:type ip_list: List[StrIP]
|
||||
:param wildcard_list: List of wildcard rules.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port names.
|
||||
:type port_list: List[str]
|
||||
:type port_list: List[Port]
|
||||
:param protocol_list: List of protocol types.
|
||||
:type protocol_list: List[str]
|
||||
:type protocol_list: List[IPProtocol]
|
||||
:param num_rules: Number of rules configured in the firewall.
|
||||
:type num_rules: int
|
||||
:param include_users: If True, report user session information.
|
||||
|
||||
@@ -13,6 +13,7 @@ from primaite.game.agent.observations.host_observations import HostObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.observations.router_observation import RouterObservation
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.ipv4_address import StrIP
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -52,13 +53,13 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
"""If True, report user session information."""
|
||||
num_ports: Optional[int] = None
|
||||
"""Number of ports."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
ip_list: Optional[List[StrIP]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[str]] = None
|
||||
port_list: Optional[List[Port]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
protocol_list: Optional[List[IPProtocol]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
|
||||
@@ -11,6 +11,9 @@ from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.ipv4_address import StrIP
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -29,13 +32,13 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
"""Number of port observations configured for this router."""
|
||||
acl: Optional[ACLObservation.ConfigSchema] = None
|
||||
"""Configuration of ACL observation on this router."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
ip_list: Optional[List[StrIP]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[str]] = None
|
||||
port_list: Optional[List[Port]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
protocol_list: Optional[List[IPProtocol]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
|
||||
@@ -83,6 +83,12 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
|
||||
next_execution_timestep: int = 0
|
||||
"""Timestep of the next action execution by the agent."""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._set_next_execution_timestep(
|
||||
timestep=self.config.agent_settings.start_step, variance=self.config.agent_settings.start_variance
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@cached_property
|
||||
def start_node(self) -> str:
|
||||
|
||||
@@ -406,32 +406,6 @@ class PrimaiteGame:
|
||||
if "service_install_duration" in defaults_config:
|
||||
new_service.install_duration = defaults_config["service_install_duration"]
|
||||
|
||||
# service-dependent options
|
||||
if service_type == "DNSClient":
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
if "dns_server" in opt:
|
||||
new_service.dns_server = IPv4Address(opt["dns_server"])
|
||||
if service_type == "DNSServer":
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
if "domain_mapping" in opt:
|
||||
for domain, ip in opt["domain_mapping"].items():
|
||||
new_service.dns_register(domain, IPv4Address(ip))
|
||||
if service_type == "DatabaseService":
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
new_service.password = opt.get("db_password", None)
|
||||
if "backup_server_ip" in opt:
|
||||
new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip")))
|
||||
if service_type == "FTPServer":
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
new_service.server_password = opt.get("server_password")
|
||||
if service_type == "NTPClient":
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
new_service.ntp_server = IPv4Address(opt.get("ntp_server_ip"))
|
||||
if "applications" in node_cfg:
|
||||
for application_cfg in node_cfg["applications"]:
|
||||
new_application = None
|
||||
|
||||
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, ClassVar, Dict, Literal, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
@@ -44,6 +44,7 @@ class NetworkNodeAdder(BaseModel):
|
||||
by the from_config method to select the correct node adder at runtime.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str
|
||||
"""Uniquely identifies the node adder class to use for adding nodes to network."""
|
||||
|
||||
|
||||
@@ -127,8 +127,8 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
self.configure(
|
||||
c2_server_ip_address=c2_remote_ip,
|
||||
keep_alive_frequency=frequency,
|
||||
masquerade_protocol=PROTOCOL_LOOKUP[protocol],
|
||||
masquerade_port=PORT_LOOKUP[port],
|
||||
masquerade_protocol=protocol,
|
||||
masquerade_port=port,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import ipv4_validator, IPV4Address
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP, port_validator
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -106,9 +106,9 @@ class DoSBot(DatabaseClient, identifier="DoSBot"):
|
||||
:rtype: RequestResponse
|
||||
"""
|
||||
if "target_ip_address" in request[-1]:
|
||||
request[-1]["target_ip_address"] = IPv4Address(request[-1]["target_ip_address"])
|
||||
request[-1]["target_ip_address"] = ipv4_validator(request[-1]["target_ip_address"])
|
||||
if "target_port" in request[-1]:
|
||||
request[-1]["target_port"] = PORT_LOOKUP[request[-1]["target_port"]]
|
||||
request[-1]["target_port"] = port_validator(request[-1]["target_port"])
|
||||
return RequestResponse.from_bool(self.configure(**request[-1]))
|
||||
|
||||
rm.add_request("configure", request_type=RequestType(func=_configure))
|
||||
|
||||
@@ -31,12 +31,10 @@ class DatabaseService(Service, identifier="DatabaseService"):
|
||||
|
||||
type: str = "DatabaseService"
|
||||
backup_server_ip: Optional[IPv4Address] = None
|
||||
db_password: Optional[str] = None
|
||||
|
||||
config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema())
|
||||
|
||||
password: Optional[str] = None
|
||||
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
|
||||
|
||||
backup_server_ip: IPv4Address = None
|
||||
"""IP address of the backup server."""
|
||||
|
||||
@@ -54,6 +52,15 @@ class DatabaseService(Service, identifier="DatabaseService"):
|
||||
self._create_db_file()
|
||||
self.backup_server_ip = self.config.backup_server_ip
|
||||
|
||||
@property
|
||||
def password(self) -> Optional[str]:
|
||||
"""Convenience property for accessing the password."""
|
||||
return self.config.db_password
|
||||
|
||||
@password.setter
|
||||
def password(self, val: str) -> None:
|
||||
self.config.db_password = val
|
||||
|
||||
def install(self):
|
||||
"""
|
||||
Perform first-time setup of the DatabaseService.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -9,8 +9,12 @@ from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -21,12 +25,11 @@ class DNSClient(Service, identifier="DNSClient"):
|
||||
"""ConfigSchema for DNSClient."""
|
||||
|
||||
type: str = "DNSClient"
|
||||
dns_server: Optional[IPV4Address] = None
|
||||
|
||||
config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema())
|
||||
dns_cache: Dict[str, IPv4Address] = {}
|
||||
"A dict of known mappings between domain/URLs names and IPv4 addresses."
|
||||
dns_server: Optional[IPv4Address] = None
|
||||
"The DNS Server the client sends requests to."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DNSClient"
|
||||
@@ -51,6 +54,15 @@ class DNSClient(Service, identifier="DNSClient"):
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
@property
|
||||
def dns_server(self) -> Optional[IPV4Address]:
|
||||
"""Convenience property for accessing the dns server configuration."""
|
||||
return self.config.dns_server
|
||||
|
||||
@dns_server.setter
|
||||
def dns_server(self, val: IPV4Address) -> None:
|
||||
self.config.dns_server = val
|
||||
|
||||
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address) -> bool:
|
||||
"""
|
||||
Adds a domain name to the DNS Client cache.
|
||||
@@ -79,6 +91,14 @@ class DNSClient(Service, identifier="DNSClient"):
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
# check if the domain is already in the DNS cache
|
||||
if target_domain in self.dns_cache:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Domain lookup for {target_domain} successful,"
|
||||
f"resolves to {self.dns_cache[target_domain]}"
|
||||
)
|
||||
return True
|
||||
|
||||
# check if DNS server is configured
|
||||
if self.dns_server is None:
|
||||
self.sys_log.warning(f"{self.name}: DNS Server is not configured")
|
||||
@@ -87,31 +107,23 @@ class DNSClient(Service, identifier="DNSClient"):
|
||||
# check if the target domain is in the client's DNS cache
|
||||
payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain))
|
||||
|
||||
# check if the domain is already in the DNS cache
|
||||
if target_domain in self.dns_cache:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Domain lookup for {target_domain} successful,"
|
||||
f"resolves to {self.dns_cache[target_domain]}"
|
||||
)
|
||||
return True
|
||||
# return False if already reattempted
|
||||
if is_reattempt:
|
||||
self.sys_log.warning(f"{self.name}: Domain lookup for {target_domain} failed")
|
||||
return False
|
||||
else:
|
||||
# return False if already reattempted
|
||||
if is_reattempt:
|
||||
self.sys_log.warning(f"{self.name}: Domain lookup for {target_domain} failed")
|
||||
return False
|
||||
else:
|
||||
# send a request to check if domain name exists in the DNS Server
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=self.dns_server, dest_port=PORT_LOOKUP["DNS"]
|
||||
)
|
||||
# send a request to check if domain name exists in the DNS Server
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=self.dns_server, dest_port=PORT_LOOKUP["DNS"]
|
||||
)
|
||||
|
||||
# recursively re-call the function passing is_reattempt=True
|
||||
return self.check_domain_exists(
|
||||
target_domain=target_domain,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
# recursively re-call the function passing is_reattempt=True
|
||||
return self.check_domain_exists(
|
||||
target_domain=target_domain,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
|
||||
def send(
|
||||
self,
|
||||
@@ -168,3 +180,9 @@ class DNSClient(Service, identifier="DNSClient"):
|
||||
|
||||
self.sys_log.warning(f"Failed to resolve domain name {payload.dns_request.domain_name_request}")
|
||||
return False
|
||||
|
||||
def install(self) -> None:
|
||||
"""Set the DNS server to be the node's DNS server unless a different one was already provided."""
|
||||
self.parent: Node
|
||||
if self.parent and not self.dns_server:
|
||||
self.config.dns_server = self.parent.dns_server
|
||||
|
||||
@@ -36,6 +36,7 @@ class DNSServer(Service, identifier="DNSServer"):
|
||||
# TCP for now
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
self.dns_table = self.config.domain_mapping
|
||||
self.start()
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -6,7 +6,6 @@ from pydantic import Field
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP
|
||||
|
||||
@@ -22,14 +21,13 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"):
|
||||
"""
|
||||
|
||||
config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema())
|
||||
|
||||
server_password: Optional[str] = None
|
||||
"""Password needed to connect to FTP server. Default is None."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
class ConfigSchema(FTPServiceABC.ConfigSchema):
|
||||
"""ConfigSchema for FTPServer."""
|
||||
|
||||
type: str = "FTPServer"
|
||||
server_password: Optional[str] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPServer"
|
||||
@@ -37,6 +35,7 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"):
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
self.start()
|
||||
self.server_password = self.config.server_password
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
"""
|
||||
|
||||
@@ -9,6 +9,7 @@ from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ntp import NTPPacket
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -21,6 +22,7 @@ class NTPClient(Service, identifier="NTPClient"):
|
||||
"""ConfigSchema for NTPClient."""
|
||||
|
||||
type: str = "NTPClient"
|
||||
ntp_server_ip: Optional[IPV4Address] = None
|
||||
|
||||
config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema())
|
||||
|
||||
@@ -33,6 +35,7 @@ class NTPClient(Service, identifier="NTPClient"):
|
||||
kwargs["port"] = PORT_LOOKUP["NTP"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"]
|
||||
super().__init__(**kwargs)
|
||||
self.ntp_server = self.config.ntp_server_ip
|
||||
self.start()
|
||||
|
||||
def configure(self, ntp_server_ip_address: IPv4Address) -> None:
|
||||
|
||||
@@ -39,3 +39,12 @@ will automatically check and convert the input value to an instance of IPv4Addre
|
||||
any Pydantic model uses it. This ensures that any field marked with this type is not just
|
||||
an IPv4Address in form, but also valid according to the rules defined in ipv4_validator.
|
||||
"""
|
||||
|
||||
|
||||
def str_ip(value: Any) -> str:
|
||||
"""Make sure it's a valid IP, but represent it as a string."""
|
||||
# TODO: this is a bit of a hack, we should change RequestResponse to be able to handle IPV4Address objects
|
||||
return str(IPV4Address(value))
|
||||
|
||||
|
||||
StrIP: Final[Annotated] = Annotated[str, BeforeValidator(str_ip)]
|
||||
|
||||
@@ -178,8 +178,6 @@ simulation:
|
||||
backup_server_ip: 192.168.1.10
|
||||
- type: WebServer
|
||||
- type: FTPServer
|
||||
options:
|
||||
server_password: arcd
|
||||
- type: NTPClient
|
||||
options:
|
||||
ntp_server_ip: 192.168.1.10
|
||||
|
||||
@@ -771,8 +771,6 @@ simulation:
|
||||
options:
|
||||
backup_server_ip: 192.168.1.16
|
||||
- type: ExtendedService
|
||||
options:
|
||||
backup_server_ip: 192.168.1.16
|
||||
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
|
||||
@@ -201,7 +201,6 @@ def test_ftp_server_install():
|
||||
|
||||
ftp_server_service: FTPServer = client_1.software_manager.software.get("FTPServer")
|
||||
assert ftp_server_service is not None
|
||||
assert ftp_server_service.server_password == "arcd"
|
||||
|
||||
|
||||
def test_ntp_client_install():
|
||||
|
||||
@@ -173,8 +173,8 @@ def test_c2_suite_configure_request(basic_network):
|
||||
c2_beacon_config = {
|
||||
"c2_server_ip_address": "192.168.0.2",
|
||||
"keep_alive_frequency": 5,
|
||||
"masquerade_protocol": "TCP",
|
||||
"masquerade_port": "HTTP",
|
||||
"masquerade_protocol": "tcp",
|
||||
"masquerade_port": 80,
|
||||
}
|
||||
|
||||
network.apply_request(["node", "node_b", "application", "C2Beacon", "configure", c2_beacon_config])
|
||||
|
||||
Reference in New Issue
Block a user