diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 9e6784c5..3ffc7b35 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, TypeVar, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validate_call import primaite.simulator.network.nmne from primaite import getLogger @@ -989,6 +989,12 @@ class RemoteUserSession(UserSession): remote_ip_address: IPV4Address local: bool = False + @classmethod + def create(cls, user: User, timestep: int, remote_ip_address: IPV4Address) -> RemoteUserSession: # noqa + return RemoteUserSession( + user=user, start_step=timestep, last_active_step=timestep, remote_ip_address=remote_ip_address + ) + def describe_state(self) -> Dict: state = super().describe_state() state["remote_ip_address"] = str(self.remote_ip_address) @@ -1066,7 +1072,9 @@ class UserSessionManager(Service): print(table.get_string(sortby="Step Last Active", reversesort=True)) def describe_state(self) -> Dict: - return super().describe_state() + state = super().describe_state() + state["active_remote_logins"] = len(self.remote_sessions) + return state @property def _user_manager(self) -> UserManager: @@ -1092,27 +1100,78 @@ class UserSessionManager(Service): self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity") - def login(self, username: str, password: str) -> Optional[str]: + @property + def remote_session_limit_reached(self) -> bool: + return len(self.remote_sessions) >= self.max_remote_sessions + + def validate_remote_session_uuid(self, remote_session_id: str) -> bool: + return remote_session_id in self.remote_sessions + + def _login( + self, username: str, password: str, local: bool = True, remote_ip_address: Optional[IPv4Address] = None + ) -> Optional[str]: if not self._can_perform_action(): return None - user = self._user_manager.authenticate_user(username=username, password=password) - if user: - self.logout() - self.local_session = UserSession.create(user=user, timestep=self.current_timestep) - self.sys_log.info(f"{self.name}: User {user.username} logged in") - return self.local_session.uuid - else: - self.sys_log.info(f"{self.name}: Incorrect username or password") - def logout(self): + user = self._user_manager.authenticate_user(username=username, password=password) + + if not user: + self.sys_log.info(f"{self.name}: Incorrect username or password") + return None + + session_id = None + if local: + create_new_session = True + if self.local_session: + if self.local_session.user != user: + # logout the current user + self.local_logout() + else: + # not required as existing logged-in user attempting to re-login + create_new_session = False + + if create_new_session: + self.local_session = UserSession.create(user=user, timestep=self.current_timestep) + + session_id = self.local_session.uuid + else: + if not self.remote_session_limit_reached: + remote_session = RemoteUserSession.create( + user=user, timestep=self.current_timestep, remote_ip_address=remote_ip_address + ) + session_id = remote_session.uuid + self.remote_sessions[session_id] = remote_session + self.sys_log.info(f"{self.name}: User {user.username} logged in") + return session_id + + def local_login(self, username: str, password: str) -> Optional[str]: + return self._login(username=username, password=password, local=True) + + @validate_call() + def remote_login(self, username: str, password: str, remote_ip_address: IPV4Address) -> Optional[str]: + return self._login(username=username, password=password, local=False, remote_ip_address=remote_ip_address) + + def _logout(self, local: bool = True, remote_session_id: Optional[str] = None): if not self._can_perform_action(): return False - if self.local_session: + session = None + if local and self.local_session: session = self.local_session session.end_step = self.current_timestep - self.historic_sessions.append(session) self.local_session = None + + if not local and remote_session_id: + session = self.remote_sessions.pop(remote_session_id) + if session: + self.historic_sessions.append(session) self.sys_log.info(f"{self.name}: User {session.user.username} logged out") + return + + def local_logout(self): + self._logout(local=True) + + def remote_logout(self, remote_session_id: str): + self._logout(local=False, remote_session_id=remote_session_id) @property def local_user_logged_in(self): @@ -1225,8 +1284,8 @@ class Node(SimComponent): def user_session_manager(self) -> UserSessionManager: return self.software_manager.software["UserSessionManager"] # noqa - def login(self, username: str, password: str) -> Optional[str]: - return self.user_session_manager.login(username, password) + def local_login(self, username: str, password: str) -> Optional[str]: + return self.user_session_manager.local_login(username, password) def logout(self): return self.user_session_manager.logout() @@ -1765,14 +1824,10 @@ class Node(SimComponent): :param pings: The number of pings to attempt, default is 4. :return: True if the ping is successful, otherwise False. """ - if not self.user_session_manager.local_user_logged_in: - return False if not isinstance(target_ip_address, IPv4Address): target_ip_address = IPv4Address(target_ip_address) if self.software_manager.icmp: - print("yes") return self.software_manager.icmp.ping(target_ip_address, pings) - print("no icmp") return False @abstractmethod diff --git a/tests/integration_tests/system/test_local_accounts.py b/tests/integration_tests/system/test_local_accounts.py deleted file mode 100644 index dbdbf857..00000000 --- a/tests/integration_tests/system/test_local_accounts.py +++ /dev/null @@ -1,37 +0,0 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.container import Network -from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.server import Server - - -def test_local_accounts_ping_temp(): - network = Network() - - # Create Computer - computer = Computer( - hostname="computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) - computer.power_on() - - # Create Server - server = Server( - hostname="server", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) - server.power_on() - - # Connect Computer and Server - network.connect(computer.network_interface[1], server.network_interface[1]) - - assert not computer.ping(server.network_interface[1].ip_address) - - computer.user_session_manager.login(username="admin", password="admin") - - assert computer.ping(server.network_interface[1].ip_address) diff --git a/tests/integration_tests/system/test_user_session_manager_logins.py b/tests/integration_tests/system/test_user_session_manager_logins.py new file mode 100644 index 00000000..955408ad --- /dev/null +++ b/tests/integration_tests/system/test_user_session_manager_logins.py @@ -0,0 +1,250 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple +from uuid import uuid4 + +import pytest + +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server + + +@pytest.fixture(scope="function") +def client_server_network() -> Tuple[Computer, Server, Network]: + network = Network() + + client = Computer( + hostname="client", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + client.power_on() + + server = Server( + hostname="server", + ip_address="192.168.1.3", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + network.connect(client.network_interface[1], server.network_interface[1]) + + return client, server, network + + +def test_local_login_success(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + + +def test_local_login_failure(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + client.user_session_manager.local_login(username="jane.doe", password="12345") + + assert not client.user_session_manager.local_user_logged_in + + +def test_new_user_local_login_success(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + client.user_manager.add_user(username="jane.doe", password="12345") + + client.user_session_manager.local_login(username="jane.doe", password="12345") + + assert client.user_session_manager.local_user_logged_in + + +def test_new_local_login_clears_previous_login(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + current_session_id = client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "admin" + + client.user_manager.add_user(username="jane.doe", password="12345") + + new_session_id = client.user_session_manager.local_login(username="jane.doe", password="12345") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "jane.doe" + + assert new_session_id != current_session_id + + +def test_new_local_login_attempt_same_uses_persists(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + current_session_id = client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "admin" + + new_session_id = client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "admin" + + assert new_session_id == current_session_id + + +def test_remote_login_success(client_server_network): + # partial test for now until we get the terminal application in so that amn actual remote connection can be made + client, server, network = client_server_network + + assert not server.user_session_manager.remote_sessions + + remote_session_id = server.user_session_manager.remote_login( + username="admin", password="admin", remote_ip_address="192.168.1.10" + ) + + assert server.user_session_manager.validate_remote_session_uuid(remote_session_id) + + server.user_session_manager.remote_logout(remote_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id) + + +def test_remote_login_failure(client_server_network): + # partial test for now until we get the terminal application in so that amn actual remote connection can be made + client, server, network = client_server_network + + assert not server.user_session_manager.remote_sessions + + remote_session_id = server.user_session_manager.remote_login( + username="jane.doe", password="12345", remote_ip_address="192.168.1.10" + ) + + assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id) + + +def test_new_user_remote_login_success(client_server_network): + client, server, network = client_server_network + + server.user_manager.add_user(username="jane.doe", password="12345") + + remote_session_id = server.user_session_manager.remote_login( + username="jane.doe", password="12345", remote_ip_address="192.168.1.10" + ) + + assert server.user_session_manager.validate_remote_session_uuid(remote_session_id) + + server.user_session_manager.remote_logout(remote_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id) + + +def test_max_remote_sessions_same_user(client_server_network): + client, server, network = client_server_network + + remote_session_ids = [ + server.user_session_manager.remote_login(username="admin", password="admin", remote_ip_address="192.168.1.10") + for _ in range(server.user_session_manager.max_remote_sessions) + ] + + assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids]) + + +def test_max_remote_sessions_different_users(client_server_network): + client, server, network = client_server_network + + remote_session_ids = [] + + for i in range(server.user_session_manager.max_remote_sessions): + username = str(uuid4()) + password = "12345" + server.user_manager.add_user(username=username, password=password) + + remote_session_ids.append( + server.user_session_manager.remote_login( + username=username, password=password, remote_ip_address="192.168.1.10" + ) + ) + + assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids]) + + +def test_max_remote_sessions_limit_reached(client_server_network): + client, server, network = client_server_network + + remote_session_ids = [ + server.user_session_manager.remote_login(username="admin", password="admin", remote_ip_address="192.168.1.10") + for _ in range(server.user_session_manager.max_remote_sessions) + ] + + assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids]) + + assert len(server.user_session_manager.remote_sessions) == server.user_session_manager.max_remote_sessions + + fourth_attempt_session_id = server.user_session_manager.remote_login( + username="admin", password="admin", remote_ip_address="192.168.1.10" + ) + + assert not server.user_session_manager.validate_remote_session_uuid(fourth_attempt_session_id) + + assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids]) + + +def test_single_remote_logout_others_persist(client_server_network): + client, server, network = client_server_network + + server.user_manager.add_user(username="jane.doe", password="12345") + server.user_manager.add_user(username="john.doe", password="12345") + + admin_session_id = server.user_session_manager.remote_login( + username="admin", password="admin", remote_ip_address="192.168.1.10" + ) + + jane_session_id = server.user_session_manager.remote_login( + username="jane.doe", password="12345", remote_ip_address="192.168.1.10" + ) + + john_session_id = server.user_session_manager.remote_login( + username="john.doe", password="12345", remote_ip_address="192.168.1.10" + ) + + server.user_session_manager.remote_logout(admin_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id) + + assert server.user_session_manager.validate_remote_session_uuid(jane_session_id) + + assert server.user_session_manager.validate_remote_session_uuid(john_session_id) + + server.user_session_manager.remote_logout(jane_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(jane_session_id) + + assert server.user_session_manager.validate_remote_session_uuid(john_session_id) + + server.user_session_manager.remote_logout(john_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(jane_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(john_session_id)