Merge remote-tracking branch 'origin/dev' into feature/2689-command-and-control

This commit is contained in:
Archer Bowen
2024-08-20 09:30:43 +01:00
26 changed files with 1761 additions and 103 deletions

View File

@@ -10,12 +10,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Random Number Generator Seeding by specifying a random number seed in the config file.
- Implemented Terminal service class, providing a generic terminal simulation.
- Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes.
- Added actions to establish SSH connections, send commands remotely and terminate SSH connections.
- Added actions to change users' passwords.
- Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the
main port they're assigned.
- Added two new red applications: ``C2Beacon`` and ``C2Server`` which aim to simulate malicious network infrastructure.
Refer to the ``Command and Control Application Suite E2E Demonstration`` notebook for more information.
- Added reward calculation details to AgentHistoryItem.
### Changed
- File and folder observations can now be configured to always show the true health status, or require scanning like before.
- Node observations can now be configured to show the number of active local and remote logins.
### Fixed
- Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config)
- Updated `SoftwareManager` `install` and `uninstall` to handle all functionality that was being done at the `install`
and `uninstall` methods in the `Node` class.
- Updated the `receive_payload_from_session_manager` method in `SoftwareManager` so that it now sends a copy of the

View File

@@ -5,7 +5,7 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, Tuple
from report import build_benchmark_md_report
from report import build_benchmark_md_report, md2pdf
from stable_baselines3 import PPO
import primaite
@@ -159,6 +159,13 @@ def run(
learning_rate: float = 3e-4,
) -> None:
"""Run the PrimAITE benchmark."""
# generate report folder
v_str = f"v{primaite.__version__}"
version_result_dir = _RESULTS_ROOT / v_str
version_result_dir.mkdir(exist_ok=True, parents=True)
output_path = version_result_dir / f"PrimAITE {v_str} Benchmark Report.md"
benchmark_start_time = datetime.now()
session_metadata_dict = {}
@@ -193,6 +200,12 @@ def run(
session_metadata=session_metadata_dict,
config_path=data_manipulation_config_path(),
results_root_path=_RESULTS_ROOT,
output_path=output_path,
)
md2pdf(
md_path=output_path,
pdf_path=str(output_path).replace(".md", ".pdf"),
css_path="static/styles.css",
)

View File

@@ -2,6 +2,7 @@
import json
import sys
from datetime import datetime
from os import PathLike
from pathlib import Path
from typing import Dict, Optional
@@ -14,7 +15,7 @@ from utils import _get_system_info
import primaite
PLOT_CONFIG = {
"size": {"auto_size": False, "width": 1500, "height": 900},
"size": {"auto_size": False, "width": 800, "height": 640},
"template": "plotly_white",
"range_slider": False,
}
@@ -144,6 +145,20 @@ def _plot_benchmark_metadata(
yaxis={"title": "Total Reward"},
title=title,
)
fig.update_layout(
legend=dict(
yanchor="top",
y=0.99,
xanchor="left",
x=0.01,
bgcolor="rgba(255,255,255,0.3)",
)
)
for trace in fig["data"]:
if trace["name"].startswith("Session"):
trace["showlegend"] = False
fig["data"][0]["name"] = "Individual Sessions"
fig["data"][0]["showlegend"] = True
return fig
@@ -194,6 +209,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
title=title,
)
fig["data"][0]["showlegend"] = True
fig.update_layout(legend=dict(yanchor="top", y=-0.2, xanchor="left", x=0.01, orientation="h"))
return fig
@@ -248,14 +264,7 @@ def _plot_av_s_per_100_steps_10_nodes(
versions = sorted(list(version_times_dict.keys()))
times = [version_times_dict[version] for version in versions]
fig.add_trace(
go.Bar(
x=versions,
y=times,
text=times,
textposition="auto",
)
)
fig.add_trace(go.Bar(x=versions, y=times, text=times, textposition="auto", texttemplate="%{y:.3f}"))
fig.update_layout(
xaxis_title="PrimAITE Version",
@@ -267,7 +276,11 @@ def _plot_av_s_per_100_steps_10_nodes(
def build_benchmark_md_report(
benchmark_start_time: datetime, session_metadata: Dict, config_path: Path, results_root_path: Path
benchmark_start_time: datetime,
session_metadata: Dict,
config_path: Path,
results_root_path: Path,
output_path: PathLike,
) -> None:
"""
Generates a Markdown report for a benchmarking session, documenting performance metrics and graphs.
@@ -319,7 +332,7 @@ def build_benchmark_md_report(
data = benchmark_metadata_dict
primaite_version = data["primaite_version"]
with open(version_result_dir / f"PrimAITE v{primaite_version} Benchmark Report.md", "w") as file:
with open(output_path, "w") as file:
# Title
file.write(f"# PrimAITE v{primaite_version} Learning Benchmark\n")
file.write("## PrimAITE Dev Team\n")
@@ -393,3 +406,15 @@ def build_benchmark_md_report(
f"![Performance of Minor and Bugfix Releases for Major Version {major_v}]"
f"({performance_benchmark_plot_path.name})\n"
)
def md2pdf(md_path: PathLike, pdf_path: PathLike, css_path: PathLike) -> None:
"""Generate PDF version of Markdown report."""
from md2pdf.core import md2pdf
md2pdf(
pdf_file_path=pdf_path,
md_file_path=md_path,
base_url=Path(md_path).parent,
css_file_path=css_path,
)

View File

@@ -0,0 +1,34 @@
body {
font-family: 'Arial', sans-serif;
line-height: 1.6;
/* margin: 1cm; */
}
h1, h2, h3, h4, h5, h6 {
font-weight: bold;
/* margin: 1em 0; */
}
p {
/* margin: 0.5em 0; */
}
ul, ol {
margin: 1em 0;
padding-left: 1.5em;
}
pre {
background: #f4f4f4;
padding: 0.5em;
overflow-x: auto;
}
img {
max-width: 100%;
height: auto;
}
table {
width: 100%;
border-collapse: collapse;
margin: 1em 0;
}
th, td {
padding: 0.5em;
border: 1px solid #ddd;
}

View File

@@ -75,7 +75,8 @@ dev = [
"wheel==0.38.4",
"nbsphinx==0.9.4",
"nbmake==1.5.4",
"pytest-xdist==3.3.1"
"pytest-xdist==3.3.1",
"md2pdf",
]
[project.scripts]

View File

@@ -1096,6 +1096,10 @@ class ConfigureC2BeaconAction(AbstractAction):
return cls.model_fields[info.field_name].default
return v
class NodeAccountsChangePasswordAction(AbstractAction):
"""Action which changes the password for a user."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
@@ -1119,6 +1123,25 @@ class ConfigureC2BeaconAction(AbstractAction):
class RansomwareConfigureC2ServerAction(AbstractAction):
"""Action which sends a command from the C2 Server to the C2 Beacon which configures a local RansomwareScript."""
def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return [
"network",
"node",
node_name,
"service",
"UserManager",
"change_password",
username,
current_password,
new_password,
]
class NodeSessionsRemoteLoginAction(AbstractAction):
"""Action which performs a remote session login."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
@@ -1135,6 +1158,25 @@ class RansomwareConfigureC2ServerAction(AbstractAction):
class RansomwareLaunchC2ServerAction(AbstractAction):
"""Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript."""
def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return [
"network",
"node",
node_name,
"service",
"Terminal",
"ssh_to_remote",
username,
password,
remote_ip,
]
class NodeSessionsRemoteLogoutAction(AbstractAction):
"""Action which performs a remote session logout."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
@@ -1160,6 +1202,15 @@ class ExfiltrationC2ServerAction(AbstractAction):
target_folder_name: str
exfiltration_folder_name: Optional[str]
def form_request(self, node_id: str, remote_ip: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return ["network", "node", node_name, "service", "Terminal", "remote_logoff", remote_ip]
class NodeSendRemoteCommandAction(AbstractAction):
"""Action which sends a terminal command to a remote node via SSH."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
@@ -1219,6 +1270,20 @@ class TerminalC2ServerAction(AbstractAction):
TerminalC2ServerAction._Opts.model_validate(command_model)
return ["network", "node", node_name, "application", "C2Server", "terminal_command", command_model]
def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return [
"network",
"node",
node_name,
"service",
"Terminal",
"send_remote_command",
remote_ip,
{"command": command},
]
class ActionManager:
"""Class which manages the action space for an agent."""
@@ -1276,6 +1341,10 @@ class ActionManager:
"C2_SERVER_RANSOMWARE_CONFIGURE": RansomwareConfigureC2ServerAction,
"C2_SERVER_TERMINAL_COMMAND": TerminalC2ServerAction,
"C2_SERVER_DATA_EXFILTRATE": ExfiltrationC2ServerAction,
"NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction,
"SSH_TO_REMOTE": NodeSessionsRemoteLoginAction,
"SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction,
"NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction,
}
"""Dictionary which maps action type strings to the corresponding action class."""

View File

@@ -36,6 +36,8 @@ class AgentHistoryItem(BaseModel):
reward: Optional[float] = None
reward_info: Dict[str, Any] = {}
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""

View File

@@ -23,8 +23,10 @@ class FileObservation(AbstractObservation, identifier="FILE"):
"""Name of the file, used for querying simulation state dictionary."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to the file in the observation."""
file_system_requires_scan: Optional[bool] = None
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
def __init__(self, where: WhereType, include_num_access: bool) -> None:
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
"""
Initialise a file observation instance.
@@ -34,9 +36,13 @@ class FileObservation(AbstractObservation, identifier="FILE"):
:type where: WhereType
:param include_num_access: Whether to include the number of accesses to the file in the observation.
:type include_num_access: bool
:param file_system_requires_scan: If True, the file must be scanned to update the health state. Tf False,
the true state is always shown.
:type file_system_requires_scan: bool
"""
self.where: WhereType = where
self.include_num_access: bool = include_num_access
self.file_system_requires_scan: bool = file_system_requires_scan
self.default_observation: ObsType = {"health_status": 0}
if self.include_num_access:
@@ -74,7 +80,11 @@ class FileObservation(AbstractObservation, identifier="FILE"):
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {"health_status": file_state["visible_status"]}
if self.file_system_requires_scan:
health_status = file_state["visible_status"]
else:
health_status = file_state["health_status"]
obs = {"health_status": health_status}
if self.include_num_access:
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
return obs
@@ -104,8 +114,15 @@ class FileObservation(AbstractObservation, identifier="FILE"):
:type parent_where: WhereType, optional
:return: Constructed file observation instance.
:rtype: FileObservation
:param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False,
the true state is always shown.
:type file_system_requires_scan: bool
"""
return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access)
return cls(
where=parent_where + ["files", config.file_name],
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
)
class FolderObservation(AbstractObservation, identifier="FOLDER"):
@@ -122,9 +139,16 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
"""Number of spaces for file observations in this folder."""
include_num_access: Optional[bool] = None
"""Whether files in this folder should include the number of accesses in their observation."""
file_system_requires_scan: Optional[bool] = None
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
def __init__(
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
self,
where: WhereType,
files: Iterable[FileObservation],
num_files: int,
include_num_access: bool,
file_system_requires_scan: bool,
) -> None:
"""
Initialise a folder observation instance.
@@ -138,12 +162,23 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
:type num_files: int
:param include_num_access: Whether to include the number of accesses to files in the observation.
:type include_num_access: bool
:param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False,
the true state is always shown.
:type file_system_requires_scan: bool
"""
self.where: WhereType = where
self.file_system_requires_scan: bool = file_system_requires_scan
self.files: List[FileObservation] = files
while len(self.files) < num_files:
self.files.append(FileObservation(where=None, include_num_access=include_num_access))
self.files.append(
FileObservation(
where=None,
include_num_access=include_num_access,
file_system_requires_scan=self.file_system_requires_scan,
)
)
while len(self.files) > num_files:
truncated_file = self.files.pop()
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
@@ -168,7 +203,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
if folder_state is NOT_PRESENT_IN_STATE:
return self.default_observation
health_status = folder_state["health_status"]
if self.file_system_requires_scan:
health_status = folder_state["visible_status"]
else:
health_status = folder_state["health_status"]
obs = {}
@@ -209,6 +247,13 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
# pass down shared/common config items
for file_config in config.files:
file_config.include_num_access = config.include_num_access
file_config.file_system_requires_scan = config.file_system_requires_scan
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)
return cls(
where=where,
files=files,
num_files=config.num_files,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
)

View File

@@ -10,6 +10,7 @@ from primaite import getLogger
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
@@ -32,6 +33,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
self,
@@ -41,6 +44,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
port_list: List[int],
protocol_list: List[str],
num_rules: int,
include_users: bool,
) -> None:
"""
Initialise a firewall observation instance.
@@ -58,9 +62,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:type protocol_list: List[str]
:param num_rules: Number of rules configured in the firewall.
:type num_rules: int
:param include_users: If True, report user session information.
:type include_users: bool
"""
self.where: WhereType = where
self.include_users: bool = include_users
self.max_users: int = 3
"""Maximum number of remote sessions observable, excess sessions are truncated."""
self.ports: List[PortObservation] = [
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
]
@@ -142,6 +150,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
:rtype: ObsType
"""
firewall_state = access_from_nested_dict(state, self.where)
if firewall_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
"ACL": {
@@ -159,6 +170,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
},
},
}
if self.include_users:
sess = firewall_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -218,4 +235,5 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
port_list=config.port_list,
protocol_list=config.protocol_list,
num_rules=config.num_rules,
include_users=config.include_users,
)

View File

@@ -48,6 +48,12 @@ class HostObservation(AbstractObservation, identifier="HOST"):
"""A dict containing which traffic types are to be included in the observation."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to files observations on this host."""
file_system_requires_scan: Optional[bool] = None
"""
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
"""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
self,
@@ -64,6 +70,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
include_nmne: bool,
monitored_traffic: Optional[Dict],
include_num_access: bool,
file_system_requires_scan: bool,
include_users: bool,
) -> None:
"""
Initialise a host observation instance.
@@ -95,10 +103,18 @@ class HostObservation(AbstractObservation, identifier="HOST"):
:type monitored_traffic: Dict
:param include_num_access: Flag to include the number of accesses to files.
:type include_num_access: bool
:param file_system_requires_scan: If True, the files and folders must be scanned to update the health state.
If False, the true state is always shown.
:type file_system_requires_scan: bool
:param include_users: If True, report user session information.
:type include_users: bool
"""
self.where: WhereType = where
self.include_num_access = include_num_access
self.include_users = include_users
self.max_users: int = 3
"""Maximum number of remote sessions observable, excess sessions are truncated."""
# Ensure lists have lengths equal to specified counts by truncating or padding
self.services: List[ServiceObservation] = services
@@ -120,7 +136,13 @@ class HostObservation(AbstractObservation, identifier="HOST"):
self.folders: List[FolderObservation] = folders
while len(self.folders) < num_folders:
self.folders.append(
FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access)
FolderObservation(
where=None,
files=[],
num_files=num_files,
include_num_access=include_num_access,
file_system_requires_scan=file_system_requires_scan,
)
)
while len(self.folders) > num_folders:
truncated_folder = self.folders.pop()
@@ -151,6 +173,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
if self.include_num_access:
self.default_observation["num_file_creations"] = 0
self.default_observation["num_file_deletions"] = 0
if self.include_users:
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
def observe(self, state: Dict) -> ObsType:
"""
@@ -178,6 +202,12 @@ class HostObservation(AbstractObservation, identifier="HOST"):
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
if self.include_users:
sess = node_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -202,6 +232,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
if self.include_num_access:
shape["num_file_creations"] = spaces.Discrete(4)
shape["num_file_deletions"] = spaces.Discrete(4)
if self.include_users:
shape["users"] = spaces.Dict(
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
)
return spaces.Dict(shape)
@classmethod
@@ -226,6 +260,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
for folder_config in config.folders:
folder_config.include_num_access = config.include_num_access
folder_config.num_files = config.num_files
folder_config.file_system_requires_scan = config.file_system_requires_scan
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
@@ -257,4 +292,6 @@ class HostObservation(AbstractObservation, identifier="HOST"):
include_nmne=config.include_nmne,
monitored_traffic=config.monitored_traffic,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
include_users=config.include_users,
)

View File

@@ -44,6 +44,10 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
"""A dict containing which traffic types are to be included in the observation."""
include_num_access: Optional[bool] = None
"""Flag to include the number of accesses."""
file_system_requires_scan: bool = True
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
include_users: Optional[bool] = True
"""If True, report user session information."""
num_ports: Optional[int] = None
"""Number of ports."""
ip_list: Optional[List[str]] = None
@@ -187,6 +191,10 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
host_config.monitored_traffic = config.monitored_traffic
if host_config.include_num_access is None:
host_config.include_num_access = config.include_num_access
if host_config.file_system_requires_scan is None:
host_config.file_system_requires_scan = config.file_system_requires_scan
if host_config.include_users is None:
host_config.include_users = config.include_users
for router_config in config.routers:
if router_config.num_ports is None:
@@ -201,6 +209,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
router_config.protocol_list = config.protocol_list
if router_config.num_rules is None:
router_config.num_rules = config.num_rules
if router_config.include_users is None:
router_config.include_users = config.include_users
for firewall_config in config.firewalls:
if firewall_config.ip_list is None:
@@ -213,6 +223,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
firewall_config.protocol_list = config.protocol_list
if firewall_config.num_rules is None:
firewall_config.num_rules = config.num_rules
if firewall_config.include_users is None:
firewall_config.include_users = config.include_users
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]

View File

@@ -39,6 +39,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
self,
@@ -46,6 +48,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
ports: List[PortObservation],
num_ports: int,
acl: ACLObservation,
include_users: bool,
) -> None:
"""
Initialise a router observation instance.
@@ -59,12 +62,16 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
:type num_ports: int
:param acl: ACL observation representing the access control list of the router.
:type acl: ACLObservation
:param include_users: If True, report user session information.
:type include_users: bool
"""
self.where: WhereType = where
self.ports: List[PortObservation] = ports
self.acl: ACLObservation = acl
self.num_ports: int = num_ports
self.include_users: bool = include_users
self.max_users: int = 3
"""Maximum number of remote sessions observable, excess sessions are truncated."""
while len(self.ports) < num_ports:
self.ports.append(PortObservation(where=None))
while len(self.ports) > num_ports:
@@ -95,6 +102,12 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
obs["ACL"] = self.acl.observe(state)
if self.ports:
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
if self.include_users:
sess = router_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -143,4 +156,4 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl, include_users=config.include_users)

View File

@@ -47,7 +47,15 @@ class AbstractReward:
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
"""Calculate the reward for the current state.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return 0.0
@classmethod
@@ -67,7 +75,15 @@ class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
"""Calculate the reward for the current state.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return 0.0
@classmethod
@@ -109,8 +125,12 @@ class DatabaseFileIntegrity(AbstractReward):
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
database_file_state = access_from_nested_dict(state, self.location_in_state)
if database_file_state is NOT_PRESENT_IN_STATE:
@@ -283,6 +303,12 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
component will keep track of that information. In that case, it doesn't matter whether the last successful
request returned was able to connect to the database server, because there has been an unsuccessful request
since.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
@@ -295,14 +321,11 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
# If the last request was actually sent, then check if the connection was established.
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
last_action_response.reward_info = {"reason": f"Can't calculate reward for {self.__class__.__name__}"}
return 0.0
last_connection_successful = db_state["last_connection_successful"]
if last_connection_successful is False:
return -1.0
elif last_connection_successful is True:
return 1.0
return 0.0
last_action_response.reward_info = {"last_connection_successful": last_connection_successful}
return 1.0 if last_connection_successful else -1.0
@classmethod
def from_config(cls, config: Dict) -> AbstractReward:
@@ -346,7 +369,15 @@ class SharedReward(AbstractReward):
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Simply access the other agent's reward and return it."""
"""Simply access the other agent's reward and return it.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return self.callback(self.agent_name)
@classmethod
@@ -379,7 +410,15 @@ class ActionPenalty(AbstractReward):
self.do_nothing_penalty = do_nothing_penalty
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the penalty to be applied."""
"""Calculate the penalty to be applied.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
if last_action_response.action == "DONOTHING":
return self.do_nothing_penalty
else:

View File

@@ -990,6 +990,7 @@ class UserManager(Service):
if user and user.password == current_password:
user.password = new_password
self.sys_log.info(f"{self.name}: Password changed for {username}")
self._user_session_manager._logout_user(user=user)
return True
self.sys_log.info(f"{self.name}: Password change failed for {username}")
return False
@@ -1027,6 +1028,10 @@ class UserManager(Service):
self.sys_log.info(f"{self.name}: Failed to enable user: {username}")
return False
@property
def _user_session_manager(self) -> "UserSessionManager":
return self.software_manager.software["UserSessionManager"] # noqa
class UserSession(SimComponent):
"""
@@ -1260,7 +1265,8 @@ class UserSessionManager(Service):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["active_remote_logins"] = len(self.remote_sessions)
state["current_local_user"] = None if not self.local_session else self.local_session.user.username
state["active_remote_sessions"] = list(self.remote_sessions.keys())
return state
@property
@@ -1440,6 +1446,19 @@ class UserSessionManager(Service):
"""
return self._logout(local=False, remote_session_id=remote_session_id)
def _logout_user(self, user: Union[str, User]) -> bool:
"""End a user session by username or user object."""
if isinstance(user, str):
user = self._user_manager.users[user] # grab user object from username
for sess_id, session in self.remote_sessions.items():
if session.user is user:
self._logout(local=False, remote_session_id=sess_id)
return True
if self.local_user_logged_in and self.local_session.user is user:
self.local_logout()
return True
return False
@property
def local_user_logged_in(self) -> bool:
"""

View File

@@ -23,6 +23,8 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service, ServiceOperatingState
# TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor
# the terminal to leverage the user session manager's list. This way we avoid potential bugs and code ducplication
class TerminalClientConnection(BaseModel):
"""
TerminalClientConnection Class.
@@ -162,22 +164,6 @@ class Terminal(Service):
def _init_request_manager(self) -> RequestManager:
"""Initialise Request manager."""
rm = super()._init_request_manager()
rm.add_request(
"send",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())),
)
def _login(request: RequestFormat, context: Dict) -> RequestResponse:
login = self._process_local_login(username=request[0], password=request[1])
if login:
return RequestResponse(
status="success",
data={
"ip_address": login.ip_address,
},
)
else:
return RequestResponse(status="failure", data={"reason": "Invalid login credentials"})
def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse:
login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2])
@@ -191,10 +177,34 @@ class Terminal(Service):
else:
return RequestResponse(status="failure", data={})
rm.add_request(
"ssh_to_remote",
request_type=RequestType(func=_remote_login),
)
def _remote_logoff(request: RequestFormat, context: Dict) -> RequestResponse:
"""Logoff from remote connection."""
ip_address = IPv4Address(request[0])
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
if remote_connection:
outcome = self._disconnect(remote_connection.connection_uuid)
if outcome:
return RequestResponse(
status="success",
data={},
)
else:
return RequestResponse(
status="failure",
data={"reason": "No remote connection held."},
)
rm.add_request("remote_logoff", request_type=RequestType(func=_remote_logoff))
def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
"""Execute an instruction."""
command: str = request[0]
ip_address: IPv4Address = IPv4Address(request[1])
ip_address: IPv4Address = IPv4Address(request[0])
command: str = request[1]["command"]
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
if remote_connection:
outcome = remote_connection.execute(command)
@@ -209,30 +219,11 @@ class Terminal(Service):
data={},
)
def _logoff(request: RequestFormat, context: Dict) -> RequestResponse:
"""Logoff from connection."""
connection_uuid = request[0]
self.parent.user_session_manager.local_logout(connection_uuid)
self._disconnect(connection_uuid)
return RequestResponse(status="success", data={})
rm.add_request(
"Login",
request_type=RequestType(func=_login),
)
rm.add_request(
"Remote Login",
request_type=RequestType(func=_remote_login),
)
rm.add_request(
"Execute",
"send_remote_command",
request_type=RequestType(func=remote_execute_request),
)
rm.add_request("Logoff", request_type=RequestType(func=_logoff))
return rm
def execute(self, command: List[Any]) -> Optional[RequestResponse]:
@@ -280,13 +271,9 @@ class Terminal(Service):
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning(f"{self.name}: Cannot login as service is not running.")
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None
if ip_address:
# Assuming that if IP is passed we are connecting to remote
return self._send_remote_login(
username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id
)
return self._send_remote_login(username=username, password=password, ip_address=ip_address)
else:
return self._process_local_login(username=username, password=password)
@@ -313,6 +300,9 @@ class Terminal(Service):
def _check_client_connection(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
if not self.parent.user_session_manager.validate_remote_session_uuid(connection_id):
self._disconnect(connection_id)
return False
return connection_id in self._connections
def _send_remote_login(
@@ -320,32 +310,24 @@ class Terminal(Service):
username: str,
password: str,
ip_address: IPv4Address,
connection_request_id: str,
connection_request_id: Optional[str] = None,
is_reattempt: bool = False,
) -> Optional[RemoteTerminalConnection]:
"""Send a remote login attempt and connect to Node.
:param: username: Username used to connect to the remote node.
:type: username: str
:param: password: Password used to connect to the remote node
:type: password: str
:param: ip_address: Target Node IP address for login attempt.
:type: ip_address: IPv4Address
:param: connection_request_id: Connection Request ID
:type: connection_request_id: str
:param: connection_request_id: Connection Request ID, if not provided, a new one is generated
:type: connection_request_id: Optional[str]
:param: is_reattempt: True if the request has been reattempted. Default False.
:type: is_reattempt: Optional[bool]
:return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False.
"""
self.sys_log.info(
f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}"
)
connection_request_id = connection_request_id or str(uuid4())
if is_reattempt:
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
if valid_connection_request:
@@ -360,6 +342,9 @@ class Terminal(Service):
self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.")
return None
self.sys_log.info(
f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}"
)
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password)

View File

@@ -0,0 +1,942 @@
io_settings:
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: false
sys_log_level: WARNING
game:
max_episode_length: 128
ports:
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 10
medium: 5
low: 0
agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_2
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
- ref: client_1_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: DataManipulationBot
- node_name: client_2
applications:
- application_name: DataManipulationBot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
monitored_traffic:
icmp:
- NONE
tcp:
- DNS
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1:eth-1<->switch_1:eth-8
- router_1:eth-2<->switch_2:eth-8
- switch_1:eth-1<->domain_controller:eth-1
- switch_1:eth-2<->web_server:eth-1
- switch_1:eth-3<->database_server:eth-1
- switch_1:eth-4<->backup_server:eth-1
- switch_1:eth-7<->security_suite:eth-1
- switch_2:eth-1<->client_1:eth-1
- switch_2:eth-2<->client_2:eth-1
- switch_2:eth-7<->security_suite:eth-2
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_SERVICE_FIX
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: ROUTER_ACL_ADDRULE
- type: ROUTER_ACL_REMOVERULE
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 1
service_id: 0
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 1
service_id: 0
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 1
service_id: 0
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 1
service_id: 0
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 1
service_id: 0
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 1
service_id: 0
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 1
service_id: 0
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 1
service_id: 0
9: # check database.db file
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
options:
node_id: 2
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_FIX"
options:
node_id: 2
service_id: 0
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
options:
node_id: 2
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
node_id: 0
19:
action: "NODE_SHUTDOWN"
options:
node_id: 0
20:
action: NODE_STARTUP
options:
node_id: 0
21:
action: NODE_RESET
options:
node_id: 0
22:
action: "NODE_OS_SCAN"
options:
node_id: 1
23:
action: "NODE_SHUTDOWN"
options:
node_id: 1
24:
action: NODE_STARTUP
options:
node_id: 1
25:
action: NODE_RESET
options:
node_id: 1
26: # old action num: 18
action: "NODE_OS_SCAN"
options:
node_id: 2
27:
action: "NODE_SHUTDOWN"
options:
node_id: 2
28:
action: NODE_STARTUP
options:
node_id: 2
29:
action: NODE_RESET
options:
node_id: 2
30:
action: "NODE_OS_SCAN"
options:
node_id: 3
31:
action: "NODE_SHUTDOWN"
options:
node_id: 3
32:
action: NODE_STARTUP
options:
node_id: 3
33:
action: NODE_RESET
options:
node_id: 3
34:
action: "NODE_OS_SCAN"
options:
node_id: 4
35:
action: "NODE_SHUTDOWN"
options:
node_id: 4
36:
action: NODE_STARTUP
options:
node_id: 4
37:
action: NODE_RESET
options:
node_id: 4
38:
action: "NODE_OS_SCAN"
options:
node_id: 5
39: # old action num: 19 # shutdown client 1
action: "NODE_SHUTDOWN"
options:
node_id: 5
40: # old action num: 20
action: NODE_STARTUP
options:
node_id: 5
41: # old action num: 21
action: NODE_RESET
options:
node_id: 5
42:
action: "NODE_OS_SCAN"
options:
node_id: 6
43:
action: "NODE_SHUTDOWN"
options:
node_id: 6
44:
action: NODE_STARTUP
options:
node_id: 6
45:
action: NODE_RESET
options:
node_id: 6
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 1
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 2
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
48: # old action num: 24 # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 3
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
49: # old action num: 25 # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 4
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
50: # old action num: 26
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 5
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
51: # old action num: 27
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 6
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
52: # old action num: 28
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 0
53: # old action num: 29
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 1
54: # old action num: 30
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 2
55: # old action num: 31
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 3
56: # old action num: 32
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 4
57: # old action num: 33
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 5
58: # old action num: 34
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 6
59: # old action num: 35
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 7
60: # old action num: 36
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 8
61: # old action num: 37
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 9
62: # old action num: 38
action: "HOST_NIC_DISABLE"
options:
node_id: 0
nic_id: 0
63: # old action num: 39
action: "HOST_NIC_ENABLE"
options:
node_id: 0
nic_id: 0
64: # old action num: 40
action: "HOST_NIC_DISABLE"
options:
node_id: 1
nic_id: 0
65: # old action num: 41
action: "HOST_NIC_ENABLE"
options:
node_id: 1
nic_id: 0
66: # old action num: 42
action: "HOST_NIC_DISABLE"
options:
node_id: 2
nic_id: 0
67: # old action num: 43
action: "HOST_NIC_ENABLE"
options:
node_id: 2
nic_id: 0
68: # old action num: 44
action: "HOST_NIC_DISABLE"
options:
node_id: 3
nic_id: 0
69: # old action num: 45
action: "HOST_NIC_ENABLE"
options:
node_id: 3
nic_id: 0
70: # old action num: 46
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 0
71: # old action num: 47
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 0
72: # old action num: 48
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
73: # old action num: 49
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
74: # old action num: 50
action: "HOST_NIC_DISABLE"
options:
node_id: 5
nic_id: 0
75: # old action num: 51
action: "HOST_NIC_ENABLE"
options:
node_id: 5
nic_id: 0
76: # old action num: 52
action: "HOST_NIC_DISABLE"
options:
node_id: 6
nic_id: 0
77: # old action num: 53
action: "HOST_NIC_ENABLE"
options:
node_id: 6
nic_id: 0
options:
nodes:
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.40
options:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_1_green_user
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_2_green_user
agent_settings:
flatten_obs: true
action_masking: true
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- hostname: router_1
type: router
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
18:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
19:
action: PERMIT
src_port: DNS
dst_port: DNS
20:
action: PERMIT
src_port: FTP
dst_port: FTP
21:
action: PERMIT
src_port: HTTP
dst_port: HTTP
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- hostname: switch_1
type: switch
num_ports: 8
- hostname: switch_2
type: switch
num_ports: 8
- hostname: domain_controller
type: server
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- hostname: web_server
type: server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: WebServer
applications:
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- hostname: database_server
type: server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- type: FTPClient
- hostname: backup_server
type: server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: FTPServer
- hostname: security_suite
type: server
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
network_interfaces:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- type: DNSClient
- hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- type: DNSClient
links:
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 8
- endpoint_a_hostname: router_1
endpoint_a_port: 2
endpoint_b_hostname: switch_2
endpoint_b_port: 8
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_hostname: domain_controller
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_hostname: web_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 3
endpoint_b_hostname: database_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 4
endpoint_b_hostname: backup_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 7
endpoint_b_hostname: security_suite
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 1
endpoint_b_hostname: client_1
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 2
endpoint_b_hostname: client_2
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 7
endpoint_b_hostname: security_suite
endpoint_b_port: 2

View File

@@ -463,6 +463,10 @@ def game_and_agent():
{"type": "C2_SERVER_RANSOMWARE_CONFIGURE"},
{"type": "C2_SERVER_TERMINAL_COMMAND"},
{"type": "C2_SERVER_DATA_EXFILTRATE"},
{"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"},
{"type": "SSH_TO_REMOTE"},
{"type": "SESSIONS_REMOTE_LOGOFF"},
{"type": "NODE_SEND_REMOTE_COMMAND"},
]
action_space = ActionManager(

View File

@@ -0,0 +1,166 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
@pytest.fixture
def game_and_agent_fixture(game_and_agent):
"""Create a game with a simple agent that can be controlled by the tests."""
game, agent = game_and_agent
router = game.simulation.network.get_node_by_hostname("router")
router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4)
return (game, agent)
def test_remote_login(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
game, agent = game_and_agent_fixture
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
client_1 = game.simulation.network.get_node_by_hostname("client_1")
# create a new user account on server_1 that will be logged into remotely
server_1_usm: UserManager = server_1.software_manager.software["UserManager"]
server_1_usm.add_user("user123", "password", is_admin=True)
action = (
"SSH_TO_REMOTE",
{
"node_id": 0,
"username": "user123",
"password": "password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
game.step()
assert agent.history[-1].response.status == "success"
connection_established = False
for conn_str, conn_obj in client_1.terminal.connections.items():
conn_obj: RemoteTerminalConnection
if conn_obj.ip_address == server_1.network_interface[1].ip_address:
connection_established = True
if not connection_established:
pytest.fail("Remote SSH connection could not be established")
def test_remote_login_wrong_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
game, agent = game_and_agent_fixture
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
client_1 = game.simulation.network.get_node_by_hostname("client_1")
# create a new user account on server_1 that will be logged into remotely
server_1_usm: UserManager = server_1.software_manager.software["UserManager"]
server_1_usm.add_user("user123", "password", is_admin=True)
action = (
"SSH_TO_REMOTE",
{
"node_id": 0,
"username": "user123",
"password": "wrong_password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
game.step()
assert agent.history[-1].response.status == "failure"
connection_established = False
for conn_str, conn_obj in client_1.terminal.connections.items():
conn_obj: RemoteTerminalConnection
if conn_obj.ip_address == server_1.network_interface[1].ip_address:
connection_established = True
if connection_established:
pytest.fail("Remote SSH connection was established despite wrong password")
def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
game, agent = game_and_agent_fixture
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
client_1 = game.simulation.network.get_node_by_hostname("client_1")
# create a new user account on server_1 that will be logged into remotely
server_1_um: UserManager = server_1.software_manager.software["UserManager"]
server_1_um.add_user("user123", "password", is_admin=True)
action = (
"NODE_ACCOUNTS_CHANGE_PASSWORD",
{
"node_id": 1, # server_1
"username": "user123",
"current_password": "password",
"new_password": "different_password",
},
)
agent.store_action(action)
game.step()
assert agent.history[-1].response.status == "success"
assert server_1_um.users["user123"].password == "different_password"
def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
game, agent = game_and_agent_fixture
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
client_1 = game.simulation.network.get_node_by_hostname("client_1")
# create a new user account on server_1 that will be logged into remotely
server_1_usm: UserManager = server_1.software_manager.software["UserManager"]
server_1_usm.add_user("user123", "password", is_admin=True)
# Log in remotely
action = (
"SSH_TO_REMOTE",
{
"node_id": 0,
"username": "user123",
"password": "password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
game.step()
# Change password
action = (
"NODE_ACCOUNTS_CHANGE_PASSWORD",
{
"node_id": 1, # server_1
"username": "user123",
"current_password": "password",
"new_password": "different_password",
},
)
agent.store_action(action)
game.step()
# Assert that the user cannot execute an action
action = (
"NODE_SEND_REMOTE_COMMAND",
{
"node_id": 0,
"remote_ip": str(server_1.network_interface[1].ip_address),
"command": ["file_system", "create", "file", "folder123", "doggo.pdf", False],
},
)
agent.store_action(action)
game.step()
assert server_1.file_system.get_folder("folder123") is None
assert server_1.file_system.get_file("folder123", "doggo.pdf") is None

View File

@@ -26,6 +26,7 @@ def test_file_observation(simulation):
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
)
assert dog_file_obs.space["health_status"] == spaces.Discrete(6)
@@ -53,6 +54,7 @@ def test_folder_observation(simulation):
root_folder_obs = FolderObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"],
include_num_access=False,
file_system_requires_scan=True,
num_files=1,
files=[],
)

View File

@@ -33,6 +33,7 @@ def test_firewall_observation():
wildcard_list=["0.0.0.255", "0.0.0.1"],
port_list=["HTTP", "DNS"],
protocol_list=["TCP"],
include_users=False,
)
observation = firewall_observation.observe(firewall.describe_state())

View File

@@ -38,6 +38,8 @@ def test_host_observation(simulation):
applications=[],
folders=[],
network_interfaces=[],
file_system_requires_scan=True,
include_users=False,
)
assert host_obs.space["operating_status"] == spaces.Discrete(5)

View File

@@ -27,7 +27,7 @@ def test_router_observation():
port_list=["HTTP", "DNS"],
protocol_list=["TCP"],
)
router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl)
router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False)
# Observe the state using the RouterObservation instance
observed_output = router_observation.observe(router.describe_state())

View File

@@ -0,0 +1,89 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import pytest
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from tests import TEST_ASSETS_ROOT
DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml"
@pytest.fixture
def env_with_ssh() -> PrimaiteGymEnv:
"""Build data manipulation environment with SSH port open on router."""
env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG)
env.agent.flatten_obs = False
router: Router = env.game.simulation.network.get_node_by_hostname("router_1")
router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3)
return env
def extract_login_numbers_from_obs(obs):
"""Traverse the observation dictionary and return number of user sessions for all nodes."""
login_nums = {}
for node_name, node_obs in obs["NODES"].items():
login_nums[node_name] = node_obs.get("users")
return login_nums
class TestUserObservations:
"""Test that the RouterObservation, FirewallObservation, and HostObservation have the correct number of logins."""
def test_no_sessions_at_episode_start(self, env_with_ssh):
"""Test that all of the login observations start at 0 before any logins occur."""
obs, *_ = env_with_ssh.step(0)
logins_obs = extract_login_numbers_from_obs(obs)
for o in logins_obs.values():
assert o["local_login"] == 0
assert o["remote_sessions"] == 0
def test_single_login(self, env_with_ssh: PrimaiteGymEnv):
"""Test that performing a remote login increases the remote_sessions observation by 1."""
client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1")
client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh
obs, *_ = env_with_ssh.step(0)
logins_obs = extract_login_numbers_from_obs(obs)
db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server
assert db_srv_logins_obs["local_login"] == 0
assert db_srv_logins_obs["remote_sessions"] == 1
for o in logins_obs.values(): # the remaining obs after popping HOST2
assert o["local_login"] == 0
assert o["remote_sessions"] == 0
def test_logout(self, env_with_ssh: PrimaiteGymEnv):
"""Test that remote_sessions observation correctly decreases upon logout."""
client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1")
client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh
db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server")
db_srv.user_manager.change_user_password("admin", "admin", "different_pass") # changing password logs out user
obs, *_ = env_with_ssh.step(0)
logins_obs = extract_login_numbers_from_obs(obs)
for o in logins_obs.values():
assert o["local_login"] == 0
assert o["remote_sessions"] == 0
def test_max_observable_sessions(self, env_with_ssh: PrimaiteGymEnv):
"""Log in from 5 remote places and check that only a max of 3 is shown in the observation."""
MAX_OBSERVABLE_SESSIONS = 3
# Right now this is hardcoded as 3 in HostObservation, FirewallObservation, and RouterObservation
obs, *_ = env_with_ssh.step(0)
logins_obs = extract_login_numbers_from_obs(obs)
db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server
db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server")
db_srv.user_session_manager.remote_session_timeout_steps = 20
db_srv.user_session_manager.max_remote_sessions = 5
node_names = ("client_1", "client_2", "backup_server", "security_suite", "domain_controller")
for i, node_name in enumerate(node_names):
node = env_with_ssh.game.simulation.network.get_node_by_hostname(node_name)
node.terminal._send_remote_login("admin", "admin", "192.168.1.14")
obs, *_ = env_with_ssh.step(0)
logins_obs = extract_login_numbers_from_obs(obs)
db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server
assert db_srv_logins_obs["remote_sessions"] == min(MAX_OBSERVABLE_SESSIONS, i + 1)
assert len(db_srv.user_session_manager.remote_sessions) == i + 1

View File

@@ -17,6 +17,7 @@ def test_file_observation():
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
)
assert dog_file_obs.observe(state) == {"health_status": 1}
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})

View File

@@ -76,13 +76,16 @@ def test_uc2_rewards(game_and_agent):
]
)
state = game.get_sim_state()
reward_value = comp.calculate(
state,
last_action_response=AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
ahi = AgentHistoryItem(
timestep=0,
action="NODE_APPLICATION_EXECUTE",
parameters={},
request=["execute"],
response=response,
)
reward_value = comp.calculate(state, last_action_response=ahi)
assert reward_value == 1.0
assert ahi.reward_info == {"last_connection_successful": True}
router.acl.remove_rule(position=2)
@@ -92,13 +95,9 @@ def test_uc2_rewards(game_and_agent):
]
)
state = game.get_sim_state()
reward_value = comp.calculate(
state,
last_action_response=AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
)
reward_value = comp.calculate(state, last_action_response=ahi)
assert reward_value == -1.0
assert ahi.reward_info == {"last_connection_successful": False}
def test_shared_reward():

View File

@@ -0,0 +1,132 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import List
import pytest
import yaml
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
from primaite.game.agent.observations.host_observations import HostObservation
class TestFileSystemRequiresScan:
@pytest.mark.parametrize(
("yaml_option_string", "expected_val"),
(
("file_system_requires_scan: true", True),
("file_system_requires_scan: false", False),
(" ", True),
),
)
def test_obs_config(self, yaml_option_string, expected_val):
"""Check that the default behaviour is to set FileSystemRequiresScan to True."""
obs_cfg_yaml = f"""
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
{yaml_option_string}
include_nmne: true
monitored_traffic:
icmp:
- NONE
tcp:
- DNS
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1:eth-1<->switch_1:eth-8
- router_1:eth-2<->switch_2:eth-8
- switch_1:eth-1<->domain_controller:eth-1
- switch_1:eth-2<->web_server:eth-1
- switch_1:eth-3<->database_server:eth-1
- switch_1:eth-4<->backup_server:eth-1
- switch_1:eth-7<->security_suite:eth-1
- switch_2:eth-1<->client_1:eth-1
- switch_2:eth-2<->client_2:eth-1
- switch_2:eth-7<->security_suite:eth-2
- type: "NONE"
label: ICS
options: {{}}
"""
cfg = yaml.safe_load(obs_cfg_yaml)
manager = ObservationManager.from_config(cfg)
hosts: List[HostObservation] = manager.obs.components["NODES"].hosts
for i, host in enumerate(hosts):
folders: List[FolderObservation] = host.folders
for j, folder in enumerate(folders):
assert folder.file_system_requires_scan == expected_val # Make sure folders require scan by default
files: List[FileObservation] = folder.files
for k, file in enumerate(files):
assert file.file_system_requires_scan == expected_val
def test_file_require_scan(self):
file_state = {"health_status": 3, "visible_status": 1}
obs_requiring_scan = FileObservation([], include_num_access=False, file_system_requires_scan=True)
assert obs_requiring_scan.observe(file_state)["health_status"] == 1
obs_not_requiring_scan = FileObservation([], include_num_access=False, file_system_requires_scan=False)
assert obs_not_requiring_scan.observe(file_state)["health_status"] == 3
def test_folder_require_scan(self):
folder_state = {"health_status": 3, "visible_status": 1}
obs_requiring_scan = FolderObservation(
[], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True
)
assert obs_requiring_scan.observe(folder_state)["health_status"] == 1
obs_not_requiring_scan = FolderObservation(
[], files=[], num_files=0, include_num_access=False, file_system_requires_scan=False
)
assert obs_not_requiring_scan.observe(folder_state)["health_status"] == 3