diff --git a/CHANGELOG.md b/CHANGELOG.md index f06301a0..6811ecaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,12 +10,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Random Number Generator Seeding by specifying a random number seed in the config file. - Implemented Terminal service class, providing a generic terminal simulation. - Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes. +- Added actions to establish SSH connections, send commands remotely and terminate SSH connections. +- Added actions to change users' passwords. - Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the main port they're assigned. - Added two new red applications: ``C2Beacon`` and ``C2Server`` which aim to simulate malicious network infrastructure. Refer to the ``Command and Control Application Suite E2E Demonstration`` notebook for more information. +- Added reward calculation details to AgentHistoryItem. ### Changed +- File and folder observations can now be configured to always show the true health status, or require scanning like before. +- Node observations can now be configured to show the number of active local and remote logins. + +### Fixed +- Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) - Updated `SoftwareManager` `install` and `uninstall` to handle all functionality that was being done at the `install` and `uninstall` methods in the `Node` class. - Updated the `receive_payload_from_session_manager` method in `SoftwareManager` so that it now sends a copy of the diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index 0e6c2acc..86ed22a9 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -5,7 +5,7 @@ from datetime import datetime from pathlib import Path from typing import Any, Dict, Final, Tuple -from report import build_benchmark_md_report +from report import build_benchmark_md_report, md2pdf from stable_baselines3 import PPO import primaite @@ -159,6 +159,13 @@ def run( learning_rate: float = 3e-4, ) -> None: """Run the PrimAITE benchmark.""" + # generate report folder + v_str = f"v{primaite.__version__}" + + version_result_dir = _RESULTS_ROOT / v_str + version_result_dir.mkdir(exist_ok=True, parents=True) + output_path = version_result_dir / f"PrimAITE {v_str} Benchmark Report.md" + benchmark_start_time = datetime.now() session_metadata_dict = {} @@ -193,6 +200,12 @@ def run( session_metadata=session_metadata_dict, config_path=data_manipulation_config_path(), results_root_path=_RESULTS_ROOT, + output_path=output_path, + ) + md2pdf( + md_path=output_path, + pdf_path=str(output_path).replace(".md", ".pdf"), + css_path="static/styles.css", ) diff --git a/benchmark/report.py b/benchmark/report.py index e1ff46b9..4035ceca 100644 --- a/benchmark/report.py +++ b/benchmark/report.py @@ -2,6 +2,7 @@ import json import sys from datetime import datetime +from os import PathLike from pathlib import Path from typing import Dict, Optional @@ -14,7 +15,7 @@ from utils import _get_system_info import primaite PLOT_CONFIG = { - "size": {"auto_size": False, "width": 1500, "height": 900}, + "size": {"auto_size": False, "width": 800, "height": 640}, "template": "plotly_white", "range_slider": False, } @@ -144,6 +145,20 @@ def _plot_benchmark_metadata( yaxis={"title": "Total Reward"}, title=title, ) + fig.update_layout( + legend=dict( + yanchor="top", + y=0.99, + xanchor="left", + x=0.01, + bgcolor="rgba(255,255,255,0.3)", + ) + ) + for trace in fig["data"]: + if trace["name"].startswith("Session"): + trace["showlegend"] = False + fig["data"][0]["name"] = "Individual Sessions" + fig["data"][0]["showlegend"] = True return fig @@ -194,6 +209,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure: title=title, ) fig["data"][0]["showlegend"] = True + fig.update_layout(legend=dict(yanchor="top", y=-0.2, xanchor="left", x=0.01, orientation="h")) return fig @@ -248,14 +264,7 @@ def _plot_av_s_per_100_steps_10_nodes( versions = sorted(list(version_times_dict.keys())) times = [version_times_dict[version] for version in versions] - fig.add_trace( - go.Bar( - x=versions, - y=times, - text=times, - textposition="auto", - ) - ) + fig.add_trace(go.Bar(x=versions, y=times, text=times, textposition="auto", texttemplate="%{y:.3f}")) fig.update_layout( xaxis_title="PrimAITE Version", @@ -267,7 +276,11 @@ def _plot_av_s_per_100_steps_10_nodes( def build_benchmark_md_report( - benchmark_start_time: datetime, session_metadata: Dict, config_path: Path, results_root_path: Path + benchmark_start_time: datetime, + session_metadata: Dict, + config_path: Path, + results_root_path: Path, + output_path: PathLike, ) -> None: """ Generates a Markdown report for a benchmarking session, documenting performance metrics and graphs. @@ -319,7 +332,7 @@ def build_benchmark_md_report( data = benchmark_metadata_dict primaite_version = data["primaite_version"] - with open(version_result_dir / f"PrimAITE v{primaite_version} Benchmark Report.md", "w") as file: + with open(output_path, "w") as file: # Title file.write(f"# PrimAITE v{primaite_version} Learning Benchmark\n") file.write("## PrimAITE Dev Team\n") @@ -393,3 +406,15 @@ def build_benchmark_md_report( f"![Performance of Minor and Bugfix Releases for Major Version {major_v}]" f"({performance_benchmark_plot_path.name})\n" ) + + +def md2pdf(md_path: PathLike, pdf_path: PathLike, css_path: PathLike) -> None: + """Generate PDF version of Markdown report.""" + from md2pdf.core import md2pdf + + md2pdf( + pdf_file_path=pdf_path, + md_file_path=md_path, + base_url=Path(md_path).parent, + css_file_path=css_path, + ) diff --git a/benchmark/static/styles.css b/benchmark/static/styles.css new file mode 100644 index 00000000..4fbb9bd5 --- /dev/null +++ b/benchmark/static/styles.css @@ -0,0 +1,34 @@ +body { + font-family: 'Arial', sans-serif; + line-height: 1.6; + /* margin: 1cm; */ +} +h1, h2, h3, h4, h5, h6 { + font-weight: bold; + /* margin: 1em 0; */ +} +p { + /* margin: 0.5em 0; */ +} +ul, ol { + margin: 1em 0; + padding-left: 1.5em; +} +pre { + background: #f4f4f4; + padding: 0.5em; + overflow-x: auto; +} +img { + max-width: 100%; + height: auto; +} +table { + width: 100%; + border-collapse: collapse; + margin: 1em 0; +} +th, td { + padding: 0.5em; + border: 1px solid #ddd; +} diff --git a/pyproject.toml b/pyproject.toml index c9b7c062..354df8b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,8 @@ dev = [ "wheel==0.38.4", "nbsphinx==0.9.4", "nbmake==1.5.4", - "pytest-xdist==3.3.1" + "pytest-xdist==3.3.1", + "md2pdf", ] [project.scripts] diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 654ac0ac..713c4eb2 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1096,6 +1096,10 @@ class ConfigureC2BeaconAction(AbstractAction): return cls.model_fields[info.field_name].default return v + +class NodeAccountsChangePasswordAction(AbstractAction): + """Action which changes the password for a user.""" + def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) @@ -1119,6 +1123,25 @@ class ConfigureC2BeaconAction(AbstractAction): class RansomwareConfigureC2ServerAction(AbstractAction): """Action which sends a command from the C2 Server to the C2 Beacon which configures a local RansomwareScript.""" + def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "UserManager", + "change_password", + username, + current_password, + new_password, + ] + + +class NodeSessionsRemoteLoginAction(AbstractAction): + """Action which performs a remote session login.""" + def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) @@ -1135,6 +1158,25 @@ class RansomwareConfigureC2ServerAction(AbstractAction): class RansomwareLaunchC2ServerAction(AbstractAction): """Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript.""" + def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "Terminal", + "ssh_to_remote", + username, + password, + remote_ip, + ] + + +class NodeSessionsRemoteLogoutAction(AbstractAction): + """Action which performs a remote session logout.""" + def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) @@ -1160,6 +1202,15 @@ class ExfiltrationC2ServerAction(AbstractAction): target_folder_name: str exfiltration_folder_name: Optional[str] + def form_request(self, node_id: str, remote_ip: str) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + return ["network", "node", node_name, "service", "Terminal", "remote_logoff", remote_ip] + + +class NodeSendRemoteCommandAction(AbstractAction): + """Action which sends a terminal command to a remote node via SSH.""" + def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) @@ -1219,6 +1270,20 @@ class TerminalC2ServerAction(AbstractAction): TerminalC2ServerAction._Opts.model_validate(command_model) return ["network", "node", node_name, "application", "C2Server", "terminal_command", command_model] + def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "Terminal", + "send_remote_command", + remote_ip, + {"command": command}, + ] + class ActionManager: """Class which manages the action space for an agent.""" @@ -1276,6 +1341,10 @@ class ActionManager: "C2_SERVER_RANSOMWARE_CONFIGURE": RansomwareConfigureC2ServerAction, "C2_SERVER_TERMINAL_COMMAND": TerminalC2ServerAction, "C2_SERVER_DATA_EXFILTRATE": ExfiltrationC2ServerAction, + "NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction, + "SSH_TO_REMOTE": NodeSessionsRemoteLoginAction, + "SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction, + "NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index f57dc191..14b97821 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -36,6 +36,8 @@ class AgentHistoryItem(BaseModel): reward: Optional[float] = None + reward_info: Dict[str, Any] = {} + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index cb48fe7d..1c73d026 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -23,8 +23,10 @@ class FileObservation(AbstractObservation, identifier="FILE"): """Name of the file, used for querying simulation state dictionary.""" include_num_access: Optional[bool] = None """Whether to include the number of accesses to the file in the observation.""" + file_system_requires_scan: Optional[bool] = None + """If True, the file must be scanned to update the health state. Tf False, the true state is always shown.""" - def __init__(self, where: WhereType, include_num_access: bool) -> None: + def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None: """ Initialise a file observation instance. @@ -34,9 +36,13 @@ class FileObservation(AbstractObservation, identifier="FILE"): :type where: WhereType :param include_num_access: Whether to include the number of accesses to the file in the observation. :type include_num_access: bool + :param file_system_requires_scan: If True, the file must be scanned to update the health state. Tf False, + the true state is always shown. + :type file_system_requires_scan: bool """ self.where: WhereType = where self.include_num_access: bool = include_num_access + self.file_system_requires_scan: bool = file_system_requires_scan self.default_observation: ObsType = {"health_status": 0} if self.include_num_access: @@ -74,7 +80,11 @@ class FileObservation(AbstractObservation, identifier="FILE"): file_state = access_from_nested_dict(state, self.where) if file_state is NOT_PRESENT_IN_STATE: return self.default_observation - obs = {"health_status": file_state["visible_status"]} + if self.file_system_requires_scan: + health_status = file_state["visible_status"] + else: + health_status = file_state["health_status"] + obs = {"health_status": health_status} if self.include_num_access: obs["num_access"] = self._categorise_num_access(file_state["num_access"]) return obs @@ -104,8 +114,15 @@ class FileObservation(AbstractObservation, identifier="FILE"): :type parent_where: WhereType, optional :return: Constructed file observation instance. :rtype: FileObservation + :param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False, + the true state is always shown. + :type file_system_requires_scan: bool """ - return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access) + return cls( + where=parent_where + ["files", config.file_name], + include_num_access=config.include_num_access, + file_system_requires_scan=config.file_system_requires_scan, + ) class FolderObservation(AbstractObservation, identifier="FOLDER"): @@ -122,9 +139,16 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): """Number of spaces for file observations in this folder.""" include_num_access: Optional[bool] = None """Whether files in this folder should include the number of accesses in their observation.""" + file_system_requires_scan: Optional[bool] = None + """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" def __init__( - self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool + self, + where: WhereType, + files: Iterable[FileObservation], + num_files: int, + include_num_access: bool, + file_system_requires_scan: bool, ) -> None: """ Initialise a folder observation instance. @@ -138,12 +162,23 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): :type num_files: int :param include_num_access: Whether to include the number of accesses to files in the observation. :type include_num_access: bool + :param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False, + the true state is always shown. + :type file_system_requires_scan: bool """ self.where: WhereType = where + self.file_system_requires_scan: bool = file_system_requires_scan + self.files: List[FileObservation] = files while len(self.files) < num_files: - self.files.append(FileObservation(where=None, include_num_access=include_num_access)) + self.files.append( + FileObservation( + where=None, + include_num_access=include_num_access, + file_system_requires_scan=self.file_system_requires_scan, + ) + ) while len(self.files) > num_files: truncated_file = self.files.pop() msg = f"Too many files in folder observation. Truncating file {truncated_file}" @@ -168,7 +203,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): if folder_state is NOT_PRESENT_IN_STATE: return self.default_observation - health_status = folder_state["health_status"] + if self.file_system_requires_scan: + health_status = folder_state["visible_status"] + else: + health_status = folder_state["health_status"] obs = {} @@ -209,6 +247,13 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): # pass down shared/common config items for file_config in config.files: file_config.include_num_access = config.include_num_access + file_config.file_system_requires_scan = config.file_system_requires_scan files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] - return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access) + return cls( + where=where, + files=files, + num_files=config.num_files, + include_num_access=config.include_num_access, + file_system_requires_scan=config.file_system_requires_scan, + ) diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 4f1a9d90..42ceaff0 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -10,6 +10,7 @@ from primaite import getLogger from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) @@ -32,6 +33,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -41,6 +44,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): port_list: List[int], protocol_list: List[str], num_rules: int, + include_users: bool, ) -> None: """ Initialise a firewall observation instance. @@ -58,9 +62,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :type protocol_list: List[str] :param num_rules: Number of rules configured in the firewall. :type num_rules: int + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where - + self.include_users: bool = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" self.ports: List[PortObservation] = [ PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3) ] @@ -142,6 +150,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic. :rtype: ObsType """ + firewall_state = access_from_nested_dict(state, self.where) + if firewall_state is NOT_PRESENT_IN_STATE: + return self.default_observation obs = { "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, "ACL": { @@ -159,6 +170,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): }, }, } + if self.include_users: + sess = firewall_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -218,4 +235,5 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules, + include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index f9fd9b1a..4419ccc7 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -48,6 +48,12 @@ class HostObservation(AbstractObservation, identifier="HOST"): """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Whether to include the number of accesses to files observations on this host.""" + file_system_requires_scan: Optional[bool] = None + """ + If True, files and folders must be scanned to update the health state. If False, true state is always shown. + """ + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -64,6 +70,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_nmne: bool, monitored_traffic: Optional[Dict], include_num_access: bool, + file_system_requires_scan: bool, + include_users: bool, ) -> None: """ Initialise a host observation instance. @@ -95,10 +103,18 @@ class HostObservation(AbstractObservation, identifier="HOST"): :type monitored_traffic: Dict :param include_num_access: Flag to include the number of accesses to files. :type include_num_access: bool + :param file_system_requires_scan: If True, the files and folders must be scanned to update the health state. + If False, the true state is always shown. + :type file_system_requires_scan: bool + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where self.include_num_access = include_num_access + self.include_users = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services @@ -120,7 +136,13 @@ class HostObservation(AbstractObservation, identifier="HOST"): self.folders: List[FolderObservation] = folders while len(self.folders) < num_folders: self.folders.append( - FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access) + FolderObservation( + where=None, + files=[], + num_files=num_files, + include_num_access=include_num_access, + file_system_requires_scan=file_system_requires_scan, + ) ) while len(self.folders) > num_folders: truncated_folder = self.folders.pop() @@ -151,6 +173,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: self.default_observation["num_file_creations"] = 0 self.default_observation["num_file_deletions"] = 0 + if self.include_users: + self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0} def observe(self, state: Dict) -> ObsType: """ @@ -178,6 +202,12 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + if self.include_users: + sess = node_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -202,6 +232,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: shape["num_file_creations"] = spaces.Discrete(4) shape["num_file_deletions"] = spaces.Discrete(4) + if self.include_users: + shape["users"] = spaces.Dict( + {"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)} + ) return spaces.Dict(shape) @classmethod @@ -226,6 +260,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): for folder_config in config.folders: folder_config.include_num_access = config.include_num_access folder_config.num_files = config.num_files + folder_config.file_system_requires_scan = config.file_system_requires_scan for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne @@ -257,4 +292,6 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic, include_num_access=config.include_num_access, + file_system_requires_scan=config.file_system_requires_scan, + include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index f7bfcc99..e263cadb 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -44,6 +44,10 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" + file_system_requires_scan: bool = True + """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" + include_users: Optional[bool] = True + """If True, report user session information.""" num_ports: Optional[int] = None """Number of ports.""" ip_list: Optional[List[str]] = None @@ -187,6 +191,10 @@ class NodesObservation(AbstractObservation, identifier="NODES"): host_config.monitored_traffic = config.monitored_traffic if host_config.include_num_access is None: host_config.include_num_access = config.include_num_access + if host_config.file_system_requires_scan is None: + host_config.file_system_requires_scan = config.file_system_requires_scan + if host_config.include_users is None: + host_config.include_users = config.include_users for router_config in config.routers: if router_config.num_ports is None: @@ -201,6 +209,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): router_config.protocol_list = config.protocol_list if router_config.num_rules is None: router_config.num_rules = config.num_rules + if router_config.include_users is None: + router_config.include_users = config.include_users for firewall_config in config.firewalls: if firewall_config.ip_list is None: @@ -213,6 +223,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): firewall_config.protocol_list = config.protocol_list if firewall_config.num_rules is None: firewall_config.num_rules = config.num_rules + if firewall_config.include_users is None: + firewall_config.include_users = config.include_users hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index f1d4ec8e..d064936a 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -39,6 +39,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -46,6 +48,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): ports: List[PortObservation], num_ports: int, acl: ACLObservation, + include_users: bool, ) -> None: """ Initialise a router observation instance. @@ -59,12 +62,16 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): :type num_ports: int :param acl: ACL observation representing the access control list of the router. :type acl: ACLObservation + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where self.ports: List[PortObservation] = ports self.acl: ACLObservation = acl self.num_ports: int = num_ports - + self.include_users: bool = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" while len(self.ports) < num_ports: self.ports.append(PortObservation(where=None)) while len(self.ports) > num_ports: @@ -95,6 +102,12 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): obs["ACL"] = self.acl.observe(state) if self.ports: obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} + if self.include_users: + sess = router_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -143,4 +156,4 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] acl = ACLObservation.from_config(config=config.acl, parent_where=where) - return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) + return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl, include_users=config.include_users) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index c959ee5b..b913501d 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -47,7 +47,15 @@ class AbstractReward: @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state.""" + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return 0.0 @classmethod @@ -67,7 +75,15 @@ class DummyReward(AbstractReward): """Dummy reward function component which always returns 0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state.""" + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return 0.0 @classmethod @@ -109,8 +125,12 @@ class DatabaseFileIntegrity(AbstractReward): def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. - :param state: The current state of the simulation. + :param state: Current simulation state :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float """ database_file_state = access_from_nested_dict(state, self.location_in_state) if database_file_state is NOT_PRESENT_IN_STATE: @@ -283,6 +303,12 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): component will keep track of that information. In that case, it doesn't matter whether the last successful request returned was able to connect to the database server, because there has been an unsuccessful request since. + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float """ if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: self._last_request_failed = last_action_response.response.status != "success" @@ -295,14 +321,11 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): # If the last request was actually sent, then check if the connection was established. db_state = access_from_nested_dict(state, self.location_in_state) if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: - _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") + last_action_response.reward_info = {"reason": f"Can't calculate reward for {self.__class__.__name__}"} return 0.0 last_connection_successful = db_state["last_connection_successful"] - if last_connection_successful is False: - return -1.0 - elif last_connection_successful is True: - return 1.0 - return 0.0 + last_action_response.reward_info = {"last_connection_successful": last_connection_successful} + return 1.0 if last_connection_successful else -1.0 @classmethod def from_config(cls, config: Dict) -> AbstractReward: @@ -346,7 +369,15 @@ class SharedReward(AbstractReward): """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Simply access the other agent's reward and return it.""" + """Simply access the other agent's reward and return it. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return self.callback(self.agent_name) @classmethod @@ -379,7 +410,15 @@ class ActionPenalty(AbstractReward): self.do_nothing_penalty = do_nothing_penalty def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the penalty to be applied.""" + """Calculate the penalty to be applied. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ if last_action_response.action == "DONOTHING": return self.do_nothing_penalty else: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 08164f22..4f73ad7b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -990,6 +990,7 @@ class UserManager(Service): if user and user.password == current_password: user.password = new_password self.sys_log.info(f"{self.name}: Password changed for {username}") + self._user_session_manager._logout_user(user=user) return True self.sys_log.info(f"{self.name}: Password change failed for {username}") return False @@ -1027,6 +1028,10 @@ class UserManager(Service): self.sys_log.info(f"{self.name}: Failed to enable user: {username}") return False + @property + def _user_session_manager(self) -> "UserSessionManager": + return self.software_manager.software["UserSessionManager"] # noqa + class UserSession(SimComponent): """ @@ -1260,7 +1265,8 @@ class UserSessionManager(Service): :return: A dictionary representing the current state. """ state = super().describe_state() - state["active_remote_logins"] = len(self.remote_sessions) + state["current_local_user"] = None if not self.local_session else self.local_session.user.username + state["active_remote_sessions"] = list(self.remote_sessions.keys()) return state @property @@ -1440,6 +1446,19 @@ class UserSessionManager(Service): """ return self._logout(local=False, remote_session_id=remote_session_id) + def _logout_user(self, user: Union[str, User]) -> bool: + """End a user session by username or user object.""" + if isinstance(user, str): + user = self._user_manager.users[user] # grab user object from username + for sess_id, session in self.remote_sessions.items(): + if session.user is user: + self._logout(local=False, remote_session_id=sess_id) + return True + if self.local_user_logged_in and self.local_session.user is user: + self.local_logout() + return True + return False + @property def local_user_logged_in(self) -> bool: """ diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index df2098df..406facd1 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -23,6 +23,8 @@ from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState +# TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor +# the terminal to leverage the user session manager's list. This way we avoid potential bugs and code ducplication class TerminalClientConnection(BaseModel): """ TerminalClientConnection Class. @@ -162,22 +164,6 @@ class Terminal(Service): def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" rm = super()._init_request_manager() - rm.add_request( - "send", - request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())), - ) - - def _login(request: RequestFormat, context: Dict) -> RequestResponse: - login = self._process_local_login(username=request[0], password=request[1]) - if login: - return RequestResponse( - status="success", - data={ - "ip_address": login.ip_address, - }, - ) - else: - return RequestResponse(status="failure", data={"reason": "Invalid login credentials"}) def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2]) @@ -191,10 +177,34 @@ class Terminal(Service): else: return RequestResponse(status="failure", data={}) + rm.add_request( + "ssh_to_remote", + request_type=RequestType(func=_remote_login), + ) + + def _remote_logoff(request: RequestFormat, context: Dict) -> RequestResponse: + """Logoff from remote connection.""" + ip_address = IPv4Address(request[0]) + remote_connection = self._get_connection_from_ip(ip_address=ip_address) + if remote_connection: + outcome = self._disconnect(remote_connection.connection_uuid) + if outcome: + return RequestResponse( + status="success", + data={}, + ) + else: + return RequestResponse( + status="failure", + data={"reason": "No remote connection held."}, + ) + + rm.add_request("remote_logoff", request_type=RequestType(func=_remote_logoff)) + def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: """Execute an instruction.""" - command: str = request[0] - ip_address: IPv4Address = IPv4Address(request[1]) + ip_address: IPv4Address = IPv4Address(request[0]) + command: str = request[1]["command"] remote_connection = self._get_connection_from_ip(ip_address=ip_address) if remote_connection: outcome = remote_connection.execute(command) @@ -209,30 +219,11 @@ class Terminal(Service): data={}, ) - def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: - """Logoff from connection.""" - connection_uuid = request[0] - self.parent.user_session_manager.local_logout(connection_uuid) - self._disconnect(connection_uuid) - return RequestResponse(status="success", data={}) - rm.add_request( - "Login", - request_type=RequestType(func=_login), - ) - - rm.add_request( - "Remote Login", - request_type=RequestType(func=_remote_login), - ) - - rm.add_request( - "Execute", + "send_remote_command", request_type=RequestType(func=remote_execute_request), ) - rm.add_request("Logoff", request_type=RequestType(func=_logoff)) - return rm def execute(self, command: List[Any]) -> Optional[RequestResponse]: @@ -280,13 +271,9 @@ class Terminal(Service): if self.operating_state != ServiceOperatingState.RUNNING: self.sys_log.warning(f"{self.name}: Cannot login as service is not running.") return None - connection_request_id = str(uuid4()) - self._client_connection_requests[connection_request_id] = None if ip_address: # Assuming that if IP is passed we are connecting to remote - return self._send_remote_login( - username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id - ) + return self._send_remote_login(username=username, password=password, ip_address=ip_address) else: return self._process_local_login(username=username, password=password) @@ -313,6 +300,9 @@ class Terminal(Service): def _check_client_connection(self, connection_id: str) -> bool: """Check that client_connection_id is valid.""" + if not self.parent.user_session_manager.validate_remote_session_uuid(connection_id): + self._disconnect(connection_id) + return False return connection_id in self._connections def _send_remote_login( @@ -320,32 +310,24 @@ class Terminal(Service): username: str, password: str, ip_address: IPv4Address, - connection_request_id: str, + connection_request_id: Optional[str] = None, is_reattempt: bool = False, ) -> Optional[RemoteTerminalConnection]: """Send a remote login attempt and connect to Node. :param: username: Username used to connect to the remote node. :type: username: str - :param: password: Password used to connect to the remote node :type: password: str - :param: ip_address: Target Node IP address for login attempt. :type: ip_address: IPv4Address - - :param: connection_request_id: Connection Request ID - :type: connection_request_id: str - + :param: connection_request_id: Connection Request ID, if not provided, a new one is generated + :type: connection_request_id: Optional[str] :param: is_reattempt: True if the request has been reattempted. Default False. :type: is_reattempt: Optional[bool] - :return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False. - """ - self.sys_log.info( - f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}" - ) + connection_request_id = connection_request_id or str(uuid4()) if is_reattempt: valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id) if valid_connection_request: @@ -360,6 +342,9 @@ class Terminal(Service): self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.") return None + self.sys_log.info( + f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}" + ) transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password) diff --git a/tests/assets/configs/data_manipulation.yaml b/tests/assets/configs/data_manipulation.yaml new file mode 100644 index 00000000..97442903 --- /dev/null +++ b/tests/assets/configs/data_manipulation.yaml @@ -0,0 +1,942 @@ +io_settings: + save_agent_actions: true + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + sys_log_level: WARNING + + +game: + max_episode_length: 128 + ports: + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + + + + - ref: data_manipulation_attacker + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "NONE" + label: ICS + options: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_SERVICE_FIX + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: ROUTER_ACL_ADDRULE + - type: ROUTER_ACL_REMOVERULE + - type: HOST_NIC_ENABLE + - type: HOST_NIC_DISABLE + + action_map: + 0: + action: DONOTHING + options: {} + # scan webapp service + 1: + action: NODE_SERVICE_SCAN + options: + node_id: 1 + service_id: 0 + # stop webapp service + 2: + action: NODE_SERVICE_STOP + options: + node_id: 1 + service_id: 0 + # start webapp service + 3: + action: "NODE_SERVICE_START" + options: + node_id: 1 + service_id: 0 + 4: + action: "NODE_SERVICE_PAUSE" + options: + node_id: 1 + service_id: 0 + 5: + action: "NODE_SERVICE_RESUME" + options: + node_id: 1 + service_id: 0 + 6: + action: "NODE_SERVICE_RESTART" + options: + node_id: 1 + service_id: 0 + 7: + action: "NODE_SERVICE_DISABLE" + options: + node_id: 1 + service_id: 0 + 8: + action: "NODE_SERVICE_ENABLE" + options: + node_id: 1 + service_id: 0 + 9: # check database.db file + action: "NODE_FILE_SCAN" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 10: + action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 11: + action: "NODE_FILE_DELETE" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 12: + action: "NODE_FILE_REPAIR" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 13: + action: "NODE_SERVICE_FIX" + options: + node_id: 2 + service_id: 0 + 14: + action: "NODE_FOLDER_SCAN" + options: + node_id: 2 + folder_id: 0 + 15: + action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + 16: + action: "NODE_FOLDER_REPAIR" + options: + node_id: 2 + folder_id: 0 + 17: + action: "NODE_FOLDER_RESTORE" + options: + node_id: 2 + folder_id: 0 + 18: + action: "NODE_OS_SCAN" + options: + node_id: 0 + 19: + action: "NODE_SHUTDOWN" + options: + node_id: 0 + 20: + action: NODE_STARTUP + options: + node_id: 0 + 21: + action: NODE_RESET + options: + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 50: # old action num: 26 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 51: # old action num: 27 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 52: # old action num: 28 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 0 + 53: # old action num: 29 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 1 + 54: # old action num: 30 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 2 + 55: # old action num: 31 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 3 + 56: # old action num: 32 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 4 + 57: # old action num: 33 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 5 + 58: # old action num: 34 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 6 + 59: # old action num: 35 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 7 + 60: # old action num: 36 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 8 + 61: # old action num: 37 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "HOST_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "HOST_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "HOST_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "HOST_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "HOST_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "HOST_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "HOST_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + + + + options: + nodes: + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.40 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_1_green_user + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user + + agent_settings: + flatten_obs: true + action_masking: true + + + + + +simulation: + network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE + nodes: + + - hostname: router_1 + type: router + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + 18: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 19: + action: PERMIT + src_port: DNS + dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - hostname: switch_1 + type: switch + num_ports: 8 + + - hostname: switch_2 + type: switch + num_ports: 8 + + - hostname: domain_controller + type: server + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.12 # web server + + - hostname: web_server + type: server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: WebServer + applications: + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + + + - hostname: database_server + type: server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - type: FTPClient + + - hostname: backup_server + type: server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: FTPServer + + - hostname: security_suite + type: server + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + network_interfaces: + 2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + links: + - endpoint_a_hostname: router_1 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 8 + - endpoint_a_hostname: router_1 + endpoint_a_port: 2 + endpoint_b_hostname: switch_2 + endpoint_b_port: 8 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: domain_controller + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: web_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 3 + endpoint_b_hostname: database_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 4 + endpoint_b_hostname: backup_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 2 diff --git a/tests/conftest.py b/tests/conftest.py index 1328bc9c..1bbff8f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -463,6 +463,10 @@ def game_and_agent(): {"type": "C2_SERVER_RANSOMWARE_CONFIGURE"}, {"type": "C2_SERVER_TERMINAL_COMMAND"}, {"type": "C2_SERVER_DATA_EXFILTRATE"}, + {"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"}, + {"type": "SSH_TO_REMOTE"}, + {"type": "SESSIONS_REMOTE_LOGOFF"}, + {"type": "NODE_SEND_REMOTE_COMMAND"}, ] action_space = ActionManager( diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py new file mode 100644 index 00000000..d011c1e8 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -0,0 +1,166 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.base import UserManager +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) + + return (game, agent) + + +def test_remote_login(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + connection_established = False + for conn_str, conn_obj in client_1.terminal.connections.items(): + conn_obj: RemoteTerminalConnection + if conn_obj.ip_address == server_1.network_interface[1].ip_address: + connection_established = True + if not connection_established: + pytest.fail("Remote SSH connection could not be established") + + +def test_remote_login_wrong_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "wrong_password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "failure" + + connection_established = False + for conn_str, conn_obj in client_1.terminal.connections.items(): + conn_obj: RemoteTerminalConnection + if conn_obj.ip_address == server_1.network_interface[1].ip_address: + connection_established = True + if connection_established: + pytest.fail("Remote SSH connection was established despite wrong password") + + +def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_um: UserManager = server_1.software_manager.software["UserManager"] + server_1_um.add_user("user123", "password", is_admin=True) + + action = ( + "NODE_ACCOUNTS_CHANGE_PASSWORD", + { + "node_id": 1, # server_1 + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + assert server_1_um.users["user123"].password == "different_password" + + +def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + # Log in remotely + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + + # Change password + action = ( + "NODE_ACCOUNTS_CHANGE_PASSWORD", + { + "node_id": 1, # server_1 + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + + # Assert that the user cannot execute an action + action = ( + "NODE_SEND_REMOTE_COMMAND", + { + "node_id": 0, + "remote_ip": str(server_1.network_interface[1].ip_address), + "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert server_1.file_system.get_folder("folder123") is None + assert server_1.file_system.get_file("folder123", "doggo.pdf") is None diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 1031dcb0..e2ab2990 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -26,6 +26,7 @@ def test_file_observation(simulation): dog_file_obs = FileObservation( where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, + file_system_requires_scan=True, ) assert dog_file_obs.space["health_status"] == spaces.Discrete(6) @@ -53,6 +54,7 @@ def test_folder_observation(simulation): root_folder_obs = FolderObservation( where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"], include_num_access=False, + file_system_requires_scan=True, num_files=1, files=[], ) diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 99417e33..34a37f5e 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -33,6 +33,7 @@ def test_firewall_observation(): wildcard_list=["0.0.0.255", "0.0.0.1"], port_list=["HTTP", "DNS"], protocol_list=["TCP"], + include_users=False, ) observation = firewall_observation.observe(firewall.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 8a36ea5c..69d9f106 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -38,6 +38,8 @@ def test_host_observation(simulation): applications=[], folders=[], network_interfaces=[], + file_system_requires_scan=True, + include_users=False, ) assert host_obs.space["operating_status"] == spaces.Discrete(5) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index c534307f..48d29cfb 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -27,7 +27,7 @@ def test_router_observation(): port_list=["HTTP", "DNS"], protocol_list=["TCP"], ) - router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl) + router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False) # Observe the state using the RouterObservation instance observed_output = router_observation.observe(router.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py new file mode 100644 index 00000000..ca5e2543 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -0,0 +1,89 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.transmission.transport_layer import Port +from tests import TEST_ASSETS_ROOT + +DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" + + +@pytest.fixture +def env_with_ssh() -> PrimaiteGymEnv: + """Build data manipulation environment with SSH port open on router.""" + env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) + env.agent.flatten_obs = False + router: Router = env.game.simulation.network.get_node_by_hostname("router_1") + router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3) + return env + + +def extract_login_numbers_from_obs(obs): + """Traverse the observation dictionary and return number of user sessions for all nodes.""" + login_nums = {} + for node_name, node_obs in obs["NODES"].items(): + login_nums[node_name] = node_obs.get("users") + return login_nums + + +class TestUserObservations: + """Test that the RouterObservation, FirewallObservation, and HostObservation have the correct number of logins.""" + + def test_no_sessions_at_episode_start(self, env_with_ssh): + """Test that all of the login observations start at 0 before any logins occur.""" + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + for o in logins_obs.values(): + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_single_login(self, env_with_ssh: PrimaiteGymEnv): + """Test that performing a remote login increases the remote_sessions observation by 1.""" + client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1") + client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + assert db_srv_logins_obs["local_login"] == 0 + assert db_srv_logins_obs["remote_sessions"] == 1 + for o in logins_obs.values(): # the remaining obs after popping HOST2 + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_logout(self, env_with_ssh: PrimaiteGymEnv): + """Test that remote_sessions observation correctly decreases upon logout.""" + client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1") + client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh + db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server") + db_srv.user_manager.change_user_password("admin", "admin", "different_pass") # changing password logs out user + + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + for o in logins_obs.values(): + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_max_observable_sessions(self, env_with_ssh: PrimaiteGymEnv): + """Log in from 5 remote places and check that only a max of 3 is shown in the observation.""" + MAX_OBSERVABLE_SESSIONS = 3 + # Right now this is hardcoded as 3 in HostObservation, FirewallObservation, and RouterObservation + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + + db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server") + db_srv.user_session_manager.remote_session_timeout_steps = 20 + db_srv.user_session_manager.max_remote_sessions = 5 + node_names = ("client_1", "client_2", "backup_server", "security_suite", "domain_controller") + + for i, node_name in enumerate(node_names): + node = env_with_ssh.game.simulation.network.get_node_by_hostname(node_name) + node.terminal._send_remote_login("admin", "admin", "192.168.1.14") + + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + + assert db_srv_logins_obs["remote_sessions"] == min(MAX_OBSERVABLE_SESSIONS, i + 1) + assert len(db_srv.user_session_manager.remote_sessions) == i + 1 diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index ff83c532..d5679007 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -17,6 +17,7 @@ def test_file_observation(): dog_file_obs = FileObservation( where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, + file_system_requires_scan=True, ) assert dog_file_obs.observe(state) == {"health_status": 1} assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)}) diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 2bf551c8..e945f482 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -76,13 +76,16 @@ def test_uc2_rewards(game_and_agent): ] ) state = game.get_sim_state() - reward_value = comp.calculate( - state, - last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response - ), + ahi = AgentHistoryItem( + timestep=0, + action="NODE_APPLICATION_EXECUTE", + parameters={}, + request=["execute"], + response=response, ) + reward_value = comp.calculate(state, last_action_response=ahi) assert reward_value == 1.0 + assert ahi.reward_info == {"last_connection_successful": True} router.acl.remove_rule(position=2) @@ -92,13 +95,9 @@ def test_uc2_rewards(game_and_agent): ] ) state = game.get_sim_state() - reward_value = comp.calculate( - state, - last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response - ), - ) + reward_value = comp.calculate(state, last_action_response=ahi) assert reward_value == -1.0 + assert ahi.reward_info == {"last_connection_successful": False} def test_shared_reward(): diff --git a/tests/unit_tests/_primaite/_game/_agent/test_observations.py b/tests/unit_tests/_primaite/_game/_agent/test_observations.py new file mode 100644 index 00000000..7f590685 --- /dev/null +++ b/tests/unit_tests/_primaite/_game/_agent/test_observations.py @@ -0,0 +1,132 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import List + +import pytest +import yaml + +from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation +from primaite.game.agent.observations.host_observations import HostObservation + + +class TestFileSystemRequiresScan: + @pytest.mark.parametrize( + ("yaml_option_string", "expected_val"), + ( + ("file_system_requires_scan: true", True), + ("file_system_requires_scan: false", False), + (" ", True), + ), + ) + def test_obs_config(self, yaml_option_string, expected_val): + """Check that the default behaviour is to set FileSystemRequiresScan to True.""" + obs_cfg_yaml = f""" + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + {yaml_option_string} + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "NONE" + label: ICS + options: {{}} + + """ + + cfg = yaml.safe_load(obs_cfg_yaml) + manager = ObservationManager.from_config(cfg) + + hosts: List[HostObservation] = manager.obs.components["NODES"].hosts + for i, host in enumerate(hosts): + folders: List[FolderObservation] = host.folders + for j, folder in enumerate(folders): + assert folder.file_system_requires_scan == expected_val # Make sure folders require scan by default + files: List[FileObservation] = folder.files + for k, file in enumerate(files): + assert file.file_system_requires_scan == expected_val + + def test_file_require_scan(self): + file_state = {"health_status": 3, "visible_status": 1} + + obs_requiring_scan = FileObservation([], include_num_access=False, file_system_requires_scan=True) + assert obs_requiring_scan.observe(file_state)["health_status"] == 1 + + obs_not_requiring_scan = FileObservation([], include_num_access=False, file_system_requires_scan=False) + assert obs_not_requiring_scan.observe(file_state)["health_status"] == 3 + + def test_folder_require_scan(self): + folder_state = {"health_status": 3, "visible_status": 1} + + obs_requiring_scan = FolderObservation( + [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True + ) + assert obs_requiring_scan.observe(folder_state)["health_status"] == 1 + + obs_not_requiring_scan = FolderObservation( + [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=False + ) + assert obs_not_requiring_scan.observe(folder_state)["health_status"] == 3