#2706 - Resolving an issue that saw disconnected terminal connections still able to send execute commands that were also then processed by the target node. Created a new class: LocalterminalConnection, for local connection objects to terminal. Calling terminal.show() when there is a local connection will have 'Local Connection' as the IP address. Receive and execute will check that the provided connection uuid is valid before actioning any commands. TerminalClientConnection objects now have an is_active flag similar to DatabaseClientConnection. Added a new test to check that terminals will reject commands from disconnected clientconnection objects.
This commit is contained in:
@@ -85,5 +85,5 @@ class SSHPacket(DataPacket):
|
||||
ssh_output: Optional[RequestResponse] = None
|
||||
"""RequestResponse from Request Manager"""
|
||||
|
||||
ssh_command: Optional[str] = None
|
||||
ssh_command: Optional[list] = None
|
||||
"""Request String"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
@@ -42,11 +43,14 @@ class TerminalClientConnection(BaseModel):
|
||||
"""Connection request ID"""
|
||||
|
||||
time: datetime = None
|
||||
"""Timestammp connection was created."""
|
||||
"""Timestamp connection was created."""
|
||||
|
||||
ip_address: IPv4Address
|
||||
"""Source IP of Connection"""
|
||||
|
||||
is_active: bool = True
|
||||
"""Flag to state whether the connection is active or not"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}(connection_id='{self.connection_uuid}')"
|
||||
|
||||
@@ -65,6 +69,28 @@ class TerminalClientConnection(BaseModel):
|
||||
"""Disconnect the session."""
|
||||
return self.parent_terminal._disconnect(connection_uuid=self.connection_uuid)
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, command: Any) -> bool:
|
||||
"""Execute a given command."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalTerminalConnection(TerminalClientConnection):
|
||||
"""
|
||||
LocalTerminalConnectionClass.
|
||||
|
||||
This class represents a local terminal when connected.
|
||||
"""
|
||||
|
||||
ip_address: str = "Local Connection"
|
||||
|
||||
def execute(self, command: Any) -> RequestResponse:
|
||||
"""Execute a given command on local Terminal."""
|
||||
if not self.is_active:
|
||||
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
|
||||
return None
|
||||
return self.parent_terminal.execute(command, connection_id=self.connection_uuid)
|
||||
|
||||
|
||||
class RemoteTerminalConnection(TerminalClientConnection):
|
||||
"""
|
||||
@@ -78,8 +104,24 @@ class RemoteTerminalConnection(TerminalClientConnection):
|
||||
"""Execute a given command on the remote Terminal."""
|
||||
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
|
||||
return False
|
||||
if not self.is_active:
|
||||
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
|
||||
return False
|
||||
# Send command to remote terminal to process.
|
||||
return self.parent_terminal.send(payload=command, session_id=self.session_id)
|
||||
|
||||
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_SERVICE_REQUEST
|
||||
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
transport_message=transport_message,
|
||||
connection_message=connection_message,
|
||||
connection_request_uuid=self.connection_request_id,
|
||||
connection_uuid=self.connection_uuid,
|
||||
ssh_command=command,
|
||||
)
|
||||
|
||||
return self.parent_terminal.send(payload=payload, session_id=self.session_id)
|
||||
|
||||
|
||||
class Terminal(Service):
|
||||
@@ -138,7 +180,8 @@ class Terminal(Service):
|
||||
def _execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""Execute an instruction."""
|
||||
command: str = request[0]
|
||||
self.execute(command)
|
||||
connection_id: str = request[1]
|
||||
self.execute(command, connection_id=connection_id)
|
||||
return RequestResponse(status="success", data={})
|
||||
|
||||
def _logoff(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
@@ -169,9 +212,14 @@ class Terminal(Service):
|
||||
|
||||
return rm
|
||||
|
||||
def execute(self, command: List[Any]) -> RequestResponse:
|
||||
def execute(self, command: List[Any], connection_id: str) -> Optional[RequestResponse]:
|
||||
"""Execute a passed ssh command via the request manager."""
|
||||
return self.parent.apply_request(command)
|
||||
valid_connection = self._check_client_connection(connection_id=connection_id)
|
||||
if valid_connection:
|
||||
return self.parent.apply_request(command)
|
||||
else:
|
||||
self.sys_log.error("Invalid connection ID provided")
|
||||
return None
|
||||
|
||||
def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection:
|
||||
"""Create a new connection object and amend to list of active connections.
|
||||
@@ -180,7 +228,7 @@ class Terminal(Service):
|
||||
:param session_id: Session ID of the new local connection
|
||||
:return: TerminalClientConnection object
|
||||
"""
|
||||
new_connection = TerminalClientConnection(
|
||||
new_connection = LocalTerminalConnection(
|
||||
parent_terminal=self,
|
||||
connection_uuid=connection_uuid,
|
||||
session_id=session_id,
|
||||
@@ -340,7 +388,7 @@ class Terminal(Service):
|
||||
self._connections[connection_id] = client_connection
|
||||
self._client_connection_requests[connection_request_id] = client_connection
|
||||
|
||||
def receive(self, session_id: str, payload: Union[SSHPacket, Dict, List], **kwargs) -> bool:
|
||||
def receive(self, session_id: str, payload: Union[SSHPacket, Dict], **kwargs) -> bool:
|
||||
"""
|
||||
Receive a payload from the Software Manager.
|
||||
|
||||
@@ -400,6 +448,17 @@ class Terminal(Service):
|
||||
source_ip=source_ip,
|
||||
)
|
||||
|
||||
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
|
||||
# Requesting a command to be executed
|
||||
self.sys_log.info("Received command to execute")
|
||||
command = payload.ssh_command
|
||||
valid_connection = self._check_client_connection(payload.connection_uuid)
|
||||
self.sys_log.info(f"Connection uuid is {valid_connection}")
|
||||
if valid_connection:
|
||||
return self.execute(command, payload.connection_uuid)
|
||||
else:
|
||||
self.sys_log.error(f"Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command.")
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "disconnect":
|
||||
connection_id = payload["connection_id"]
|
||||
@@ -410,10 +469,6 @@ class Terminal(Service):
|
||||
else:
|
||||
self.sys_log.info("No Active connection held for received connection ID.")
|
||||
|
||||
if isinstance(payload, list):
|
||||
# A request? For me?
|
||||
self.execute(payload)
|
||||
|
||||
return True
|
||||
|
||||
def _disconnect(self, connection_uuid: str) -> bool:
|
||||
@@ -426,16 +481,25 @@ class Terminal(Service):
|
||||
self.sys_log.warning("No remote connection present")
|
||||
return False
|
||||
|
||||
# session_id = self._connections[connection_uuid].session_id
|
||||
connection: RemoteTerminalConnection = self._connections.pop(connection_uuid)
|
||||
session_id = connection.session_id
|
||||
connection = self._connections.pop(connection_uuid)
|
||||
connection.is_active = False
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_uuid}, dest_port=self.port, session_id=session_id
|
||||
)
|
||||
self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}")
|
||||
return True
|
||||
if isinstance(connection, RemoteTerminalConnection):
|
||||
# Send disconnect command via software manager
|
||||
session_id = connection.session_id
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_uuid},
|
||||
dest_port=self.port,
|
||||
session_id=session_id,
|
||||
)
|
||||
self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}")
|
||||
return True
|
||||
|
||||
elif isinstance(connection, LocalTerminalConnection):
|
||||
# No further action needed
|
||||
return True
|
||||
|
||||
def send(
|
||||
self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None
|
||||
|
||||
@@ -352,3 +352,27 @@ def test_multiple_remote_terminals_same_node(basic_network):
|
||||
remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11")
|
||||
|
||||
assert len(terminal_a._connections) == 10
|
||||
|
||||
|
||||
def test_terminal_rejects_commands_if_disconnect(basic_network):
|
||||
"""Test to check terminal will ignore commands from disconnected connections"""
|
||||
network: Network = basic_network
|
||||
computer_a: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
|
||||
terminal_b: Terminal = computer_b.software_manager.software.get("Terminal")
|
||||
|
||||
remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11")
|
||||
|
||||
assert len(terminal_a._connections) == 1
|
||||
assert len(terminal_b._connections) == 1
|
||||
|
||||
remote_connection.disconnect()
|
||||
|
||||
assert len(terminal_a._connections) == 0
|
||||
assert len(terminal_b._connections) == 0
|
||||
|
||||
assert remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) is False
|
||||
|
||||
assert not computer_b.software_manager.software.get("RansomwareScript")
|
||||
|
||||
Reference in New Issue
Block a user