From 322a691e53f5658398a2e778d2b820b604ffa04a Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 2 Aug 2024 23:21:35 +0100 Subject: [PATCH 01/13] #2768 - Added listen_on_ports attribute to IOSoftware. updated software manager so that it sends copies of payloads to listening ports too. Added integration test that installs a listening service to snoop on DB traffic. --- .../simulator/system/core/software_manager.py | 23 +++++-- .../services/database/database_service.py | 2 + src/primaite/simulator/system/software.py | 5 +- .../system/test_service_listening_on_ports.py | 64 +++++++++++++++++++ 4 files changed, 87 insertions(+), 7 deletions(-) create mode 100644 tests/integration_tests/system/test_service_listening_on_ports.py diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index e00afba6..7b36097b 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -1,4 +1,5 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from copy import deepcopy from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union @@ -76,6 +77,8 @@ class SoftwareManager: for software in self.port_protocol_mapping.values(): if software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}: open_ports.append(software.port) + if software.listen_on_ports: + open_ports += list(software.listen_on_ports) return open_ports def check_port_is_open(self, port: Port, protocol: IPProtocol) -> bool: @@ -223,7 +226,9 @@ class SoftwareManager: frame: Frame, ): """ - Receive a payload from the SessionManager and forward it to the corresponding service or application. + Receive a payload from the SessionManager and forward it to the corresponding service or applications. + + This function handles both software assigned a specific port, and software listening in on other ports. :param payload: The payload being received. :param session: The transport session the payload originates from. @@ -231,11 +236,17 @@ class SoftwareManager: if payload.__class__.__name__ == "PortScanPayload": self.software.get("NMAP").receive(payload=payload, session_id=session_id) return - receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None) - if receiver: - receiver.receive( - payload=payload, session_id=session_id, from_network_interface=from_network_interface, frame=frame - ) + main_receiver = self.port_protocol_mapping.get((port, protocol), None) + listening_receivers = [software for software in self.software.values() if port in software.listen_on_ports] + receivers = [main_receiver] + listening_receivers if main_receiver else listening_receivers + if receivers: + for receiver in receivers: + receiver.receive( + payload=deepcopy(payload), + session_id=session_id, + from_network_interface=from_network_interface, + frame=frame, + ) else: self.sys_log.warning(f"No service or application found for port {port} and protocol {protocol}") pass diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 22ae0ff3..56edcf89 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -377,6 +377,8 @@ class DatabaseService(Service): ) else: result = {"status_code": 401, "type": "sql"} + else: + self.sys_log.info(f"{self.name}: Ignoring payload as it is not a Database payload") self.send(payload=result, session_id=session_id) return True diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 7c27534a..7a3d675c 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -4,9 +4,10 @@ from abc import abstractmethod from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv4Network -from typing import Any, Dict, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Union from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -252,6 +253,8 @@ class IOSoftware(Software): "Indicates if the software uses UDP protocol for communication. Default is True." port: Port "The port to which the software is connected." + listen_on_ports: Set[Port] = Field(default_factory=set) + "The set of ports to listen on." protocol: IPProtocol "The IP Protocol the Software operates on." _connections: Dict[str, Dict] = {} diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py new file mode 100644 index 00000000..0cb1ad54 --- /dev/null +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -0,0 +1,64 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Any, Dict, List, Set + +from pydantic import Field + +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.simulator.system.services.service import Service + + +class _DatabaseListener(Service): + name: str = "DatabaseListener" + protocol: IPProtocol = IPProtocol.TCP + port: Port = Port.NONE + listen_on_ports: Set[Port] = {Port.POSTGRES_SERVER} + payloads_received: List[Any] = Field(default_factory=list) + + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + self.payloads_received.append(payload) + self.sys_log.info(f"{self.name}: received payload {payload}") + return True + + def describe_state(self) -> Dict: + return super().describe_state() + + +def test_http_listener(client_server): + computer, server = client_server + + server.software_manager.install(DatabaseService) + server_db = server.software_manager.software["DatabaseService"] + server_db.start() + + server.software_manager.install(_DatabaseListener) + server_db_listener: _DatabaseListener = server.software_manager.software["DatabaseListener"] + server_db_listener.start() + + computer.software_manager.install(DatabaseClient) + computer_db_client: DatabaseClient = computer.software_manager.software["DatabaseClient"] + + computer_db_client.run() + computer_db_client.server_ip_address = server.network_interface[1].ip_address + + assert len(server_db_listener.payloads_received) == 0 + computer.session_manager.receive_payload_from_software_manager( + payload="masquerade as Database traffic", + dst_ip_address=server.network_interface[1].ip_address, + dst_port=Port.POSTGRES_SERVER, + ip_protocol=IPProtocol.TCP, + ) + + assert len(server_db_listener.payloads_received) == 1 + + db_connection = computer_db_client.get_new_connection() + + assert db_connection + + assert len(server_db_listener.payloads_received) == 2 + + assert db_connection.query("SELECT") + + assert len(server_db_listener.payloads_received) == 3 From 368e846c8b59488746e56610727fd3d99bc54090 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 7 Aug 2024 10:07:19 +0100 Subject: [PATCH 02/13] 2772 - Generate pdf benchmark from --- benchmark/primaite_benchmark.py | 15 ++++++++++- benchmark/report.py | 47 +++++++++++++++++++++++++-------- benchmark/static/styles.css | 34 ++++++++++++++++++++++++ pyproject.toml | 3 ++- 4 files changed, 86 insertions(+), 13 deletions(-) create mode 100644 benchmark/static/styles.css diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index 0e6c2acc..2b09870d 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -5,7 +5,7 @@ from datetime import datetime from pathlib import Path from typing import Any, Dict, Final, Tuple -from report import build_benchmark_md_report +from report import build_benchmark_md_report, md2pdf from stable_baselines3 import PPO import primaite @@ -159,6 +159,13 @@ def run( learning_rate: float = 3e-4, ) -> None: """Run the PrimAITE benchmark.""" + # generate report folder + v_str = f"v{primaite.__version__}" + + version_result_dir = _RESULTS_ROOT / v_str + version_result_dir.mkdir(exist_ok=True, parents=True) + output_path = version_result_dir / f"PrimAITE {v_str} Benchmark Report.md" + benchmark_start_time = datetime.now() session_metadata_dict = {} @@ -193,6 +200,12 @@ def run( session_metadata=session_metadata_dict, config_path=data_manipulation_config_path(), results_root_path=_RESULTS_ROOT, + output_path=output_path, + ) + md2pdf( + md_path=output_path, + pdf_path=str(output_path).replace(".md", ".pdf"), + css_path="benchmark/static/styles.css", ) diff --git a/benchmark/report.py b/benchmark/report.py index e1ff46b9..408e91cf 100644 --- a/benchmark/report.py +++ b/benchmark/report.py @@ -2,6 +2,7 @@ import json import sys from datetime import datetime +from os import PathLike from pathlib import Path from typing import Dict, Optional @@ -14,7 +15,7 @@ from utils import _get_system_info import primaite PLOT_CONFIG = { - "size": {"auto_size": False, "width": 1500, "height": 900}, + "size": {"auto_size": False, "width": 800, "height": 800}, "template": "plotly_white", "range_slider": False, } @@ -144,6 +145,20 @@ def _plot_benchmark_metadata( yaxis={"title": "Total Reward"}, title=title, ) + fig.update_layout( + legend=dict( + yanchor="top", + y=0.99, + xanchor="left", + x=0.01, + bgcolor="rgba(255,255,255,0.3)", + ) + ) + for trace in fig["data"]: + if trace["name"].startswith("Session"): + trace["showlegend"] = False + fig["data"][0]["name"] = "Individual Sessions" + fig["data"][0]["showlegend"] = True return fig @@ -194,6 +209,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure: title=title, ) fig["data"][0]["showlegend"] = True + fig.update_layout(legend=dict(yanchor="top", y=-0.2, xanchor="left", x=0.01, orientation="h")) return fig @@ -248,14 +264,7 @@ def _plot_av_s_per_100_steps_10_nodes( versions = sorted(list(version_times_dict.keys())) times = [version_times_dict[version] for version in versions] - fig.add_trace( - go.Bar( - x=versions, - y=times, - text=times, - textposition="auto", - ) - ) + fig.add_trace(go.Bar(x=versions, y=times, text=times, textposition="auto", texttemplate="%{y:.3f}")) fig.update_layout( xaxis_title="PrimAITE Version", @@ -267,7 +276,11 @@ def _plot_av_s_per_100_steps_10_nodes( def build_benchmark_md_report( - benchmark_start_time: datetime, session_metadata: Dict, config_path: Path, results_root_path: Path + benchmark_start_time: datetime, + session_metadata: Dict, + config_path: Path, + results_root_path: Path, + output_path: PathLike, ) -> None: """ Generates a Markdown report for a benchmarking session, documenting performance metrics and graphs. @@ -319,7 +332,7 @@ def build_benchmark_md_report( data = benchmark_metadata_dict primaite_version = data["primaite_version"] - with open(version_result_dir / f"PrimAITE v{primaite_version} Benchmark Report.md", "w") as file: + with open(output_path, "w") as file: # Title file.write(f"# PrimAITE v{primaite_version} Learning Benchmark\n") file.write("## PrimAITE Dev Team\n") @@ -393,3 +406,15 @@ def build_benchmark_md_report( f"![Performance of Minor and Bugfix Releases for Major Version {major_v}]" f"({performance_benchmark_plot_path.name})\n" ) + + +def md2pdf(md_path: PathLike, pdf_path: PathLike, css_path: PathLike) -> None: + """Generate PDF version of Markdown report.""" + from md2pdf.core import md2pdf + + md2pdf( + pdf_file_path=pdf_path, + md_file_path=md_path, + base_url=Path(md_path).parent, + css_file_path=css_path, + ) diff --git a/benchmark/static/styles.css b/benchmark/static/styles.css new file mode 100644 index 00000000..4fbb9bd5 --- /dev/null +++ b/benchmark/static/styles.css @@ -0,0 +1,34 @@ +body { + font-family: 'Arial', sans-serif; + line-height: 1.6; + /* margin: 1cm; */ +} +h1, h2, h3, h4, h5, h6 { + font-weight: bold; + /* margin: 1em 0; */ +} +p { + /* margin: 0.5em 0; */ +} +ul, ol { + margin: 1em 0; + padding-left: 1.5em; +} +pre { + background: #f4f4f4; + padding: 0.5em; + overflow-x: auto; +} +img { + max-width: 100%; + height: auto; +} +table { + width: 100%; + border-collapse: collapse; + margin: 1em 0; +} +th, td { + padding: 0.5em; + border: 1px solid #ddd; +} diff --git a/pyproject.toml b/pyproject.toml index c9b7c062..354df8b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,8 @@ dev = [ "wheel==0.38.4", "nbsphinx==0.9.4", "nbmake==1.5.4", - "pytest-xdist==3.3.1" + "pytest-xdist==3.3.1", + "md2pdf", ] [project.scripts] From 1802648436255edba7593ee39826a109701d83d7 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 7 Aug 2024 11:31:51 +0100 Subject: [PATCH 03/13] #2781 - Initial commit with changes to Terminal to integrate with user_session_manager. Login and logout are now talking to the monitored user session --- .../system/services/terminal/terminal.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 4be2c501..11101d55 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -277,8 +277,7 @@ class Terminal(Service): :return: boolean, True if successful, else False """ # TODO: Un-comment this when UserSessionManager is merged. - # connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) - connection_uuid = str(uuid4()) + connection_uuid = self.parent.user_session_manager.local_login(username=username, password=password) if connection_uuid: self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}") # Add new local session to list of connections and return @@ -332,7 +331,7 @@ class Terminal(Service): self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.") return remote_terminal_connection else: - self.sys_log.warning(f"Connection request{connection_request_id} declined") + self.sys_log.warning(f"Connection request {connection_request_id} declined") return None else: self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.") @@ -405,13 +404,14 @@ class Terminal(Service): if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: # validate & add connection # TODO: uncomment this as part of 2781 - # connection_id = self.parent.UserSessionManager.login(username=username, password=password) - connection_id = str(uuid4()) + username = payload.user_account.username + password = payload.user_account.password + connection_id = self.parent.user_session_manager.remote_login( + username=username, password=password, remote_ip_address=source_ip + ) + # connection_id = str(uuid4()) if connection_id: connection_request_id = payload.connection_request_uuid - username = payload.user_account.username - password = payload.user_account.password - print(f"Connection ID is: {connection_request_id}") self.sys_log.info(f"Connection authorised, session_id: {session_id}") self._create_remote_connection( connection_id=connection_id, @@ -469,6 +469,7 @@ class Terminal(Service): if valid_id: self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from remote.") self._disconnect(payload["connection_id"]) + self.parent.user_session_manager.remote_logout(remote_session_id=connection_id) else: self.sys_log.info("No Active connection held for received connection ID.") @@ -501,7 +502,7 @@ class Terminal(Service): return True elif isinstance(connection, LocalTerminalConnection): - # No further action needed + self.parent.user_session_manager.local_logout() return True def send( From 9fea34bb434c1a277baf85d8aeac19a7859a8160 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 7 Aug 2024 11:58:17 +0100 Subject: [PATCH 04/13] #2781 - Correcting terminal tests and fixing a typo in base.py --- .../simulator/network/hardware/base.py | 2 +- .../system/services/terminal/terminal.py | 1 - .../_system/_services/test_terminal.py | 32 +++++++++---------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 9230dd47..142561f5 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1174,7 +1174,7 @@ class UserSessionManager(Service): """ rm = super()._init_request_manager() - # todo add doc about requeest schemas + # todo add doc about request schemas rm.add_request( "remote_login", RequestType( diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 11101d55..5e684d89 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -456,7 +456,6 @@ class Terminal(Service): self.sys_log.info("Received command to execute") command = payload.ssh_command valid_connection = self._check_client_connection(payload.connection_uuid) - self.sys_log.info(f"Connection uuid is {valid_connection}") if valid_connection: return self.execute(command, payload.connection_uuid) else: diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 9286fa49..ffe48ab5 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -185,7 +185,7 @@ def test_terminal_receive(basic_network): ) term_a_on_node_b: RemoteTerminalConnection = terminal_a.login( - username="username", password="password", ip_address="192.168.0.11" + username="admin", password="admin", ip_address="192.168.0.11" ) term_a_on_node_b.execute(["file_system", "create", "folder", folder_name]) @@ -208,7 +208,7 @@ def test_terminal_install(basic_network): ) term_a_on_node_b: RemoteTerminalConnection = terminal_a.login( - username="username", password="password", ip_address="192.168.0.11" + username="admin", password="admin", ip_address="192.168.0.11" ) term_a_on_node_b.execute(["software_manager", "application", "install", "RansomwareScript"]) @@ -225,9 +225,7 @@ def test_terminal_fail_when_closed(basic_network): terminal.operating_state = ServiceOperatingState.STOPPED - assert not terminal.login( - username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address - ) + assert not terminal.login(username="admin", password="admin", ip_address=computer_b.network_interface[1].ip_address) def test_terminal_disconnect(basic_network): @@ -241,7 +239,7 @@ def test_terminal_disconnect(basic_network): assert len(terminal_b._connections) == 0 term_a_on_term_b = terminal_a.login( - username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address + username="admin", password="admin", ip_address=computer_b.network_interface[1].ip_address ) assert len(terminal_b._connections) == 1 @@ -260,7 +258,7 @@ def test_terminal_ignores_when_off(basic_network): computer_b: Computer = network.get_node_by_hostname("node_b") term_a_on_term_b: RemoteTerminalConnection = terminal_a.login( - username="admin", password="Admin123!", ip_address="192.168.0.11" + username="admin", password="admin", ip_address="192.168.0.11" ) # login to computer_b terminal_a.operating_state = ServiceOperatingState.STOPPED @@ -276,7 +274,7 @@ def test_computer_remote_login_to_router(wireless_wan_network): assert len(pc_a_terminal._connections) == 0 - pc_a_on_router_1 = pc_a_terminal.login(username="username", password="password", ip_address="192.168.1.1") + pc_a_on_router_1 = pc_a_terminal.login(username="admin", password="admin", ip_address="192.168.1.1") assert len(pc_a_terminal._connections) == 1 @@ -295,7 +293,7 @@ def test_router_remote_login_to_computer(wireless_wan_network): assert len(router_1_terminal._connections) == 0 - router_1_on_pc_a = router_1_terminal.login(username="username", password="password", ip_address="192.168.0.2") + router_1_on_pc_a = router_1_terminal.login(username="admin", password="admin", ip_address="192.168.0.2") assert len(router_1_terminal._connections) == 1 @@ -317,7 +315,7 @@ def test_router_blocks_SSH_traffic(wireless_wan_network): assert len(pc_a_terminal._connections) == 0 - pc_a_terminal.login(username="username", password="password", ip_address="192.168.0.2") + pc_a_terminal.login(username="admin", password="admin", ip_address="192.168.0.2") assert len(pc_a_terminal._connections) == 0 @@ -333,7 +331,7 @@ def test_SSH_across_network(wireless_wan_network): assert len(terminal_a._connections) == 0 - terminal_b_on_terminal_a = terminal_b.login(username="username", password="password", ip_address="192.168.0.2") + terminal_b_on_terminal_a = terminal_b.login(username="admin", password="admin", ip_address="192.168.0.2") assert len(terminal_a._connections) == 1 @@ -347,11 +345,13 @@ def test_multiple_remote_terminals_same_node(basic_network): assert len(terminal_a._connections) == 0 - # Spam login requests to terminal. - for attempt in range(10): - remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11") + # Spam login requests to node. + for attempt in range(3): + remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") - assert len(terminal_a._connections) == 10 + terminal_a.show() + + assert len(terminal_a._connections) == 3 def test_terminal_rejects_commands_if_disconnect(basic_network): @@ -363,7 +363,7 @@ def test_terminal_rejects_commands_if_disconnect(basic_network): terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") - remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11") + remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") assert len(terminal_a._connections) == 1 assert len(terminal_b._connections) == 1 From d2693d974f48b9dad4cead272560783b7b420b94 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 7 Aug 2024 13:18:20 +0000 Subject: [PATCH 05/13] Fix relative path to primaite benchmark to align with build pipeline step --- benchmark/primaite_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index 2b09870d..86ed22a9 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -205,7 +205,7 @@ def run( md2pdf( md_path=output_path, pdf_path=str(output_path).replace(".md", ".pdf"), - css_path="benchmark/static/styles.css", + css_path="static/styles.css", ) From 93ef3076f552baa5d3b8be303843ac1659472b37 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 8 Aug 2024 11:33:42 +0100 Subject: [PATCH 06/13] #2781 - user_session_manager._timeout_session() now sends a user_timeout command when closing remote sessions. Corrected source_ip in Terminal.receive() --- .../notebooks/Terminal-Processing.ipynb | 26 ++++++++-- .../simulator/network/hardware/base.py | 7 +++ .../system/services/terminal/terminal.py | 52 ++++++++++++------- 3 files changed, 64 insertions(+), 21 deletions(-) diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 30b1a5e7..f3848c84 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -80,14 +80,14 @@ "outputs": [], "source": [ "# Login to the remote (node_b) from local (node_a)\n", - "term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=\"192.168.0.11\")" + "term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"admin\", ip_address=\"192.168.0.11\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You can view all active connections to a terminal through use of the `show()` method" + "You can view all active connections to a terminal through use of the `show()` method, " ] }, { @@ -96,7 +96,11 @@ "metadata": {}, "outputs": [], "source": [ - "terminal_b.show()" + "terminal_b.show()\n", + "print(term_a_term_b_remote_connection.ssh_session_id)\n", + "computer_b.user_session_manager.show(include_session_id=True)\n", + "computer_b.user_session_manager.show()\n", + "\n" ] }, { @@ -183,6 +187,22 @@ "\n", "terminal_b.show()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Disconnected Terminal sessions will no longer show in the node's `user_session_manager` as active, but will be under the historic sessions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "computer_b.user_session_manager.show(include_historic=True, include_session_id=True)" + ] } ], "metadata": { diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 142561f5..7842aa66 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1294,6 +1294,13 @@ class UserSessionManager(Service): self.remote_sessions.pop(session.uuid) session_type = "Remote" session_identity = f"{session_identity} {session.remote_ip_address}" + self.parent.terminal._connections.pop(session.uuid) + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "user_timeout", "connection_id": session.uuid}, + dest_port=Port.SSH, + dest_ip_address=session.remote_ip_address, + ) self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity") diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 5e684d89..46386d3b 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -33,8 +33,8 @@ class TerminalClientConnection(BaseModel): parent_terminal: Terminal """The parent Node that this connection was created on.""" - session_id: str = None - """Session ID that connection is linked to""" + ssh_session_id: str = None + """Session ID that connection is linked to, used for sending commands via session manager.""" connection_uuid: str = None """Connection UUID""" @@ -52,7 +52,9 @@ class TerminalClientConnection(BaseModel): """Flag to state whether the connection is active or not""" def __str__(self) -> str: - return f"{self.__class__.__name__}(connection_id='{self.connection_uuid}')" + return ( + f"{self.__class__.__name__}(connection_id: '{self.connection_uuid}, ssh_session_id: {self.ssh_session_id}')" + ) def __repr__(self) -> str: return self.__str__() @@ -124,13 +126,14 @@ class RemoteTerminalConnection(TerminalClientConnection): ssh_command=command, ) - return self.parent_terminal.send(payload=payload, session_id=self.session_id) + return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id) class Terminal(Service): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" _client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {} + """Dictionary of connect requests made to remote nodes.""" def __init__(self, **kwargs): kwargs["name"] = "Terminal" @@ -169,7 +172,14 @@ class Terminal(Service): def _login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._process_local_login(username=request[0], password=request[1]) if login: - return RequestResponse(status="success", data={}) + return RequestResponse( + status="success", + data={ + "connection ID": login.connection_uuid, + "ssh_session_id": login.ssh_session_id, + "ip_address": login.ip_address, + }, + ) else: return RequestResponse(status="failure", data={}) @@ -184,16 +194,13 @@ class Terminal(Service): """Execute an instruction.""" command: str = request[0] connection_id: str = request[1] - self.execute(command, connection_id=connection_id) - return RequestResponse(status="success", data={}) + return self.execute(command, connection_id=connection_id) def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: """Logoff from connection.""" connection_uuid = request[0] - # TODO: Uncomment this when UserSessionManager merged. - # self.parent.UserSessionManager.logoff(connection_uuid) + self.parent.user_session_manager.local_logout(connection_uuid) self._disconnect(connection_uuid) - return RequestResponse(status="success", data={}) rm.add_request( @@ -234,7 +241,7 @@ class Terminal(Service): new_connection = LocalTerminalConnection( parent_terminal=self, connection_uuid=connection_uuid, - session_id=session_id, + ssh_session_id=session_id, time=datetime.now(), ) self._connections[connection_uuid] = new_connection @@ -288,11 +295,11 @@ class Terminal(Service): def _validate_client_connection_request(self, connection_id: str) -> bool: """Check that client_connection_id is valid.""" - return True if connection_id in self._client_connection_requests else False + return connection_id in self._client_connection_requests def _check_client_connection(self, connection_id: str) -> bool: """Check that client_connection_id is valid.""" - return True if connection_id in self._connections else False + return connection_id in self._connections def _send_remote_login( self, @@ -381,7 +388,7 @@ class Terminal(Service): """ client_connection = RemoteTerminalConnection( parent_terminal=self, - session_id=session_id, + ssh_session_id=session_id, connection_uuid=connection_id, ip_address=source_ip, connection_request_id=connection_request_id, @@ -398,7 +405,7 @@ class Terminal(Service): :param session_id: The session id the payload relates to. :return: True. """ - source_ip = kwargs["from_network_interface"].ip_address + source_ip = [kwargs["frame"].ip.src_ip_address][0] self.sys_log.info(f"Received payload: {payload}. Source: {source_ip}") if isinstance(payload, SSHPacket): if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: @@ -470,12 +477,21 @@ class Terminal(Service): self._disconnect(payload["connection_id"]) self.parent.user_session_manager.remote_logout(remote_session_id=connection_id) else: - self.sys_log.info("No Active connection held for received connection ID.") + self.sys_log.error("No Active connection held for received connection ID.") + + if payload["type"] == "user_timeout": + connection_id = payload["connection_id"] + valid_id = self._check_client_connection(connection_id) + if valid_id: + self._connections.pop(connection_id) + self.sys_log.info(f"{self.name}: Connection {connection_id} disconnected due to inactivity.") + else: + self.sys_log.error(f"{self.name}: Connection {connection_id} is invalid.") return True def _disconnect(self, connection_uuid: str) -> bool: - """Disconnect from the remote. + """Disconnect connection. :param connection_uuid: Connection ID that we want to disconnect. :return True if successful, False otherwise. @@ -489,7 +505,7 @@ class Terminal(Service): if isinstance(connection, RemoteTerminalConnection): # Send disconnect command via software manager - session_id = connection.session_id + session_id = connection.ssh_session_id software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( From ff054830bca7ce3ede3e6bcac8d89d7559193061 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 8 Aug 2024 11:57:30 +0100 Subject: [PATCH 07/13] #2781 - Correcting some typos in Terminal notebook and elaborating the data in _remote_login request --- src/primaite/notebooks/Terminal-Processing.ipynb | 11 +++-------- .../simulator/system/services/terminal/terminal.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index f3848c84..fdf405a7 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -87,7 +87,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can view all active connections to a terminal through use of the `show()` method, " + "You can view all active connections to a terminal through use of the `show()` method." ] }, { @@ -96,11 +96,7 @@ "metadata": {}, "outputs": [], "source": [ - "terminal_b.show()\n", - "print(term_a_term_b_remote_connection.ssh_session_id)\n", - "computer_b.user_session_manager.show(include_session_id=True)\n", - "computer_b.user_session_manager.show()\n", - "\n" + "terminal_b.show()" ] }, { @@ -184,7 +180,6 @@ "term_a_term_b_remote_connection.disconnect()\n", "\n", "terminal_a.show()\n", - "\n", "terminal_b.show()" ] }, @@ -192,7 +187,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Disconnected Terminal sessions will no longer show in the node's `user_session_manager` as active, but will be under the historic sessions" + "Disconnected Terminal sessions will no longer show in the node's Terminal connection list, but will be under the historic sessions in the `user_session_manager`." ] }, { diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 46386d3b..aa3b5d62 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -181,12 +181,19 @@ class Terminal(Service): }, ) else: - return RequestResponse(status="failure", data={}) + return RequestResponse(status="failure", data={"reason": "Invalid login credentials"}) def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2]) if login: - return RequestResponse(status="success", data={}) + return RequestResponse( + status="success", + data={ + "connection ID": login.connection_uuid, + "ssh_session_id": login.ssh_session_id, + "ip_address": login.ip_address, + }, + ) else: return RequestResponse(status="failure", data={}) From 5f5ea5e5246ddb9295267251e2f445ea47fdee6a Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 8 Aug 2024 14:20:23 +0100 Subject: [PATCH 08/13] #2718 - Updates to Terminal following discussion about implementation with actions. --- .../system/services/terminal/terminal.py | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index aa3b5d62..88b6d3a3 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -52,9 +52,7 @@ class TerminalClientConnection(BaseModel): """Flag to state whether the connection is active or not""" def __str__(self) -> str: - return ( - f"{self.__class__.__name__}(connection_id: '{self.connection_uuid}, ssh_session_id: {self.ssh_session_id}')" - ) + return f"{self.__class__.__name__}(connection_id: '{self.connection_uuid}, ip_address: {self.ip_address}')" def __repr__(self) -> str: return self.__str__() @@ -176,7 +174,6 @@ class Terminal(Service): status="success", data={ "connection ID": login.connection_uuid, - "ssh_session_id": login.ssh_session_id, "ip_address": login.ip_address, }, ) @@ -189,19 +186,28 @@ class Terminal(Service): return RequestResponse( status="success", data={ - "connection ID": login.connection_uuid, - "ssh_session_id": login.ssh_session_id, "ip_address": login.ip_address, }, ) else: return RequestResponse(status="failure", data={}) - def _execute_request(request: RequestFormat, context: Dict) -> RequestResponse: + def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: """Execute an instruction.""" command: str = request[0] - connection_id: str = request[1] - return self.execute(command, connection_id=connection_id) + ip_address: IPv4Address = IPv4Address(request[1]) + remote_connection = self._get_connection_from_ip(ip_address=ip_address) + outcome = remote_connection.execute(command) + if outcome: + return RequestResponse( + status="success", + data={}, + ) + else: + return RequestResponse( + status="failure", + data={}, + ) def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: """Logoff from connection.""" @@ -222,20 +228,23 @@ class Terminal(Service): rm.add_request( "Execute", - request_type=RequestType(func=_execute_request), + request_type=RequestType(func=remote_execute_request), ) rm.add_request("Logoff", request_type=RequestType(func=_logoff)) return rm - def execute(self, command: List[Any], connection_id: str) -> Optional[RequestResponse]: + def execute(self, command: List[Any]) -> Optional[RequestResponse]: """Execute a passed ssh command via the request manager.""" - valid_connection = self._check_client_connection(connection_id=connection_id) - if valid_connection: - return self.parent.apply_request(command) + return self.parent.apply_request(command) + + def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]: + """Find Remote Terminal Connection from a given IP.""" + for connection in self._connections: + if self._connections[connection].ip_address == ip_address: + return self._connections[connection] else: - self.sys_log.error("Invalid connection ID provided") return None def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection: @@ -471,7 +480,7 @@ class Terminal(Service): command = payload.ssh_command valid_connection = self._check_client_connection(payload.connection_uuid) if valid_connection: - return self.execute(command, payload.connection_uuid) + return self.execute(command) else: self.sys_log.error(f"Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command.") From 116ac725b0faf5ab5b4c4e51b7a4f080a17aa1a3 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 8 Aug 2024 14:23:10 +0100 Subject: [PATCH 09/13] #2718 - making terminal rm _login() and _remote_login() consistent in their RequestResponse --- src/primaite/simulator/system/services/terminal/terminal.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 88b6d3a3..85e0c87f 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -173,7 +173,6 @@ class Terminal(Service): return RequestResponse( status="success", data={ - "connection ID": login.connection_uuid, "ip_address": login.ip_address, }, ) From 665c53d880b8f19f8ceebf0f67ea4c7a151811d5 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 8 Aug 2024 15:48:44 +0100 Subject: [PATCH 10/13] #2781 - Actioning review comments --- .../simulator/network/hardware/base.py | 4 ++ .../system/services/terminal/terminal.py | 62 ++++++++++--------- .../_system/_services/test_terminal.py | 26 ++++++++ 3 files changed, 63 insertions(+), 29 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 7842aa66..1441c93b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1278,6 +1278,10 @@ class UserSessionManager(Service): if self.local_session: if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep: self._timeout_session(self.local_session) + for session in self.remote_sessions: + remote_session = self.remote_sessions[session] + if remote_session.last_active_step + self.remote_session_timeout_steps <= timestep: + self._timeout_session(remote_session) def _timeout_session(self, session: UserSession) -> None: """ diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 85e0c87f..876b1694 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -196,12 +196,13 @@ class Terminal(Service): command: str = request[0] ip_address: IPv4Address = IPv4Address(request[1]) remote_connection = self._get_connection_from_ip(ip_address=ip_address) - outcome = remote_connection.execute(command) - if outcome: - return RequestResponse( - status="success", - data={}, - ) + if remote_connection: + outcome = remote_connection.execute(command) + if outcome: + return RequestResponse( + status="success", + data={}, + ) else: return RequestResponse( status="failure", @@ -240,11 +241,9 @@ class Terminal(Service): def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]: """Find Remote Terminal Connection from a given IP.""" - for connection in self._connections: - if self._connections[connection].ip_address == ip_address: - return self._connections[connection] - else: - return None + for connection in self._connections.values(): + if connection.ip_address == ip_address: + return connection def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection: """Create a new connection object and amend to list of active connections. @@ -279,7 +278,7 @@ class Terminal(Service): :type: ip_address: Optional[IPv4Address] """ if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.warning("Cannot login as service is not running.") + self.sys_log.warning(f"{self.name}: Cannot login as service is not running.") return None connection_request_id = str(uuid4()) self._client_connection_requests[connection_request_id] = None @@ -301,11 +300,11 @@ class Terminal(Service): # TODO: Un-comment this when UserSessionManager is merged. connection_uuid = self.parent.user_session_manager.local_login(username=username, password=password) if connection_uuid: - self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}") + self.sys_log.info(f"{self.name}: Login request authorised, connection uuid: {connection_uuid}") # Add new local session to list of connections and return return self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection") else: - self.sys_log.warning("Login failed, incorrect Username or Password") + self.sys_log.warning(f"{self.name}: Login failed, incorrect Username or Password") return None def _validate_client_connection_request(self, connection_id: str) -> bool: @@ -344,7 +343,9 @@ class Terminal(Service): :return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False. """ - self.sys_log.info(f"Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}") + self.sys_log.info( + f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}" + ) if is_reattempt: valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id) if valid_connection_request: @@ -353,7 +354,7 @@ class Terminal(Service): self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.") return remote_terminal_connection else: - self.sys_log.warning(f"Connection request {connection_request_id} declined") + self.sys_log.warning(f"{self.name}: Connection request {connection_request_id} declined") return None else: self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.") @@ -420,8 +421,8 @@ class Terminal(Service): :param session_id: The session id the payload relates to. :return: True. """ - source_ip = [kwargs["frame"].ip.src_ip_address][0] - self.sys_log.info(f"Received payload: {payload}. Source: {source_ip}") + source_ip = kwargs["frame"].ip.src_ip_address + self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}") if isinstance(payload, SSHPacket): if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: # validate & add connection @@ -431,10 +432,9 @@ class Terminal(Service): connection_id = self.parent.user_session_manager.remote_login( username=username, password=password, remote_ip_address=source_ip ) - # connection_id = str(uuid4()) if connection_id: connection_request_id = payload.connection_request_uuid - self.sys_log.info(f"Connection authorised, session_id: {session_id}") + self.sys_log.info(f"{self.name}: Connection authorised, session_id: {session_id}") self._create_remote_connection( connection_id=connection_id, connection_request_id=connection_request_id, @@ -465,7 +465,7 @@ class Terminal(Service): payload=payload, dest_port=self.port, session_id=session_id ) elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: - self.sys_log.info("Login Successful") + self.sys_log.info(f"{self.name}: Login Successful") self._create_remote_connection( connection_id=payload.connection_uuid, connection_request_id=payload.connection_request_uuid, @@ -475,13 +475,16 @@ class Terminal(Service): elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: # Requesting a command to be executed - self.sys_log.info("Received command to execute") + self.sys_log.info(f"{self.name}: Received command to execute") command = payload.ssh_command valid_connection = self._check_client_connection(payload.connection_uuid) if valid_connection: - return self.execute(command) + self.execute(command) + return True else: - self.sys_log.error(f"Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command.") + self.sys_log.error( + f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command." + ) if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "disconnect": @@ -492,13 +495,14 @@ class Terminal(Service): self._disconnect(payload["connection_id"]) self.parent.user_session_manager.remote_logout(remote_session_id=connection_id) else: - self.sys_log.error("No Active connection held for received connection ID.") + self.sys_log.error(f"{self.name}: No Active connection held for received connection ID.") if payload["type"] == "user_timeout": connection_id = payload["connection_id"] valid_id = self._check_client_connection(connection_id) if valid_id: - self._connections.pop(connection_id) + connection = self._connections.pop(connection_id) + connection.is_active = False self.sys_log.info(f"{self.name}: Connection {connection_id} disconnected due to inactivity.") else: self.sys_log.error(f"{self.name}: Connection {connection_id} is invalid.") @@ -512,7 +516,7 @@ class Terminal(Service): :return True if successful, False otherwise. """ if not self._connections: - self.sys_log.warning("No remote connection present") + self.sys_log.warning(f"{self.name}: No remote connection present") return False connection = self._connections.pop(connection_uuid) @@ -545,10 +549,10 @@ class Terminal(Service): :param dest_up_address: The IP address of the payload destination. """ if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!") + self.sys_log.warning(f"{self.name}: Cannot send commands when Operating state is {self.operating_state}!") return False - self.sys_log.debug(f"Sending payload: {payload}") + self.sys_log.debug(f"{self.name}: Sending payload: {payload}") return super().send( payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index ffe48ab5..41858b90 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -248,6 +248,8 @@ def test_terminal_disconnect(basic_network): assert len(terminal_b._connections) == 0 + assert term_a_on_term_b.is_active is False + def test_terminal_ignores_when_off(basic_network): """Terminal should ignore commands when not running""" @@ -378,3 +380,27 @@ def test_terminal_rejects_commands_if_disconnect(basic_network): assert not computer_b.software_manager.software.get("RansomwareScript") assert remote_connection.is_active is False + + +def test_terminal_connection_timeout(basic_network): + """Test that terminal_connections are affected by UserSession timeout.""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") + + remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") + + assert len(terminal_a._connections) == 1 + assert len(terminal_b._connections) == 1 + assert len(computer_b.user_session_manager.remote_sessions) == 1 + + remote_session = computer_b.user_session_manager.remote_sessions[remote_connection.connection_uuid] + computer_b.user_session_manager._timeout_session(remote_session) + + assert len(terminal_a._connections) == 0 + assert len(terminal_b._connections) == 0 + assert len(computer_b.user_session_manager.remote_sessions) == 0 + + assert not remote_connection.is_active From a3a9ca9963c4fc67e46a5eeeb5d067d9c764d2d5 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 8 Aug 2024 21:20:20 +0100 Subject: [PATCH 11/13] #2768 - Fixed issue causing main port to not be included in list of open ports. documented the configuration of listen_on_ports. added test that tests listen_on_ports configuration from yaml. --- CHANGELOG.md | 11 +++++- .../system/common/common_configuration.rst | 32 +++++++++++++++ src/primaite/game/game.py | 22 ++++++++++- .../simulator/system/core/software_manager.py | 29 ++++++++------ ...ic_node_with_software_listening_ports.yaml | 39 +++++++++++++++++++ .../system/test_service_listening_on_ports.py | 20 ++++++++++ 6 files changed, 139 insertions(+), 14 deletions(-) create mode 100644 tests/assets/configs/basic_node_with_software_listening_ports.yaml diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d999607..c354aa14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,9 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Random Number Generator Seeding by specifying a random number seed in the config file. - Implemented Terminal service class, providing a generic terminal simulation. +- Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes. +- Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the + main port they're assigned. ### Changed -- Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality. +- Updated `SoftwareManager` `install` and `uninstall` to handle all functionality that was being done at the `install` + and `uninstall` methods in the `Node` class. +- Updated the `receive_payload_from_session_manager` method in `SoftwareManager` so that it now sends a copy of the + payload to any software listening on the destination port of the `Frame`. + +### Removed +- Removed the `install` and `uninstall` methods in the `Node` class. ## [3.2.0] - 2024-07-18 diff --git a/docs/source/simulation_components/system/common/common_configuration.rst b/docs/source/simulation_components/system/common/common_configuration.rst index e35ee378..420166dd 100644 --- a/docs/source/simulation_components/system/common/common_configuration.rst +++ b/docs/source/simulation_components/system/common/common_configuration.rst @@ -25,3 +25,35 @@ The configuration options are the attributes that fall under the options for an Optional. Default value is ``2``. The number of timesteps the |SOFTWARE_NAME| will remain in a ``FIXING`` state before going into a ``GOOD`` state. + + +``listen_on_ports`` +""""""""""""""""""" + +The set of ports to listen on. This is in addition to the main port the software is designated. This set can either be +the string name of ports or the port integers + +Example: + +.. code-block:: yaml + + simulation: + network: + nodes: + - hostname: client + type: computer + ip_address: 192.168.10.11 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + services: + - type: DatabaseService + options: + backup_server_ip: 10.10.1.12 + listen_on_ports: + - 631 + applications: + - type: WebBrowser + options: + target_url: http://sometech.ai + listen_on_ports: + - SMB diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 6a97ad25..3d3caed9 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK """PrimAITE game - Encapsulates the simulation and agents.""" from ipaddress import IPv4Address -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import numpy as np from pydantic import BaseModel, ConfigDict @@ -44,8 +44,10 @@ from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.simulator.system.software import Software _LOGGER = getLogger(__name__) @@ -328,6 +330,21 @@ class PrimaiteGame: user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa for user_cfg in node_cfg["users"]: user_manager.add_user(**user_cfg, bypass_can_perform_action=True) + + def _set_software_listen_on_ports(software: Union[Software, Service], software_cfg: Dict): + """Set listener ports on software.""" + listen_on_ports = [] + for port_id in set(software_cfg.get("options", {}).get("listen_on_ports", [])): + print("yes", port_id) + port = None + if isinstance(port_id, int): + port = Port(port_id) + elif isinstance(port_id, str): + port = Port[port_id] + if port: + listen_on_ports.append(port) + software.listen_on_ports = set(listen_on_ports) + if "services" in node_cfg: for service_cfg in node_cfg["services"]: new_service = None @@ -341,6 +358,7 @@ class PrimaiteGame: if "fix_duration" in service_cfg.get("options", {}): new_service.fixing_duration = service_cfg["options"]["fix_duration"] + _set_software_listen_on_ports(new_service, service_cfg) # start the service new_service.start() else: @@ -390,6 +408,8 @@ class PrimaiteGame: _LOGGER.error(msg) raise ValueError(msg) + _set_software_listen_on_ports(new_application, application_cfg) + # run the application new_application.run() diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 7b36097b..d45611ed 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -237,19 +237,24 @@ class SoftwareManager: self.software.get("NMAP").receive(payload=payload, session_id=session_id) return main_receiver = self.port_protocol_mapping.get((port, protocol), None) - listening_receivers = [software for software in self.software.values() if port in software.listen_on_ports] - receivers = [main_receiver] + listening_receivers if main_receiver else listening_receivers - if receivers: - for receiver in receivers: - receiver.receive( - payload=deepcopy(payload), - session_id=session_id, - from_network_interface=from_network_interface, - frame=frame, - ) - else: + if main_receiver: + main_receiver.receive( + payload=payload, session_id=session_id, from_network_interface=from_network_interface, frame=frame + ) + listening_receivers = [ + software + for software in self.software.values() + if port in software.listen_on_ports and software != main_receiver + ] + for receiver in listening_receivers: + receiver.receive( + payload=deepcopy(payload), + session_id=session_id, + from_network_interface=from_network_interface, + frame=frame, + ) + if not main_receiver and not listening_receivers: self.sys_log.warning(f"No service or application found for port {port} and protocol {protocol}") - pass def show(self, markdown: bool = False): """ diff --git a/tests/assets/configs/basic_node_with_software_listening_ports.yaml b/tests/assets/configs/basic_node_with_software_listening_ports.yaml new file mode 100644 index 00000000..53eee87f --- /dev/null +++ b/tests/assets/configs/basic_node_with_software_listening_ports.yaml @@ -0,0 +1,39 @@ +io_settings: + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + sys_log_level: WARNING + agent_log_level: INFO + save_agent_logs: true + write_agent_log_to_terminal: True + + +game: + max_episode_length: 256 + ports: + - ARP + protocols: + - ICMP + - UDP + + +simulation: + network: + nodes: + - hostname: client + type: computer + ip_address: 192.168.10.11 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + services: + - type: DatabaseService + options: + backup_server_ip: 10.10.1.12 + listen_on_ports: + - 631 + applications: + - type: WebBrowser + options: + target_url: http://sometech.ai + listen_on_ports: + - SMB diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py index 0cb1ad54..fd502a70 100644 --- a/tests/integration_tests/system/test_service_listening_on_ports.py +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -1,13 +1,17 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from typing import Any, Dict, List, Set +import yaml from pydantic import Field +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.service import Service +from tests import TEST_ASSETS_ROOT class _DatabaseListener(Service): @@ -62,3 +66,19 @@ def test_http_listener(client_server): assert db_connection.query("SELECT") assert len(server_db_listener.payloads_received) == 3 + + +def test_set_listen_on_ports_from_config(): + config_path = TEST_ASSETS_ROOT / "configs" / "basic_node_with_software_listening_ports.yaml" + + with open(config_path, "r") as f: + config_dict = yaml.safe_load(f) + network = PrimaiteGame.from_config(cfg=config_dict).simulation.network + + client: Computer = network.get_node_by_hostname("client") + assert Port.SMB in client.software_manager.get_open_ports() + assert Port.IPP in client.software_manager.get_open_ports() + + web_browser = client.software_manager.software["WebBrowser"] + + assert not web_browser.listen_on_ports.difference({Port.SMB, Port.IPP}) From 72e6e78ed7c9b39ec04643888016c9b3830a9745 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 9 Aug 2024 09:32:13 +0100 Subject: [PATCH 12/13] #2768 - Removed debugging print statement --- src/primaite/game/game.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 3d3caed9..9117d30a 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -335,7 +335,6 @@ class PrimaiteGame: """Set listener ports on software.""" listen_on_ports = [] for port_id in set(software_cfg.get("options", {}).get("listen_on_ports", [])): - print("yes", port_id) port = None if isinstance(port_id, int): port = Port(port_id) From bf44ceaeac912195b683492d5d0843b9d74de16d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 9 Aug 2024 09:26:37 +0000 Subject: [PATCH 13/13] Apply suggestions from code review --- benchmark/report.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/report.py b/benchmark/report.py index 408e91cf..4035ceca 100644 --- a/benchmark/report.py +++ b/benchmark/report.py @@ -15,7 +15,7 @@ from utils import _get_system_info import primaite PLOT_CONFIG = { - "size": {"auto_size": False, "width": 800, "height": 800}, + "size": {"auto_size": False, "width": 800, "height": 640}, "template": "plotly_white", "range_slider": False, }