Merged PR 593: Core bugfixes in support of extension tests.

## Summary
- Corrected some validation in observations and actions to use strings (in alignment with 'describe_state' methods.)
- Fixed bug where periodic agent would start on step 0 instead of on the configured start step
- Improved validations on network node adder
- Added database password to config schema of database service
- DNS client lookup no longer requires a DNS server address to be configured if the requested domain exists in the client cache
- DNS client can now inherit the parent node's DNS server address

## Test process
Unit tests and tests against the extension

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [X] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

#3029 - Add string-ip validator, improve validation, fix minor bugs in pulling schema data

Related work items: #3029
This commit is contained in:
Marek Wolan
2025-01-31 14:47:16 +00:00
20 changed files with 129 additions and 76 deletions

View File

@@ -23,6 +23,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed action space options which were previously used for assigning meaning to action space IDs
- Updated tests that don't use YAMLs to still use the new action and agent schemas
### Fixed
- DNS client no longer fails to check its cache if a DNS server address is missing.
- DNS client now correctly inherits the node's DNS address configuration setting.
## [3.3.0] - 2024-09-04
### Added

View File

@@ -6,6 +6,9 @@ from pydantic import ConfigDict, Field
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.ipv4_address import StrIP
from primaite.utils.validation.port import Port
__all__ = (
"ConfigureRansomwareScriptAction",
@@ -64,8 +67,8 @@ class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"):
model_config = ConfigDict(extra="forbid")
node_name: str
target_ip_address: Optional[str] = None
target_port: Optional[str] = None
target_ip_address: Optional[StrIP] = None
target_port: Optional[Port] = None
payload: Optional[str] = None
repeat: Optional[bool] = None
port_scan_p_of_success: Optional[float] = None
@@ -95,10 +98,10 @@ class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"):
"""Configuration schema for ConfigureC2BeaconAction."""
node_name: str
c2_server_ip_address: str
c2_server_ip_address: StrIP
keep_alive_frequency: int = Field(default=5, ge=1)
masquerade_protocol: str = Field(default="TCP")
masquerade_port: str = Field(default="HTTP")
masquerade_protocol: IPProtocol = Field(default="tcp")
masquerade_port: Port = Field(default=80)
@classmethod
def form_request(self, config: ConfigSchema) -> RequestFormat:
@@ -121,7 +124,7 @@ class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_c
"""Configuration schema for NodeSendRemoteCommandAction."""
node_name: str
remote_ip: str
remote_ip: StrIP
command: RequestFormat
@classmethod
@@ -149,7 +152,7 @@ class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_comm
node_name: str
commands: Union[List[RequestFormat], RequestFormat]
ip_address: Optional[str]
ip_address: Optional[StrIP]
username: Optional[str]
password: Optional[str]
@@ -198,7 +201,7 @@ class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfi
node_name: str
username: Optional[str]
password: Optional[str]
target_ip_address: str
target_ip_address: StrIP
target_file_name: str
target_folder_name: str
exfiltration_folder_name: Optional[str]
@@ -229,7 +232,7 @@ class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_databa
"""Schema for options that can be passed to this action."""
node_name: str
server_ip_address: Optional[str] = None
server_ip_address: Optional[StrIP] = None
server_password: Optional[str] = None
@classmethod

View File

@@ -1,7 +1,6 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Dict, List, Optional
from gymnasium import spaces
@@ -10,6 +9,9 @@ from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.ipv4_address import StrIP
from primaite.utils.validation.port import Port
_LOGGER = getLogger(__name__)
@@ -20,13 +22,13 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for ACLObservation."""
ip_list: Optional[List[IPv4Address]] = None
ip_list: Optional[List[StrIP]] = None
"""List of IP addresses."""
wildcard_list: Optional[List[str]] = None
"""List of wildcard strings."""
port_list: Optional[List[str]] = None
port_list: Optional[List[Port]] = None
"""List of port names."""
protocol_list: Optional[List[str]] = None
protocol_list: Optional[List[IPProtocol]] = None
"""List of protocol names."""
num_rules: Optional[int] = None
"""Number of ACL rules."""
@@ -35,10 +37,10 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
self,
where: WhereType,
num_rules: int,
ip_list: List[IPv4Address],
ip_list: List[StrIP],
wildcard_list: List[str],
port_list: List[str],
protocol_list: List[str],
port_list: List[Port],
protocol_list: List[IPProtocol],
) -> None:
"""
Initialise an ACL observation instance.
@@ -48,13 +50,13 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
:param num_rules: Number of ACL rules.
:type num_rules: int
:param ip_list: List of IP addresses.
:type ip_list: List[IPv4Address]
:type ip_list: List[StrIP]
:param wildcard_list: List of wildcard strings.
:type wildcard_list: List[str]
:param port_list: List of port names.
:type port_list: List[str]
:type port_list: List[Port]
:param protocol_list: List of protocol names.
:type protocol_list: List[str]
:type protocol_list: List[IPProtocol]
"""
self.where = where
self.num_rules: int = num_rules

View File

@@ -11,6 +11,9 @@ from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.ipv4_address import StrIP
from primaite.utils.validation.port import Port
_LOGGER = getLogger(__name__)
@@ -23,13 +26,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
hostname: str
"""Hostname of the firewall node, used for querying simulation state dictionary."""
ip_list: Optional[List[str]] = None
ip_list: Optional[List[StrIP]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[str]] = None
port_list: Optional[List[Port]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
protocol_list: Optional[List[IPProtocol]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
@@ -39,10 +42,10 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
def __init__(
self,
where: WhereType,
ip_list: List[str],
ip_list: List[StrIP],
wildcard_list: List[str],
port_list: List[str],
protocol_list: List[str],
port_list: List[Port],
protocol_list: List[IPProtocol],
num_rules: int,
include_users: bool,
) -> None:
@@ -53,13 +56,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
A typical location for a firewall might be ['network', 'nodes', <firewall_hostname>].
:type where: WhereType
:param ip_list: List of IP addresses.
:type ip_list: List[str]
:type ip_list: List[StrIP]
:param wildcard_list: List of wildcard rules.
:type wildcard_list: List[str]
:param port_list: List of port names.
:type port_list: List[str]
:type port_list: List[Port]
:param protocol_list: List of protocol types.
:type protocol_list: List[str]
:type protocol_list: List[IPProtocol]
:param num_rules: Number of rules configured in the firewall.
:type num_rules: int
:param include_users: If True, report user session information.

View File

@@ -13,6 +13,7 @@ from primaite.game.agent.observations.host_observations import HostObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.observations.router_observation import RouterObservation
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.ipv4_address import StrIP
from primaite.utils.validation.port import Port
_LOGGER = getLogger(__name__)
@@ -52,13 +53,13 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
"""If True, report user session information."""
num_ports: Optional[int] = None
"""Number of ports."""
ip_list: Optional[List[str]] = None
ip_list: Optional[List[StrIP]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[str]] = None
port_list: Optional[List[Port]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
protocol_list: Optional[List[IPProtocol]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""

View File

@@ -11,6 +11,9 @@ from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.ipv4_address import StrIP
from primaite.utils.validation.port import Port
_LOGGER = getLogger(__name__)
@@ -29,13 +32,13 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
"""Number of port observations configured for this router."""
acl: Optional[ACLObservation.ConfigSchema] = None
"""Configuration of ACL observation on this router."""
ip_list: Optional[List[str]] = None
ip_list: Optional[List[StrIP]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[str]] = None
port_list: Optional[List[Port]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
protocol_list: Optional[List[IPProtocol]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""

View File

@@ -83,6 +83,12 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
next_execution_timestep: int = 0
"""Timestep of the next action execution by the agent."""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._set_next_execution_timestep(
timestep=self.config.agent_settings.start_step, variance=self.config.agent_settings.start_variance
)
@computed_field
@cached_property
def start_node(self) -> str:

View File

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Literal, Optional, Type
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, ConfigDict, model_validator
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -44,6 +44,7 @@ class NetworkNodeAdder(BaseModel):
by the from_config method to select the correct node adder at runtime.
"""
model_config = ConfigDict(extra="forbid")
type: str
"""Uniquely identifies the node adder class to use for adding nodes to network."""

View File

@@ -127,8 +127,8 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
self.configure(
c2_server_ip_address=c2_remote_ip,
keep_alive_frequency=frequency,
masquerade_protocol=PROTOCOL_LOOKUP[protocol],
masquerade_port=PORT_LOOKUP[port],
masquerade_protocol=protocol,
masquerade_port=port,
)
)

View File

@@ -10,8 +10,8 @@ from primaite.game.science import simulate_trial
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP
from primaite.utils.validation.ipv4_address import ipv4_validator, IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP, port_validator
_LOGGER = getLogger(__name__)
@@ -106,9 +106,9 @@ class DoSBot(DatabaseClient, identifier="DoSBot"):
:rtype: RequestResponse
"""
if "target_ip_address" in request[-1]:
request[-1]["target_ip_address"] = IPv4Address(request[-1]["target_ip_address"])
request[-1]["target_ip_address"] = ipv4_validator(request[-1]["target_ip_address"])
if "target_port" in request[-1]:
request[-1]["target_port"] = PORT_LOOKUP[request[-1]["target_port"]]
request[-1]["target_port"] = port_validator(request[-1]["target_port"])
return RequestResponse.from_bool(self.configure(**request[-1]))
rm.add_request("configure", request_type=RequestType(func=_configure))

View File

@@ -31,6 +31,7 @@ class DatabaseService(Service, identifier="DatabaseService"):
type: str = "DatabaseService"
backup_server_ip: Optional[IPv4Address] = None
db_password: Optional[str] = None
config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema())
@@ -54,6 +55,11 @@ class DatabaseService(Service, identifier="DatabaseService"):
self._create_db_file()
self.backup_server_ip = self.config.backup_server_ip
@property
def password(self) -> Optional[str]:
"""Convenience property for accessing the password."""
return self.config.db_password
def install(self):
"""
Perform first-time setup of the DatabaseService.

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address
from typing import Dict, Optional
from typing import Dict, Optional, TYPE_CHECKING
from pydantic import Field
@@ -9,8 +9,12 @@ from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP
if TYPE_CHECKING:
from primaite.simulator.network.hardware.base import Node
_LOGGER = getLogger(__name__)
@@ -21,6 +25,7 @@ class DNSClient(Service, identifier="DNSClient"):
"""ConfigSchema for DNSClient."""
type: str = "DNSClient"
dns_server: Optional[IPV4Address] = None
config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema())
dns_cache: Dict[str, IPv4Address] = {}
@@ -38,6 +43,11 @@ class DNSClient(Service, identifier="DNSClient"):
super().__init__(**kwargs)
self.start()
@property
def dns_server(self) -> Optional[IPV4Address]:
"""Convenience property for accessing the dns server configuration."""
return self.config.dns_server
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
@@ -79,6 +89,14 @@ class DNSClient(Service, identifier="DNSClient"):
if not self._can_perform_action():
return False
# check if the domain is already in the DNS cache
if target_domain in self.dns_cache:
self.sys_log.info(
f"{self.name}: Domain lookup for {target_domain} successful,"
f"resolves to {self.dns_cache[target_domain]}"
)
return True
# check if DNS server is configured
if self.dns_server is None:
self.sys_log.warning(f"{self.name}: DNS Server is not configured")
@@ -87,14 +105,6 @@ class DNSClient(Service, identifier="DNSClient"):
# check if the target domain is in the client's DNS cache
payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain))
# check if the domain is already in the DNS cache
if target_domain in self.dns_cache:
self.sys_log.info(
f"{self.name}: Domain lookup for {target_domain} successful,"
f"resolves to {self.dns_cache[target_domain]}"
)
return True
else:
# return False if already reattempted
if is_reattempt:
self.sys_log.warning(f"{self.name}: Domain lookup for {target_domain} failed")
@@ -168,3 +178,9 @@ class DNSClient(Service, identifier="DNSClient"):
self.sys_log.warning(f"Failed to resolve domain name {payload.dns_request.domain_name_request}")
return False
def install(self) -> None:
"""Set the DNS server to be the node's DNS server unless a different one was already provided."""
self.parent: Node
if self.parent and not self.dns_server:
self.config.dns_server = self.parent.dns_server

View File

@@ -36,6 +36,7 @@ class DNSServer(Service, identifier="DNSServer"):
# TCP for now
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
self.dns_table = self.config.domain_mapping
self.start()
def describe_state(self) -> Dict:

View File

@@ -6,7 +6,6 @@ from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import Service
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP
@@ -22,14 +21,13 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"):
"""
config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema())
server_password: Optional[str] = None
"""Password needed to connect to FTP server. Default is None."""
class ConfigSchema(Service.ConfigSchema):
class ConfigSchema(FTPServiceABC.ConfigSchema):
"""ConfigSchema for FTPServer."""
type: str = "FTPServer"
server_password: Optional[str] = None
def __init__(self, **kwargs):
kwargs["name"] = "FTPServer"
@@ -37,6 +35,7 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"):
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
self.start()
self.server_password = self.config.server_password
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""

View File

@@ -9,6 +9,7 @@ from primaite import getLogger
from primaite.simulator.network.protocols.ntp import NTPPacket
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -21,6 +22,7 @@ class NTPClient(Service, identifier="NTPClient"):
"""ConfigSchema for NTPClient."""
type: str = "NTPClient"
ntp_server_ip: Optional[IPV4Address] = None
config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema())
@@ -33,6 +35,7 @@ class NTPClient(Service, identifier="NTPClient"):
kwargs["port"] = PORT_LOOKUP["NTP"]
kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"]
super().__init__(**kwargs)
self.ntp_server = self.config.ntp_server_ip
self.start()
def configure(self, ntp_server_ip_address: IPv4Address) -> None:

View File

@@ -39,3 +39,12 @@ will automatically check and convert the input value to an instance of IPv4Addre
any Pydantic model uses it. This ensures that any field marked with this type is not just
an IPv4Address in form, but also valid according to the rules defined in ipv4_validator.
"""
def str_ip(value: Any) -> str:
"""Make sure it's a valid IP, but represent it as a string."""
# TODO: this is a bit of a hack, we should change RequestResponse to be able to handle IPV4Address objects
return str(IPV4Address(value))
StrIP: Final[Annotated] = Annotated[str, BeforeValidator(str_ip)]

View File

@@ -178,8 +178,6 @@ simulation:
backup_server_ip: 192.168.1.10
- type: WebServer
- type: FTPServer
options:
server_password: arcd
- type: NTPClient
options:
ntp_server_ip: 192.168.1.10

View File

@@ -771,8 +771,6 @@ simulation:
options:
backup_server_ip: 192.168.1.16
- type: ExtendedService
options:
backup_server_ip: 192.168.1.16
- hostname: client_2
type: computer

View File

@@ -201,7 +201,6 @@ def test_ftp_server_install():
ftp_server_service: FTPServer = client_1.software_manager.software.get("FTPServer")
assert ftp_server_service is not None
assert ftp_server_service.server_password == "arcd"
def test_ntp_client_install():

View File

@@ -173,8 +173,8 @@ def test_c2_suite_configure_request(basic_network):
c2_beacon_config = {
"c2_server_ip_address": "192.168.0.2",
"keep_alive_frequency": 5,
"masquerade_protocol": "TCP",
"masquerade_port": "HTTP",
"masquerade_protocol": "tcp",
"masquerade_port": 80,
}
network.apply_request(["node", "node_b", "application", "C2Beacon", "configure", c2_beacon_config])