diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 5848ade4..1fb936cd 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -293,6 +293,7 @@ class HostNode(Node): * DNS (Domain Name System) Client: Resolves domain names to IP addresses. * FTP (File Transfer Protocol) Client: Enables file transfers between the host and FTP servers. * NTP (Network Time Protocol) Client: Synchronizes the system clock with NTP servers. + * Terminal Client: Handles SSH requests between HostNode and external components. Applications: ------------ diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 22ae0ff3..f061b3c7 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -21,7 +21,7 @@ class DatabaseService(Service): """ A class for simulating a generic SQL Server service. - This class inherits from the `Service` class and provides methods to simulate a SQL database. + This class inherits from the `Service` class and provides methods to simulate a SQL database. """ password: Optional[str] = None diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 5f8719ac..589492ba 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -9,7 +9,7 @@ from pydantic import BaseModel from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager -from primaite.simulator.network.hardware.nodes.host.host_node import HostNode +from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -24,10 +24,7 @@ class TerminalClientConnection(BaseModel): This class is used to record current User Connections within the Terminal class. """ - connection_id: str - """Connection UUID.""" - - parent_node: HostNode + parent_node: Node # Technically I think this should be HostNode, but that causes a circular import. """The parent Node that this connection was created on.""" is_active: bool = True @@ -76,6 +73,8 @@ class Terminal(Service): kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) + # %% Util + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -100,41 +99,70 @@ class Terminal(Service): rm = super()._init_request_manager() return rm + def _validate_login(self, user_account: Optional[str]) -> bool: + """Validate login credentials are valid.""" + # TODO: Interact with UserManager to check user_account details + if len(self.user_connections) == 0: + # No current connections + self.sys_log.warning("Login Required!") + return False + else: + return True + # %% Inbound - def _generate_connection_id(self) -> str: + def _generate_connection_uuid(self) -> str: """Generate a unique connection ID.""" return str(uuid4()) - def process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: - """Process User request to login to Terminal.""" - if user_account in self.user_connections: + def login(self, dest_ip_address: IPv4Address, **kwargs) -> bool: + """Process User request to login to Terminal. + + :param dest_ip_address: The IP address of the node we want to connect to. + :return: True if successful, False otherwise. + """ + if self.operating_state != ServiceOperatingState.RUNNING: + self.sys_log.warning("Cannot process login as service is not running") + return False + if self.connection_uuid in self.user_connections: self.sys_log.debug("User authentication passed") return True else: - self._ssh_process_login(dest_ip_address=dest_ip_address, user_account=user_account) - self.process_login(dest_ip_address=dest_ip_address, user_account=user_account) + # Need to send a login request + # TODO: Refactor with UserManager changes to provide correct credentials and validate. + transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST + connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN + payload: SSHPacket = SSHPacket( + payload="login", transport_message=transport_message, connection_message=connection_message + ) + + self.sys_log.debug(f"Sending login request to {dest_ip_address}") + self.send(payload=payload, dest_ip_address=dest_ip_address) def _ssh_process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: - """Processes the login attempt. Returns a SSHPacket which either rejects the login or accepts it.""" + """Processes the login attempt. Returns a bool which either rejects the login or accepts it.""" # we assume that the login fails unless we meet all the criteria. transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_FAILURE connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_FAILED # Hard coded at current - replace with another method to handle local accounts. - if user_account == f"{self.user_name:} placeholder, {self.password:} placeholder": # hardcoded - connection_id = self._generate_connection_id() - if not self.add_connection(self, connection_id=connection_id): + if user_account == "Username: placeholder, Password: placeholder": # hardcoded + self.connection_uuid = self._generate_connection_uuid() + if not self.add_connection(connection_id=self.connection_uuid): self.sys_log.warning( f"{self.name}: Connect request for {dest_ip_address} declined. Service is at capacity." ) return False else: - self.sys_log.info(f"{self.name}: Connect request for ID: {connection_id} authorised") + self.sys_log.info(f"{self.name}: Connect request for ID: {self.connection_uuid} authorised") transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_CONFIRMATION - new_connection = TerminalClientConnection(connection_id=connection_id, dest_ip_address=dest_ip_address) - self.user_connections[connection_id] = new_connection + new_connection = TerminalClientConnection( + parent_node=self.software_manager.node, + connection_id=self.connection_uuid, + dest_ip_address=dest_ip_address, + ) + self.user_connections[self.connection_uuid] = new_connection self.is_connected = True payload: SSHPacket = SSHPacket(transport_message=transport_message, connection_message=connection_message) @@ -142,23 +170,52 @@ class Terminal(Service): self.send(payload=payload, dest_ip_address=dest_ip_address) return True + def _ssh_process_logoff(self, session_id: str, *args, **kwargs) -> bool: + """Process the logoff attempt. Return a bool if succesful or unsuccessful.""" + # TODO: Should remove + + def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: + """Receive Payload and process for a response.""" + if not isinstance(payload, SSHPacket): + return False + + if self.operating_state != ServiceOperatingState.RUNNING: + self.sys_log.warning("Cannot process message as not running") + return False + + self.sys_log.debug(f"Received payload: {payload} from session: {session_id}") + + if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: + connection_id = kwargs["connection_id"] + dest_ip_address = kwargs["dest_ip_address"] + self._ssh_process_logoff(session_id=session_id) + self.disconnect(dest_ip_address=dest_ip_address) + self.sys_log.debug(f"Disconnecting {connection_id}") + # We need to close on the other machine as well + + elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: + # validate login + user_account = "Username: placeholder, Password: placeholder" + self._ssh_process_login(dest_ip_address="192.168.0.10", user_account=user_account) + + elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: + self.sys_log.debug("Login Successful") + self.is_connected = True + return True + + else: + self.sys_log.warning("Encounter unexpected message type, rejecting connection") + return False + + return True + # %% Outbound - - def login(self, dest_ip_address: IPv4Address) -> bool: - """ - Perform an initial login request. - - If this fails, raises an error. - """ - # TODO: This will need elaborating when user accounts are implemented - self.sys_log.info("Attempting Login") - return self.ssh_remote_login(self, dest_ip_address=dest_ip_address, user_account=self.user_account) - - def ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: + def _ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: """Remote login to terminal via SSH.""" if not user_account: - # Setting default creds (Best to use this until we have more clarification around user accounts) - user_account = {self.user_name: "placeholder", self.password: "placeholder"} + # TODO: Generic hardcoded info, will need to be updated with UserManager. + user_account = "Username: placeholder, Password: placeholder" + # something like self.user_manager.get_user_details ? # Implement SSHPacket class payload: SSHPacket = SSHPacket( @@ -166,7 +223,6 @@ class Terminal(Service): connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, user_account=user_account, ) - # self.send will return bool, payload unchanged? if self.send(payload=payload, dest_ip_address=dest_ip_address): if payload.connection_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: self.sys_log.info(f"{self.name} established an ssh connection with {dest_ip_address}") @@ -187,33 +243,50 @@ class Terminal(Service): else: return False - def disconnect(self, connection_id: str): - """Disconnect from remote.""" - self._disconnect(connection_id) + def disconnect(self, dest_ip_address: IPv4Address) -> bool: + """Disconnect from remote connection. + + :param dest_ip_address: The IP address fo the connection we are terminating. + :return: True if successful, False otherwise. + """ + self._disconnect(dest_ip_address=dest_ip_address) self.is_connected = False - def _disconnect(self, connection_id: str) -> bool: + def _disconnect(self, dest_ip_address: IPv4Address) -> bool: if not self.is_connected: return False if len(self.user_connections) == 0: self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.") return False - if not self.user_connections.get(connection_id): + if not self.user_connections.get(self.connection_uuid): return False software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload={"type": "disconnect", "connection_id": connection_id}, - dest_ip_address=self.server_ip_address, + payload={"type": "disconnect", "connection_id": self.connection_uuid}, + dest_ip_address=dest_ip_address, dest_port=self.port, ) - connection = self.user_connections.pop(connection_id) - self.terminate_connection(connection_id=connection_id) + connection = self.user_connections.pop(self.connection_uuid) connection.is_active = False - self.sys_log.info( - f"{self.name}: Disconnected {connection_id} from: {self.user_connections[connection_id]._dest_ip_address}" - ) - self.connected = False + self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}") return True + + def send( + self, + payload: SSHPacket, + dest_ip_address: IPv4Address, + ) -> bool: + """ + Send a payload out from the Terminal. + + :param payload: The payload to be sent. + :param dest_up_address: The IP address of the payload destination. + """ + if self.operating_state != ServiceOperatingState.RUNNING: + self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!") + return False + self.sys_log.debug(f"Sending payload: {payload}") + return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port) diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index bbfa4f43..a261f272 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -106,7 +106,7 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network): expected_result = { IPv4Address("192.168.10.1"): {IPProtocol.UDP: [Port.ARP]}, IPv4Address("192.168.10.22"): { - IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS], + IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS, Port.SSH], IPProtocol.UDP: [Port.ARP, Port.NTP], }, } diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py new file mode 100644 index 00000000..6b0365ce --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -0,0 +1,123 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage +from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.services.terminal.terminal import Terminal +from primaite.simulator.system.software import SoftwareHealthState + + +@pytest.fixture(scope="function") +def terminal_on_computer() -> Tuple[Terminal, Computer]: + computer: Computer = Computer( + hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0 + ) + computer.power_on() + terminal: Terminal = computer.software_manager.software.get("Terminal") + + return [terminal, computer] + + +@pytest.fixture(scope="function") +def basic_network() -> Network: + network = Network() + node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a.power_on() + node_a.software_manager.get_open_ports() + + node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b.power_on() + network.connect(node_a.network_interface[1], node_b.network_interface[1]) + + return network + + +def test_terminal_creation(terminal_on_computer): + terminal, computer = terminal_on_computer + terminal.describe_state() + + +def test_terminal_install_default(): + """Terminal should be auto installed onto Nodes""" + computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + computer.power_on() + + assert computer.software_manager.software.get("Terminal") + + +def test_terminal_not_on_switch(): + """Ensure terminal does not auto-install to switch""" + test_switch = Switch(hostname="Test") + + assert not test_switch.software_manager.software.get("Terminal") + + +def test_terminal_send(basic_network): + """Check that Terminal can send""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + + payload: SSHPacket = SSHPacket( + payload="Test_Payload", + transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, + ) + + assert terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") + + +def test_terminal_fail_when_closed(basic_network): + """Ensure Terminal won't attempt to send/receive when off""" + network: Network = basic_network + computer: Computer = network.get_node_by_hostname("node_a") + terminal: Terminal = computer.software_manager.software.get("Terminal") + + terminal.operating_state = ServiceOperatingState.STOPPED + + assert terminal.login(dest_ip_address="192.168.0.11") is False + + +def test_terminal_disconnect(basic_network): + """Terminal should set is_connected to false on disconnect""" + network: Network = basic_network + computer: Computer = network.get_node_by_hostname("node_a") + terminal: Terminal = computer.software_manager.software.get("Terminal") + + assert terminal.is_connected is False + + terminal.login(dest_ip_address="192.168.0.11") + + assert terminal.is_connected is True + + terminal.disconnect(dest_ip_address="192.168.0.11") + + assert terminal.is_connected is False + + +def test_terminal_ignores_when_off(basic_network): + """Terminal should ignore commands when not running""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + + computer_b: Computer = network.get_node_by_hostname("node_b") + + terminal_a.login(dest_ip_address="192.168.0.11") # login to computer_b + + assert terminal_a.is_connected is True + + terminal_a.operating_state = ServiceOperatingState.STOPPED + + payload: SSHPacket = SSHPacket( + payload="Test_Payload", + transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA, + ) + + assert not terminal_a.send(payload=payload, dest_ip_address="192.168.0.11")