Merge in updates from dev

This commit is contained in:
Charlie Crane
2025-02-27 18:21:43 +00:00
70 changed files with 2140 additions and 239 deletions

View File

@@ -1 +1 @@
4.0.0a1-dev
4.0.0-dev

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC, abstractmethod
from typing import ClassVar, List, Optional, Union
from typing import ClassVar, List, Literal, Optional, Union
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
@@ -153,8 +153,6 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, discriminator="node-nmap-po
class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-network-service-recon"):
"""Action which performs an nmap network service recon (ping scan followed by port scan)."""
config: "NodeNetworkServiceReconAction.ConfigSchema"
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
"""Configuration schema for NodeNetworkServiceReconAction."""
@@ -179,3 +177,70 @@ class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-
"show": config.show,
},
]
class NodeAccountsAddUserAction(AbstractAction, discriminator="node-account-add-user"):
class ConfigSchema(AbstractAction.ConfigSchema):
type: Literal["node-account-add-user"] = "node-account-add-user"
node_name: str
username: str
password: str
is_admin: bool
@classmethod
@staticmethod
def form_request(config: ConfigSchema) -> RequestFormat:
return [
"network",
"node",
config.node_name,
"service",
"user-manager",
"add_user",
config.username,
config.password,
config.is_admin,
]
class NodeAccountsDisableUserAction(AbstractAction, discriminator="node-account-disable-user"):
class ConfigSchema(AbstractAction.ConfigSchema):
type: Literal["node-account-disable-user"] = "node-account-disable-user"
node_name: str
username: str
@classmethod
@staticmethod
def form_request(config: ConfigSchema) -> RequestFormat:
return [
"network",
"node",
config.node_name,
"service",
"user-manager",
"disable_user",
config.username,
]
class NodeSendLocalCommandAction(AbstractAction, discriminator="node-send-local-command"):
class ConfigSchema(AbstractAction.ConfigSchema):
type: Literal["node-send-local-command"] = "node-send-local-command"
node_name: str
username: str
password: str
command: RequestFormat
@staticmethod
def form_request(config: ConfigSchema) -> RequestFormat:
return [
"network",
"node",
config.node_name,
"service",
"terminal",
"send_local_command",
config.username,
config.password,
{"command": config.command},
]

View File

@@ -34,8 +34,6 @@ class NodeSessionAbstractAction(AbstractAction, ABC):
class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="node-session-remote-login"):
"""Action which performs a remote session login."""
config: "NodeSessionsRemoteLoginAction.ConfigSchema"
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeSessionsRemoteLoginAction."""
@@ -53,7 +51,7 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="no
config.node_name,
"service",
"terminal",
"node-session-remote-login",
"node_session_remote_login",
config.username,
config.password,
config.remote_ip,

View File

@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from prettytable import PrettyTable
from pydantic import BaseModel, ConfigDict, Field
from primaite.game.agent.actions import ActionManager
@@ -42,6 +43,9 @@ class AgentHistoryItem(BaseModel):
reward_info: Dict[str, Any] = {}
observation: Optional[ObsType] = None
"""The observation space data for this step."""
class AbstractAgent(BaseModel, ABC):
"""Base class for scripted and RL agents."""
@@ -67,6 +71,9 @@ class AbstractAgent(BaseModel, ABC):
default_factory=lambda: ObservationManager.ConfigSchema()
)
reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
thresholds: Optional[Dict] = {}
# TODO: this is only relevant to some observations, need to refactor the way thresholds are dealt with (#3085)
"""A dict containing the observation thresholds."""
config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
@@ -95,10 +102,34 @@ class AbstractAgent(BaseModel, ABC):
def model_post_init(self, __context: Any) -> None:
"""Overwrite the default empty action, observation, and rewards with ones defined through the config."""
self.action_manager = ActionManager(config=self.config.action_space)
self.config.observation_space.options.thresholds = self.config.thresholds
self.observation_manager = ObservationManager(config=self.config.observation_space)
self.reward_function = RewardFunction(config=self.config.reward_function)
return super().model_post_init(__context)
def show_history(self, ignored_actions: Optional[list] = None):
"""
Print an agent action provided it's not the do-nothing action.
:param ignored_actions: OPTIONAL: List of actions to be ignored when displaying the history.
If not provided, defaults to ignore do-nothing actions.
"""
if not ignored_actions:
ignored_actions = ["do-nothing"]
table = PrettyTable()
table.field_names = ["Step", "Action", "Params", "Response", "Response Data"]
print(f"Actions for '{self.config.ref}':")
for item in self.history:
if item.action in ignored_actions:
pass
else:
# format dict by putting each key-value entry on a separate line and putting a blank line on the end.
param_string = "\n".join([*[f"{k}: {v:.30}" for k, v in item.parameters.items()], ""])
data_string = "\n".join([*[f"{k}: {v:.30}" for k, v in item.response.data], ""])
table.add_row([item.timestep, item.action, param_string, item.response.status, data_string])
print(table)
def update_observation(self, state: Dict) -> ObsType:
"""
Convert a state from the simulator into an observation for the agent using the observation space.
@@ -145,12 +176,23 @@ class AbstractAgent(BaseModel, ABC):
return request
def process_action_response(
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
self,
timestep: int,
action: str,
parameters: Dict[str, Any],
request: RequestFormat,
response: RequestResponse,
observation: ObsType,
) -> None:
"""Process the response from the most recent action."""
self.history.append(
AgentHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
timestep=timestep,
action=action,
parameters=parameters,
request=request,
response=response,
observation=observation,
)
)

View File

@@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, discriminator="file"):
file_system_requires_scan: Optional[bool] = None
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
def __init__(
self,
where: WhereType,
include_num_access: bool,
file_system_requires_scan: bool,
thresholds: Optional[Dict] = {},
) -> None:
"""
Initialise a file observation instance.
@@ -48,10 +54,36 @@ class FileObservation(AbstractObservation, discriminator="file"):
if self.include_num_access:
self.default_observation["num_access"] = 0
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
if thresholds.get("file_access") is None:
self.low_file_access_threshold = 0
self.med_file_access_threshold = 5
self.high_file_access_threshold = 10
else:
self._set_file_access_threshold(
thresholds=[
thresholds.get("file_access")["low"],
thresholds.get("file_access")["medium"],
thresholds.get("file_access")["high"],
]
)
def _set_file_access_threshold(self, thresholds: List[int]):
"""
Method that validates and then sets the file access threshold.
:param: thresholds: The file access threshold to validate and set.
"""
if self._validate_thresholds(
thresholds=[
thresholds[0],
thresholds[1],
thresholds[2],
],
threshold_identifier="file_access",
):
self.low_file_access_threshold = thresholds[0]
self.med_file_access_threshold = thresholds[1]
self.high_file_access_threshold = thresholds[2]
def _categorise_num_access(self, num_access: int) -> int:
"""
@@ -60,11 +92,11 @@ class FileObservation(AbstractObservation, discriminator="file"):
:param num_access: Number of file accesses.
:return: Bin number corresponding to the number of accesses.
"""
if num_access > self.high_threshold:
if num_access > self.high_file_access_threshold:
return 3
elif num_access > self.med_threshold:
elif num_access > self.med_file_access_threshold:
return 2
elif num_access > self.low_threshold:
elif num_access > self.low_file_access_threshold:
return 1
return 0
@@ -122,6 +154,7 @@ class FileObservation(AbstractObservation, discriminator="file"):
where=parent_where + ["files", config.file_name],
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
thresholds=config.thresholds,
)
@@ -149,6 +182,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
num_files: int,
include_num_access: bool,
file_system_requires_scan: bool,
thresholds: Optional[Dict] = {},
) -> None:
"""
Initialise a folder observation instance.
@@ -177,6 +211,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
where=None,
include_num_access=include_num_access,
file_system_requires_scan=self.file_system_requires_scan,
thresholds=thresholds,
)
)
while len(self.files) > num_files:
@@ -253,6 +288,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
for file_config in config.files:
file_config.include_num_access = config.include_num_access
file_config.file_system_requires_scan = config.file_system_requires_scan
file_config.thresholds = config.thresholds
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
return cls(
@@ -261,4 +297,5 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
num_files=config.num_files,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
thresholds=config.thresholds,
)

View File

@@ -54,7 +54,15 @@ class HostObservation(AbstractObservation, discriminator="host"):
"""
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
"""
include_users: Optional[bool] = None
services_requires_scan: Optional[bool] = None
"""
If True, services must be scanned to update the health state. If False, true state is always shown.
"""
applications_requires_scan: Optional[bool] = None
"""
If True, applications must be scanned to update the health state. If False, true state is always shown.
"""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
@@ -73,6 +81,8 @@ class HostObservation(AbstractObservation, discriminator="host"):
monitored_traffic: Optional[Dict],
include_num_access: bool,
file_system_requires_scan: bool,
services_requires_scan: bool,
applications_requires_scan: bool,
include_users: bool,
) -> None:
"""
@@ -108,6 +118,12 @@ class HostObservation(AbstractObservation, discriminator="host"):
:param file_system_requires_scan: If True, the files and folders must be scanned to update the health state.
If False, the true state is always shown.
:type file_system_requires_scan: bool
:param services_requires_scan: If True, services must be scanned to update the health state.
If False, the true state is always shown.
:type services_requires_scan: bool
:param applications_requires_scan: If True, applications must be scanned to update the health state.
If False, the true state is always shown.
:type applications_requires_scan: bool
:param include_users: If True, report user session information.
:type include_users: bool
"""
@@ -121,7 +137,7 @@ class HostObservation(AbstractObservation, discriminator="host"):
# Ensure lists have lengths equal to specified counts by truncating or padding
self.services: List[ServiceObservation] = services
while len(self.services) < num_services:
self.services.append(ServiceObservation(where=None))
self.services.append(ServiceObservation(where=None, services_requires_scan=services_requires_scan))
while len(self.services) > num_services:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
@@ -129,7 +145,9 @@ class HostObservation(AbstractObservation, discriminator="host"):
self.applications: List[ApplicationObservation] = applications
while len(self.applications) < num_applications:
self.applications.append(ApplicationObservation(where=None))
self.applications.append(
ApplicationObservation(where=None, applications_requires_scan=applications_requires_scan)
)
while len(self.applications) > num_applications:
truncated_application = self.applications.pop()
msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}"
@@ -153,7 +171,13 @@ class HostObservation(AbstractObservation, discriminator="host"):
self.nics: List[NICObservation] = network_interfaces
while len(self.nics) < num_nics:
self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic))
self.nics.append(
NICObservation(
where=None,
include_nmne=include_nmne,
monitored_traffic=monitored_traffic,
)
)
while len(self.nics) > num_nics:
truncated_nic = self.nics.pop()
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
@@ -269,8 +293,15 @@ class HostObservation(AbstractObservation, discriminator="host"):
folder_config.include_num_access = config.include_num_access
folder_config.num_files = config.num_files
folder_config.file_system_requires_scan = config.file_system_requires_scan
folder_config.thresholds = config.thresholds
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
nic_config.thresholds = config.thresholds
for service_config in config.services:
service_config.services_requires_scan = config.services_requires_scan
for application_config in config.applications:
application_config.applications_requires_scan = config.applications_requires_scan
application_config.thresholds = config.thresholds
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
@@ -281,7 +312,10 @@ class HostObservation(AbstractObservation, discriminator="host"):
count = 1
while len(nics) < config.num_nics:
nic_config = NICObservation.ConfigSchema(
nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic
nic_num=count,
include_nmne=config.include_nmne,
monitored_traffic=config.monitored_traffic,
thresholds=config.thresholds,
)
nics.append(NICObservation.from_config(config=nic_config, parent_where=where))
count += 1
@@ -301,5 +335,7 @@ class HostObservation(AbstractObservation, discriminator="host"):
monitored_traffic=config.monitored_traffic,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
services_requires_scan=config.services_requires_scan,
applications_requires_scan=config.applications_requires_scan,
include_users=config.include_users,
)

View File

@@ -1,13 +1,14 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from typing import Dict, List, Optional
from typing import ClassVar, Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.nmne import NMNEConfig
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.port import Port
@@ -15,6 +16,9 @@ from primaite.utils.validation.port import Port
class NICObservation(AbstractObservation, discriminator="network-interface"):
"""Status information about a network interface within the simulation environment."""
capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne
"A Boolean specifying whether malicious network events should be captured."
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NICObservation."""
@@ -25,7 +29,13 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None
"""A dict containing which traffic types are to be included in the observation."""
def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None:
def __init__(
self,
where: WhereType,
include_nmne: bool,
monitored_traffic: Optional[Dict] = None,
thresholds: Dict = {},
) -> None:
"""
Initialise a network interface observation instance.
@@ -45,10 +55,18 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
self.nmne_inbound_last_step: int = 0
self.nmne_outbound_last_step: int = 0
# TODO: allow these to be configured in yaml
self.high_nmne_threshold = 10
self.med_nmne_threshold = 5
self.low_nmne_threshold = 0
if thresholds.get("nmne") is None:
self.low_nmne_threshold = 0
self.med_nmne_threshold = 5
self.high_nmne_threshold = 10
else:
self._set_nmne_threshold(
thresholds=[
thresholds.get("nmne")["low"],
thresholds.get("nmne")["medium"],
thresholds.get("nmne")["high"],
]
)
self.monitored_traffic = monitored_traffic
if self.monitored_traffic:
@@ -105,6 +123,20 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
bandwidth_utilisation = traffic_value / nic_max_bandwidth
return int(bandwidth_utilisation * 9) + 1
def _set_nmne_threshold(self, thresholds: List[int]):
"""
Method that validates and then sets the NMNE threshold.
:param: thresholds: The NMNE threshold to validate and set.
"""
if self._validate_thresholds(
thresholds=thresholds,
threshold_identifier="nmne",
):
self.low_nmne_threshold = thresholds[0]
self.med_nmne_threshold = thresholds[1]
self.high_nmne_threshold = thresholds[2]
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
@@ -116,7 +148,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
"""
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
if nic_state is NOT_PRESENT_IN_STATE or self.where is None:
return self.default_observation
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
@@ -164,7 +196,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
for port in self.monitored_traffic[protocol]:
obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0}
if self.include_nmne:
if self.capture_nmne and self.include_nmne:
obs.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
@@ -224,6 +256,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
where=parent_where + ["NICs", config.nic_num],
include_nmne=config.include_nmne,
monitored_traffic=config.monitored_traffic,
thresholds=config.thresholds,
)

View File

@@ -48,7 +48,13 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
include_num_access: Optional[bool] = None
"""Flag to include the number of accesses."""
file_system_requires_scan: bool = True
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
"""If True, the folder must be scanned to update the health state. If False, the true state is always shown."""
services_requires_scan: bool = True
"""If True, the services must be scanned to update the health state.
If False, the true state is always shown."""
applications_requires_scan: bool = True
"""If True, the applications must be scanned to update the health state.
If False, the true state is always shown."""
include_users: Optional[bool] = True
"""If True, report user session information."""
num_ports: Optional[int] = None
@@ -196,8 +202,14 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
host_config.include_num_access = config.include_num_access
if host_config.file_system_requires_scan is None:
host_config.file_system_requires_scan = config.file_system_requires_scan
if host_config.services_requires_scan is None:
host_config.services_requires_scan = config.services_requires_scan
if host_config.applications_requires_scan is None:
host_config.applications_requires_scan = config.applications_requires_scan
if host_config.include_users is None:
host_config.include_users = config.include_users
if not host_config.thresholds:
host_config.thresholds = config.thresholds
for router_config in config.routers:
if router_config.num_ports is None:
@@ -214,6 +226,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
router_config.num_rules = config.num_rules
if router_config.include_users is None:
router_config.include_users = config.include_users
if not router_config.thresholds:
router_config.thresholds = config.thresholds
for firewall_config in config.firewalls:
if firewall_config.ip_list is None:
@@ -228,6 +242,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
firewall_config.num_rules = config.num_rules
if firewall_config.include_users is None:
firewall_config.include_users = config.include_users
if not firewall_config.thresholds:
firewall_config.thresholds = config.thresholds
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]

View File

@@ -114,7 +114,9 @@ class NestedObservation(AbstractObservation, discriminator="custom"):
instances = dict()
for component in config.components:
obs_class = AbstractObservation._registry[component.type]
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options))
obs_instance = obs_class.from_config(
config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds)
)
instances[component.label] = obs_instance
return cls(components=instances)
@@ -228,7 +230,7 @@ class ObservationManager(BaseModel):
return self.obs.space
@classmethod
def from_config(cls, config: Optional[Dict]) -> "ObservationManager":
def from_config(cls, config: Optional[Dict], thresholds: Optional[Dict] = {}) -> "ObservationManager":
"""
Create observation space from a config.
@@ -239,11 +241,10 @@ class ObservationManager(BaseModel):
AbstractObservation
options: this must adhere to the chosen observation type's ConfigSchema nested class.
:type config: Dict
:param thresholds: Dictionary containing the observation thresholds.
:type thresholds: Optional[Dict]
"""
if config is None:
return cls(NullObservation())
obs_type = config["type"]
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]))
obs_manager = cls(observation)
obs_manager = cls(config=config)
return obs_manager

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Optional, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Type, Union
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -19,6 +19,9 @@ class AbstractObservation(ABC):
class ConfigSchema(ABC, BaseModel):
"""Config schema for observations."""
thresholds: Optional[Dict] = {}
"""A dict containing the observation thresholds."""
model_config = ConfigDict(extra="forbid")
_registry: Dict[str, Type["AbstractObservation"]] = {}
@@ -69,3 +72,34 @@ class AbstractObservation(ABC):
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation":
"""Create this observation space component form a serialised format."""
return cls()
def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool:
"""
Method that checks if the thresholds are non overlapping and in the correct (ascending) order.
Pass in the thresholds from low to high e.g.
thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold]
Throws an error if the threshold is not valid
:param: thresholds: List of thresholds in ascending order.
:type: List[int]
:param: threshold_identifier: The name of the threshold option.
:type: Optional[str]
:returns: bool
"""
if thresholds is None or len(thresholds) < 2:
raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}")
for idx in range(1, len(thresholds)):
if not isinstance(thresholds[idx], int):
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
if not isinstance(thresholds[idx - 1], int):
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
if thresholds[idx] <= thresholds[idx - 1]:
raise Exception(
f"{threshold_identifier} threshold ({thresholds[idx - 1]}) "
f"is greater than or equal to ({thresholds[idx]}.)"
)
return True

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from typing import Dict
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -19,7 +19,10 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
service_name: str
"""Name of the service, used for querying simulation state dictionary"""
def __init__(self, where: WhereType) -> None:
services_requires_scan: Optional[bool] = None
"""If True, services must be scanned to update the health state. If False, true state is always shown."""
def __init__(self, where: WhereType, services_requires_scan: bool) -> None:
"""
Initialise a service observation instance.
@@ -28,6 +31,7 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
:type where: WhereType
"""
self.where = where
self.services_requires_scan = services_requires_scan
self.default_observation = {"operating_status": 0, "health_status": 0}
def observe(self, state: Dict) -> ObsType:
@@ -44,7 +48,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
return self.default_observation
return {
"operating_status": service_state["operating_state"],
"health_status": service_state["health_state_visible"],
"health_status": service_state["health_state_visible"]
if self.services_requires_scan
else service_state["health_state_actual"],
}
@property
@@ -70,7 +76,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
:return: Constructed service observation instance.
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", config.service_name])
return cls(
where=parent_where + ["services", config.service_name], services_requires_scan=config.services_requires_scan
)
class ApplicationObservation(AbstractObservation, discriminator="application"):
@@ -82,7 +90,12 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
application_name: str
"""Name of the application, used for querying simulation state dictionary"""
def __init__(self, where: WhereType) -> None:
applications_requires_scan: Optional[bool] = None
"""
If True, applications must be scanned to update the health state. If False, true state is always shown.
"""
def __init__(self, where: WhereType, applications_requires_scan: bool, thresholds: Optional[Dict] = {}) -> None:
"""
Initialise an application observation instance.
@@ -92,25 +105,52 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
:type where: WhereType
"""
self.where = where
self.applications_requires_scan = applications_requires_scan
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
if thresholds.get("app_executions") is None:
self.low_app_execution_threshold = 0
self.med_app_execution_threshold = 5
self.high_app_execution_threshold = 10
else:
self._set_application_execution_thresholds(
thresholds=[
thresholds.get("app_executions")["low"],
thresholds.get("app_executions")["medium"],
thresholds.get("app_executions")["high"],
]
)
def _set_application_execution_thresholds(self, thresholds: List[int]):
"""
Method that validates and then sets the application execution threshold.
:param: thresholds: The application execution threshold to validate and set.
"""
if self._validate_thresholds(
thresholds=[
thresholds[0],
thresholds[1],
thresholds[2],
],
threshold_identifier="app_executions",
):
self.low_app_execution_threshold = thresholds[0]
self.med_app_execution_threshold = thresholds[1]
self.high_app_execution_threshold = thresholds[2]
def _categorise_num_executions(self, num_executions: int) -> int:
"""
Represent number of file accesses as a categorical variable.
Represent number of application executions as a categorical variable.
:param num_access: Number of file accesses.
:param num_access: Number of application executions.
:return: Bin number corresponding to the number of accesses.
"""
if num_executions > self.high_threshold:
if num_executions > self.high_app_execution_threshold:
return 3
elif num_executions > self.med_threshold:
elif num_executions > self.med_app_execution_threshold:
return 2
elif num_executions > self.low_threshold:
elif num_executions > self.low_app_execution_threshold:
return 1
return 0
@@ -128,7 +168,9 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
return self.default_observation
return {
"operating_status": application_state["operating_state"],
"health_status": application_state["health_state_visible"],
"health_status": application_state["health_state_visible"]
if self.applications_requires_scan
else application_state["health_state_actual"],
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
}
@@ -161,4 +203,8 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
:return: Constructed application observation instance.
:rtype: ApplicationObservation
"""
return cls(where=parent_where + ["applications", config.application_name])
return cls(
where=parent_where + ["applications", config.application_name],
applications_requires_scan=config.applications_requires_scan,
thresholds=config.thresholds,
)

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict
from primaite import DEFAULT_BANDWIDTH, getLogger
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
from primaite.game.agent.observations import NICObservation
from primaite.game.agent.rewards import SharedReward
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
@@ -44,15 +45,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
SERVICE_TYPES_MAPPING = {
"DNSClient": DNSClient,
"DNSServer": DNSServer,
"DatabaseService": DatabaseService,
"WebServer": WebServer,
"FTPClient": FTPClient,
"FTPServer": FTPServer,
"NTPClient": NTPClient,
"NTPServer": NTPServer,
"Terminal": Terminal,
"dns-client": DNSClient,
"dns-server": DNSServer,
"database-service": DatabaseService,
"web-server": WebServer,
"ftp-client": FTPClient,
"ftp-server": FTPServer,
"ntp-client": NTPClient,
"ntp-server": NTPServer,
"terminal": Terminal,
}
"""List of available services that can be installed on nodes in the PrimAITE Simulation."""
@@ -68,6 +69,8 @@ class PrimaiteGameOptions(BaseModel):
seed: int = None
"""Random number seed for RNGs."""
generate_seed_value: bool = False
"""Internally generated seed value."""
max_episode_length: int = 256
"""Maximum number of episodes for the PrimAITE game."""
ports: List[Port]
@@ -175,6 +178,7 @@ class PrimaiteGame:
parameters=parameters,
request=request,
response=response,
observation=obs,
)
def pre_timestep(self) -> None:
@@ -263,6 +267,7 @@ class PrimaiteGame:
node_sets_cfg = network_config.get("node_sets", [])
# Set the NMNE capture config
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne
for node_cfg in nodes_cfg:
n_type = node_cfg["type"]
@@ -293,6 +298,7 @@ class PrimaiteGame:
if "users" in node_cfg and new_node.software_manager.software.get("user-manager"):
user_manager: UserManager = new_node.software_manager.software["user-manager"] # noqa
for user_cfg in node_cfg["users"]:
user_manager.add_user(**user_cfg, bypass_can_perform_action=True)
@@ -407,6 +413,7 @@ class PrimaiteGame:
agents_cfg = cfg.get("agents", [])
for agent_cfg in agents_cfg:
agent_cfg = {**agent_cfg, "thresholds": game.options.thresholds}
new_agent = AbstractAgent.from_config(agent_cfg)
game.agents[agent_cfg["ref"]] = new_agent
if isinstance(new_agent, ProxyAgent):

View File

@@ -47,7 +47,7 @@
"source": [
"def make_cfg_have_flat_obs(cfg):\n",
" for agent in cfg['agents']:\n",
" if agent['type'] == \"ProxyAgent\":\n",
" if agent['type'] == \"proxy-agent\":\n",
" agent['agent_settings']['flatten_obs'] = False"
]
},

View File

@@ -684,6 +684,15 @@
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.game.agents[\"data_manipulation_attacker\"].show_history()"
]
},
{
"cell_type": "markdown",
"metadata": {},

View File

@@ -153,6 +153,49 @@
"PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = was_enabled\n",
"PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = was_syslogs_enabled"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Viewing Agent history\n",
"\n",
"It's possible to view the actions carried out by an agent for a given training session using the `show_history()` method. By default, this will be all actions apart from DONOTHING actions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"env = PrimaiteGymEnv(env_config=cfg)\n",
"\n",
"# Run the training session to generate some resultant data.\n",
"for i in range(100):\n",
" env.step(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calling `.show_history()` should show us when the Data Manipulation used the `NODE_APPLICATION_EXECUTE` action."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"attacker = env.game.agents[\"data_manipulation_attacker\"]\n",
"\n",
"attacker.show_history()"
]
}
],
"metadata": {

View File

@@ -15,15 +15,6 @@
"To display the available dev-mode options, run the command below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -9,6 +9,13 @@
"© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simulation Layer Implementation."
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -186,6 +193,22 @@
"computer_b.file_system.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Information about the latest response when executing a remote command can be seen by calling the `last_response` attribute within `Terminal`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(terminal_a.last_response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -224,11 +247,254 @@
"source": [
"computer_b.user_session_manager.show(include_historic=True, include_session_id=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Game Layer Implementation\n",
"\n",
"This notebook section will detail the implementation of how the game layer utilises the terminal to support different agent actions.\n",
"\n",
"The ``Terminal`` is used in a variety of different ways in the game layer. Specifically, the terminal is leveraged to implement the following actions:\n",
"\n",
"\n",
"| Game Layer Action | Simulation Layer |\n",
"|-----------------------------------|--------------------------|\n",
"| ``node-send-local-command`` | Uses the given user credentials, creates a ``LocalTerminalSession`` and executes the given command and returns the ``RequestResponse``.\n",
"| ``node-session-remote-login`` | Uses the given user credentials and remote IP to create a ``RemoteTerminalSession``.\n",
"| ``node-send-remote-command`` | Uses the given remote IP to locate the correct ``RemoteTerminalSession``, executes the given command and returns the ``RequestsResponse``."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Game Layer Setup\n",
"\n",
"Similar to other notebooks, the next code cells create a custom proxy agent to demonstrate how these commands can be leveraged by agents in the ``UC2`` network environment.\n",
"\n",
"If you're unfamiliar with ``UC2`` then please refer to the [UC2-E2E-Demo notebook for further reference](./Data-Manipulation-E2E-Demonstration.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.session.environment import PrimaiteGymEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"custom_terminal_agent = \"\"\"\n",
" - ref: CustomC2Agent\n",
" team: RED\n",
" type: proxy-agent\n",
" action_space:\n",
" action_map:\n",
" 0:\n",
" action: do-nothing\n",
" options: {}\n",
" 1:\n",
" action: node-send-local-command\n",
" options:\n",
" node_name: client_1\n",
" username: admin\n",
" password: admin\n",
" command:\n",
" - file_system\n",
" - create\n",
" - file\n",
" - downloads\n",
" - dog.png\n",
" - False\n",
" 2:\n",
" action: node-session-remote-login\n",
" options:\n",
" node_name: client_1\n",
" username: admin\n",
" password: admin\n",
" remote_ip: 192.168.10.22\n",
" 3:\n",
" action: node-send-remote-command\n",
" options:\n",
" node_name: client_1\n",
" remote_ip: 192.168.10.22\n",
" command:\n",
" - file_system\n",
" - create\n",
" - file\n",
" - downloads\n",
" - cat.png\n",
" - False\n",
"\"\"\"\n",
"custom_terminal_agent_yaml = yaml.safe_load(custom_terminal_agent)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(data_manipulation_config_path()) as f:\n",
" cfg = yaml.safe_load(f)\n",
" # removing all agents & adding the custom agent.\n",
" cfg['agents'] = {}\n",
" cfg['agents'] = custom_terminal_agent_yaml\n",
"\n",
"env = PrimaiteGymEnv(env_config=cfg)\n",
"\n",
"client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n",
"client_2: Computer = env.game.simulation.network.get_node_by_hostname(\"client_2\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Terminal Action | ``node-send-local-command`` \n",
"\n",
"The yaml snippet below shows all the relevant agent options for this action:\n",
"\n",
"```yaml\n",
"\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: node-send-local-command\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" - node_name: client_1\n",
" ...\n",
" ...\n",
" action_map:\n",
" 1:\n",
" action: node-send-local-command\n",
" options:\n",
" node_id: 0 # Index 0 at the node list.\n",
" username: admin\n",
" password: admin\n",
" command:\n",
" - file_system\n",
" - create\n",
" - file\n",
" - downloads\n",
" - dog.png\n",
" - False\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.step(1)\n",
"client_1.file_system.show(full=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Terminal Action | ``node-session-remote-login`` \n",
"\n",
"The yaml snippet below shows all the relevant agent options for this action:\n",
"\n",
"```yaml\n",
"\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: node-session-remote-login\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" - node_name: client_1\n",
" ...\n",
" ...\n",
" action_map:\n",
" 2:\n",
" action: node-session-remote-login\n",
" options:\n",
" node_id: 0 # Index 0 at the node list.\n",
" username: admin\n",
" password: admin\n",
" remote_ip: 192.168.10.22 # client_2's ip address.\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.step(2)\n",
"client_2.session_manager.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Terminal Action | ``node-send-remote-command``\n",
"\n",
"The yaml snippet below shows all the relevant agent options for this action:\n",
"\n",
"```yaml\n",
"\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: node-send-remote-command\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" - node_name: client_1\n",
" ...\n",
" ...\n",
" action_map:\n",
" 1:\n",
" action: node-send-remote-command\n",
" options:\n",
" node_id: 0 # Index 0 at the node list.\n",
" remote_ip: 192.168.10.22\n",
" commands:\n",
" - file_system\n",
" - create\n",
" - file\n",
" - downloads\n",
" - cat.png\n",
" - False\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.step(3)\n",
"client_2.file_system.show(full=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "venv",
"language": "python",
"name": "python3"
},

View File

@@ -26,14 +26,26 @@ except ModuleNotFoundError:
_LOGGER.debug("Torch not available for importing")
def set_random_seed(seed: int) -> Union[None, int]:
def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]:
"""
Set random number generators.
If seed is None or -1 and generate_seed_value is True randomly generate a
seed value.
If seed is > -1 and generate_seed_value is True ignore the latter and use
the provide seed value.
:param seed: int
:param generate_seed_value: bool
:return: None or the int representing the seed used.
"""
if seed is None or seed == -1:
return None
if generate_seed_value:
rng = np.random.default_rng()
# 2**32-1 is highest value for python RNG seed.
seed = int(rng.integers(low=0, high=2**32 - 1))
else:
return None
elif seed < -1:
raise ValueError("Invalid random number seed")
# Seed python RNG
@@ -50,6 +62,13 @@ def set_random_seed(seed: int) -> Union[None, int]:
return seed
def log_seed_value(seed: int):
"""Log the selected seed value to file."""
path = SIM_OUTPUT.path / "seed.log"
with open(path, "w") as file:
file.write(f"Seed value = {seed}")
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
@@ -65,7 +84,8 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Object that returns a config corresponding to the current episode."""
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
"""Get RNG seed from config file. NB: Must be before game instantiation."""
self.seed = set_random_seed(self.seed)
self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value")
self.seed = set_random_seed(self.seed, self.generate_seed_value)
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
@@ -79,6 +99,8 @@ class PrimaiteGymEnv(gymnasium.Env):
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
log_seed_value(self.seed)
def action_masks(self) -> np.ndarray:
"""
Return the action mask for the agent.
@@ -146,7 +168,7 @@ class PrimaiteGymEnv(gymnasium.Env):
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if seed is not None:
set_random_seed(seed)
set_random_seed(seed, self.generate_seed_value)
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
if self.io.settings.save_agent_actions:

View File

@@ -862,7 +862,21 @@ class UserManager(Service, discriminator="user-manager"):
"""
rm = super()._init_request_manager()
# todo add doc about requeest schemas
# todo add doc about request schemas
rm.add_request(
"add_user",
RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.add_user(username=request[0], password=request[1], is_admin=request[2])
)
),
)
rm.add_request(
"disable_user",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.disable_user(username=request[0]))
),
)
rm.add_request(
"change_password",
RequestType(
@@ -1570,7 +1584,7 @@ class Node(SimComponent, ABC):
operating_state: Any = None
users: Any = None # Temporary to appease "extra=forbid"
users: List[Dict] = [] # Temporary to appease "extra=forbid"
config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
"""Configuration items within Node"""
@@ -1636,6 +1650,8 @@ class Node(SimComponent, ABC):
self._install_system_software()
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
for user in self.config.users:
self.user_manager.add_user(**user, bypass_can_perform_action=True)
@property
def user_manager(self) -> Optional[UserManager]:
@@ -1767,7 +1783,7 @@ class Node(SimComponent, ABC):
"""
application_name = request[0]
if self.software_manager.software.get(application_name):
self.sys_log.warning(f"Can't install {application_name}. It's already installed.")
self.sys_log.info(f"Can't install {application_name}. It's already installed.")
return RequestResponse(status="success", data={"reason": "already installed"})
application_class = Application._registry[application_name]
self.software_manager.install(application_class)

View File

@@ -2,11 +2,12 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Literal, Optional
from typing import Any, ClassVar, Dict, List, Literal, Optional
from pydantic import Field
from primaite import getLogger
from primaite.simulator.file_system.file_type import FileType
from primaite.simulator.network.hardware.base import (
IPWiredNetworkInterface,
Link,
@@ -339,7 +340,7 @@ class HostNode(Node, discriminator="host-node"):
ip_address: IPV4Address
services: Any = None # temporarily unset to appease extra="forbid"
applications: Any = None # temporarily unset to appease extra="forbid"
folders: Any = None # temporarily unset to appease extra="forbid"
folders: List[Dict] = {} # temporarily unset to appease extra="forbid"
network_interfaces: Any = None # temporarily unset to appease extra="forbid"
config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
@@ -348,6 +349,18 @@ class HostNode(Node, discriminator="host-node"):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask))
for folder in self.config.folders:
# handle empty foler defined by just a string
self.file_system.create_folder(folder["folder_name"])
for file in folder.get("files", []):
self.file_system.create_file(
folder_name=folder["folder_name"],
file_name=file["file_name"],
size=file.get("size", 0),
file_type=FileType[file.get("type", "UNKNOWN").upper()],
)
@property
def nmap(self) -> Optional[NMAP]:
"""

View File

@@ -49,7 +49,7 @@ class Firewall(Router, discriminator="firewall"):
Example:
>>> from primaite.simulator.network.transmission.network_layer import IPProtocol
>>> from primaite.simulator.network.transmission.transport_layer import Port
>>> from primaite.utils.validation.port import Port
>>> firewall = Firewall(hostname="Firewall1")
>>> firewall.configure_internal_port(ip_address="192.168.1.1", subnet_mask="255.255.255.0")
>>> firewall.configure_external_port(ip_address="10.0.0.1", subnet_mask="255.255.255.0")

View File

@@ -467,6 +467,7 @@ class AccessControlList(SimComponent):
"""Check if a packet with the given properties is permitted through the ACL."""
permitted = False
rule: ACLRule = None
for _rule in self._acl:
if not _rule:
continue
@@ -1215,9 +1216,9 @@ class Router(NetworkNode, discriminator="router"):
config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema())
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
"user-session-manager": UserSessionManager,
"user-manager": UserManager,
"terminal": Terminal,
}
network_interfaces: Dict[str, RouterInterface] = {}
@@ -1384,6 +1385,12 @@ class Router(NetworkNode, discriminator="router"):
return False
def subject_to_acl(self, frame: Frame) -> bool:
"""Check that frame is subject to ACL rules."""
if frame.ip.protocol == "udp" and frame.is_arp:
return False
return True
def receive_frame(self, frame: Frame, from_network_interface: RouterInterface):
"""
Processes an incoming frame received on one of the router's interfaces.
@@ -1397,8 +1404,12 @@ class Router(NetworkNode, discriminator="router"):
if self.operating_state != NodeOperatingState.ON:
return
# Check if it's permitted
permitted, rule = self.acl.is_permitted(frame)
if self.subject_to_acl(frame=frame):
# Check if it's permitted
permitted, rule = self.acl.is_permitted(frame)
else:
permitted = True
rule = None
if not permitted:
at_port = self._get_port_of_nic(from_network_interface)

View File

@@ -163,7 +163,7 @@ class Frame(BaseModel):
"""
Checks if the Frame is an ARP (Address Resolution Protocol) packet.
This is determined by checking if the destination port of the TCP header is equal to the ARP port.
This is determined by checking if the destination and source port of the UDP header is equal to the ARP port.
:return: True if the Frame is an ARP packet, otherwise False.
"""

View File

@@ -415,5 +415,5 @@ class SessionManager:
table.align = "l"
table.title = f"{self.sys_log.hostname} Session Manager"
for session in self.sessions_by_key.values():
table.add_row([session.dst_ip_address, session.dst_port, session.protocol])
table.add_row([session.with_ip_address, session.dst_port, session.protocol])
print(table)

View File

@@ -55,7 +55,7 @@ class ARP(Service, discriminator="arp"):
:param markdown: If True, format the output as Markdown. Otherwise, use plain text.
"""
table = PrettyTable(["IP Address", "MAC Address", "Via"])
table = PrettyTable(["IP Address", "MAC Address", "Via", "Port"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
@@ -66,6 +66,7 @@ class ARP(Service, discriminator="arp"):
str(ip),
arp.mac_address,
self.software_manager.node.network_interfaces[arp.network_interface_uuid].mac_address,
self.software_manager.node.network_interfaces[arp.network_interface_uuid].port_num,
]
)
print(table)

View File

@@ -142,12 +142,20 @@ class Terminal(Service, discriminator="terminal"):
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
"""Dictionary of connect requests made to remote nodes."""
_last_response: Optional[RequestResponse] = None
"""Last response received from RequestManager, for returning remote RequestResponse."""
def __init__(self, **kwargs):
kwargs["name"] = "terminal"
kwargs["port"] = PORT_LOOKUP["SSH"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
@property
def last_response(self) -> Optional[RequestResponse]:
"""Public version of _last_response attribute."""
return self._last_response
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -186,7 +194,7 @@ class Terminal(Service, discriminator="terminal"):
return RequestResponse(status="failure", data={})
rm.add_request(
"node-session-remote-login",
"node_session_remote_login",
request_type=RequestType(func=_remote_login),
)
@@ -209,28 +217,45 @@ class Terminal(Service, discriminator="terminal"):
command: str = request[1]["command"]
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
if remote_connection:
outcome = remote_connection.execute(command)
if outcome:
return RequestResponse(
status="success",
data={},
)
else:
return RequestResponse(
status="failure",
data={},
)
remote_connection.execute(command)
return self.last_response if not None else RequestResponse(status="failure", data={})
return RequestResponse(
status="failure",
data={"reason": "Failed to execute command."},
)
rm.add_request(
"send_remote_command",
request_type=RequestType(func=remote_execute_request),
)
def local_execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
"""Executes a command using a local terminal session."""
command: str = request[2]["command"]
local_connection = self._process_local_login(username=request[0], password=request[1])
if local_connection:
outcome = local_connection.execute(command)
if outcome:
return RequestResponse(
status="success",
data={"reason": outcome},
)
return RequestResponse(
status="success",
data={"reason": "Local Terminal failed to resolve command. Potentially invalid credentials?"},
)
rm.add_request(
"send_local_command",
request_type=RequestType(func=local_execute_request),
)
return rm
def execute(self, command: List[Any]) -> Optional[RequestResponse]:
"""Execute a passed ssh command via the request manager."""
return self.parent.apply_request(command)
self._last_response = self.parent.apply_request(command)
return self._last_response
def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]:
"""Find Remote Terminal Connection from a given IP."""
@@ -409,6 +434,8 @@ class Terminal(Service, discriminator="terminal"):
"""
source_ip = kwargs["frame"].ip.src_ip_address
self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}")
self._last_response = None # Clear last response
if isinstance(payload, SSHPacket):
if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
# validate & add connection
@@ -457,6 +484,9 @@ class Terminal(Service, discriminator="terminal"):
session_id=session_id,
source_ip=source_ip,
)
self._last_response: RequestResponse = RequestResponse(
status="success", data={"reason": "Login Successful"}
)
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
# Requesting a command to be executed
@@ -468,12 +498,32 @@ class Terminal(Service, discriminator="terminal"):
payload.connection_uuid
)
remote_session.last_active_step = self.software_manager.node.user_session_manager.current_timestep
self.execute(command)
self._last_response: RequestResponse = self.execute(command)
if self._last_response.status == "success":
transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS
else:
transport_message = SSHTransportMessage.SSH_MSG_SERVICE_FAILED
payload: SSHPacket = SSHPacket(
payload=self._last_response,
transport_message=transport_message,
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA,
)
self.software_manager.send_payload_to_session_manager(
payload=payload, dest_port=self.port, session_id=session_id
)
return True
else:
self.sys_log.error(
f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command."
)
elif (
payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS
or SSHTransportMessage.SSH_MSG_SERVICE_FAILED
):
# Likely receiving command ack from remote.
self._last_response = payload.payload
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "disconnect":

View File

@@ -117,37 +117,44 @@ class WebServer(Service, discriminator="web-server"):
:type: payload: HttpRequestPacket
"""
response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND, payload=payload)
try:
parsed_url = urlparse(payload.request_url)
path = parsed_url.path.strip("/")
if len(path) < 1:
parsed_url = urlparse(payload.request_url)
path = parsed_url.path.strip("/") if parsed_url and parsed_url.path else ""
if len(path) < 1:
# query succeeded
response.status_code = HttpStatusCode.OK
if path.startswith("users"):
# get data from DatabaseServer
# get all users
if not self._establish_db_connection():
# unable to create a db connection
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
return response
if self.db_connection.query("SELECT"):
# query succeeded
self.set_health_state(SoftwareHealthState.GOOD)
response.status_code = HttpStatusCode.OK
else:
self.set_health_state(SoftwareHealthState.COMPROMISED)
return response
if path.startswith("users"):
# get data from DatabaseServer
# get all users
if not self.db_connection:
self._establish_db_connection()
if self.db_connection.query("SELECT"):
# query succeeded
self.set_health_state(SoftwareHealthState.GOOD)
response.status_code = HttpStatusCode.OK
else:
self.set_health_state(SoftwareHealthState.COMPROMISED)
return response
except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 )
# something went wrong on the server
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
return response
def _establish_db_connection(self) -> None:
def _establish_db_connection(self) -> bool:
"""Establish a connection to db."""
# if active db connection, return true
if self.db_connection:
return True
# otherwise, try to create db connection
db_client = self.software_manager.software.get("database-client")
if db_client is None:
return False # database client not installed
self.db_connection: DatabaseClientConnection = db_client.get_new_connection()
return self.db_connection is not None
def send(
self,