diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 7be81982..361c2552 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -56,6 +56,8 @@ class SSHConnectionMessage(IntEnum): SSH_MSG_CHANNEL_CLOSE = 87 """Closes the channel.""" + SSH_LOGOFF_ACK = 89 + """Logoff confirmation acknowledgement""" class SSHPacket(DataPacket): """Represents an SSHPacket.""" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 1dd3133d..e5ff9054 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -24,8 +24,8 @@ class TerminalClientConnection(BaseModel): This class is used to record current User Connections within the Terminal class. """ - connection_id: str - """Connection UUID.""" + session_id: str + """Session UUID.""" parent_node: Node # Technically I think this should be HostNode, but that causes a circular import. """The parent Node that this connection was created on.""" @@ -76,6 +76,8 @@ class Terminal(Service): kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) + # %% Util + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -100,6 +102,22 @@ class Terminal(Service): rm = super()._init_request_manager() return rm + def _validate_login(self, user_account: Optional[str]) -> bool: + """Validate login credentials are valid.""" + # Pending login/Usermanager implementation + if user_account: + # validate bits - poke UserManager with provided info + # return self.user_manager.validate(user_account) + pass + else: + pass + # user_account = next(iter(self.user_connections)) + # return self.user_manager.validate(user_account) + + return True + + + # %% Inbound def _generate_connection_id(self) -> str: @@ -142,40 +160,50 @@ class Terminal(Service): self.send(payload=payload, dest_ip_address=dest_ip_address) return True - def validate_user(self, user: Dict[str]) -> bool: - return True if user.get("username") in self.user_connections else False + def validate_user(self, session_id: str) -> bool: + return True - def _ssh_process_logoff(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: + def _ssh_process_logoff(self, session_id: str, *args, **kwargs) -> bool: """Process the logoff attempt. Return a bool if succesful or unsuccessful.""" - if self.validate_user(user_account): + if self.validate_user(session_id): # Account is logged in - self.user_connections.pop[user_account["username"]] # assumption atm - self.is_connected = False return True else: self.sys_log.warning("User account credentials invalid.") + return False def _ssh_process_command(self, session_id: str, *args, **kwargs) -> bool: return True - def send_logoff_ack(self): + def send_logoff_ack(self, session_id: str): """Send confirmation of successful disconnect""" transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS - connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE + connection_message = SSHConnectionMessage.SSH_LOGOFF_ACK payload: SSHPacket = SSHPacket( transport_message=transport_message, connection_message=connection_message, - ssh_output=RequestResponse(status="success"), + ssh_output=RequestResponse(status="success", data={"reason": "Successfully Disconnected"}), ) - self.send(payload=payload) + self.send(payload=payload, session_id=session_id) def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: + # shouldn't be expecting to see anything other than SSHPacket payloads currently + # confirm that we are receiving the + if not isinstance(payload, SSHPacket): + return False self.sys_log.debug(f"Received payload: {payload} from session: {session_id}") - if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: - result = self._ssh_process_logoff(session_id=session_id) + + if payload.connection_message == SSHConnectionMessage.SSH_LOGOFF_ACK: + # Logoff acknowledgement received. NFA needed. + self.sys_log.debug("Received confirmation of successful disconnect") + return True + + elif payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: + self._ssh_process_logoff(session_id=session_id) + self.sys_log.debug("Disconnect message received, sending logoff ack") # We need to close on the other machine as well - self.send_logoff_ack() + self.send_logoff_ack(session_id=session_id) elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: src_ip = kwargs.get("frame").ip.src_ip_address @@ -191,12 +219,13 @@ class Terminal(Service): # send a SSH_MSG_CHANNEL_CLOSE if there is a session_id otherwise SSH_MSG_OPEN_FAILED return False - self.send(payload=result, session_id=session_id) + # self.send(payload=result, session_id=session_id) return True + # %% Outbound - def login(self, dest_ip_address: IPv4Address) -> bool: + def login(self, dest_ip_address: IPv4Address, user_account: dict[str]) -> bool: """ Perform an initial login request. @@ -204,13 +233,14 @@ class Terminal(Service): """ # TODO: This will need elaborating when user accounts are implemented self.sys_log.info("Attempting Login") - return self.ssh_remote_login(self, dest_ip_address=dest_ip_address, user_account=self.user_account) + return self._ssh_remote_login(self, dest_ip_address=dest_ip_address, user_account=user_account) - def ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: + def _ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: """Remote login to terminal via SSH.""" if not user_account: # Setting default creds (Best to use this until we have more clarification around user accounts) user_account = {self.user_name: "placeholder", self.password: "placeholder"} + # something like self.user_manager.get_user_details ? # Implement SSHPacket class payload: SSHPacket = SSHPacket( @@ -275,5 +305,9 @@ class Terminal(Service): payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None, + user_account: Optional[str] = None, ) -> bool: - return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=Port.SSH, session_id=session_id) + """Send a payload out from the Terminal.""" + self._validate_login(user_account) + self.sys_log.debug(f"Sending payload: {payload} from session: {session_id}") + return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id) diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index bbfa4f43..a261f272 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -106,7 +106,7 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network): expected_result = { IPv4Address("192.168.10.1"): {IPProtocol.UDP: [Port.ARP]}, IPv4Address("192.168.10.22"): { - IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS], + IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS, Port.SSH], IPProtocol.UDP: [Port.ARP, Port.NTP], }, }