diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index afa79c0a..49dc941b 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -5,9 +5,16 @@ .. _Terminal: Terminal -######## +======== -The ``Terminal`` provides a generic terminal simulation, by extending the base Service class +The ``Terminal.py`` class provides a generic terminal simulation, by extending the base Service class within PrimAITE. The aim of this is to act as the primary entrypoint for Nodes within the environment. + + +Overview +-------- + +The Terminal service uses Secure Socket (SSH) as the communication method between terminals. They operate on port 22, and are part of the services automatically +installed on Nodes when they are instantiated. Key capabilities ================ @@ -17,21 +24,22 @@ Key capabilities - Simulates common Terminal commands - Leverages the Service base class for install/uninstall, status tracking etc. - Usage ===== - - Install on a node via the ``SoftwareManager`` to start the Terminal - - Terminal Clients connect, execute commands and disconnect. + - Pre-Installs on any `HostNode` component. See the below code example of how to access the terminal. + - Terminal Clients connect, execute commands and disconnect from remote components. + - Ensures that users are logged in to the component before executing any commands. - Service runs on SSH port 22 by default. Implementation ============== -- Manages SSH commands -- Ensures User login before sending commands -- Processes SSH commands -- Returns results in a ** format. +The terminal takes inspiration from the `Database Client` and `Database Service` classes, and leverages the `UserSessionManager` +to provide User Credential authentication when receiving/processing commands. + +Terminal acts as the interface between the user/component and both the Session and Requests Managers, facilitating +the passing of requests to both. Python diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index af1c550a..5eb181a6 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -1,7 +1,8 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from enum import IntEnum -from typing import Dict, Optional +from ipaddress import IPv4Address +from typing import Optional from primaite.interface.request import RequestResponse from primaite.simulator.network.protocols.packet import DataPacket @@ -58,21 +59,32 @@ class SSHConnectionMessage(IntEnum): class SSHUserCredentials(DataPacket): - """Hold Username and Password in SSH Packets""" + """Hold Username and Password in SSH Packets.""" - username: str = None + username: str """Username for login""" - password: str = None + password: str """Password for login""" class SSHPacket(DataPacket): """Represents an SSHPacket.""" - transport_message: SSHTransportMessage = None + sender_ip_address: IPv4Address + """Sender IP Address""" - connection_message: SSHConnectionMessage = None + target_ip_address: IPv4Address + """Target IP Address""" + + transport_message: SSHTransportMessage + """Message Transport Type""" + + connection_message: SSHConnectionMessage + """Message Connection Status""" + + user_account: Optional[SSHUserCredentials] = None + """User Account Credentials if passed""" connection_uuid: Optional[str] = None # The connection uuid used to validate the session diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 039fbeb3..7f37bc28 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -3,30 +3,33 @@ from __future__ import annotations from ipaddress import IPv4Address from typing import Dict, List, Optional -from uuid import uuid4 from pydantic import BaseModel from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType from primaite.simulator.network.hardware.base import Node -from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage +from primaite.simulator.network.protocols.ssh import ( + SSHConnectionMessage, + SSHPacket, + SSHTransportMessage, + SSHUserCredentials, +) from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState - # TODO: This might not be needed now? class TerminalClientConnection(BaseModel): """ TerminalClientConnection Class. - This class is used to record current User Connections within the Terminal class. + This class is used to record current remote User Connections to the Terminal class. """ - parent_node: Node # Technically I think this should be HostNode, but that causes a circular import. + parent_node: Node # Technically should be HostNode but this causes circular import error. """The parent Node that this connection was created on.""" is_active: bool = True @@ -35,6 +38,9 @@ class TerminalClientConnection(BaseModel): _dest_ip_address: IPv4Address """Destination IP address of connection""" + _connection_uuid: str = None + """Connection UUID""" + @property def dest_ip_address(self) -> Optional[IPv4Address]: """Destination IP Address.""" @@ -48,7 +54,7 @@ class TerminalClientConnection(BaseModel): def disconnect(self): """Disconnect the connection.""" if self.client and self.is_active: - self.client._disconnect(self.connection_id) # noqa + self.client._disconnect(self._connection_uuid) # noqa class Terminal(Service): @@ -63,6 +69,10 @@ class Terminal(Service): operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING """Initial Operating State""" + remote_connection: TerminalClientConnection = None + + parent: Node + """Parent component the terminal service is installed on.""" def __init__(self, **kwargs): kwargs["name"] = "Terminal" @@ -93,18 +103,21 @@ class Terminal(Service): _login_valid = Terminal._LoginValidator(terminal=self) rm = super()._init_request_manager() - rm.add_request("login", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid)) + rm.add_request( + "login", + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid + ), + ) return rm - def _validate_login(self, connection_id: str) -> bool: + def _validate_login(self) -> bool: """Validate login credentials are valid.""" - return self.parent.UserSessionManager.validate_remote_session_uuid(connection_id) - + return self.parent.UserSessionManager.validate_remote_session_uuid(self.connection_uuid) class _LoginValidator(RequestPermissionValidator): """ - When requests come in, this validator will only allow them through if the - User is logged into the Terminal. + When requests come in, this validator will only allow them through if the User is logged into the Terminal. Login is required before making use of the Terminal. """ @@ -113,18 +126,17 @@ class Terminal(Service): """Save a reference to the Terminal instance.""" def __call__(self, request: RequestFormat, context: Dict) -> bool: - """Return whether the Terminal has valid login credentials""" - return self.terminal.login_status - + """Return whether the Terminal has valid login credentials.""" + return self.terminal.is_connected + @property def fail_message(self) -> str: - """Message that is reported when a request is rejected by this validator""" - return ("Cannot perform request on terminal as not logged in.") - + """Message that is reported when a request is rejected by this validator.""" + return "Cannot perform request on terminal as not logged in." # %% Inbound - def login(self, username: str, password: str, ip_address: Optional[IPv4Address]=None) -> bool: + def login(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool: """Process User request to login to Terminal. :param dest_ip_address: The IP address of the node we want to connect to. @@ -136,15 +148,12 @@ class Terminal(Service): self.sys_log.warning("Cannot process login as service is not running") return False - # need to determine if this is a local or remote login - if ip_address: - # ip_address has been given for remote login + # if ip_address has been provided, we assume we are logging in to a remote terminal. return self._send_remote_login(username=username, password=password, ip_address=ip_address) return self._process_local_login(username=username, password=password) - def _process_local_login(self, username: str, password: str) -> bool: """Local session login to terminal.""" self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) @@ -157,25 +166,54 @@ class Terminal(Service): def _send_remote_login(self, username: str, password: str, ip_address: IPv4Address) -> bool: """Attempt to login to a remote terminal.""" - pass + transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST + connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA + user_account: SSHUserCredentials = SSHUserCredentials(username=username, password=password) + payload: SSHPacket = SSHPacket( + transport_message=transport_message, + connection_message=connection_message, + user_account=user_account, + target_ip_address=ip_address, + sender_ip_address=self.parent.network_interface[1].ip_address, + ) + self.sys_log.info(f"Sending remote login request to {ip_address}") + return self.send(payload=payload, dest_ip_address=ip_address) - def _process_remote_login(self, username: str, password: str, ip_address:IPv4Address) -> bool: + def _process_remote_login(self, payload: SSHPacket) -> bool: """Processes a remote terminal requesting to login to this terminal.""" + username: str = payload.user_account.username + password: str = payload.user_account.password self.connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) + self.sys_log.info(f"Sending UserAuth request to UserSessionManager, username={username}, password={password}") + if self.connection_uuid: # Send uuid to remote - self.sys_log.info(f"Remote login authorised, connection ID {self.connection_uuid} for {username} on {ip_address}") - # send back to origin. + self.sys_log.info( + f"Remote login authorised, connection ID {self.connection_uuid} for " + f"{username} on {payload.sender_ip_address}" + ) + transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS + connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA + payload = SSHPacket( + transport_message=transport_message, + connection_message=connection_message, + connection_uuid=self.connection_uuid, + sender_ip_address=self.parent.network_interface[1].ip_address, + target_ip_address=payload.sender_ip_address, + ) + self.send(payload=payload, dest_ip_address=payload.target_ip_address) return True else: + # UserSessionManager has returned None self.sys_log.warning("Login failed, incorrect Username or Password") return False - - def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: + def receive(self, payload: SSHPacket, **kwargs) -> bool: """Receive Payload and process for a response.""" + self.sys_log.debug(f"Received payload: {payload}") + if not isinstance(payload, SSHPacket): return False @@ -184,6 +222,7 @@ class Terminal(Service): return False if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: + # Close the channel connection_id = kwargs["connection_id"] dest_ip_address = kwargs["dest_ip_address"] self.disconnect(dest_ip_address=dest_ip_address) @@ -191,12 +230,13 @@ class Terminal(Service): # We need to close on the other machine as well elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: - # validate login - user_account = "Username: placeholder, Password: placeholder" - self._ssh_process_login(dest_ip_address="192.168.0.10", user_account=user_account) + """Login Request Received.""" + self._process_remote_login(payload=payload) + self.sys_log.info("User Auth Success!") elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: - self.sys_log.debug("Login Successful") + self.sys_log.info(f"Login Successful, connection ID is {payload.connection_uuid}") + self.connection_uuid = payload.connection_uuid self.is_connected = True return True @@ -208,6 +248,26 @@ class Terminal(Service): # %% Outbound + def _disconnect(self, dest_ip_address: IPv4Address) -> bool: + """Disconnect from the remote.""" + if not self.is_connected: + self.sys_log.warning("Not currently connected to remote") + return False + + if not self.remote_connection: + self.sys_log.warning("No remote connection present") + return False + + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "disconnect", "connection_id": self.remote_connection._connection_uuid}, + dest_ip_address=dest_ip_address, + dest_port=self.port, + ) + self.connection_uuid = None + self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}") + return True + def disconnect(self, dest_ip_address: IPv4Address) -> bool: """Disconnect from remote connection. @@ -217,28 +277,6 @@ class Terminal(Service): self._disconnect(dest_ip_address=dest_ip_address) self.is_connected = False - def _disconnect(self, dest_ip_address: IPv4Address) -> bool: - if not self.is_connected: - return False - - if len(self.user_connections) == 0: - self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.") - return False - if not self.user_connections.get(self.connection_uuid): - return False - software_manager: SoftwareManager = self.software_manager - software_manager.send_payload_to_session_manager( - payload={"type": "disconnect", "connection_id": self.connection_uuid}, - dest_ip_address=dest_ip_address, - dest_port=self.port, - ) - connection = self.user_connections.pop(self.connection_uuid) - - connection.is_active = False - - self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}") - return True - def send( self, payload: SSHPacket, diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 673b11a3..65346b45 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -62,14 +62,17 @@ def test_terminal_send(basic_network): network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") payload: SSHPacket = SSHPacket( payload="Test_Payload", transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, + sender_ip_address=computer_a.network_interface[1].ip_address, + target_ip_address=computer_b.network_interface[1].ip_address, ) - assert terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") + assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address) def test_terminal_fail_when_closed(basic_network): @@ -77,27 +80,33 @@ def test_terminal_fail_when_closed(basic_network): network: Network = basic_network computer: Computer = network.get_node_by_hostname("node_a") terminal: Terminal = computer.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") terminal.operating_state = ServiceOperatingState.STOPPED - assert terminal.login(ip_address="192.168.0.11") is False + assert ( + terminal.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address) + is False + ) def test_terminal_disconnect(basic_network): """Terminal should set is_connected to false on disconnect""" network: Network = basic_network - computer: Computer = network.get_node_by_hostname("node_a") - terminal: Terminal = computer.software_manager.software.get("Terminal") + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") - assert terminal.is_connected is False + assert terminal_a.is_connected is False - terminal.login(ip_address="192.168.0.11") + terminal_a.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address) - assert terminal.is_connected is True + assert terminal_a.is_connected is True - terminal.disconnect(dest_ip_address="192.168.0.11") + terminal_a.disconnect(dest_ip_address=computer_b.network_interface[1].ip_address) - assert terminal.is_connected is False + assert terminal_a.is_connected is False def test_terminal_ignores_when_off(basic_network):