diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index ca9663d8..be7f842f 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -85,5 +85,5 @@ class SSHPacket(DataPacket): ssh_output: Optional[RequestResponse] = None """RequestResponse from Request Manager""" - ssh_command: Optional[str] = None + ssh_command: Optional[list] = None """Request String""" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 0bcec90d..0ebae491 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -1,6 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations +from abc import abstractmethod from datetime import datetime from ipaddress import IPv4Address from typing import Any, Dict, List, Optional, Union @@ -42,11 +43,14 @@ class TerminalClientConnection(BaseModel): """Connection request ID""" time: datetime = None - """Timestammp connection was created.""" + """Timestamp connection was created.""" ip_address: IPv4Address """Source IP of Connection""" + is_active: bool = True + """Flag to state whether the connection is active or not""" + def __str__(self) -> str: return f"{self.__class__.__name__}(connection_id='{self.connection_uuid}')" @@ -65,6 +69,28 @@ class TerminalClientConnection(BaseModel): """Disconnect the session.""" return self.parent_terminal._disconnect(connection_uuid=self.connection_uuid) + @abstractmethod + def execute(self, command: Any) -> bool: + """Execute a given command.""" + pass + + +class LocalTerminalConnection(TerminalClientConnection): + """ + LocalTerminalConnectionClass. + + This class represents a local terminal when connected. + """ + + ip_address: str = "Local Connection" + + def execute(self, command: Any) -> RequestResponse: + """Execute a given command on local Terminal.""" + if not self.is_active: + self.parent_terminal.sys_log.warning("Connection inactive, cannot execute") + return None + return self.parent_terminal.execute(command, connection_id=self.connection_uuid) + class RemoteTerminalConnection(TerminalClientConnection): """ @@ -78,8 +104,24 @@ class RemoteTerminalConnection(TerminalClientConnection): """Execute a given command on the remote Terminal.""" if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING: self.parent_terminal.sys_log.warning("Cannot process command as system not running") + return False + if not self.is_active: + self.parent_terminal.sys_log.warning("Connection inactive, cannot execute") + return False # Send command to remote terminal to process. - return self.parent_terminal.send(payload=command, session_id=self.session_id) + + transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_SERVICE_REQUEST + connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA + + payload: SSHPacket = SSHPacket( + transport_message=transport_message, + connection_message=connection_message, + connection_request_uuid=self.connection_request_id, + connection_uuid=self.connection_uuid, + ssh_command=command, + ) + + return self.parent_terminal.send(payload=payload, session_id=self.session_id) class Terminal(Service): @@ -138,7 +180,8 @@ class Terminal(Service): def _execute_request(request: RequestFormat, context: Dict) -> RequestResponse: """Execute an instruction.""" command: str = request[0] - self.execute(command) + connection_id: str = request[1] + self.execute(command, connection_id=connection_id) return RequestResponse(status="success", data={}) def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: @@ -169,9 +212,14 @@ class Terminal(Service): return rm - def execute(self, command: List[Any]) -> RequestResponse: + def execute(self, command: List[Any], connection_id: str) -> Optional[RequestResponse]: """Execute a passed ssh command via the request manager.""" - return self.parent.apply_request(command) + valid_connection = self._check_client_connection(connection_id=connection_id) + if valid_connection: + return self.parent.apply_request(command) + else: + self.sys_log.error("Invalid connection ID provided") + return None def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection: """Create a new connection object and amend to list of active connections. @@ -180,7 +228,7 @@ class Terminal(Service): :param session_id: Session ID of the new local connection :return: TerminalClientConnection object """ - new_connection = TerminalClientConnection( + new_connection = LocalTerminalConnection( parent_terminal=self, connection_uuid=connection_uuid, session_id=session_id, @@ -340,7 +388,7 @@ class Terminal(Service): self._connections[connection_id] = client_connection self._client_connection_requests[connection_request_id] = client_connection - def receive(self, session_id: str, payload: Union[SSHPacket, Dict, List], **kwargs) -> bool: + def receive(self, session_id: str, payload: Union[SSHPacket, Dict], **kwargs) -> bool: """ Receive a payload from the Software Manager. @@ -400,6 +448,17 @@ class Terminal(Service): source_ip=source_ip, ) + elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: + # Requesting a command to be executed + self.sys_log.info("Received command to execute") + command = payload.ssh_command + valid_connection = self._check_client_connection(payload.connection_uuid) + self.sys_log.info(f"Connection uuid is {valid_connection}") + if valid_connection: + return self.execute(command, payload.connection_uuid) + else: + self.sys_log.error(f"Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command.") + if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "disconnect": connection_id = payload["connection_id"] @@ -410,10 +469,6 @@ class Terminal(Service): else: self.sys_log.info("No Active connection held for received connection ID.") - if isinstance(payload, list): - # A request? For me? - self.execute(payload) - return True def _disconnect(self, connection_uuid: str) -> bool: @@ -426,16 +481,25 @@ class Terminal(Service): self.sys_log.warning("No remote connection present") return False - # session_id = self._connections[connection_uuid].session_id - connection: RemoteTerminalConnection = self._connections.pop(connection_uuid) - session_id = connection.session_id + connection = self._connections.pop(connection_uuid) + connection.is_active = False - software_manager: SoftwareManager = self.software_manager - software_manager.send_payload_to_session_manager( - payload={"type": "disconnect", "connection_id": connection_uuid}, dest_port=self.port, session_id=session_id - ) - self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}") - return True + if isinstance(connection, RemoteTerminalConnection): + # Send disconnect command via software manager + session_id = connection.session_id + + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "disconnect", "connection_id": connection_uuid}, + dest_port=self.port, + session_id=session_id, + ) + self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}") + return True + + elif isinstance(connection, LocalTerminalConnection): + # No further action needed + return True def send( self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 7e98e501..cdd0ebb3 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -352,3 +352,27 @@ def test_multiple_remote_terminals_same_node(basic_network): remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11") assert len(terminal_a._connections) == 10 + + +def test_terminal_rejects_commands_if_disconnect(basic_network): + """Test to check terminal will ignore commands from disconnected connections""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + + terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") + + remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11") + + assert len(terminal_a._connections) == 1 + assert len(terminal_b._connections) == 1 + + remote_connection.disconnect() + + assert len(terminal_a._connections) == 0 + assert len(terminal_b._connections) == 0 + + assert remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) is False + + assert not computer_b.software_manager.software.get("RansomwareScript")