From 2eb36149b28a55cdea48e6d8ea63f6e883de9112 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 15 Jul 2024 08:20:11 +0100 Subject: [PATCH] #2710 - Prep for draft PR --- .../simulator/network/hardware/base.py | 1 - .../simulator/network/protocols/ssh.py | 1 + .../services/database/database_service.py | 2 +- .../system/services/terminal/terminal.py | 189 ++++++++---------- .../_system/_services/test_terminal.py | 112 +++++++++++ 5 files changed, 199 insertions(+), 106 deletions(-) create mode 100644 tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 6942d280..610dd071 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -256,7 +256,6 @@ class NetworkInterface(SimComponent, ABC): """ # Determine the direction of the traffic direction = "inbound" if inbound else "outbound" - # Initialize protocol and port variables protocol = None port = None diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 361c2552..7d1f915e 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -59,6 +59,7 @@ class SSHConnectionMessage(IntEnum): SSH_LOGOFF_ACK = 89 """Logoff confirmation acknowledgement""" + class SSHPacket(DataPacket): """Represents an SSHPacket.""" diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 22ae0ff3..d6feafbd 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -19,7 +19,7 @@ _LOGGER = getLogger(__name__) class DatabaseService(Service): """ - A class for simulating a generic SQL Server service. +A class for simulating a generic SQL Server service. This class inherits from the `Service` class and provides methods to simulate a SQL database. """ diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index e5ff9054..3324c4e4 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -2,7 +2,7 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from uuid import uuid4 from pydantic import BaseModel @@ -24,9 +24,6 @@ class TerminalClientConnection(BaseModel): This class is used to record current User Connections within the Terminal class. """ - session_id: str - """Session UUID.""" - parent_node: Node # Technically I think this should be HostNode, but that causes a circular import. """The parent Node that this connection was created on.""" @@ -104,34 +101,44 @@ class Terminal(Service): def _validate_login(self, user_account: Optional[str]) -> bool: """Validate login credentials are valid.""" - # Pending login/Usermanager implementation - if user_account: - # validate bits - poke UserManager with provided info - # return self.user_manager.validate(user_account) - pass + # TODO: Interact with UserManager to check user_account details + if len(self.user_connections) == 0: + # No current connections + self.sys_log.warning("Login Required!") + return False else: - pass - # user_account = next(iter(self.user_connections)) - # return self.user_manager.validate(user_account) - - return True - - + return True # %% Inbound - def _generate_connection_id(self) -> str: + def _generate_connection_uuid(self) -> str: """Generate a unique connection ID.""" return str(uuid4()) - def process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: - """Process User request to login to Terminal.""" - if user_account in self.user_connections: + def login(self, dest_ip_address: IPv4Address, **kwargs) -> bool: + """Process User request to login to Terminal. + + :param dest_ip_address: The IP address of the node we want to connect to. + :return: True if successful, False otherwise. + """ + if self.operating_state != ServiceOperatingState.RUNNING: + self.sys_log.warning("Cannot process login as service is not running") + return False + user_account = f"Username: placeholder, Password: placeholder" + if self.connection_uuid in self.user_connections: self.sys_log.debug("User authentication passed") return True else: - self._ssh_process_login(dest_ip_address=dest_ip_address, user_account=user_account) - self.process_login(dest_ip_address=dest_ip_address, user_account=user_account) + # Need to send a login request + # TODO: Refactor with UserManager changes to provide correct credentials and validate. + transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST + connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN + payload: SSHPacket = SSHPacket(payload="login", + transport_message=transport_message, + connection_message=connection_message) + + self.sys_log.debug(f"Sending login request to {dest_ip_address}") + self.send(payload=payload, dest_ip_address=dest_ip_address) def _ssh_process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: """Processes the login attempt. Returns a bool which either rejects the login or accepts it.""" @@ -140,19 +147,20 @@ class Terminal(Service): connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_FAILED # Hard coded at current - replace with another method to handle local accounts. - if user_account == f"{self.user_name:} placeholder, {self.password:} placeholder": # hardcoded - connection_id = self._generate_connection_id() - if not self.add_connection(self, connection_id=connection_id): + if user_account == "Username: placeholder, Password: placeholder": # hardcoded + self.connection_uuid = self._generate_connection_uuid() + if not self.add_connection(connection_id=self.connection_uuid): self.sys_log.warning( f"{self.name}: Connect request for {dest_ip_address} declined. Service is at capacity." ) return False else: - self.sys_log.info(f"{self.name}: Connect request for ID: {connection_id} authorised") + self.sys_log.info(f"{self.name}: Connect request for ID: {self.connection_uuid} authorised") transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_CONFIRMATION - new_connection = TerminalClientConnection(connection_id=connection_id, dest_ip_address=dest_ip_address) - self.user_connections[connection_id] = new_connection + new_connection = TerminalClientConnection(parent_node = self.software_manager.node, + connection_id=self.connection_uuid, dest_ip_address=dest_ip_address) + self.user_connections[self.connection_uuid] = new_connection self.is_connected = True payload: SSHPacket = SSHPacket(transport_message=transport_message, connection_message=connection_message) @@ -160,86 +168,51 @@ class Terminal(Service): self.send(payload=payload, dest_ip_address=dest_ip_address) return True - def validate_user(self, session_id: str) -> bool: - return True - def _ssh_process_logoff(self, session_id: str, *args, **kwargs) -> bool: """Process the logoff attempt. Return a bool if succesful or unsuccessful.""" - - if self.validate_user(session_id): - # Account is logged in - return True - else: - self.sys_log.warning("User account credentials invalid.") - return False - - def _ssh_process_command(self, session_id: str, *args, **kwargs) -> bool: - return True - - def send_logoff_ack(self, session_id: str): - """Send confirmation of successful disconnect""" - transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS - connection_message = SSHConnectionMessage.SSH_LOGOFF_ACK - payload: SSHPacket = SSHPacket( - transport_message=transport_message, - connection_message=connection_message, - ssh_output=RequestResponse(status="success", data={"reason": "Successfully Disconnected"}), - ) - self.send(payload=payload, session_id=session_id) + # TODO: Should remove def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: - # shouldn't be expecting to see anything other than SSHPacket payloads currently - # confirm that we are receiving the + """Receive Payload and process for a response.""" if not isinstance(payload, SSHPacket): return False + + if self.operating_state != ServiceOperatingState.RUNNING: + self.sys_log.warning(f"Cannot process message as not running") + return False + self.sys_log.debug(f"Received payload: {payload} from session: {session_id}") - if payload.connection_message == SSHConnectionMessage.SSH_LOGOFF_ACK: - # Logoff acknowledgement received. NFA needed. - self.sys_log.debug("Received confirmation of successful disconnect") - return True - - elif payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: + if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: + connection_id = kwargs["connection_id"] + dest_ip_address = kwargs["dest_ip_address"] self._ssh_process_logoff(session_id=session_id) - self.sys_log.debug("Disconnect message received, sending logoff ack") + self.disconnect(dest_ip_address=dest_ip_address) + self.sys_log.debug(f"Disconnecting {connection_id}") # We need to close on the other machine as well - self.send_logoff_ack(session_id=session_id) elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: - src_ip = kwargs.get("frame").ip.src_ip_address - user_account = payload.get("user_account", {}) - result = self._ssh_process_login(src_ip=src_ip, session_id=session_id, user_account=user_account) + # validate login + user_account = "Username: placeholder, Password: placeholder" + self._ssh_process_login(dest_ip_address="192.168.0.10", user_account=user_account) - elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: - # Ensure we only ever process requests if we have a established connection (e.g session_id is provided and validated) - result = self._ssh_process_command(session_id=session_id) + elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: + self.sys_log.debug("Login Successful") + self.is_connected = True + return True else: self.sys_log.warning("Encounter unexpected message type, rejecting connection") - # send a SSH_MSG_CHANNEL_CLOSE if there is a session_id otherwise SSH_MSG_OPEN_FAILED return False - # self.send(payload=result, session_id=session_id) return True - # %% Outbound - - def login(self, dest_ip_address: IPv4Address, user_account: dict[str]) -> bool: - """ - Perform an initial login request. - - If this fails, raises an error. - """ - # TODO: This will need elaborating when user accounts are implemented - self.sys_log.info("Attempting Login") - return self._ssh_remote_login(self, dest_ip_address=dest_ip_address, user_account=user_account) - def _ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: """Remote login to terminal via SSH.""" if not user_account: - # Setting default creds (Best to use this until we have more clarification around user accounts) - user_account = {self.user_name: "placeholder", self.password: "placeholder"} + # TODO: Generic hardcoded info, will need to be updated with UserManager. + user_account = f"Username: placeholder, Password: placeholder" # something like self.user_manager.get_user_details ? # Implement SSHPacket class @@ -248,7 +221,6 @@ class Terminal(Service): connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, user_account=user_account, ) - # self.send will return bool, payload unchanged? if self.send(payload=payload, dest_ip_address=dest_ip_address): if payload.connection_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: self.sys_log.info(f"{self.name} established an ssh connection with {dest_ip_address}") @@ -269,45 +241,54 @@ class Terminal(Service): else: return False - def disconnect(self, connection_id: str): - """Disconnect from remote.""" - self._disconnect(connection_id) + def disconnect(self, dest_ip_address: IPv4Address) -> bool: + """Disconnect from remote connection. + + :param dest_ip_address: The IP address fo the connection we are terminating. + :return: True if successful, False otherwise. + """ + self._disconnect(dest_ip_address=dest_ip_address) self.is_connected = False - def _disconnect(self, connection_id: str) -> bool: + def _disconnect(self, dest_ip_address: IPv4Address) -> bool: if not self.is_connected: return False if len(self.user_connections) == 0: self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.") return False - if not self.user_connections.get(connection_id): + if not self.user_connections.get(self.connection_uuid): return False software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload={"type": "disconnect", "connection_id": connection_id}, - dest_ip_address=self.server_ip_address, - dest_port=self.port, + payload={"type": "disconnect", "connection_id": self.connection_uuid}, + dest_ip_address=dest_ip_address, + dest_port=self.port ) - connection = self.user_connections.pop(connection_id) - self.terminate_connection(connection_id=connection_id) + connection = self.user_connections.pop(self.connection_uuid) connection.is_active = False self.sys_log.info( - f"{self.name}: Disconnected {connection_id} from: {self.user_connections[connection_id]._dest_ip_address}" + f"{self.name}: Disconnected {self.connection_uuid}" ) - self.connected = False return True def send( self, payload: SSHPacket, - dest_ip_address: Optional[IPv4Address] = None, - session_id: Optional[str] = None, - user_account: Optional[str] = None, + dest_ip_address: IPv4Address, ) -> bool: - """Send a payload out from the Terminal.""" - self._validate_login(user_account) - self.sys_log.debug(f"Sending payload: {payload} from session: {session_id}") - return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id) + """ + Send a payload out from the Terminal. + + :param payload: The payload to be sent. + :param dest_up_address: The IP address of the payload destination. + """ + if self.operating_state != ServiceOperatingState.RUNNING: + self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!") + return False + self.sys_log.debug(f"Sending payload: {payload}") + return super().send( + payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port + ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py new file mode 100644 index 00000000..62933b5c --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -0,0 +1,112 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage +from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.services.terminal.terminal import Terminal +from primaite.simulator.system.software import SoftwareHealthState + +@pytest.fixture(scope="function") +def terminal_on_computer() -> Tuple[Terminal, Computer]: + computer: Computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + computer.power_on() + terminal: Terminal = computer.software_manager.software.get("Terminal") + + return [terminal, computer] + +@pytest.fixture(scope="function") +def basic_network() -> Network: + network = Network() + node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a.power_on() + node_a.software_manager.get_open_ports() + + node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b.power_on() + network.connect(node_a.network_interface[1], node_b.network_interface[1]) + + return network + + +def test_terminal_creation(terminal_on_computer): + terminal, computer = terminal_on_computer + terminal.describe_state() + +def test_terminal_install_default(): + """Terminal should be auto installed onto Nodes""" + computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + computer.power_on() + + assert computer.software_manager.software.get("Terminal") + +def test_terminal_not_on_switch(): + """Ensure terminal does not auto-install to switch""" + test_switch = Switch(hostname="Test") + + assert not test_switch.software_manager.software.get("Terminal") + +def test_terminal_send(basic_network): + """Check that Terminal can send """ + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + + payload: SSHPacket = SSHPacket(payload="Test_Payload", + transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN) + + + assert terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") + + +def test_terminal_fail_when_closed(basic_network): + """Ensure Terminal won't attempt to send/receive when off""" + network: Network = basic_network + computer: Computer = network.get_node_by_hostname("node_a") + terminal: Terminal = computer.software_manager.software.get("Terminal") + + terminal.operating_state = ServiceOperatingState.STOPPED + + assert terminal.login(dest_ip_address="192.168.0.11") is False + + +def test_terminal_disconnect(basic_network): + """Terminal should set is_connected to false on disconnect""" + network: Network = basic_network + computer: Computer = network.get_node_by_hostname("node_a") + terminal: Terminal = computer.software_manager.software.get("Terminal") + + assert terminal.is_connected is False + + terminal.login(dest_ip_address="192.168.0.11") + + assert terminal.is_connected is True + + terminal.disconnect(dest_ip_address="192.168.0.11") + + assert terminal.is_connected is False + +def test_terminal_ignores_when_off(basic_network): + """Terminal should ignore commands when not running""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + + computer_b: Computer = network.get_node_by_hostname("node_b") + + terminal_a.login(dest_ip_address="192.168.0.11") # login to computer_b + + assert terminal_a.is_connected is True + + terminal_a.operating_state = ServiceOperatingState.STOPPED + + payload: SSHPacket = SSHPacket(payload="Test_Payload", + transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA) + + assert not terminal_a.send(payload=payload, dest_ip_address="192.168.0.11")