Merge in updates from dev
This commit is contained in:
@@ -1 +1 @@
|
||||
4.0.0a1-dev
|
||||
4.0.0-dev
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ClassVar, List, Optional, Union
|
||||
from typing import ClassVar, List, Literal, Optional, Union
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
@@ -153,8 +153,6 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, discriminator="node-nmap-po
|
||||
class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-network-service-recon"):
|
||||
"""Action which performs an nmap network service recon (ping scan followed by port scan)."""
|
||||
|
||||
config: "NodeNetworkServiceReconAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeNetworkServiceReconAction."""
|
||||
|
||||
@@ -179,3 +177,70 @@ class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-
|
||||
"show": config.show,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class NodeAccountsAddUserAction(AbstractAction, discriminator="node-account-add-user"):
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
type: Literal["node-account-add-user"] = "node-account-add-user"
|
||||
node_name: str
|
||||
username: str
|
||||
password: str
|
||||
is_admin: bool
|
||||
|
||||
@classmethod
|
||||
@staticmethod
|
||||
def form_request(config: ConfigSchema) -> RequestFormat:
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"user-manager",
|
||||
"add_user",
|
||||
config.username,
|
||||
config.password,
|
||||
config.is_admin,
|
||||
]
|
||||
|
||||
|
||||
class NodeAccountsDisableUserAction(AbstractAction, discriminator="node-account-disable-user"):
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
type: Literal["node-account-disable-user"] = "node-account-disable-user"
|
||||
node_name: str
|
||||
username: str
|
||||
|
||||
@classmethod
|
||||
@staticmethod
|
||||
def form_request(config: ConfigSchema) -> RequestFormat:
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"user-manager",
|
||||
"disable_user",
|
||||
config.username,
|
||||
]
|
||||
|
||||
|
||||
class NodeSendLocalCommandAction(AbstractAction, discriminator="node-send-local-command"):
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
type: Literal["node-send-local-command"] = "node-send-local-command"
|
||||
node_name: str
|
||||
username: str
|
||||
password: str
|
||||
command: RequestFormat
|
||||
|
||||
@staticmethod
|
||||
def form_request(config: ConfigSchema) -> RequestFormat:
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"terminal",
|
||||
"send_local_command",
|
||||
config.username,
|
||||
config.password,
|
||||
{"command": config.command},
|
||||
]
|
||||
|
||||
@@ -34,8 +34,6 @@ class NodeSessionAbstractAction(AbstractAction, ABC):
|
||||
class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="node-session-remote-login"):
|
||||
"""Action which performs a remote session login."""
|
||||
|
||||
config: "NodeSessionsRemoteLoginAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSessionsRemoteLoginAction."""
|
||||
|
||||
@@ -53,7 +51,7 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="no
|
||||
config.node_name,
|
||||
"service",
|
||||
"terminal",
|
||||
"node-session-remote-login",
|
||||
"node_session_remote_login",
|
||||
config.username,
|
||||
config.password,
|
||||
config.remote_ip,
|
||||
|
||||
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from prettytable import PrettyTable
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
@@ -42,6 +43,9 @@ class AgentHistoryItem(BaseModel):
|
||||
|
||||
reward_info: Dict[str, Any] = {}
|
||||
|
||||
observation: Optional[ObsType] = None
|
||||
"""The observation space data for this step."""
|
||||
|
||||
|
||||
class AbstractAgent(BaseModel, ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
@@ -67,6 +71,9 @@ class AbstractAgent(BaseModel, ABC):
|
||||
default_factory=lambda: ObservationManager.ConfigSchema()
|
||||
)
|
||||
reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
|
||||
thresholds: Optional[Dict] = {}
|
||||
# TODO: this is only relevant to some observations, need to refactor the way thresholds are dealt with (#3085)
|
||||
"""A dict containing the observation thresholds."""
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
|
||||
|
||||
@@ -95,10 +102,34 @@ class AbstractAgent(BaseModel, ABC):
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Overwrite the default empty action, observation, and rewards with ones defined through the config."""
|
||||
self.action_manager = ActionManager(config=self.config.action_space)
|
||||
self.config.observation_space.options.thresholds = self.config.thresholds
|
||||
self.observation_manager = ObservationManager(config=self.config.observation_space)
|
||||
self.reward_function = RewardFunction(config=self.config.reward_function)
|
||||
return super().model_post_init(__context)
|
||||
|
||||
def show_history(self, ignored_actions: Optional[list] = None):
|
||||
"""
|
||||
Print an agent action provided it's not the do-nothing action.
|
||||
|
||||
:param ignored_actions: OPTIONAL: List of actions to be ignored when displaying the history.
|
||||
If not provided, defaults to ignore do-nothing actions.
|
||||
"""
|
||||
if not ignored_actions:
|
||||
ignored_actions = ["do-nothing"]
|
||||
table = PrettyTable()
|
||||
table.field_names = ["Step", "Action", "Params", "Response", "Response Data"]
|
||||
print(f"Actions for '{self.config.ref}':")
|
||||
for item in self.history:
|
||||
if item.action in ignored_actions:
|
||||
pass
|
||||
else:
|
||||
# format dict by putting each key-value entry on a separate line and putting a blank line on the end.
|
||||
param_string = "\n".join([*[f"{k}: {v:.30}" for k, v in item.parameters.items()], ""])
|
||||
data_string = "\n".join([*[f"{k}: {v:.30}" for k, v in item.response.data], ""])
|
||||
|
||||
table.add_row([item.timestep, item.action, param_string, item.response.status, data_string])
|
||||
print(table)
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Convert a state from the simulator into an observation for the agent using the observation space.
|
||||
@@ -145,12 +176,23 @@ class AbstractAgent(BaseModel, ABC):
|
||||
return request
|
||||
|
||||
def process_action_response(
|
||||
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
|
||||
self,
|
||||
timestep: int,
|
||||
action: str,
|
||||
parameters: Dict[str, Any],
|
||||
request: RequestFormat,
|
||||
response: RequestResponse,
|
||||
observation: ObsType,
|
||||
) -> None:
|
||||
"""Process the response from the most recent action."""
|
||||
self.history.append(
|
||||
AgentHistoryItem(
|
||||
timestep=timestep, action=action, parameters=parameters, request=request, response=response
|
||||
timestep=timestep,
|
||||
action=action,
|
||||
parameters=parameters,
|
||||
request=request,
|
||||
response=response,
|
||||
observation=observation,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, discriminator="file"):
|
||||
file_system_requires_scan: Optional[bool] = None
|
||||
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
thresholds: Optional[Dict] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a file observation instance.
|
||||
|
||||
@@ -48,10 +54,36 @@ class FileObservation(AbstractObservation, discriminator="file"):
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_access"] = 0
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
if thresholds.get("file_access") is None:
|
||||
self.low_file_access_threshold = 0
|
||||
self.med_file_access_threshold = 5
|
||||
self.high_file_access_threshold = 10
|
||||
else:
|
||||
self._set_file_access_threshold(
|
||||
thresholds=[
|
||||
thresholds.get("file_access")["low"],
|
||||
thresholds.get("file_access")["medium"],
|
||||
thresholds.get("file_access")["high"],
|
||||
]
|
||||
)
|
||||
|
||||
def _set_file_access_threshold(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the file access threshold.
|
||||
|
||||
:param: thresholds: The file access threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds[0],
|
||||
thresholds[1],
|
||||
thresholds[2],
|
||||
],
|
||||
threshold_identifier="file_access",
|
||||
):
|
||||
self.low_file_access_threshold = thresholds[0]
|
||||
self.med_file_access_threshold = thresholds[1]
|
||||
self.high_file_access_threshold = thresholds[2]
|
||||
|
||||
def _categorise_num_access(self, num_access: int) -> int:
|
||||
"""
|
||||
@@ -60,11 +92,11 @@ class FileObservation(AbstractObservation, discriminator="file"):
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_access > self.high_threshold:
|
||||
if num_access > self.high_file_access_threshold:
|
||||
return 3
|
||||
elif num_access > self.med_threshold:
|
||||
elif num_access > self.med_file_access_threshold:
|
||||
return 2
|
||||
elif num_access > self.low_threshold:
|
||||
elif num_access > self.low_file_access_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
@@ -122,6 +154,7 @@ class FileObservation(AbstractObservation, discriminator="file"):
|
||||
where=parent_where + ["files", config.file_name],
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
|
||||
@@ -149,6 +182,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
|
||||
num_files: int,
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
thresholds: Optional[Dict] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a folder observation instance.
|
||||
@@ -177,6 +211,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
|
||||
where=None,
|
||||
include_num_access=include_num_access,
|
||||
file_system_requires_scan=self.file_system_requires_scan,
|
||||
thresholds=thresholds,
|
||||
)
|
||||
)
|
||||
while len(self.files) > num_files:
|
||||
@@ -253,6 +288,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
file_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
file_config.thresholds = config.thresholds
|
||||
|
||||
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
|
||||
return cls(
|
||||
@@ -261,4 +297,5 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
|
||||
num_files=config.num_files,
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
@@ -54,7 +54,15 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
"""
|
||||
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
include_users: Optional[bool] = None
|
||||
services_requires_scan: Optional[bool] = None
|
||||
"""
|
||||
If True, services must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
applications_requires_scan: Optional[bool] = None
|
||||
"""
|
||||
If True, applications must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
include_users: Optional[bool] = True
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
@@ -73,6 +81,8 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
monitored_traffic: Optional[Dict],
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
services_requires_scan: bool,
|
||||
applications_requires_scan: bool,
|
||||
include_users: bool,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -108,6 +118,12 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
:param file_system_requires_scan: If True, the files and folders must be scanned to update the health state.
|
||||
If False, the true state is always shown.
|
||||
:type file_system_requires_scan: bool
|
||||
:param services_requires_scan: If True, services must be scanned to update the health state.
|
||||
If False, the true state is always shown.
|
||||
:type services_requires_scan: bool
|
||||
:param applications_requires_scan: If True, applications must be scanned to update the health state.
|
||||
If False, the true state is always shown.
|
||||
:type applications_requires_scan: bool
|
||||
:param include_users: If True, report user session information.
|
||||
:type include_users: bool
|
||||
"""
|
||||
@@ -121,7 +137,7 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
# Ensure lists have lengths equal to specified counts by truncating or padding
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services:
|
||||
self.services.append(ServiceObservation(where=None))
|
||||
self.services.append(ServiceObservation(where=None, services_requires_scan=services_requires_scan))
|
||||
while len(self.services) > num_services:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
@@ -129,7 +145,9 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
|
||||
self.applications: List[ApplicationObservation] = applications
|
||||
while len(self.applications) < num_applications:
|
||||
self.applications.append(ApplicationObservation(where=None))
|
||||
self.applications.append(
|
||||
ApplicationObservation(where=None, applications_requires_scan=applications_requires_scan)
|
||||
)
|
||||
while len(self.applications) > num_applications:
|
||||
truncated_application = self.applications.pop()
|
||||
msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}"
|
||||
@@ -153,7 +171,13 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
|
||||
self.nics: List[NICObservation] = network_interfaces
|
||||
while len(self.nics) < num_nics:
|
||||
self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic))
|
||||
self.nics.append(
|
||||
NICObservation(
|
||||
where=None,
|
||||
include_nmne=include_nmne,
|
||||
monitored_traffic=monitored_traffic,
|
||||
)
|
||||
)
|
||||
while len(self.nics) > num_nics:
|
||||
truncated_nic = self.nics.pop()
|
||||
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
|
||||
@@ -269,8 +293,15 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
folder_config.include_num_access = config.include_num_access
|
||||
folder_config.num_files = config.num_files
|
||||
folder_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
folder_config.thresholds = config.thresholds
|
||||
for nic_config in config.network_interfaces:
|
||||
nic_config.include_nmne = config.include_nmne
|
||||
nic_config.thresholds = config.thresholds
|
||||
for service_config in config.services:
|
||||
service_config.services_requires_scan = config.services_requires_scan
|
||||
for application_config in config.applications:
|
||||
application_config.applications_requires_scan = config.applications_requires_scan
|
||||
application_config.thresholds = config.thresholds
|
||||
|
||||
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
|
||||
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
|
||||
@@ -281,7 +312,10 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
count = 1
|
||||
while len(nics) < config.num_nics:
|
||||
nic_config = NICObservation.ConfigSchema(
|
||||
nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic
|
||||
nic_num=count,
|
||||
include_nmne=config.include_nmne,
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
nics.append(NICObservation.from_config(config=nic_config, parent_where=where))
|
||||
count += 1
|
||||
@@ -301,5 +335,7 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
services_requires_scan=config.services_requires_scan,
|
||||
applications_requires_scan=config.applications_requires_scan,
|
||||
include_users=config.include_users,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import ClassVar, Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
@@ -15,6 +16,9 @@ from primaite.utils.validation.port import Port
|
||||
class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
"""Status information about a network interface within the simulation environment."""
|
||||
|
||||
capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne
|
||||
"A Boolean specifying whether malicious network events should be captured."
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NICObservation."""
|
||||
|
||||
@@ -25,7 +29,13 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None
|
||||
"""A dict containing which traffic types are to be included in the observation."""
|
||||
|
||||
def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
include_nmne: bool,
|
||||
monitored_traffic: Optional[Dict] = None,
|
||||
thresholds: Dict = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a network interface observation instance.
|
||||
|
||||
@@ -45,10 +55,18 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
self.nmne_inbound_last_step: int = 0
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_nmne_threshold = 10
|
||||
self.med_nmne_threshold = 5
|
||||
self.low_nmne_threshold = 0
|
||||
if thresholds.get("nmne") is None:
|
||||
self.low_nmne_threshold = 0
|
||||
self.med_nmne_threshold = 5
|
||||
self.high_nmne_threshold = 10
|
||||
else:
|
||||
self._set_nmne_threshold(
|
||||
thresholds=[
|
||||
thresholds.get("nmne")["low"],
|
||||
thresholds.get("nmne")["medium"],
|
||||
thresholds.get("nmne")["high"],
|
||||
]
|
||||
)
|
||||
|
||||
self.monitored_traffic = monitored_traffic
|
||||
if self.monitored_traffic:
|
||||
@@ -105,6 +123,20 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
bandwidth_utilisation = traffic_value / nic_max_bandwidth
|
||||
return int(bandwidth_utilisation * 9) + 1
|
||||
|
||||
def _set_nmne_threshold(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the NMNE threshold.
|
||||
|
||||
:param: thresholds: The NMNE threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=thresholds,
|
||||
threshold_identifier="nmne",
|
||||
):
|
||||
self.low_nmne_threshold = thresholds[0]
|
||||
self.med_nmne_threshold = thresholds[1]
|
||||
self.high_nmne_threshold = thresholds[2]
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
@@ -116,7 +148,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
"""
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
if nic_state is NOT_PRESENT_IN_STATE or self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
@@ -164,7 +196,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
for port in self.monitored_traffic[protocol]:
|
||||
obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0}
|
||||
|
||||
if self.include_nmne:
|
||||
if self.capture_nmne and self.include_nmne:
|
||||
obs.update({"NMNE": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
@@ -224,6 +256,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
where=parent_where + ["NICs", config.nic_num],
|
||||
include_nmne=config.include_nmne,
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,13 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Flag to include the number of accesses."""
|
||||
file_system_requires_scan: bool = True
|
||||
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
"""If True, the folder must be scanned to update the health state. If False, the true state is always shown."""
|
||||
services_requires_scan: bool = True
|
||||
"""If True, the services must be scanned to update the health state.
|
||||
If False, the true state is always shown."""
|
||||
applications_requires_scan: bool = True
|
||||
"""If True, the applications must be scanned to update the health state.
|
||||
If False, the true state is always shown."""
|
||||
include_users: Optional[bool] = True
|
||||
"""If True, report user session information."""
|
||||
num_ports: Optional[int] = None
|
||||
@@ -196,8 +202,14 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
|
||||
host_config.include_num_access = config.include_num_access
|
||||
if host_config.file_system_requires_scan is None:
|
||||
host_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
if host_config.services_requires_scan is None:
|
||||
host_config.services_requires_scan = config.services_requires_scan
|
||||
if host_config.applications_requires_scan is None:
|
||||
host_config.applications_requires_scan = config.applications_requires_scan
|
||||
if host_config.include_users is None:
|
||||
host_config.include_users = config.include_users
|
||||
if not host_config.thresholds:
|
||||
host_config.thresholds = config.thresholds
|
||||
|
||||
for router_config in config.routers:
|
||||
if router_config.num_ports is None:
|
||||
@@ -214,6 +226,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
|
||||
router_config.num_rules = config.num_rules
|
||||
if router_config.include_users is None:
|
||||
router_config.include_users = config.include_users
|
||||
if not router_config.thresholds:
|
||||
router_config.thresholds = config.thresholds
|
||||
|
||||
for firewall_config in config.firewalls:
|
||||
if firewall_config.ip_list is None:
|
||||
@@ -228,6 +242,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
|
||||
firewall_config.num_rules = config.num_rules
|
||||
if firewall_config.include_users is None:
|
||||
firewall_config.include_users = config.include_users
|
||||
if not firewall_config.thresholds:
|
||||
firewall_config.thresholds = config.thresholds
|
||||
|
||||
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
|
||||
|
||||
@@ -114,7 +114,9 @@ class NestedObservation(AbstractObservation, discriminator="custom"):
|
||||
instances = dict()
|
||||
for component in config.components:
|
||||
obs_class = AbstractObservation._registry[component.type]
|
||||
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options))
|
||||
obs_instance = obs_class.from_config(
|
||||
config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds)
|
||||
)
|
||||
instances[component.label] = obs_instance
|
||||
return cls(components=instances)
|
||||
|
||||
@@ -228,7 +230,7 @@ class ObservationManager(BaseModel):
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Optional[Dict]) -> "ObservationManager":
|
||||
def from_config(cls, config: Optional[Dict], thresholds: Optional[Dict] = {}) -> "ObservationManager":
|
||||
"""
|
||||
Create observation space from a config.
|
||||
|
||||
@@ -239,11 +241,10 @@ class ObservationManager(BaseModel):
|
||||
AbstractObservation
|
||||
options: this must adhere to the chosen observation type's ConfigSchema nested class.
|
||||
:type config: Dict
|
||||
:param thresholds: Dictionary containing the observation thresholds.
|
||||
:type thresholds: Optional[Dict]
|
||||
"""
|
||||
if config is None:
|
||||
return cls(NullObservation())
|
||||
obs_type = config["type"]
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]))
|
||||
obs_manager = cls(observation)
|
||||
obs_manager = cls(config=config)
|
||||
return obs_manager
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Iterable, Optional, Type, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -19,6 +19,9 @@ class AbstractObservation(ABC):
|
||||
class ConfigSchema(ABC, BaseModel):
|
||||
"""Config schema for observations."""
|
||||
|
||||
thresholds: Optional[Dict] = {}
|
||||
"""A dict containing the observation thresholds."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
_registry: Dict[str, Type["AbstractObservation"]] = {}
|
||||
@@ -69,3 +72,34 @@ class AbstractObservation(ABC):
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation":
|
||||
"""Create this observation space component form a serialised format."""
|
||||
return cls()
|
||||
|
||||
def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool:
|
||||
"""
|
||||
Method that checks if the thresholds are non overlapping and in the correct (ascending) order.
|
||||
|
||||
Pass in the thresholds from low to high e.g.
|
||||
thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold]
|
||||
|
||||
Throws an error if the threshold is not valid
|
||||
|
||||
:param: thresholds: List of thresholds in ascending order.
|
||||
:type: List[int]
|
||||
:param: threshold_identifier: The name of the threshold option.
|
||||
:type: Optional[str]
|
||||
|
||||
:returns: bool
|
||||
"""
|
||||
if thresholds is None or len(thresholds) < 2:
|
||||
raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}")
|
||||
for idx in range(1, len(thresholds)):
|
||||
if not isinstance(thresholds[idx], int):
|
||||
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
|
||||
if not isinstance(thresholds[idx - 1], int):
|
||||
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
|
||||
|
||||
if thresholds[idx] <= thresholds[idx - 1]:
|
||||
raise Exception(
|
||||
f"{threshold_identifier} threshold ({thresholds[idx - 1]}) "
|
||||
f"is greater than or equal to ({thresholds[idx]}.)"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -19,7 +19,10 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
|
||||
service_name: str
|
||||
"""Name of the service, used for querying simulation state dictionary"""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
services_requires_scan: Optional[bool] = None
|
||||
"""If True, services must be scanned to update the health state. If False, true state is always shown."""
|
||||
|
||||
def __init__(self, where: WhereType, services_requires_scan: bool) -> None:
|
||||
"""
|
||||
Initialise a service observation instance.
|
||||
|
||||
@@ -28,6 +31,7 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.services_requires_scan = services_requires_scan
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
@@ -44,7 +48,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": service_state["operating_state"],
|
||||
"health_status": service_state["health_state_visible"],
|
||||
"health_status": service_state["health_state_visible"]
|
||||
if self.services_requires_scan
|
||||
else service_state["health_state_actual"],
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -70,7 +76,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
|
||||
:return: Constructed service observation instance.
|
||||
:rtype: ServiceObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", config.service_name])
|
||||
return cls(
|
||||
where=parent_where + ["services", config.service_name], services_requires_scan=config.services_requires_scan
|
||||
)
|
||||
|
||||
|
||||
class ApplicationObservation(AbstractObservation, discriminator="application"):
|
||||
@@ -82,7 +90,12 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
|
||||
application_name: str
|
||||
"""Name of the application, used for querying simulation state dictionary"""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
applications_requires_scan: Optional[bool] = None
|
||||
"""
|
||||
If True, applications must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
|
||||
def __init__(self, where: WhereType, applications_requires_scan: bool, thresholds: Optional[Dict] = {}) -> None:
|
||||
"""
|
||||
Initialise an application observation instance.
|
||||
|
||||
@@ -92,25 +105,52 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.applications_requires_scan = applications_requires_scan
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
if thresholds.get("app_executions") is None:
|
||||
self.low_app_execution_threshold = 0
|
||||
self.med_app_execution_threshold = 5
|
||||
self.high_app_execution_threshold = 10
|
||||
else:
|
||||
self._set_application_execution_thresholds(
|
||||
thresholds=[
|
||||
thresholds.get("app_executions")["low"],
|
||||
thresholds.get("app_executions")["medium"],
|
||||
thresholds.get("app_executions")["high"],
|
||||
]
|
||||
)
|
||||
|
||||
def _set_application_execution_thresholds(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the application execution threshold.
|
||||
|
||||
:param: thresholds: The application execution threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds[0],
|
||||
thresholds[1],
|
||||
thresholds[2],
|
||||
],
|
||||
threshold_identifier="app_executions",
|
||||
):
|
||||
self.low_app_execution_threshold = thresholds[0]
|
||||
self.med_app_execution_threshold = thresholds[1]
|
||||
self.high_app_execution_threshold = thresholds[2]
|
||||
|
||||
def _categorise_num_executions(self, num_executions: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
Represent number of application executions as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:param num_access: Number of application executions.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_executions > self.high_threshold:
|
||||
if num_executions > self.high_app_execution_threshold:
|
||||
return 3
|
||||
elif num_executions > self.med_threshold:
|
||||
elif num_executions > self.med_app_execution_threshold:
|
||||
return 2
|
||||
elif num_executions > self.low_threshold:
|
||||
elif num_executions > self.low_app_execution_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
@@ -128,7 +168,9 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": application_state["operating_state"],
|
||||
"health_status": application_state["health_state_visible"],
|
||||
"health_status": application_state["health_state_visible"]
|
||||
if self.applications_requires_scan
|
||||
else application_state["health_state_actual"],
|
||||
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
|
||||
}
|
||||
|
||||
@@ -161,4 +203,8 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
|
||||
:return: Constructed application observation instance.
|
||||
:rtype: ApplicationObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["applications", config.application_name])
|
||||
return cls(
|
||||
where=parent_where + ["applications", config.application_name],
|
||||
applications_requires_scan=config.applications_requires_scan,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import DEFAULT_BANDWIDTH, getLogger
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
|
||||
from primaite.game.agent.observations import NICObservation
|
||||
from primaite.game.agent.rewards import SharedReward
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
@@ -44,15 +45,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
SERVICE_TYPES_MAPPING = {
|
||||
"DNSClient": DNSClient,
|
||||
"DNSServer": DNSServer,
|
||||
"DatabaseService": DatabaseService,
|
||||
"WebServer": WebServer,
|
||||
"FTPClient": FTPClient,
|
||||
"FTPServer": FTPServer,
|
||||
"NTPClient": NTPClient,
|
||||
"NTPServer": NTPServer,
|
||||
"Terminal": Terminal,
|
||||
"dns-client": DNSClient,
|
||||
"dns-server": DNSServer,
|
||||
"database-service": DatabaseService,
|
||||
"web-server": WebServer,
|
||||
"ftp-client": FTPClient,
|
||||
"ftp-server": FTPServer,
|
||||
"ntp-client": NTPClient,
|
||||
"ntp-server": NTPServer,
|
||||
"terminal": Terminal,
|
||||
}
|
||||
"""List of available services that can be installed on nodes in the PrimAITE Simulation."""
|
||||
|
||||
@@ -68,6 +69,8 @@ class PrimaiteGameOptions(BaseModel):
|
||||
|
||||
seed: int = None
|
||||
"""Random number seed for RNGs."""
|
||||
generate_seed_value: bool = False
|
||||
"""Internally generated seed value."""
|
||||
max_episode_length: int = 256
|
||||
"""Maximum number of episodes for the PrimAITE game."""
|
||||
ports: List[Port]
|
||||
@@ -175,6 +178,7 @@ class PrimaiteGame:
|
||||
parameters=parameters,
|
||||
request=request,
|
||||
response=response,
|
||||
observation=obs,
|
||||
)
|
||||
|
||||
def pre_timestep(self) -> None:
|
||||
@@ -263,6 +267,7 @@ class PrimaiteGame:
|
||||
node_sets_cfg = network_config.get("node_sets", [])
|
||||
# Set the NMNE capture config
|
||||
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
|
||||
NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne
|
||||
|
||||
for node_cfg in nodes_cfg:
|
||||
n_type = node_cfg["type"]
|
||||
@@ -293,6 +298,7 @@ class PrimaiteGame:
|
||||
|
||||
if "users" in node_cfg and new_node.software_manager.software.get("user-manager"):
|
||||
user_manager: UserManager = new_node.software_manager.software["user-manager"] # noqa
|
||||
|
||||
for user_cfg in node_cfg["users"]:
|
||||
user_manager.add_user(**user_cfg, bypass_can_perform_action=True)
|
||||
|
||||
@@ -407,6 +413,7 @@ class PrimaiteGame:
|
||||
agents_cfg = cfg.get("agents", [])
|
||||
|
||||
for agent_cfg in agents_cfg:
|
||||
agent_cfg = {**agent_cfg, "thresholds": game.options.thresholds}
|
||||
new_agent = AbstractAgent.from_config(agent_cfg)
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
if isinstance(new_agent, ProxyAgent):
|
||||
|
||||
@@ -47,7 +47,7 @@
|
||||
"source": [
|
||||
"def make_cfg_have_flat_obs(cfg):\n",
|
||||
" for agent in cfg['agents']:\n",
|
||||
" if agent['type'] == \"ProxyAgent\":\n",
|
||||
" if agent['type'] == \"proxy-agent\":\n",
|
||||
" agent['agent_settings']['flatten_obs'] = False"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -684,6 +684,15 @@
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.game.agents[\"data_manipulation_attacker\"].show_history()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
||||
@@ -153,6 +153,49 @@
|
||||
"PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = was_enabled\n",
|
||||
"PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = was_syslogs_enabled"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Viewing Agent history\n",
|
||||
"\n",
|
||||
"It's possible to view the actions carried out by an agent for a given training session using the `show_history()` method. By default, this will be all actions apart from DONOTHING actions."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(env_config=cfg)\n",
|
||||
"\n",
|
||||
"# Run the training session to generate some resultant data.\n",
|
||||
"for i in range(100):\n",
|
||||
" env.step(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `.show_history()` should show us when the Data Manipulation used the `NODE_APPLICATION_EXECUTE` action."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"attacker = env.game.agents[\"data_manipulation_attacker\"]\n",
|
||||
"\n",
|
||||
"attacker.show_history()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -15,15 +15,6 @@
|
||||
"To display the available dev-mode options, run the command below:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -9,6 +9,13 @@
|
||||
"© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Simulation Layer Implementation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -186,6 +193,22 @@
|
||||
"computer_b.file_system.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Information about the latest response when executing a remote command can be seen by calling the `last_response` attribute within `Terminal`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(terminal_a.last_response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -224,11 +247,254 @@
|
||||
"source": [
|
||||
"computer_b.user_session_manager.show(include_historic=True, include_session_id=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Game Layer Implementation\n",
|
||||
"\n",
|
||||
"This notebook section will detail the implementation of how the game layer utilises the terminal to support different agent actions.\n",
|
||||
"\n",
|
||||
"The ``Terminal`` is used in a variety of different ways in the game layer. Specifically, the terminal is leveraged to implement the following actions:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"| Game Layer Action | Simulation Layer |\n",
|
||||
"|-----------------------------------|--------------------------|\n",
|
||||
"| ``node-send-local-command`` | Uses the given user credentials, creates a ``LocalTerminalSession`` and executes the given command and returns the ``RequestResponse``.\n",
|
||||
"| ``node-session-remote-login`` | Uses the given user credentials and remote IP to create a ``RemoteTerminalSession``.\n",
|
||||
"| ``node-send-remote-command`` | Uses the given remote IP to locate the correct ``RemoteTerminalSession``, executes the given command and returns the ``RequestsResponse``."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Game Layer Setup\n",
|
||||
"\n",
|
||||
"Similar to other notebooks, the next code cells create a custom proxy agent to demonstrate how these commands can be leveraged by agents in the ``UC2`` network environment.\n",
|
||||
"\n",
|
||||
"If you're unfamiliar with ``UC2`` then please refer to the [UC2-E2E-Demo notebook for further reference](./Data-Manipulation-E2E-Demonstration.ipynb)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"custom_terminal_agent = \"\"\"\n",
|
||||
" - ref: CustomC2Agent\n",
|
||||
" team: RED\n",
|
||||
" type: proxy-agent\n",
|
||||
" action_space:\n",
|
||||
" action_map:\n",
|
||||
" 0:\n",
|
||||
" action: do-nothing\n",
|
||||
" options: {}\n",
|
||||
" 1:\n",
|
||||
" action: node-send-local-command\n",
|
||||
" options:\n",
|
||||
" node_name: client_1\n",
|
||||
" username: admin\n",
|
||||
" password: admin\n",
|
||||
" command:\n",
|
||||
" - file_system\n",
|
||||
" - create\n",
|
||||
" - file\n",
|
||||
" - downloads\n",
|
||||
" - dog.png\n",
|
||||
" - False\n",
|
||||
" 2:\n",
|
||||
" action: node-session-remote-login\n",
|
||||
" options:\n",
|
||||
" node_name: client_1\n",
|
||||
" username: admin\n",
|
||||
" password: admin\n",
|
||||
" remote_ip: 192.168.10.22\n",
|
||||
" 3:\n",
|
||||
" action: node-send-remote-command\n",
|
||||
" options:\n",
|
||||
" node_name: client_1\n",
|
||||
" remote_ip: 192.168.10.22\n",
|
||||
" command:\n",
|
||||
" - file_system\n",
|
||||
" - create\n",
|
||||
" - file\n",
|
||||
" - downloads\n",
|
||||
" - cat.png\n",
|
||||
" - False\n",
|
||||
"\"\"\"\n",
|
||||
"custom_terminal_agent_yaml = yaml.safe_load(custom_terminal_agent)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path()) as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
" # removing all agents & adding the custom agent.\n",
|
||||
" cfg['agents'] = {}\n",
|
||||
" cfg['agents'] = custom_terminal_agent_yaml\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(env_config=cfg)\n",
|
||||
"\n",
|
||||
"client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n",
|
||||
"client_2: Computer = env.game.simulation.network.get_node_by_hostname(\"client_2\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Terminal Action | ``node-send-local-command`` \n",
|
||||
"\n",
|
||||
"The yaml snippet below shows all the relevant agent options for this action:\n",
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: node-send-local-command\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: client_1\n",
|
||||
" ...\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 1:\n",
|
||||
" action: node-send-local-command\n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Index 0 at the node list.\n",
|
||||
" username: admin\n",
|
||||
" password: admin\n",
|
||||
" command:\n",
|
||||
" - file_system\n",
|
||||
" - create\n",
|
||||
" - file\n",
|
||||
" - downloads\n",
|
||||
" - dog.png\n",
|
||||
" - False\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(1)\n",
|
||||
"client_1.file_system.show(full=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Terminal Action | ``node-session-remote-login`` \n",
|
||||
"\n",
|
||||
"The yaml snippet below shows all the relevant agent options for this action:\n",
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: node-session-remote-login\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: client_1\n",
|
||||
" ...\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 2:\n",
|
||||
" action: node-session-remote-login\n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Index 0 at the node list.\n",
|
||||
" username: admin\n",
|
||||
" password: admin\n",
|
||||
" remote_ip: 192.168.10.22 # client_2's ip address.\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(2)\n",
|
||||
"client_2.session_manager.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Terminal Action | ``node-send-remote-command``\n",
|
||||
"\n",
|
||||
"The yaml snippet below shows all the relevant agent options for this action:\n",
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: node-send-remote-command\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: client_1\n",
|
||||
" ...\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 1:\n",
|
||||
" action: node-send-remote-command\n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Index 0 at the node list.\n",
|
||||
" remote_ip: 192.168.10.22\n",
|
||||
" commands:\n",
|
||||
" - file_system\n",
|
||||
" - create\n",
|
||||
" - file\n",
|
||||
" - downloads\n",
|
||||
" - cat.png\n",
|
||||
" - False\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(3)\n",
|
||||
"client_2.file_system.show(full=True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
||||
@@ -26,14 +26,26 @@ except ModuleNotFoundError:
|
||||
_LOGGER.debug("Torch not available for importing")
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> Union[None, int]:
|
||||
def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]:
|
||||
"""
|
||||
Set random number generators.
|
||||
|
||||
If seed is None or -1 and generate_seed_value is True randomly generate a
|
||||
seed value.
|
||||
If seed is > -1 and generate_seed_value is True ignore the latter and use
|
||||
the provide seed value.
|
||||
|
||||
:param seed: int
|
||||
:param generate_seed_value: bool
|
||||
:return: None or the int representing the seed used.
|
||||
"""
|
||||
if seed is None or seed == -1:
|
||||
return None
|
||||
if generate_seed_value:
|
||||
rng = np.random.default_rng()
|
||||
# 2**32-1 is highest value for python RNG seed.
|
||||
seed = int(rng.integers(low=0, high=2**32 - 1))
|
||||
else:
|
||||
return None
|
||||
elif seed < -1:
|
||||
raise ValueError("Invalid random number seed")
|
||||
# Seed python RNG
|
||||
@@ -50,6 +62,13 @@ def set_random_seed(seed: int) -> Union[None, int]:
|
||||
return seed
|
||||
|
||||
|
||||
def log_seed_value(seed: int):
|
||||
"""Log the selected seed value to file."""
|
||||
path = SIM_OUTPUT.path / "seed.log"
|
||||
with open(path, "w") as file:
|
||||
file.write(f"Seed value = {seed}")
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
Thin wrapper env to provide agents with a gymnasium API.
|
||||
@@ -65,7 +84,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
|
||||
"""Get RNG seed from config file. NB: Must be before game instantiation."""
|
||||
self.seed = set_random_seed(self.seed)
|
||||
self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value")
|
||||
self.seed = set_random_seed(self.seed, self.generate_seed_value)
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
|
||||
@@ -79,6 +99,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
|
||||
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
|
||||
|
||||
log_seed_value(self.seed)
|
||||
|
||||
def action_masks(self) -> np.ndarray:
|
||||
"""
|
||||
Return the action mask for the agent.
|
||||
@@ -146,7 +168,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if seed is not None:
|
||||
set_random_seed(seed)
|
||||
set_random_seed(seed, self.generate_seed_value)
|
||||
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
|
||||
@@ -862,7 +862,21 @@ class UserManager(Service, discriminator="user-manager"):
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
# todo add doc about requeest schemas
|
||||
# todo add doc about request schemas
|
||||
rm.add_request(
|
||||
"add_user",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.add_user(username=request[0], password=request[1], is_admin=request[2])
|
||||
)
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
"disable_user",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.disable_user(username=request[0]))
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
"change_password",
|
||||
RequestType(
|
||||
@@ -1570,7 +1584,7 @@ class Node(SimComponent, ABC):
|
||||
|
||||
operating_state: Any = None
|
||||
|
||||
users: Any = None # Temporary to appease "extra=forbid"
|
||||
users: List[Dict] = [] # Temporary to appease "extra=forbid"
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
|
||||
"""Configuration items within Node"""
|
||||
@@ -1636,6 +1650,8 @@ class Node(SimComponent, ABC):
|
||||
self._install_system_software()
|
||||
self.session_manager.node = self
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
for user in self.config.users:
|
||||
self.user_manager.add_user(**user, bypass_can_perform_action=True)
|
||||
|
||||
@property
|
||||
def user_manager(self) -> Optional[UserManager]:
|
||||
@@ -1767,7 +1783,7 @@ class Node(SimComponent, ABC):
|
||||
"""
|
||||
application_name = request[0]
|
||||
if self.software_manager.software.get(application_name):
|
||||
self.sys_log.warning(f"Can't install {application_name}. It's already installed.")
|
||||
self.sys_log.info(f"Can't install {application_name}. It's already installed.")
|
||||
return RequestResponse(status="success", data={"reason": "already installed"})
|
||||
application_class = Application._registry[application_name]
|
||||
self.software_manager.install(application_class)
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, ClassVar, Dict, Literal, Optional
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
from primaite.simulator.network.hardware.base import (
|
||||
IPWiredNetworkInterface,
|
||||
Link,
|
||||
@@ -339,7 +340,7 @@ class HostNode(Node, discriminator="host-node"):
|
||||
ip_address: IPV4Address
|
||||
services: Any = None # temporarily unset to appease extra="forbid"
|
||||
applications: Any = None # temporarily unset to appease extra="forbid"
|
||||
folders: Any = None # temporarily unset to appease extra="forbid"
|
||||
folders: List[Dict] = {} # temporarily unset to appease extra="forbid"
|
||||
network_interfaces: Any = None # temporarily unset to appease extra="forbid"
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
|
||||
@@ -348,6 +349,18 @@ class HostNode(Node, discriminator="host-node"):
|
||||
super().__init__(**kwargs)
|
||||
self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask))
|
||||
|
||||
for folder in self.config.folders:
|
||||
# handle empty foler defined by just a string
|
||||
self.file_system.create_folder(folder["folder_name"])
|
||||
|
||||
for file in folder.get("files", []):
|
||||
self.file_system.create_file(
|
||||
folder_name=folder["folder_name"],
|
||||
file_name=file["file_name"],
|
||||
size=file.get("size", 0),
|
||||
file_type=FileType[file.get("type", "UNKNOWN").upper()],
|
||||
)
|
||||
|
||||
@property
|
||||
def nmap(self) -> Optional[NMAP]:
|
||||
"""
|
||||
|
||||
@@ -49,7 +49,7 @@ class Firewall(Router, discriminator="firewall"):
|
||||
|
||||
Example:
|
||||
>>> from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
>>> from primaite.simulator.network.transmission.transport_layer import Port
|
||||
>>> from primaite.utils.validation.port import Port
|
||||
>>> firewall = Firewall(hostname="Firewall1")
|
||||
>>> firewall.configure_internal_port(ip_address="192.168.1.1", subnet_mask="255.255.255.0")
|
||||
>>> firewall.configure_external_port(ip_address="10.0.0.1", subnet_mask="255.255.255.0")
|
||||
|
||||
@@ -467,6 +467,7 @@ class AccessControlList(SimComponent):
|
||||
"""Check if a packet with the given properties is permitted through the ACL."""
|
||||
permitted = False
|
||||
rule: ACLRule = None
|
||||
|
||||
for _rule in self._acl:
|
||||
if not _rule:
|
||||
continue
|
||||
@@ -1215,9 +1216,9 @@ class Router(NetworkNode, discriminator="router"):
|
||||
config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema())
|
||||
|
||||
SYSTEM_SOFTWARE: ClassVar[Dict] = {
|
||||
"UserSessionManager": UserSessionManager,
|
||||
"UserManager": UserManager,
|
||||
"Terminal": Terminal,
|
||||
"user-session-manager": UserSessionManager,
|
||||
"user-manager": UserManager,
|
||||
"terminal": Terminal,
|
||||
}
|
||||
|
||||
network_interfaces: Dict[str, RouterInterface] = {}
|
||||
@@ -1384,6 +1385,12 @@ class Router(NetworkNode, discriminator="router"):
|
||||
|
||||
return False
|
||||
|
||||
def subject_to_acl(self, frame: Frame) -> bool:
|
||||
"""Check that frame is subject to ACL rules."""
|
||||
if frame.ip.protocol == "udp" and frame.is_arp:
|
||||
return False
|
||||
return True
|
||||
|
||||
def receive_frame(self, frame: Frame, from_network_interface: RouterInterface):
|
||||
"""
|
||||
Processes an incoming frame received on one of the router's interfaces.
|
||||
@@ -1397,8 +1404,12 @@ class Router(NetworkNode, discriminator="router"):
|
||||
if self.operating_state != NodeOperatingState.ON:
|
||||
return
|
||||
|
||||
# Check if it's permitted
|
||||
permitted, rule = self.acl.is_permitted(frame)
|
||||
if self.subject_to_acl(frame=frame):
|
||||
# Check if it's permitted
|
||||
permitted, rule = self.acl.is_permitted(frame)
|
||||
else:
|
||||
permitted = True
|
||||
rule = None
|
||||
|
||||
if not permitted:
|
||||
at_port = self._get_port_of_nic(from_network_interface)
|
||||
|
||||
@@ -163,7 +163,7 @@ class Frame(BaseModel):
|
||||
"""
|
||||
Checks if the Frame is an ARP (Address Resolution Protocol) packet.
|
||||
|
||||
This is determined by checking if the destination port of the TCP header is equal to the ARP port.
|
||||
This is determined by checking if the destination and source port of the UDP header is equal to the ARP port.
|
||||
|
||||
:return: True if the Frame is an ARP packet, otherwise False.
|
||||
"""
|
||||
|
||||
@@ -415,5 +415,5 @@ class SessionManager:
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} Session Manager"
|
||||
for session in self.sessions_by_key.values():
|
||||
table.add_row([session.dst_ip_address, session.dst_port, session.protocol])
|
||||
table.add_row([session.with_ip_address, session.dst_port, session.protocol])
|
||||
print(table)
|
||||
|
||||
@@ -55,7 +55,7 @@ class ARP(Service, discriminator="arp"):
|
||||
|
||||
:param markdown: If True, format the output as Markdown. Otherwise, use plain text.
|
||||
"""
|
||||
table = PrettyTable(["IP Address", "MAC Address", "Via"])
|
||||
table = PrettyTable(["IP Address", "MAC Address", "Via", "Port"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
@@ -66,6 +66,7 @@ class ARP(Service, discriminator="arp"):
|
||||
str(ip),
|
||||
arp.mac_address,
|
||||
self.software_manager.node.network_interfaces[arp.network_interface_uuid].mac_address,
|
||||
self.software_manager.node.network_interfaces[arp.network_interface_uuid].port_num,
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -142,12 +142,20 @@ class Terminal(Service, discriminator="terminal"):
|
||||
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
|
||||
"""Dictionary of connect requests made to remote nodes."""
|
||||
|
||||
_last_response: Optional[RequestResponse] = None
|
||||
"""Last response received from RequestManager, for returning remote RequestResponse."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "terminal"
|
||||
kwargs["port"] = PORT_LOOKUP["SSH"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def last_response(self) -> Optional[RequestResponse]:
|
||||
"""Public version of _last_response attribute."""
|
||||
return self._last_response
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -186,7 +194,7 @@ class Terminal(Service, discriminator="terminal"):
|
||||
return RequestResponse(status="failure", data={})
|
||||
|
||||
rm.add_request(
|
||||
"node-session-remote-login",
|
||||
"node_session_remote_login",
|
||||
request_type=RequestType(func=_remote_login),
|
||||
)
|
||||
|
||||
@@ -209,28 +217,45 @@ class Terminal(Service, discriminator="terminal"):
|
||||
command: str = request[1]["command"]
|
||||
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
|
||||
if remote_connection:
|
||||
outcome = remote_connection.execute(command)
|
||||
if outcome:
|
||||
return RequestResponse(
|
||||
status="success",
|
||||
data={},
|
||||
)
|
||||
else:
|
||||
return RequestResponse(
|
||||
status="failure",
|
||||
data={},
|
||||
)
|
||||
remote_connection.execute(command)
|
||||
return self.last_response if not None else RequestResponse(status="failure", data={})
|
||||
return RequestResponse(
|
||||
status="failure",
|
||||
data={"reason": "Failed to execute command."},
|
||||
)
|
||||
|
||||
rm.add_request(
|
||||
"send_remote_command",
|
||||
request_type=RequestType(func=remote_execute_request),
|
||||
)
|
||||
|
||||
def local_execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""Executes a command using a local terminal session."""
|
||||
command: str = request[2]["command"]
|
||||
local_connection = self._process_local_login(username=request[0], password=request[1])
|
||||
if local_connection:
|
||||
outcome = local_connection.execute(command)
|
||||
if outcome:
|
||||
return RequestResponse(
|
||||
status="success",
|
||||
data={"reason": outcome},
|
||||
)
|
||||
return RequestResponse(
|
||||
status="success",
|
||||
data={"reason": "Local Terminal failed to resolve command. Potentially invalid credentials?"},
|
||||
)
|
||||
|
||||
rm.add_request(
|
||||
"send_local_command",
|
||||
request_type=RequestType(func=local_execute_request),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
def execute(self, command: List[Any]) -> Optional[RequestResponse]:
|
||||
"""Execute a passed ssh command via the request manager."""
|
||||
return self.parent.apply_request(command)
|
||||
self._last_response = self.parent.apply_request(command)
|
||||
return self._last_response
|
||||
|
||||
def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]:
|
||||
"""Find Remote Terminal Connection from a given IP."""
|
||||
@@ -409,6 +434,8 @@ class Terminal(Service, discriminator="terminal"):
|
||||
"""
|
||||
source_ip = kwargs["frame"].ip.src_ip_address
|
||||
self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}")
|
||||
self._last_response = None # Clear last response
|
||||
|
||||
if isinstance(payload, SSHPacket):
|
||||
if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
|
||||
# validate & add connection
|
||||
@@ -457,6 +484,9 @@ class Terminal(Service, discriminator="terminal"):
|
||||
session_id=session_id,
|
||||
source_ip=source_ip,
|
||||
)
|
||||
self._last_response: RequestResponse = RequestResponse(
|
||||
status="success", data={"reason": "Login Successful"}
|
||||
)
|
||||
|
||||
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
|
||||
# Requesting a command to be executed
|
||||
@@ -468,12 +498,32 @@ class Terminal(Service, discriminator="terminal"):
|
||||
payload.connection_uuid
|
||||
)
|
||||
remote_session.last_active_step = self.software_manager.node.user_session_manager.current_timestep
|
||||
self.execute(command)
|
||||
self._last_response: RequestResponse = self.execute(command)
|
||||
|
||||
if self._last_response.status == "success":
|
||||
transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS
|
||||
else:
|
||||
transport_message = SSHTransportMessage.SSH_MSG_SERVICE_FAILED
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
payload=self._last_response,
|
||||
transport_message=transport_message,
|
||||
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA,
|
||||
)
|
||||
self.software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_port=self.port, session_id=session_id
|
||||
)
|
||||
return True
|
||||
else:
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command."
|
||||
)
|
||||
elif (
|
||||
payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS
|
||||
or SSHTransportMessage.SSH_MSG_SERVICE_FAILED
|
||||
):
|
||||
# Likely receiving command ack from remote.
|
||||
self._last_response = payload.payload
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "disconnect":
|
||||
|
||||
@@ -117,37 +117,44 @@ class WebServer(Service, discriminator="web-server"):
|
||||
:type: payload: HttpRequestPacket
|
||||
"""
|
||||
response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND, payload=payload)
|
||||
try:
|
||||
parsed_url = urlparse(payload.request_url)
|
||||
path = parsed_url.path.strip("/")
|
||||
|
||||
if len(path) < 1:
|
||||
parsed_url = urlparse(payload.request_url)
|
||||
path = parsed_url.path.strip("/") if parsed_url and parsed_url.path else ""
|
||||
|
||||
if len(path) < 1:
|
||||
# query succeeded
|
||||
response.status_code = HttpStatusCode.OK
|
||||
|
||||
if path.startswith("users"):
|
||||
# get data from DatabaseServer
|
||||
# get all users
|
||||
if not self._establish_db_connection():
|
||||
# unable to create a db connection
|
||||
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
|
||||
return response
|
||||
|
||||
if self.db_connection.query("SELECT"):
|
||||
# query succeeded
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
response.status_code = HttpStatusCode.OK
|
||||
else:
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
return response
|
||||
|
||||
if path.startswith("users"):
|
||||
# get data from DatabaseServer
|
||||
# get all users
|
||||
if not self.db_connection:
|
||||
self._establish_db_connection()
|
||||
|
||||
if self.db_connection.query("SELECT"):
|
||||
# query succeeded
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
response.status_code = HttpStatusCode.OK
|
||||
else:
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
return response
|
||||
except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 )
|
||||
# something went wrong on the server
|
||||
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
|
||||
return response
|
||||
|
||||
def _establish_db_connection(self) -> None:
|
||||
def _establish_db_connection(self) -> bool:
|
||||
"""Establish a connection to db."""
|
||||
# if active db connection, return true
|
||||
if self.db_connection:
|
||||
return True
|
||||
|
||||
# otherwise, try to create db connection
|
||||
db_client = self.software_manager.software.get("database-client")
|
||||
|
||||
if db_client is None:
|
||||
return False # database client not installed
|
||||
|
||||
self.db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
return self.db_connection is not None
|
||||
|
||||
def send(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user