diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 30b1a5e7..fdf405a7 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -80,14 +80,14 @@ "outputs": [], "source": [ "# Login to the remote (node_b) from local (node_a)\n", - "term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=\"192.168.0.11\")" + "term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"admin\", ip_address=\"192.168.0.11\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You can view all active connections to a terminal through use of the `show()` method" + "You can view all active connections to a terminal through use of the `show()` method." ] }, { @@ -180,9 +180,24 @@ "term_a_term_b_remote_connection.disconnect()\n", "\n", "terminal_a.show()\n", - "\n", "terminal_b.show()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Disconnected Terminal sessions will no longer show in the node's Terminal connection list, but will be under the historic sessions in the `user_session_manager`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "computer_b.user_session_manager.show(include_historic=True, include_session_id=True)" + ] } ], "metadata": { diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 9230dd47..1441c93b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1174,7 +1174,7 @@ class UserSessionManager(Service): """ rm = super()._init_request_manager() - # todo add doc about requeest schemas + # todo add doc about request schemas rm.add_request( "remote_login", RequestType( @@ -1278,6 +1278,10 @@ class UserSessionManager(Service): if self.local_session: if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep: self._timeout_session(self.local_session) + for session in self.remote_sessions: + remote_session = self.remote_sessions[session] + if remote_session.last_active_step + self.remote_session_timeout_steps <= timestep: + self._timeout_session(remote_session) def _timeout_session(self, session: UserSession) -> None: """ @@ -1294,6 +1298,13 @@ class UserSessionManager(Service): self.remote_sessions.pop(session.uuid) session_type = "Remote" session_identity = f"{session_identity} {session.remote_ip_address}" + self.parent.terminal._connections.pop(session.uuid) + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "user_timeout", "connection_id": session.uuid}, + dest_port=Port.SSH, + dest_ip_address=session.remote_ip_address, + ) self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity") diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 4be2c501..876b1694 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -33,8 +33,8 @@ class TerminalClientConnection(BaseModel): parent_terminal: Terminal """The parent Node that this connection was created on.""" - session_id: str = None - """Session ID that connection is linked to""" + ssh_session_id: str = None + """Session ID that connection is linked to, used for sending commands via session manager.""" connection_uuid: str = None """Connection UUID""" @@ -52,7 +52,7 @@ class TerminalClientConnection(BaseModel): """Flag to state whether the connection is active or not""" def __str__(self) -> str: - return f"{self.__class__.__name__}(connection_id='{self.connection_uuid}')" + return f"{self.__class__.__name__}(connection_id: '{self.connection_uuid}, ip_address: {self.ip_address}')" def __repr__(self) -> str: return self.__str__() @@ -124,13 +124,14 @@ class RemoteTerminalConnection(TerminalClientConnection): ssh_command=command, ) - return self.parent_terminal.send(payload=payload, session_id=self.session_id) + return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id) class Terminal(Service): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" _client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {} + """Dictionary of connect requests made to remote nodes.""" def __init__(self, **kwargs): kwargs["name"] = "Terminal" @@ -169,31 +170,50 @@ class Terminal(Service): def _login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._process_local_login(username=request[0], password=request[1]) if login: - return RequestResponse(status="success", data={}) + return RequestResponse( + status="success", + data={ + "ip_address": login.ip_address, + }, + ) else: - return RequestResponse(status="failure", data={}) + return RequestResponse(status="failure", data={"reason": "Invalid login credentials"}) def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2]) if login: - return RequestResponse(status="success", data={}) + return RequestResponse( + status="success", + data={ + "ip_address": login.ip_address, + }, + ) else: return RequestResponse(status="failure", data={}) - def _execute_request(request: RequestFormat, context: Dict) -> RequestResponse: + def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: """Execute an instruction.""" command: str = request[0] - connection_id: str = request[1] - self.execute(command, connection_id=connection_id) - return RequestResponse(status="success", data={}) + ip_address: IPv4Address = IPv4Address(request[1]) + remote_connection = self._get_connection_from_ip(ip_address=ip_address) + if remote_connection: + outcome = remote_connection.execute(command) + if outcome: + return RequestResponse( + status="success", + data={}, + ) + else: + return RequestResponse( + status="failure", + data={}, + ) def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: """Logoff from connection.""" connection_uuid = request[0] - # TODO: Uncomment this when UserSessionManager merged. - # self.parent.UserSessionManager.logoff(connection_uuid) + self.parent.user_session_manager.local_logout(connection_uuid) self._disconnect(connection_uuid) - return RequestResponse(status="success", data={}) rm.add_request( @@ -208,21 +228,22 @@ class Terminal(Service): rm.add_request( "Execute", - request_type=RequestType(func=_execute_request), + request_type=RequestType(func=remote_execute_request), ) rm.add_request("Logoff", request_type=RequestType(func=_logoff)) return rm - def execute(self, command: List[Any], connection_id: str) -> Optional[RequestResponse]: + def execute(self, command: List[Any]) -> Optional[RequestResponse]: """Execute a passed ssh command via the request manager.""" - valid_connection = self._check_client_connection(connection_id=connection_id) - if valid_connection: - return self.parent.apply_request(command) - else: - self.sys_log.error("Invalid connection ID provided") - return None + return self.parent.apply_request(command) + + def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]: + """Find Remote Terminal Connection from a given IP.""" + for connection in self._connections.values(): + if connection.ip_address == ip_address: + return connection def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection: """Create a new connection object and amend to list of active connections. @@ -234,7 +255,7 @@ class Terminal(Service): new_connection = LocalTerminalConnection( parent_terminal=self, connection_uuid=connection_uuid, - session_id=session_id, + ssh_session_id=session_id, time=datetime.now(), ) self._connections[connection_uuid] = new_connection @@ -257,7 +278,7 @@ class Terminal(Service): :type: ip_address: Optional[IPv4Address] """ if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.warning("Cannot login as service is not running.") + self.sys_log.warning(f"{self.name}: Cannot login as service is not running.") return None connection_request_id = str(uuid4()) self._client_connection_requests[connection_request_id] = None @@ -277,23 +298,22 @@ class Terminal(Service): :return: boolean, True if successful, else False """ # TODO: Un-comment this when UserSessionManager is merged. - # connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) - connection_uuid = str(uuid4()) + connection_uuid = self.parent.user_session_manager.local_login(username=username, password=password) if connection_uuid: - self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}") + self.sys_log.info(f"{self.name}: Login request authorised, connection uuid: {connection_uuid}") # Add new local session to list of connections and return return self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection") else: - self.sys_log.warning("Login failed, incorrect Username or Password") + self.sys_log.warning(f"{self.name}: Login failed, incorrect Username or Password") return None def _validate_client_connection_request(self, connection_id: str) -> bool: """Check that client_connection_id is valid.""" - return True if connection_id in self._client_connection_requests else False + return connection_id in self._client_connection_requests def _check_client_connection(self, connection_id: str) -> bool: """Check that client_connection_id is valid.""" - return True if connection_id in self._connections else False + return connection_id in self._connections def _send_remote_login( self, @@ -323,7 +343,9 @@ class Terminal(Service): :return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False. """ - self.sys_log.info(f"Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}") + self.sys_log.info( + f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}" + ) if is_reattempt: valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id) if valid_connection_request: @@ -332,7 +354,7 @@ class Terminal(Service): self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.") return remote_terminal_connection else: - self.sys_log.warning(f"Connection request{connection_request_id} declined") + self.sys_log.warning(f"{self.name}: Connection request {connection_request_id} declined") return None else: self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.") @@ -382,7 +404,7 @@ class Terminal(Service): """ client_connection = RemoteTerminalConnection( parent_terminal=self, - session_id=session_id, + ssh_session_id=session_id, connection_uuid=connection_id, ip_address=source_ip, connection_request_id=connection_request_id, @@ -399,20 +421,20 @@ class Terminal(Service): :param session_id: The session id the payload relates to. :return: True. """ - source_ip = kwargs["from_network_interface"].ip_address - self.sys_log.info(f"Received payload: {payload}. Source: {source_ip}") + source_ip = kwargs["frame"].ip.src_ip_address + self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}") if isinstance(payload, SSHPacket): if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: # validate & add connection # TODO: uncomment this as part of 2781 - # connection_id = self.parent.UserSessionManager.login(username=username, password=password) - connection_id = str(uuid4()) + username = payload.user_account.username + password = payload.user_account.password + connection_id = self.parent.user_session_manager.remote_login( + username=username, password=password, remote_ip_address=source_ip + ) if connection_id: connection_request_id = payload.connection_request_uuid - username = payload.user_account.username - password = payload.user_account.password - print(f"Connection ID is: {connection_request_id}") - self.sys_log.info(f"Connection authorised, session_id: {session_id}") + self.sys_log.info(f"{self.name}: Connection authorised, session_id: {session_id}") self._create_remote_connection( connection_id=connection_id, connection_request_id=connection_request_id, @@ -443,7 +465,7 @@ class Terminal(Service): payload=payload, dest_port=self.port, session_id=session_id ) elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: - self.sys_log.info("Login Successful") + self.sys_log.info(f"{self.name}: Login Successful") self._create_remote_connection( connection_id=payload.connection_uuid, connection_request_id=payload.connection_request_uuid, @@ -453,14 +475,16 @@ class Terminal(Service): elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: # Requesting a command to be executed - self.sys_log.info("Received command to execute") + self.sys_log.info(f"{self.name}: Received command to execute") command = payload.ssh_command valid_connection = self._check_client_connection(payload.connection_uuid) - self.sys_log.info(f"Connection uuid is {valid_connection}") if valid_connection: - return self.execute(command, payload.connection_uuid) + self.execute(command) + return True else: - self.sys_log.error(f"Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command.") + self.sys_log.error( + f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command." + ) if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "disconnect": @@ -469,19 +493,30 @@ class Terminal(Service): if valid_id: self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from remote.") self._disconnect(payload["connection_id"]) + self.parent.user_session_manager.remote_logout(remote_session_id=connection_id) else: - self.sys_log.info("No Active connection held for received connection ID.") + self.sys_log.error(f"{self.name}: No Active connection held for received connection ID.") + + if payload["type"] == "user_timeout": + connection_id = payload["connection_id"] + valid_id = self._check_client_connection(connection_id) + if valid_id: + connection = self._connections.pop(connection_id) + connection.is_active = False + self.sys_log.info(f"{self.name}: Connection {connection_id} disconnected due to inactivity.") + else: + self.sys_log.error(f"{self.name}: Connection {connection_id} is invalid.") return True def _disconnect(self, connection_uuid: str) -> bool: - """Disconnect from the remote. + """Disconnect connection. :param connection_uuid: Connection ID that we want to disconnect. :return True if successful, False otherwise. """ if not self._connections: - self.sys_log.warning("No remote connection present") + self.sys_log.warning(f"{self.name}: No remote connection present") return False connection = self._connections.pop(connection_uuid) @@ -489,7 +524,7 @@ class Terminal(Service): if isinstance(connection, RemoteTerminalConnection): # Send disconnect command via software manager - session_id = connection.session_id + session_id = connection.ssh_session_id software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( @@ -501,7 +536,7 @@ class Terminal(Service): return True elif isinstance(connection, LocalTerminalConnection): - # No further action needed + self.parent.user_session_manager.local_logout() return True def send( @@ -514,10 +549,10 @@ class Terminal(Service): :param dest_up_address: The IP address of the payload destination. """ if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!") + self.sys_log.warning(f"{self.name}: Cannot send commands when Operating state is {self.operating_state}!") return False - self.sys_log.debug(f"Sending payload: {payload}") + self.sys_log.debug(f"{self.name}: Sending payload: {payload}") return super().send( payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 9286fa49..41858b90 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -185,7 +185,7 @@ def test_terminal_receive(basic_network): ) term_a_on_node_b: RemoteTerminalConnection = terminal_a.login( - username="username", password="password", ip_address="192.168.0.11" + username="admin", password="admin", ip_address="192.168.0.11" ) term_a_on_node_b.execute(["file_system", "create", "folder", folder_name]) @@ -208,7 +208,7 @@ def test_terminal_install(basic_network): ) term_a_on_node_b: RemoteTerminalConnection = terminal_a.login( - username="username", password="password", ip_address="192.168.0.11" + username="admin", password="admin", ip_address="192.168.0.11" ) term_a_on_node_b.execute(["software_manager", "application", "install", "RansomwareScript"]) @@ -225,9 +225,7 @@ def test_terminal_fail_when_closed(basic_network): terminal.operating_state = ServiceOperatingState.STOPPED - assert not terminal.login( - username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address - ) + assert not terminal.login(username="admin", password="admin", ip_address=computer_b.network_interface[1].ip_address) def test_terminal_disconnect(basic_network): @@ -241,7 +239,7 @@ def test_terminal_disconnect(basic_network): assert len(terminal_b._connections) == 0 term_a_on_term_b = terminal_a.login( - username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address + username="admin", password="admin", ip_address=computer_b.network_interface[1].ip_address ) assert len(terminal_b._connections) == 1 @@ -250,6 +248,8 @@ def test_terminal_disconnect(basic_network): assert len(terminal_b._connections) == 0 + assert term_a_on_term_b.is_active is False + def test_terminal_ignores_when_off(basic_network): """Terminal should ignore commands when not running""" @@ -260,7 +260,7 @@ def test_terminal_ignores_when_off(basic_network): computer_b: Computer = network.get_node_by_hostname("node_b") term_a_on_term_b: RemoteTerminalConnection = terminal_a.login( - username="admin", password="Admin123!", ip_address="192.168.0.11" + username="admin", password="admin", ip_address="192.168.0.11" ) # login to computer_b terminal_a.operating_state = ServiceOperatingState.STOPPED @@ -276,7 +276,7 @@ def test_computer_remote_login_to_router(wireless_wan_network): assert len(pc_a_terminal._connections) == 0 - pc_a_on_router_1 = pc_a_terminal.login(username="username", password="password", ip_address="192.168.1.1") + pc_a_on_router_1 = pc_a_terminal.login(username="admin", password="admin", ip_address="192.168.1.1") assert len(pc_a_terminal._connections) == 1 @@ -295,7 +295,7 @@ def test_router_remote_login_to_computer(wireless_wan_network): assert len(router_1_terminal._connections) == 0 - router_1_on_pc_a = router_1_terminal.login(username="username", password="password", ip_address="192.168.0.2") + router_1_on_pc_a = router_1_terminal.login(username="admin", password="admin", ip_address="192.168.0.2") assert len(router_1_terminal._connections) == 1 @@ -317,7 +317,7 @@ def test_router_blocks_SSH_traffic(wireless_wan_network): assert len(pc_a_terminal._connections) == 0 - pc_a_terminal.login(username="username", password="password", ip_address="192.168.0.2") + pc_a_terminal.login(username="admin", password="admin", ip_address="192.168.0.2") assert len(pc_a_terminal._connections) == 0 @@ -333,7 +333,7 @@ def test_SSH_across_network(wireless_wan_network): assert len(terminal_a._connections) == 0 - terminal_b_on_terminal_a = terminal_b.login(username="username", password="password", ip_address="192.168.0.2") + terminal_b_on_terminal_a = terminal_b.login(username="admin", password="admin", ip_address="192.168.0.2") assert len(terminal_a._connections) == 1 @@ -347,11 +347,13 @@ def test_multiple_remote_terminals_same_node(basic_network): assert len(terminal_a._connections) == 0 - # Spam login requests to terminal. - for attempt in range(10): - remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11") + # Spam login requests to node. + for attempt in range(3): + remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") - assert len(terminal_a._connections) == 10 + terminal_a.show() + + assert len(terminal_a._connections) == 3 def test_terminal_rejects_commands_if_disconnect(basic_network): @@ -363,7 +365,7 @@ def test_terminal_rejects_commands_if_disconnect(basic_network): terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") - remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11") + remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") assert len(terminal_a._connections) == 1 assert len(terminal_b._connections) == 1 @@ -378,3 +380,27 @@ def test_terminal_rejects_commands_if_disconnect(basic_network): assert not computer_b.software_manager.software.get("RansomwareScript") assert remote_connection.is_active is False + + +def test_terminal_connection_timeout(basic_network): + """Test that terminal_connections are affected by UserSession timeout.""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") + + remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") + + assert len(terminal_a._connections) == 1 + assert len(terminal_b._connections) == 1 + assert len(computer_b.user_session_manager.remote_sessions) == 1 + + remote_session = computer_b.user_session_manager.remote_sessions[remote_connection.connection_uuid] + computer_b.user_session_manager._timeout_session(remote_session) + + assert len(terminal_a._connections) == 0 + assert len(terminal_b._connections) == 0 + assert len(computer_b.user_session_manager.remote_sessions) == 0 + + assert not remote_connection.is_active