#2735 - make the disabled/enabled admins/non-admins dynamic properties for simplicity. Added num_of_logins to User. Added additional test for counting user logins. Added all users to the UserManager describe_state function. Refactored model fields with empty dict as default value to have direct instantiation instead of using Field(default_factory=dict) or Field(default_factory=: lambda: {}).

This commit is contained in:
Chris McCarthy
2024-08-02 12:47:02 +01:00
parent 61c7cc2da3
commit 696236aa61
2 changed files with 68 additions and 11 deletions

View File

@@ -817,6 +817,9 @@ class User(SimComponent):
is_admin: bool = False
"""Boolean flag indicating whether the user has admin privileges"""
num_of_logins: int = 0
"""Counts the number of the User has logged in"""
def describe_state(self) -> Dict:
"""
Returns a dictionary representing the current state of the user.
@@ -835,9 +838,7 @@ class UserManager(Service):
:param disabled_admins: A dictionary of currently disabled admin users by their usernames
"""
users: Dict[str, User] = Field(default_factory=dict)
admins: Dict[str, User] = Field(default_factory=dict)
disabled_admins: Dict[str, User] = Field(default_factory=dict)
users: Dict[str, User] = {}
def __init__(self, **kwargs):
"""
@@ -880,6 +881,7 @@ class UserManager(Service):
"""
state = super().describe_state()
state.update({"total_users": len(self.users), "total_admins": len(self.admins) + len(self.disabled_admins)})
state["users"] = {k: v.describe_state() for k, v in self.users.items()}
return state
def show(self, markdown: bool = False):
@@ -897,6 +899,42 @@ class UserManager(Service):
table.add_row([user.username, user.is_admin, user.disabled])
print(table.get_string(sortby="Username"))
@property
def non_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all enabled non-admin users.
:return: A dictionary with usernames as keys and User objects as values for non-admin, non-disabled users.
"""
return {k: v for k, v in self.users.items() if not v.is_admin and not v.disabled}
@property
def disabled_non_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all disabled non-admin users.
:return: A dictionary with usernames as keys and User objects as values for non-admin, disabled users.
"""
return {k: v for k, v in self.users.items() if not v.is_admin and v.disabled}
@property
def admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all enabled admin users.
:return: A dictionary with usernames as keys and User objects as values for admin, non-disabled users.
"""
return {k: v for k, v in self.users.items() if v.is_admin and not v.disabled}
@property
def disabled_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all disabled admin users.
:return: A dictionary with usernames as keys and User objects as values for admin, disabled users.
"""
return {k: v for k, v in self.users.items() if v.is_admin and v.disabled}
def install(self) -> None:
"""Setup default user during first-time installation."""
self.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True)
@@ -922,8 +960,6 @@ class UserManager(Service):
return False
user = User(username=username, password=password, is_admin=is_admin)
self.users[username] = user
if is_admin:
self.admins[username] = user
self.sys_log.info(f"{self.name}: Added new {'admin' if is_admin else 'user'}: {username}")
return True
@@ -978,8 +1014,6 @@ class UserManager(Service):
return False
self.users[username].disabled = True
self.sys_log.info(f"{self.name}: User disabled: {username}")
if username in self.admins:
self.disabled_admins[username] = self.admins.pop(username)
return True
self.sys_log.info(f"{self.name}: Failed to disable user: {username}")
return False
@@ -994,8 +1028,6 @@ class UserManager(Service):
if username in self.users and self.users[username].disabled:
self.users[username].disabled = False
self.sys_log.info(f"{self.name}: User enabled: {username}")
if username in self.disabled_admins:
self.admins[username] = self.disabled_admins.pop(username)
return True
self.sys_log.info(f"{self.name}: Failed to enable user: {username}")
return False
@@ -1028,7 +1060,7 @@ class UserSession(SimComponent):
"""The timestep when the session ended, if applicable."""
local: bool = True
"""Indicates if the session is local. Defaults to True."""
"""Indicates if the session is a local session or a remote session. Defaults to True as a local session."""
@classmethod
def create(cls, user: User, timestep: int) -> UserSession:
@@ -1041,6 +1073,7 @@ class UserSession(SimComponent):
:param timestep: The timestep when the session is created.
:return: An instance of UserSession.
"""
user.num_of_logins += 1
return UserSession(user=user, start_step=timestep, last_active_step=timestep)
def describe_state(self) -> Dict:
@@ -1107,7 +1140,7 @@ class UserSessionManager(Service):
local_session: Optional[UserSession] = None
"""The current local user session, if any."""
remote_sessions: Dict[str, RemoteUserSession] = Field(default_factory=dict)
remote_sessions: Dict[str, RemoteUserSession] = {}
"""A dictionary of active remote user sessions."""
historic_sessions: List[UserSession] = Field(default_factory=list)

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import User
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
@@ -46,6 +47,29 @@ def test_local_login_success(client_server_network):
assert client.user_session_manager.local_user_logged_in
def test_login_count_increases(client_server_network):
client, server, network = client_server_network
admin_user: User = client.user_manager.users["admin"]
assert admin_user.num_of_logins == 0
client.user_session_manager.local_login(username="admin", password="admin")
assert admin_user.num_of_logins == 1
client.user_session_manager.local_login(username="admin", password="admin")
# shouldn't change as user is already logged in
assert admin_user.num_of_logins == 1
client.user_session_manager.local_logout()
client.user_session_manager.local_login(username="admin", password="admin")
assert admin_user.num_of_logins == 2
def test_local_login_failure(client_server_network):
client, server, network = client_server_network