diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 7842aa66..1441c93b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1278,6 +1278,10 @@ class UserSessionManager(Service): if self.local_session: if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep: self._timeout_session(self.local_session) + for session in self.remote_sessions: + remote_session = self.remote_sessions[session] + if remote_session.last_active_step + self.remote_session_timeout_steps <= timestep: + self._timeout_session(remote_session) def _timeout_session(self, session: UserSession) -> None: """ diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 85e0c87f..876b1694 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -196,12 +196,13 @@ class Terminal(Service): command: str = request[0] ip_address: IPv4Address = IPv4Address(request[1]) remote_connection = self._get_connection_from_ip(ip_address=ip_address) - outcome = remote_connection.execute(command) - if outcome: - return RequestResponse( - status="success", - data={}, - ) + if remote_connection: + outcome = remote_connection.execute(command) + if outcome: + return RequestResponse( + status="success", + data={}, + ) else: return RequestResponse( status="failure", @@ -240,11 +241,9 @@ class Terminal(Service): def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]: """Find Remote Terminal Connection from a given IP.""" - for connection in self._connections: - if self._connections[connection].ip_address == ip_address: - return self._connections[connection] - else: - return None + for connection in self._connections.values(): + if connection.ip_address == ip_address: + return connection def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection: """Create a new connection object and amend to list of active connections. @@ -279,7 +278,7 @@ class Terminal(Service): :type: ip_address: Optional[IPv4Address] """ if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.warning("Cannot login as service is not running.") + self.sys_log.warning(f"{self.name}: Cannot login as service is not running.") return None connection_request_id = str(uuid4()) self._client_connection_requests[connection_request_id] = None @@ -301,11 +300,11 @@ class Terminal(Service): # TODO: Un-comment this when UserSessionManager is merged. connection_uuid = self.parent.user_session_manager.local_login(username=username, password=password) if connection_uuid: - self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}") + self.sys_log.info(f"{self.name}: Login request authorised, connection uuid: {connection_uuid}") # Add new local session to list of connections and return return self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection") else: - self.sys_log.warning("Login failed, incorrect Username or Password") + self.sys_log.warning(f"{self.name}: Login failed, incorrect Username or Password") return None def _validate_client_connection_request(self, connection_id: str) -> bool: @@ -344,7 +343,9 @@ class Terminal(Service): :return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False. """ - self.sys_log.info(f"Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}") + self.sys_log.info( + f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}" + ) if is_reattempt: valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id) if valid_connection_request: @@ -353,7 +354,7 @@ class Terminal(Service): self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.") return remote_terminal_connection else: - self.sys_log.warning(f"Connection request {connection_request_id} declined") + self.sys_log.warning(f"{self.name}: Connection request {connection_request_id} declined") return None else: self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.") @@ -420,8 +421,8 @@ class Terminal(Service): :param session_id: The session id the payload relates to. :return: True. """ - source_ip = [kwargs["frame"].ip.src_ip_address][0] - self.sys_log.info(f"Received payload: {payload}. Source: {source_ip}") + source_ip = kwargs["frame"].ip.src_ip_address + self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}") if isinstance(payload, SSHPacket): if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: # validate & add connection @@ -431,10 +432,9 @@ class Terminal(Service): connection_id = self.parent.user_session_manager.remote_login( username=username, password=password, remote_ip_address=source_ip ) - # connection_id = str(uuid4()) if connection_id: connection_request_id = payload.connection_request_uuid - self.sys_log.info(f"Connection authorised, session_id: {session_id}") + self.sys_log.info(f"{self.name}: Connection authorised, session_id: {session_id}") self._create_remote_connection( connection_id=connection_id, connection_request_id=connection_request_id, @@ -465,7 +465,7 @@ class Terminal(Service): payload=payload, dest_port=self.port, session_id=session_id ) elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: - self.sys_log.info("Login Successful") + self.sys_log.info(f"{self.name}: Login Successful") self._create_remote_connection( connection_id=payload.connection_uuid, connection_request_id=payload.connection_request_uuid, @@ -475,13 +475,16 @@ class Terminal(Service): elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: # Requesting a command to be executed - self.sys_log.info("Received command to execute") + self.sys_log.info(f"{self.name}: Received command to execute") command = payload.ssh_command valid_connection = self._check_client_connection(payload.connection_uuid) if valid_connection: - return self.execute(command) + self.execute(command) + return True else: - self.sys_log.error(f"Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command.") + self.sys_log.error( + f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command." + ) if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "disconnect": @@ -492,13 +495,14 @@ class Terminal(Service): self._disconnect(payload["connection_id"]) self.parent.user_session_manager.remote_logout(remote_session_id=connection_id) else: - self.sys_log.error("No Active connection held for received connection ID.") + self.sys_log.error(f"{self.name}: No Active connection held for received connection ID.") if payload["type"] == "user_timeout": connection_id = payload["connection_id"] valid_id = self._check_client_connection(connection_id) if valid_id: - self._connections.pop(connection_id) + connection = self._connections.pop(connection_id) + connection.is_active = False self.sys_log.info(f"{self.name}: Connection {connection_id} disconnected due to inactivity.") else: self.sys_log.error(f"{self.name}: Connection {connection_id} is invalid.") @@ -512,7 +516,7 @@ class Terminal(Service): :return True if successful, False otherwise. """ if not self._connections: - self.sys_log.warning("No remote connection present") + self.sys_log.warning(f"{self.name}: No remote connection present") return False connection = self._connections.pop(connection_uuid) @@ -545,10 +549,10 @@ class Terminal(Service): :param dest_up_address: The IP address of the payload destination. """ if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!") + self.sys_log.warning(f"{self.name}: Cannot send commands when Operating state is {self.operating_state}!") return False - self.sys_log.debug(f"Sending payload: {payload}") + self.sys_log.debug(f"{self.name}: Sending payload: {payload}") return super().send( payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index ffe48ab5..41858b90 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -248,6 +248,8 @@ def test_terminal_disconnect(basic_network): assert len(terminal_b._connections) == 0 + assert term_a_on_term_b.is_active is False + def test_terminal_ignores_when_off(basic_network): """Terminal should ignore commands when not running""" @@ -378,3 +380,27 @@ def test_terminal_rejects_commands_if_disconnect(basic_network): assert not computer_b.software_manager.software.get("RansomwareScript") assert remote_connection.is_active is False + + +def test_terminal_connection_timeout(basic_network): + """Test that terminal_connections are affected by UserSession timeout.""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") + + remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") + + assert len(terminal_a._connections) == 1 + assert len(terminal_b._connections) == 1 + assert len(computer_b.user_session_manager.remote_sessions) == 1 + + remote_session = computer_b.user_session_manager.remote_sessions[remote_connection.connection_uuid] + computer_b.user_session_manager._timeout_session(remote_session) + + assert len(terminal_a._connections) == 0 + assert len(terminal_b._connections) == 0 + assert len(computer_b.user_session_manager.remote_sessions) == 0 + + assert not remote_connection.is_active