#2706 - Resolving an issue that saw disconnected terminal connections still able to send execute commands that were also then processed by the target node. Created a new class: LocalterminalConnection, for local connection objects to terminal. Calling terminal.show() when there is a local connection will have 'Local Connection' as the IP address. Receive and execute will check that the provided connection uuid is valid before actioning any commands. TerminalClientConnection objects now have an is_active flag similar to DatabaseClientConnection. Added a new test to check that terminals will reject commands from disconnected clientconnection objects.

This commit is contained in:
Charlie Crane
2024-08-06 19:09:23 +01:00
parent de14dfdc48
commit d05fd00594
3 changed files with 109 additions and 21 deletions

View File

@@ -85,5 +85,5 @@ class SSHPacket(DataPacket):
ssh_output: Optional[RequestResponse] = None
"""RequestResponse from Request Manager"""
ssh_command: Optional[str] = None
ssh_command: Optional[list] = None
"""Request String"""

View File

@@ -1,6 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import abstractmethod
from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Union
@@ -42,11 +43,14 @@ class TerminalClientConnection(BaseModel):
"""Connection request ID"""
time: datetime = None
"""Timestammp connection was created."""
"""Timestamp connection was created."""
ip_address: IPv4Address
"""Source IP of Connection"""
is_active: bool = True
"""Flag to state whether the connection is active or not"""
def __str__(self) -> str:
return f"{self.__class__.__name__}(connection_id='{self.connection_uuid}')"
@@ -65,6 +69,28 @@ class TerminalClientConnection(BaseModel):
"""Disconnect the session."""
return self.parent_terminal._disconnect(connection_uuid=self.connection_uuid)
@abstractmethod
def execute(self, command: Any) -> bool:
"""Execute a given command."""
pass
class LocalTerminalConnection(TerminalClientConnection):
"""
LocalTerminalConnectionClass.
This class represents a local terminal when connected.
"""
ip_address: str = "Local Connection"
def execute(self, command: Any) -> RequestResponse:
"""Execute a given command on local Terminal."""
if not self.is_active:
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
return None
return self.parent_terminal.execute(command, connection_id=self.connection_uuid)
class RemoteTerminalConnection(TerminalClientConnection):
"""
@@ -78,8 +104,24 @@ class RemoteTerminalConnection(TerminalClientConnection):
"""Execute a given command on the remote Terminal."""
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
return False
if not self.is_active:
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
return False
# Send command to remote terminal to process.
return self.parent_terminal.send(payload=command, session_id=self.session_id)
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_SERVICE_REQUEST
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
payload: SSHPacket = SSHPacket(
transport_message=transport_message,
connection_message=connection_message,
connection_request_uuid=self.connection_request_id,
connection_uuid=self.connection_uuid,
ssh_command=command,
)
return self.parent_terminal.send(payload=payload, session_id=self.session_id)
class Terminal(Service):
@@ -138,7 +180,8 @@ class Terminal(Service):
def _execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
"""Execute an instruction."""
command: str = request[0]
self.execute(command)
connection_id: str = request[1]
self.execute(command, connection_id=connection_id)
return RequestResponse(status="success", data={})
def _logoff(request: RequestFormat, context: Dict) -> RequestResponse:
@@ -169,9 +212,14 @@ class Terminal(Service):
return rm
def execute(self, command: List[Any]) -> RequestResponse:
def execute(self, command: List[Any], connection_id: str) -> Optional[RequestResponse]:
"""Execute a passed ssh command via the request manager."""
return self.parent.apply_request(command)
valid_connection = self._check_client_connection(connection_id=connection_id)
if valid_connection:
return self.parent.apply_request(command)
else:
self.sys_log.error("Invalid connection ID provided")
return None
def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection:
"""Create a new connection object and amend to list of active connections.
@@ -180,7 +228,7 @@ class Terminal(Service):
:param session_id: Session ID of the new local connection
:return: TerminalClientConnection object
"""
new_connection = TerminalClientConnection(
new_connection = LocalTerminalConnection(
parent_terminal=self,
connection_uuid=connection_uuid,
session_id=session_id,
@@ -340,7 +388,7 @@ class Terminal(Service):
self._connections[connection_id] = client_connection
self._client_connection_requests[connection_request_id] = client_connection
def receive(self, session_id: str, payload: Union[SSHPacket, Dict, List], **kwargs) -> bool:
def receive(self, session_id: str, payload: Union[SSHPacket, Dict], **kwargs) -> bool:
"""
Receive a payload from the Software Manager.
@@ -400,6 +448,17 @@ class Terminal(Service):
source_ip=source_ip,
)
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
# Requesting a command to be executed
self.sys_log.info("Received command to execute")
command = payload.ssh_command
valid_connection = self._check_client_connection(payload.connection_uuid)
self.sys_log.info(f"Connection uuid is {valid_connection}")
if valid_connection:
return self.execute(command, payload.connection_uuid)
else:
self.sys_log.error(f"Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command.")
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "disconnect":
connection_id = payload["connection_id"]
@@ -410,10 +469,6 @@ class Terminal(Service):
else:
self.sys_log.info("No Active connection held for received connection ID.")
if isinstance(payload, list):
# A request? For me?
self.execute(payload)
return True
def _disconnect(self, connection_uuid: str) -> bool:
@@ -426,16 +481,25 @@ class Terminal(Service):
self.sys_log.warning("No remote connection present")
return False
# session_id = self._connections[connection_uuid].session_id
connection: RemoteTerminalConnection = self._connections.pop(connection_uuid)
session_id = connection.session_id
connection = self._connections.pop(connection_uuid)
connection.is_active = False
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": connection_uuid}, dest_port=self.port, session_id=session_id
)
self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}")
return True
if isinstance(connection, RemoteTerminalConnection):
# Send disconnect command via software manager
session_id = connection.session_id
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": connection_uuid},
dest_port=self.port,
session_id=session_id,
)
self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}")
return True
elif isinstance(connection, LocalTerminalConnection):
# No further action needed
return True
def send(
self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None

View File

@@ -352,3 +352,27 @@ def test_multiple_remote_terminals_same_node(basic_network):
remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11")
assert len(terminal_a._connections) == 10
def test_terminal_rejects_commands_if_disconnect(basic_network):
"""Test to check terminal will ignore commands from disconnected connections"""
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
terminal_b: Terminal = computer_b.software_manager.software.get("Terminal")
remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11")
assert len(terminal_a._connections) == 1
assert len(terminal_b._connections) == 1
remote_connection.disconnect()
assert len(terminal_a._connections) == 0
assert len(terminal_b._connections) == 0
assert remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) is False
assert not computer_b.software_manager.software.get("RansomwareScript")