#2712 - Updates to the login logic and fixing resultant test failures. Updates to terminal.rst and ssh.py

This commit is contained in:
Charlie Crane
2024-07-23 15:18:20 +01:00
parent 3c590a8733
commit a7f9e4502e
4 changed files with 146 additions and 79 deletions

View File

@@ -5,9 +5,16 @@
.. _Terminal:
Terminal
########
========
The ``Terminal`` provides a generic terminal simulation, by extending the base Service class
The ``Terminal.py`` class provides a generic terminal simulation, by extending the base Service class within PrimAITE. The aim of this is to act as the primary entrypoint for Nodes within the environment.
Overview
--------
The Terminal service uses Secure Socket (SSH) as the communication method between terminals. They operate on port 22, and are part of the services automatically
installed on Nodes when they are instantiated.
Key capabilities
================
@@ -17,21 +24,22 @@ Key capabilities
- Simulates common Terminal commands
- Leverages the Service base class for install/uninstall, status tracking etc.
Usage
=====
- Install on a node via the ``SoftwareManager`` to start the Terminal
- Terminal Clients connect, execute commands and disconnect.
- Pre-Installs on any `HostNode` component. See the below code example of how to access the terminal.
- Terminal Clients connect, execute commands and disconnect from remote components.
- Ensures that users are logged in to the component before executing any commands.
- Service runs on SSH port 22 by default.
Implementation
==============
- Manages SSH commands
- Ensures User login before sending commands
- Processes SSH commands
- Returns results in a *<TBD>* format.
The terminal takes inspiration from the `Database Client` and `Database Service` classes, and leverages the `UserSessionManager`
to provide User Credential authentication when receiving/processing commands.
Terminal acts as the interface between the user/component and both the Session and Requests Managers, facilitating
the passing of requests to both.
Python

View File

@@ -1,7 +1,8 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import IntEnum
from typing import Dict, Optional
from ipaddress import IPv4Address
from typing import Optional
from primaite.interface.request import RequestResponse
from primaite.simulator.network.protocols.packet import DataPacket
@@ -58,21 +59,32 @@ class SSHConnectionMessage(IntEnum):
class SSHUserCredentials(DataPacket):
"""Hold Username and Password in SSH Packets"""
"""Hold Username and Password in SSH Packets."""
username: str = None
username: str
"""Username for login"""
password: str = None
password: str
"""Password for login"""
class SSHPacket(DataPacket):
"""Represents an SSHPacket."""
transport_message: SSHTransportMessage = None
sender_ip_address: IPv4Address
"""Sender IP Address"""
connection_message: SSHConnectionMessage = None
target_ip_address: IPv4Address
"""Target IP Address"""
transport_message: SSHTransportMessage
"""Message Transport Type"""
connection_message: SSHConnectionMessage
"""Message Connection Status"""
user_account: Optional[SSHUserCredentials] = None
"""User Account Credentials if passed"""
connection_uuid: Optional[str] = None # The connection uuid used to validate the session

View File

@@ -3,30 +3,33 @@ from __future__ import annotations
from ipaddress import IPv4Address
from typing import Dict, List, Optional
from uuid import uuid4
from pydantic import BaseModel
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage
from primaite.simulator.network.protocols.ssh import (
SSHConnectionMessage,
SSHPacket,
SSHTransportMessage,
SSHUserCredentials,
)
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service, ServiceOperatingState
# TODO: This might not be needed now?
class TerminalClientConnection(BaseModel):
"""
TerminalClientConnection Class.
This class is used to record current User Connections within the Terminal class.
This class is used to record current remote User Connections to the Terminal class.
"""
parent_node: Node # Technically I think this should be HostNode, but that causes a circular import.
parent_node: Node # Technically should be HostNode but this causes circular import error.
"""The parent Node that this connection was created on."""
is_active: bool = True
@@ -35,6 +38,9 @@ class TerminalClientConnection(BaseModel):
_dest_ip_address: IPv4Address
"""Destination IP address of connection"""
_connection_uuid: str = None
"""Connection UUID"""
@property
def dest_ip_address(self) -> Optional[IPv4Address]:
"""Destination IP Address."""
@@ -48,7 +54,7 @@ class TerminalClientConnection(BaseModel):
def disconnect(self):
"""Disconnect the connection."""
if self.client and self.is_active:
self.client._disconnect(self.connection_id) # noqa
self.client._disconnect(self._connection_uuid) # noqa
class Terminal(Service):
@@ -63,6 +69,10 @@ class Terminal(Service):
operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING
"""Initial Operating State"""
remote_connection: TerminalClientConnection = None
parent: Node
"""Parent component the terminal service is installed on."""
def __init__(self, **kwargs):
kwargs["name"] = "Terminal"
@@ -93,18 +103,21 @@ class Terminal(Service):
_login_valid = Terminal._LoginValidator(terminal=self)
rm = super()._init_request_manager()
rm.add_request("login", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid))
rm.add_request(
"login",
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid
),
)
return rm
def _validate_login(self, connection_id: str) -> bool:
def _validate_login(self) -> bool:
"""Validate login credentials are valid."""
return self.parent.UserSessionManager.validate_remote_session_uuid(connection_id)
return self.parent.UserSessionManager.validate_remote_session_uuid(self.connection_uuid)
class _LoginValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only allow them through if the
User is logged into the Terminal.
When requests come in, this validator will only allow them through if the User is logged into the Terminal.
Login is required before making use of the Terminal.
"""
@@ -113,18 +126,17 @@ class Terminal(Service):
"""Save a reference to the Terminal instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the Terminal has valid login credentials"""
return self.terminal.login_status
"""Return whether the Terminal has valid login credentials."""
return self.terminal.is_connected
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator"""
return ("Cannot perform request on terminal as not logged in.")
"""Message that is reported when a request is rejected by this validator."""
return "Cannot perform request on terminal as not logged in."
# %% Inbound
def login(self, username: str, password: str, ip_address: Optional[IPv4Address]=None) -> bool:
def login(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool:
"""Process User request to login to Terminal.
:param dest_ip_address: The IP address of the node we want to connect to.
@@ -136,15 +148,12 @@ class Terminal(Service):
self.sys_log.warning("Cannot process login as service is not running")
return False
# need to determine if this is a local or remote login
if ip_address:
# ip_address has been given for remote login
# if ip_address has been provided, we assume we are logging in to a remote terminal.
return self._send_remote_login(username=username, password=password, ip_address=ip_address)
return self._process_local_login(username=username, password=password)
def _process_local_login(self, username: str, password: str) -> bool:
"""Local session login to terminal."""
self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password)
@@ -157,25 +166,54 @@ class Terminal(Service):
def _send_remote_login(self, username: str, password: str, ip_address: IPv4Address) -> bool:
"""Attempt to login to a remote terminal."""
pass
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
user_account: SSHUserCredentials = SSHUserCredentials(username=username, password=password)
payload: SSHPacket = SSHPacket(
transport_message=transport_message,
connection_message=connection_message,
user_account=user_account,
target_ip_address=ip_address,
sender_ip_address=self.parent.network_interface[1].ip_address,
)
self.sys_log.info(f"Sending remote login request to {ip_address}")
return self.send(payload=payload, dest_ip_address=ip_address)
def _process_remote_login(self, username: str, password: str, ip_address:IPv4Address) -> bool:
def _process_remote_login(self, payload: SSHPacket) -> bool:
"""Processes a remote terminal requesting to login to this terminal."""
username: str = payload.user_account.username
password: str = payload.user_account.password
self.connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password)
self.sys_log.info(f"Sending UserAuth request to UserSessionManager, username={username}, password={password}")
if self.connection_uuid:
# Send uuid to remote
self.sys_log.info(f"Remote login authorised, connection ID {self.connection_uuid} for {username} on {ip_address}")
# send back to origin.
self.sys_log.info(
f"Remote login authorised, connection ID {self.connection_uuid} for "
f"{username} on {payload.sender_ip_address}"
)
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
payload = SSHPacket(
transport_message=transport_message,
connection_message=connection_message,
connection_uuid=self.connection_uuid,
sender_ip_address=self.parent.network_interface[1].ip_address,
target_ip_address=payload.sender_ip_address,
)
self.send(payload=payload, dest_ip_address=payload.target_ip_address)
return True
else:
# UserSessionManager has returned None
self.sys_log.warning("Login failed, incorrect Username or Password")
return False
def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool:
def receive(self, payload: SSHPacket, **kwargs) -> bool:
"""Receive Payload and process for a response."""
self.sys_log.debug(f"Received payload: {payload}")
if not isinstance(payload, SSHPacket):
return False
@@ -184,6 +222,7 @@ class Terminal(Service):
return False
if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE:
# Close the channel
connection_id = kwargs["connection_id"]
dest_ip_address = kwargs["dest_ip_address"]
self.disconnect(dest_ip_address=dest_ip_address)
@@ -191,12 +230,13 @@ class Terminal(Service):
# We need to close on the other machine as well
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
# validate login
user_account = "Username: placeholder, Password: placeholder"
self._ssh_process_login(dest_ip_address="192.168.0.10", user_account=user_account)
"""Login Request Received."""
self._process_remote_login(payload=payload)
self.sys_log.info("User Auth Success!")
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS:
self.sys_log.debug("Login Successful")
self.sys_log.info(f"Login Successful, connection ID is {payload.connection_uuid}")
self.connection_uuid = payload.connection_uuid
self.is_connected = True
return True
@@ -208,6 +248,26 @@ class Terminal(Service):
# %% Outbound
def _disconnect(self, dest_ip_address: IPv4Address) -> bool:
"""Disconnect from the remote."""
if not self.is_connected:
self.sys_log.warning("Not currently connected to remote")
return False
if not self.remote_connection:
self.sys_log.warning("No remote connection present")
return False
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": self.remote_connection._connection_uuid},
dest_ip_address=dest_ip_address,
dest_port=self.port,
)
self.connection_uuid = None
self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}")
return True
def disconnect(self, dest_ip_address: IPv4Address) -> bool:
"""Disconnect from remote connection.
@@ -217,28 +277,6 @@ class Terminal(Service):
self._disconnect(dest_ip_address=dest_ip_address)
self.is_connected = False
def _disconnect(self, dest_ip_address: IPv4Address) -> bool:
if not self.is_connected:
return False
if len(self.user_connections) == 0:
self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.")
return False
if not self.user_connections.get(self.connection_uuid):
return False
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": self.connection_uuid},
dest_ip_address=dest_ip_address,
dest_port=self.port,
)
connection = self.user_connections.pop(self.connection_uuid)
connection.is_active = False
self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}")
return True
def send(
self,
payload: SSHPacket,

View File

@@ -62,14 +62,17 @@ def test_terminal_send(basic_network):
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
payload: SSHPacket = SSHPacket(
payload="Test_Payload",
transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,
sender_ip_address=computer_a.network_interface[1].ip_address,
target_ip_address=computer_b.network_interface[1].ip_address,
)
assert terminal_a.send(payload=payload, dest_ip_address="192.168.0.11")
assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)
def test_terminal_fail_when_closed(basic_network):
@@ -77,27 +80,33 @@ def test_terminal_fail_when_closed(basic_network):
network: Network = basic_network
computer: Computer = network.get_node_by_hostname("node_a")
terminal: Terminal = computer.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
terminal.operating_state = ServiceOperatingState.STOPPED
assert terminal.login(ip_address="192.168.0.11") is False
assert (
terminal.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address)
is False
)
def test_terminal_disconnect(basic_network):
"""Terminal should set is_connected to false on disconnect"""
network: Network = basic_network
computer: Computer = network.get_node_by_hostname("node_a")
terminal: Terminal = computer.software_manager.software.get("Terminal")
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
terminal_b: Terminal = computer_b.software_manager.software.get("Terminal")
assert terminal.is_connected is False
assert terminal_a.is_connected is False
terminal.login(ip_address="192.168.0.11")
terminal_a.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address)
assert terminal.is_connected is True
assert terminal_a.is_connected is True
terminal.disconnect(dest_ip_address="192.168.0.11")
terminal_a.disconnect(dest_ip_address=computer_b.network_interface[1].ip_address)
assert terminal.is_connected is False
assert terminal_a.is_connected is False
def test_terminal_ignores_when_off(basic_network):