#2712 - Updates to the login logic and fixing resultant test failures. Updates to terminal.rst and ssh.py
This commit is contained in:
@@ -5,9 +5,16 @@
|
||||
.. _Terminal:
|
||||
|
||||
Terminal
|
||||
########
|
||||
========
|
||||
|
||||
The ``Terminal`` provides a generic terminal simulation, by extending the base Service class
|
||||
The ``Terminal.py`` class provides a generic terminal simulation, by extending the base Service class within PrimAITE. The aim of this is to act as the primary entrypoint for Nodes within the environment.
|
||||
|
||||
|
||||
Overview
|
||||
--------
|
||||
|
||||
The Terminal service uses Secure Socket (SSH) as the communication method between terminals. They operate on port 22, and are part of the services automatically
|
||||
installed on Nodes when they are instantiated.
|
||||
|
||||
Key capabilities
|
||||
================
|
||||
@@ -17,21 +24,22 @@ Key capabilities
|
||||
- Simulates common Terminal commands
|
||||
- Leverages the Service base class for install/uninstall, status tracking etc.
|
||||
|
||||
|
||||
Usage
|
||||
=====
|
||||
|
||||
- Install on a node via the ``SoftwareManager`` to start the Terminal
|
||||
- Terminal Clients connect, execute commands and disconnect.
|
||||
- Pre-Installs on any `HostNode` component. See the below code example of how to access the terminal.
|
||||
- Terminal Clients connect, execute commands and disconnect from remote components.
|
||||
- Ensures that users are logged in to the component before executing any commands.
|
||||
- Service runs on SSH port 22 by default.
|
||||
|
||||
Implementation
|
||||
==============
|
||||
|
||||
- Manages SSH commands
|
||||
- Ensures User login before sending commands
|
||||
- Processes SSH commands
|
||||
- Returns results in a *<TBD>* format.
|
||||
The terminal takes inspiration from the `Database Client` and `Database Service` classes, and leverages the `UserSessionManager`
|
||||
to provide User Credential authentication when receiving/processing commands.
|
||||
|
||||
Terminal acts as the interface between the user/component and both the Session and Requests Managers, facilitating
|
||||
the passing of requests to both.
|
||||
|
||||
|
||||
Python
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Dict, Optional
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.network.protocols.packet import DataPacket
|
||||
@@ -58,21 +59,32 @@ class SSHConnectionMessage(IntEnum):
|
||||
|
||||
|
||||
class SSHUserCredentials(DataPacket):
|
||||
"""Hold Username and Password in SSH Packets"""
|
||||
"""Hold Username and Password in SSH Packets."""
|
||||
|
||||
username: str = None
|
||||
username: str
|
||||
"""Username for login"""
|
||||
|
||||
password: str = None
|
||||
password: str
|
||||
"""Password for login"""
|
||||
|
||||
|
||||
class SSHPacket(DataPacket):
|
||||
"""Represents an SSHPacket."""
|
||||
|
||||
transport_message: SSHTransportMessage = None
|
||||
sender_ip_address: IPv4Address
|
||||
"""Sender IP Address"""
|
||||
|
||||
connection_message: SSHConnectionMessage = None
|
||||
target_ip_address: IPv4Address
|
||||
"""Target IP Address"""
|
||||
|
||||
transport_message: SSHTransportMessage
|
||||
"""Message Transport Type"""
|
||||
|
||||
connection_message: SSHConnectionMessage
|
||||
"""Message Connection Status"""
|
||||
|
||||
user_account: Optional[SSHUserCredentials] = None
|
||||
"""User Account Credentials if passed"""
|
||||
|
||||
connection_uuid: Optional[str] = None # The connection uuid used to validate the session
|
||||
|
||||
|
||||
@@ -3,30 +3,33 @@ from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage
|
||||
from primaite.simulator.network.protocols.ssh import (
|
||||
SSHConnectionMessage,
|
||||
SSHPacket,
|
||||
SSHTransportMessage,
|
||||
SSHUserCredentials,
|
||||
)
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
|
||||
|
||||
|
||||
# TODO: This might not be needed now?
|
||||
class TerminalClientConnection(BaseModel):
|
||||
"""
|
||||
TerminalClientConnection Class.
|
||||
|
||||
This class is used to record current User Connections within the Terminal class.
|
||||
This class is used to record current remote User Connections to the Terminal class.
|
||||
"""
|
||||
|
||||
parent_node: Node # Technically I think this should be HostNode, but that causes a circular import.
|
||||
parent_node: Node # Technically should be HostNode but this causes circular import error.
|
||||
"""The parent Node that this connection was created on."""
|
||||
|
||||
is_active: bool = True
|
||||
@@ -35,6 +38,9 @@ class TerminalClientConnection(BaseModel):
|
||||
_dest_ip_address: IPv4Address
|
||||
"""Destination IP address of connection"""
|
||||
|
||||
_connection_uuid: str = None
|
||||
"""Connection UUID"""
|
||||
|
||||
@property
|
||||
def dest_ip_address(self) -> Optional[IPv4Address]:
|
||||
"""Destination IP Address."""
|
||||
@@ -48,7 +54,7 @@ class TerminalClientConnection(BaseModel):
|
||||
def disconnect(self):
|
||||
"""Disconnect the connection."""
|
||||
if self.client and self.is_active:
|
||||
self.client._disconnect(self.connection_id) # noqa
|
||||
self.client._disconnect(self._connection_uuid) # noqa
|
||||
|
||||
|
||||
class Terminal(Service):
|
||||
@@ -63,6 +69,10 @@ class Terminal(Service):
|
||||
operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING
|
||||
"""Initial Operating State"""
|
||||
|
||||
remote_connection: TerminalClientConnection = None
|
||||
|
||||
parent: Node
|
||||
"""Parent component the terminal service is installed on."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "Terminal"
|
||||
@@ -93,18 +103,21 @@ class Terminal(Service):
|
||||
_login_valid = Terminal._LoginValidator(terminal=self)
|
||||
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("login", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid))
|
||||
rm.add_request(
|
||||
"login",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid
|
||||
),
|
||||
)
|
||||
return rm
|
||||
|
||||
def _validate_login(self, connection_id: str) -> bool:
|
||||
def _validate_login(self) -> bool:
|
||||
"""Validate login credentials are valid."""
|
||||
return self.parent.UserSessionManager.validate_remote_session_uuid(connection_id)
|
||||
|
||||
return self.parent.UserSessionManager.validate_remote_session_uuid(self.connection_uuid)
|
||||
|
||||
class _LoginValidator(RequestPermissionValidator):
|
||||
"""
|
||||
When requests come in, this validator will only allow them through if the
|
||||
User is logged into the Terminal.
|
||||
When requests come in, this validator will only allow them through if the User is logged into the Terminal.
|
||||
|
||||
Login is required before making use of the Terminal.
|
||||
"""
|
||||
@@ -113,18 +126,17 @@ class Terminal(Service):
|
||||
"""Save a reference to the Terminal instance."""
|
||||
|
||||
def __call__(self, request: RequestFormat, context: Dict) -> bool:
|
||||
"""Return whether the Terminal has valid login credentials"""
|
||||
return self.terminal.login_status
|
||||
|
||||
"""Return whether the Terminal has valid login credentials."""
|
||||
return self.terminal.is_connected
|
||||
|
||||
@property
|
||||
def fail_message(self) -> str:
|
||||
"""Message that is reported when a request is rejected by this validator"""
|
||||
return ("Cannot perform request on terminal as not logged in.")
|
||||
|
||||
"""Message that is reported when a request is rejected by this validator."""
|
||||
return "Cannot perform request on terminal as not logged in."
|
||||
|
||||
# %% Inbound
|
||||
|
||||
def login(self, username: str, password: str, ip_address: Optional[IPv4Address]=None) -> bool:
|
||||
def login(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool:
|
||||
"""Process User request to login to Terminal.
|
||||
|
||||
:param dest_ip_address: The IP address of the node we want to connect to.
|
||||
@@ -136,15 +148,12 @@ class Terminal(Service):
|
||||
self.sys_log.warning("Cannot process login as service is not running")
|
||||
return False
|
||||
|
||||
# need to determine if this is a local or remote login
|
||||
|
||||
if ip_address:
|
||||
# ip_address has been given for remote login
|
||||
# if ip_address has been provided, we assume we are logging in to a remote terminal.
|
||||
return self._send_remote_login(username=username, password=password, ip_address=ip_address)
|
||||
|
||||
return self._process_local_login(username=username, password=password)
|
||||
|
||||
|
||||
def _process_local_login(self, username: str, password: str) -> bool:
|
||||
"""Local session login to terminal."""
|
||||
self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password)
|
||||
@@ -157,25 +166,54 @@ class Terminal(Service):
|
||||
|
||||
def _send_remote_login(self, username: str, password: str, ip_address: IPv4Address) -> bool:
|
||||
"""Attempt to login to a remote terminal."""
|
||||
pass
|
||||
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST
|
||||
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
|
||||
user_account: SSHUserCredentials = SSHUserCredentials(username=username, password=password)
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
transport_message=transport_message,
|
||||
connection_message=connection_message,
|
||||
user_account=user_account,
|
||||
target_ip_address=ip_address,
|
||||
sender_ip_address=self.parent.network_interface[1].ip_address,
|
||||
)
|
||||
|
||||
self.sys_log.info(f"Sending remote login request to {ip_address}")
|
||||
return self.send(payload=payload, dest_ip_address=ip_address)
|
||||
|
||||
def _process_remote_login(self, username: str, password: str, ip_address:IPv4Address) -> bool:
|
||||
def _process_remote_login(self, payload: SSHPacket) -> bool:
|
||||
"""Processes a remote terminal requesting to login to this terminal."""
|
||||
username: str = payload.user_account.username
|
||||
password: str = payload.user_account.password
|
||||
self.connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password)
|
||||
self.sys_log.info(f"Sending UserAuth request to UserSessionManager, username={username}, password={password}")
|
||||
|
||||
if self.connection_uuid:
|
||||
# Send uuid to remote
|
||||
self.sys_log.info(f"Remote login authorised, connection ID {self.connection_uuid} for {username} on {ip_address}")
|
||||
# send back to origin.
|
||||
self.sys_log.info(
|
||||
f"Remote login authorised, connection ID {self.connection_uuid} for "
|
||||
f"{username} on {payload.sender_ip_address}"
|
||||
)
|
||||
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS
|
||||
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
|
||||
payload = SSHPacket(
|
||||
transport_message=transport_message,
|
||||
connection_message=connection_message,
|
||||
connection_uuid=self.connection_uuid,
|
||||
sender_ip_address=self.parent.network_interface[1].ip_address,
|
||||
target_ip_address=payload.sender_ip_address,
|
||||
)
|
||||
self.send(payload=payload, dest_ip_address=payload.target_ip_address)
|
||||
return True
|
||||
else:
|
||||
# UserSessionManager has returned None
|
||||
self.sys_log.warning("Login failed, incorrect Username or Password")
|
||||
return False
|
||||
|
||||
|
||||
def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool:
|
||||
def receive(self, payload: SSHPacket, **kwargs) -> bool:
|
||||
"""Receive Payload and process for a response."""
|
||||
self.sys_log.debug(f"Received payload: {payload}")
|
||||
|
||||
if not isinstance(payload, SSHPacket):
|
||||
return False
|
||||
|
||||
@@ -184,6 +222,7 @@ class Terminal(Service):
|
||||
return False
|
||||
|
||||
if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE:
|
||||
# Close the channel
|
||||
connection_id = kwargs["connection_id"]
|
||||
dest_ip_address = kwargs["dest_ip_address"]
|
||||
self.disconnect(dest_ip_address=dest_ip_address)
|
||||
@@ -191,12 +230,13 @@ class Terminal(Service):
|
||||
# We need to close on the other machine as well
|
||||
|
||||
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
|
||||
# validate login
|
||||
user_account = "Username: placeholder, Password: placeholder"
|
||||
self._ssh_process_login(dest_ip_address="192.168.0.10", user_account=user_account)
|
||||
"""Login Request Received."""
|
||||
self._process_remote_login(payload=payload)
|
||||
self.sys_log.info("User Auth Success!")
|
||||
|
||||
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS:
|
||||
self.sys_log.debug("Login Successful")
|
||||
self.sys_log.info(f"Login Successful, connection ID is {payload.connection_uuid}")
|
||||
self.connection_uuid = payload.connection_uuid
|
||||
self.is_connected = True
|
||||
return True
|
||||
|
||||
@@ -208,6 +248,26 @@ class Terminal(Service):
|
||||
|
||||
# %% Outbound
|
||||
|
||||
def _disconnect(self, dest_ip_address: IPv4Address) -> bool:
|
||||
"""Disconnect from the remote."""
|
||||
if not self.is_connected:
|
||||
self.sys_log.warning("Not currently connected to remote")
|
||||
return False
|
||||
|
||||
if not self.remote_connection:
|
||||
self.sys_log.warning("No remote connection present")
|
||||
return False
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": self.remote_connection._connection_uuid},
|
||||
dest_ip_address=dest_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
self.connection_uuid = None
|
||||
self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}")
|
||||
return True
|
||||
|
||||
def disconnect(self, dest_ip_address: IPv4Address) -> bool:
|
||||
"""Disconnect from remote connection.
|
||||
|
||||
@@ -217,28 +277,6 @@ class Terminal(Service):
|
||||
self._disconnect(dest_ip_address=dest_ip_address)
|
||||
self.is_connected = False
|
||||
|
||||
def _disconnect(self, dest_ip_address: IPv4Address) -> bool:
|
||||
if not self.is_connected:
|
||||
return False
|
||||
|
||||
if len(self.user_connections) == 0:
|
||||
self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.")
|
||||
return False
|
||||
if not self.user_connections.get(self.connection_uuid):
|
||||
return False
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": self.connection_uuid},
|
||||
dest_ip_address=dest_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
connection = self.user_connections.pop(self.connection_uuid)
|
||||
|
||||
connection.is_active = False
|
||||
|
||||
self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}")
|
||||
return True
|
||||
|
||||
def send(
|
||||
self,
|
||||
payload: SSHPacket,
|
||||
|
||||
@@ -62,14 +62,17 @@ def test_terminal_send(basic_network):
|
||||
network: Network = basic_network
|
||||
computer_a: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
payload="Test_Payload",
|
||||
transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,
|
||||
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,
|
||||
sender_ip_address=computer_a.network_interface[1].ip_address,
|
||||
target_ip_address=computer_b.network_interface[1].ip_address,
|
||||
)
|
||||
|
||||
assert terminal_a.send(payload=payload, dest_ip_address="192.168.0.11")
|
||||
assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)
|
||||
|
||||
|
||||
def test_terminal_fail_when_closed(basic_network):
|
||||
@@ -77,27 +80,33 @@ def test_terminal_fail_when_closed(basic_network):
|
||||
network: Network = basic_network
|
||||
computer: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal: Terminal = computer.software_manager.software.get("Terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
|
||||
terminal.operating_state = ServiceOperatingState.STOPPED
|
||||
|
||||
assert terminal.login(ip_address="192.168.0.11") is False
|
||||
assert (
|
||||
terminal.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_terminal_disconnect(basic_network):
|
||||
"""Terminal should set is_connected to false on disconnect"""
|
||||
network: Network = basic_network
|
||||
computer: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal: Terminal = computer.software_manager.software.get("Terminal")
|
||||
computer_a: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
terminal_b: Terminal = computer_b.software_manager.software.get("Terminal")
|
||||
|
||||
assert terminal.is_connected is False
|
||||
assert terminal_a.is_connected is False
|
||||
|
||||
terminal.login(ip_address="192.168.0.11")
|
||||
terminal_a.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address)
|
||||
|
||||
assert terminal.is_connected is True
|
||||
assert terminal_a.is_connected is True
|
||||
|
||||
terminal.disconnect(dest_ip_address="192.168.0.11")
|
||||
terminal_a.disconnect(dest_ip_address=computer_b.network_interface[1].ip_address)
|
||||
|
||||
assert terminal.is_connected is False
|
||||
assert terminal_a.is_connected is False
|
||||
|
||||
|
||||
def test_terminal_ignores_when_off(basic_network):
|
||||
|
||||
Reference in New Issue
Block a user