diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index afa79c0a..4d1285d1 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -5,33 +5,41 @@ .. _Terminal: Terminal -######## +======== -The ``Terminal`` provides a generic terminal simulation, by extending the base Service class +The ``Terminal.py`` class provides a generic terminal simulation, by extending the base Service class within PrimAITE. The aim of this is to act as the primary entrypoint for Nodes within the environment. + + +Overview +-------- + +The Terminal service uses Secure Socket (SSH) as the communication method between terminals. They operate on port 22, and are part of the services automatically +installed on Nodes when they are instantiated. Key capabilities ================ - Authenticates User connection by maintaining an active User account. - Ensures packets are matched to an existing session - - Simulates common Terminal commands + - Simulates common Terminal processes/commands. - Leverages the Service base class for install/uninstall, status tracking etc. - Usage ===== - - Install on a node via the ``SoftwareManager`` to start the Terminal - - Terminal Clients connect, execute commands and disconnect. + - Pre-Installs on any `HostNode` component. See the below code example of how to access the terminal. + - Terminal Clients connect, execute commands and disconnect from remote components. + - Ensures that users are logged in to the component before executing any commands. - Service runs on SSH port 22 by default. Implementation ============== -- Manages SSH commands -- Ensures User login before sending commands -- Processes SSH commands -- Returns results in a ** format. +The terminal takes inspiration from the `Database Client` and `Database Service` classes, and leverages the `UserSessionManager` +to provide User Credential authentication when receiving/processing commands. + +Terminal acts as the interface between the user/component and both the Session and Requests Managers, facilitating +the passing of requests to both. Python diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 7be81982..5eb181a6 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -1,7 +1,8 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from enum import IntEnum -from typing import Dict, Optional +from ipaddress import IPv4Address +from typing import Optional from primaite.interface.request import RequestResponse from primaite.simulator.network.protocols.packet import DataPacket @@ -57,15 +58,36 @@ class SSHConnectionMessage(IntEnum): """Closes the channel.""" +class SSHUserCredentials(DataPacket): + """Hold Username and Password in SSH Packets.""" + + username: str + """Username for login""" + + password: str + """Password for login""" + + class SSHPacket(DataPacket): """Represents an SSHPacket.""" - transport_message: SSHTransportMessage = None + sender_ip_address: IPv4Address + """Sender IP Address""" - connection_message: SSHConnectionMessage = None + target_ip_address: IPv4Address + """Target IP Address""" - ssh_command: Optional[str] = None # This is the request string + transport_message: SSHTransportMessage + """Message Transport Type""" + + connection_message: SSHConnectionMessage + """Message Connection Status""" + + user_account: Optional[SSHUserCredentials] = None + """User Account Credentials if passed""" + + connection_uuid: Optional[str] = None # The connection uuid used to validate the session ssh_output: Optional[RequestResponse] = None # The Request Manager's returned RequestResponse - user_account: Optional[Dict] = None # The user account we will use to login if we do not have a current connection. + ssh_command: Optional[str] = None # This is the request string diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 589492ba..f01b44a2 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -7,10 +7,15 @@ from uuid import uuid4 from pydantic import BaseModel -from primaite.interface.request import RequestResponse -from primaite.simulator.core import RequestManager +from primaite.interface.request import RequestFormat, RequestResponse +from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType from primaite.simulator.network.hardware.base import Node -from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage +from primaite.simulator.network.protocols.ssh import ( + SSHConnectionMessage, + SSHPacket, + SSHTransportMessage, + SSHUserCredentials, +) from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager @@ -21,10 +26,10 @@ class TerminalClientConnection(BaseModel): """ TerminalClientConnection Class. - This class is used to record current User Connections within the Terminal class. + This class is used to record current remote User Connections to the Terminal class. """ - parent_node: Node # Technically I think this should be HostNode, but that causes a circular import. + parent_node: Node # Technically should be HostNode but this causes circular import error. """The parent Node that this connection was created on.""" is_active: bool = True @@ -33,6 +38,9 @@ class TerminalClientConnection(BaseModel): _dest_ip_address: IPv4Address """Destination IP address of connection""" + _connection_uuid: str = None + """Connection UUID""" + @property def dest_ip_address(self) -> Optional[IPv4Address]: """Destination IP Address.""" @@ -46,15 +54,12 @@ class TerminalClientConnection(BaseModel): def disconnect(self): """Disconnect the connection.""" if self.client and self.is_active: - self.client._disconnect(self.connection_id) # noqa + self.client._disconnect(self._connection_uuid) # noqa class Terminal(Service): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" - user_account: Optional[str] = None - "The User Account used for login" - is_connected: bool = False "Boolean Value for whether connected" @@ -64,8 +69,7 @@ class Terminal(Service): operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING """Initial Operating State""" - user_connections: Dict[str, TerminalClientConnection] = {} - """List of authenticated connected users""" + remote_connection: Dict[str, TerminalClientConnection] = {} def __init__(self, **kwargs): kwargs["name"] = "Terminal" @@ -85,97 +89,138 @@ class Terminal(Service): :rtype: Dict """ state = super().describe_state() - - state.update({"hostname": self.name}) return state def apply_request(self, request: List[str | int | float | Dict], context: Dict | None = None) -> RequestResponse: - """Apply Temrinal Request.""" + """Apply Terminal Request.""" return super().apply_request(request, context) def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" - # TODO: Expand with a login validator? + _login_valid = Terminal._LoginValidator(terminal=self) + rm = super()._init_request_manager() + rm.add_request( + "send", + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self.send()), validator=_login_valid + ), + ) return rm - def _validate_login(self, user_account: Optional[str]) -> bool: + def _validate_login(self) -> bool: """Validate login credentials are valid.""" - # TODO: Interact with UserManager to check user_account details - if len(self.user_connections) == 0: - # No current connections - self.sys_log.warning("Login Required!") - return False - else: - return True + # return self.parent.UserSessionManager.validate_remote_session_uuid(self.connection_uuid) + return True + + class _LoginValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only allow them through if the User is logged into the Terminal. + + Login is required before making use of the Terminal. + """ + + terminal: Terminal + """Save a reference to the Terminal instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the Terminal is connected.""" + return self.terminal.is_connected + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return "Cannot perform request on terminal as not logged in." # %% Inbound - def _generate_connection_uuid(self) -> str: - """Generate a unique connection ID.""" - return str(uuid4()) - - def login(self, dest_ip_address: IPv4Address, **kwargs) -> bool: + def login(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool: """Process User request to login to Terminal. :param dest_ip_address: The IP address of the node we want to connect to. + :param username: The username credential. + :param password: The user password component of credentials. :return: True if successful, False otherwise. """ if self.operating_state != ServiceOperatingState.RUNNING: self.sys_log.warning("Cannot process login as service is not running") return False - if self.connection_uuid in self.user_connections: - self.sys_log.debug("User authentication passed") + + if ip_address: + # if ip_address has been provided, we assume we are logging in to a remote terminal. + return self._send_remote_login(username=username, password=password, ip_address=ip_address) + + return self._process_local_login(username=username, password=password) + + def _process_local_login(self, username: str, password: str) -> bool: + """Local session login to terminal.""" + # self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) + self.connection_uuid = str(uuid4()) + if self.connection_uuid: + self.sys_log.info(f"Login request authorised, connection uuid: {self.connection_uuid}") return True else: - # Need to send a login request - # TODO: Refactor with UserManager changes to provide correct credentials and validate. - transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST - connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN - payload: SSHPacket = SSHPacket( - payload="login", transport_message=transport_message, connection_message=connection_message + self.sys_log.warning("Login failed, incorrect Username or Password") + return False + + def _send_remote_login(self, username: str, password: str, ip_address: IPv4Address) -> bool: + """Attempt to login to a remote terminal.""" + transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST + connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA + user_account: SSHUserCredentials = SSHUserCredentials(username=username, password=password) + + payload: SSHPacket = SSHPacket( + transport_message=transport_message, + connection_message=connection_message, + user_account=user_account, + target_ip_address=ip_address, + sender_ip_address=self.parent.network_interface[1].ip_address, + ) + + self.sys_log.info(f"Sending remote login request to {ip_address}") + return self.send(payload=payload, dest_ip_address=ip_address) + + def _process_remote_login(self, payload: SSHPacket) -> bool: + """Processes a remote terminal requesting to login to this terminal.""" + username: str = payload.user_account.username + password: str = payload.user_account.password + self.sys_log.info(f"Sending UserAuth request to UserSessionManager, username={username}, password={password}") + # connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) + connection_uuid = str(uuid4()) + + if connection_uuid: + # Send uuid to remote + self.sys_log.info( + f"Remote login authorised, connection ID {self.connection_uuid} for " + f"{username} on {payload.sender_ip_address}" + ) + transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS + connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA + return_payload = SSHPacket( + transport_message=transport_message, + connection_message=connection_message, + connection_uuid=connection_uuid, + sender_ip_address=self.parent.network_interface[1].ip_address, + target_ip_address=payload.sender_ip_address, + ) + self.send(payload=return_payload, dest_ip_address=return_payload.target_ip_address) + + self.remote_connection[connection_uuid] = TerminalClientConnection( + parent_node=self.software_manager.node, + _dest_ip_address=payload.sender_ip_address, + connection_uuid=connection_uuid, ) - self.sys_log.debug(f"Sending login request to {dest_ip_address}") - self.send(payload=payload, dest_ip_address=dest_ip_address) + return True + else: + # UserSessionManager has returned None + self.sys_log.warning("Login failed, incorrect Username or Password") + return False - def _ssh_process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: - """Processes the login attempt. Returns a bool which either rejects the login or accepts it.""" - # we assume that the login fails unless we meet all the criteria. - transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_FAILURE - connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_FAILED - - # Hard coded at current - replace with another method to handle local accounts. - if user_account == "Username: placeholder, Password: placeholder": # hardcoded - self.connection_uuid = self._generate_connection_uuid() - if not self.add_connection(connection_id=self.connection_uuid): - self.sys_log.warning( - f"{self.name}: Connect request for {dest_ip_address} declined. Service is at capacity." - ) - return False - else: - self.sys_log.info(f"{self.name}: Connect request for ID: {self.connection_uuid} authorised") - transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS - connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_CONFIRMATION - new_connection = TerminalClientConnection( - parent_node=self.software_manager.node, - connection_id=self.connection_uuid, - dest_ip_address=dest_ip_address, - ) - self.user_connections[self.connection_uuid] = new_connection - self.is_connected = True - - payload: SSHPacket = SSHPacket(transport_message=transport_message, connection_message=connection_message) - - self.send(payload=payload, dest_ip_address=dest_ip_address) - return True - - def _ssh_process_logoff(self, session_id: str, *args, **kwargs) -> bool: - """Process the logoff attempt. Return a bool if succesful or unsuccessful.""" - # TODO: Should remove - - def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: + def receive(self, payload: SSHPacket, **kwargs) -> bool: """Receive Payload and process for a response.""" + self.sys_log.debug(f"Received payload: {payload}") + if not isinstance(payload, SSHPacket): return False @@ -183,23 +228,21 @@ class Terminal(Service): self.sys_log.warning("Cannot process message as not running") return False - self.sys_log.debug(f"Received payload: {payload} from session: {session_id}") - if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: + # Close the channel connection_id = kwargs["connection_id"] dest_ip_address = kwargs["dest_ip_address"] - self._ssh_process_logoff(session_id=session_id) self.disconnect(dest_ip_address=dest_ip_address) self.sys_log.debug(f"Disconnecting {connection_id}") # We need to close on the other machine as well elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: - # validate login - user_account = "Username: placeholder, Password: placeholder" - self._ssh_process_login(dest_ip_address="192.168.0.10", user_account=user_account) + """Login Request Received.""" + self._process_remote_login(payload=payload) elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: - self.sys_log.debug("Login Successful") + self.sys_log.info(f"Login Successful, connection ID is {payload.connection_uuid}") + self.connection_uuid = payload.connection_uuid self.is_connected = True return True @@ -210,38 +253,26 @@ class Terminal(Service): return True # %% Outbound - def _ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: - """Remote login to terminal via SSH.""" - if not user_account: - # TODO: Generic hardcoded info, will need to be updated with UserManager. - user_account = "Username: placeholder, Password: placeholder" - # something like self.user_manager.get_user_details ? - # Implement SSHPacket class - payload: SSHPacket = SSHPacket( - transport_message=SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST, - connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, - user_account=user_account, + def _disconnect(self, dest_ip_address: IPv4Address) -> bool: + """Disconnect from the remote.""" + if not self.is_connected: + self.sys_log.warning("Not currently connected to remote") + return False + + if not self.remote_connection: + self.sys_log.warning("No remote connection present") + return False + + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "disconnect", "connection_id": self.remote_connection._connection_uuid}, + dest_ip_address=dest_ip_address, + dest_port=self.port, ) - if self.send(payload=payload, dest_ip_address=dest_ip_address): - if payload.connection_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: - self.sys_log.info(f"{self.name} established an ssh connection with {dest_ip_address}") - # Need to confirm if self.uuid is correct. - self.add_connection(self, connection_id=self.uuid, session_id=self.session_id) - return True - else: - self.sys_log.error("Login Failed. Incorrect credentials provided.") - return False - else: - self.sys_log.error("Login Failed. Incorrect credentials provided.") - return False - - def check_connection(self, connection_id: str) -> bool: - """Check whether the connection is valid.""" - if self.is_connected: - return self.send(dest_ip_address=self.dest_ip_address, connection_id=connection_id) - else: - return False + self.connection_uuid = None + self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}") + return True def disconnect(self, dest_ip_address: IPv4Address) -> bool: """Disconnect from remote connection. @@ -252,28 +283,6 @@ class Terminal(Service): self._disconnect(dest_ip_address=dest_ip_address) self.is_connected = False - def _disconnect(self, dest_ip_address: IPv4Address) -> bool: - if not self.is_connected: - return False - - if len(self.user_connections) == 0: - self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.") - return False - if not self.user_connections.get(self.connection_uuid): - return False - software_manager: SoftwareManager = self.software_manager - software_manager.send_payload_to_session_manager( - payload={"type": "disconnect", "connection_id": self.connection_uuid}, - dest_ip_address=dest_ip_address, - dest_port=self.port, - ) - connection = self.user_connections.pop(self.connection_uuid) - - connection.is_active = False - - self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}") - return True - def send( self, payload: SSHPacket, diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 6b0365ce..65346b45 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -62,14 +62,17 @@ def test_terminal_send(basic_network): network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") payload: SSHPacket = SSHPacket( payload="Test_Payload", transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, + sender_ip_address=computer_a.network_interface[1].ip_address, + target_ip_address=computer_b.network_interface[1].ip_address, ) - assert terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") + assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address) def test_terminal_fail_when_closed(basic_network): @@ -77,27 +80,33 @@ def test_terminal_fail_when_closed(basic_network): network: Network = basic_network computer: Computer = network.get_node_by_hostname("node_a") terminal: Terminal = computer.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") terminal.operating_state = ServiceOperatingState.STOPPED - assert terminal.login(dest_ip_address="192.168.0.11") is False + assert ( + terminal.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address) + is False + ) def test_terminal_disconnect(basic_network): """Terminal should set is_connected to false on disconnect""" network: Network = basic_network - computer: Computer = network.get_node_by_hostname("node_a") - terminal: Terminal = computer.software_manager.software.get("Terminal") + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") - assert terminal.is_connected is False + assert terminal_a.is_connected is False - terminal.login(dest_ip_address="192.168.0.11") + terminal_a.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address) - assert terminal.is_connected is True + assert terminal_a.is_connected is True - terminal.disconnect(dest_ip_address="192.168.0.11") + terminal_a.disconnect(dest_ip_address=computer_b.network_interface[1].ip_address) - assert terminal.is_connected is False + assert terminal_a.is_connected is False def test_terminal_ignores_when_off(basic_network): @@ -108,7 +117,7 @@ def test_terminal_ignores_when_off(basic_network): computer_b: Computer = network.get_node_by_hostname("node_b") - terminal_a.login(dest_ip_address="192.168.0.11") # login to computer_b + terminal_a.login(ip_address="192.168.0.11") # login to computer_b assert terminal_a.is_connected is True