Test my validators

This commit is contained in:
Marek Wolan
2023-08-03 16:26:33 +01:00
parent 94617c57a4
commit 2a680c1e48
6 changed files with 235 additions and 55 deletions

View File

@@ -1,11 +1,11 @@
"""Core of the PrimAITE Simulator."""
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional
from uuid import uuid4
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Extra
from primaite import getLogger
from primaite.simulator.domain import AccountGroup
_LOGGER = getLogger(__name__)
@@ -33,23 +33,6 @@ class AllowAllValidator(ActionPermissionValidator):
return True
class GroupMembershipValidator(ActionPermissionValidator):
"""Permit actions based on group membership."""
def __init__(self, allowed_groups: List[AccountGroup]) -> None:
"""TODO."""
self.allowed_groups = allowed_groups
def __call__(self, request: List[str], context: Dict) -> bool:
"""Permit the action if the request comes from an account which belongs to the right group."""
# if context request source is part of any groups mentioned in self.allow_groups, return true, otherwise false
requestor_groups: List[str] = context["request_source"]["groups"]
for allowed_group in self.allowed_groups:
if allowed_group.name in requestor_groups:
return True
return False
class Action:
"""
This object stores data related to a single action.
@@ -83,7 +66,7 @@ class ActionManager:
def __init__(self) -> None:
"""TODO."""
self.actions: Dict[str, Action]
self.actions: Dict[str, Action] = {}
def process_request(self, request: List[str], context: Dict) -> None:
"""Process action request."""
@@ -106,17 +89,20 @@ class ActionManager:
action.func(action_options, context)
def add_action(self, name: str, action: Action) -> None:
self.actions[name] = action
class SimComponent(BaseModel):
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""
model_config = ConfigDict(arbitrary_types_allowed=True)
uuid: str
model_config = ConfigDict(arbitrary_types_allowed=True, extra=Extra.allow)
uuid: str = str(uuid4())
"The component UUID."
def __init__(self, **kwargs) -> None:
self.action_manager: Optional[ActionManager] = None
super().__init__(**kwargs)
self.action_manager: Optional[ActionManager] = None
@abstractmethod
def describe_state(self) -> Dict:

View File

@@ -1,3 +0,0 @@
from primaite.simulator.domain.account import Account, AccountGroup, AccountType
__all__ = ["Account", "AccountGroup", "AccountType"]

View File

@@ -1,6 +1,6 @@
"""User account simulation."""
from enum import Enum
from typing import Callable, Dict, List, TypeAlias
from typing import Any, Callable, Dict, List
from primaite import getLogger
from primaite.simulator.core import SimComponent
@@ -8,9 +8,6 @@ from primaite.simulator.core import SimComponent
_LOGGER = getLogger(__name__)
__temp_node = TypeAlias() # placeholder while nodes don't exist
class AccountType(Enum):
"""Whether the account is intended for a user to log in or for a service to use."""
@@ -20,19 +17,6 @@ class AccountType(Enum):
"User accounts are used to allow agents to log in and perform actions"
class AccountGroup(Enum):
"""Permissions are set at group-level and accounts can belong to these groups."""
local_user = 1
"For performing basic actions on a node"
domain_user = 2
"For performing basic actions to the domain"
local_admin = 3
"For full access to actions on a node"
domain_admin = 4
"For full access"
class AccountStatus(Enum):
"""Whether the account is active."""

View File

@@ -1,29 +1,71 @@
from typing import Dict, Final, List, TypeAlias
from enum import Enum
from typing import Any, Dict, Final, List
from primaite.simulator.core import ActionPermissionValidator, SimComponent
from primaite.simulator.domain.account import Account, AccountType
from primaite.simulator.core import SimComponent
from primaite.simulator.domain import Account, AccountGroup, AccountType
# placeholder while these objects don't yet exist
__temp_node = TypeAlias()
__temp_application = TypeAlias()
__temp_folder = TypeAlias()
__temp_file = TypeAlias()
class temp_node:
pass
class temp_application:
pass
class temp_folder:
pass
class temp_file:
pass
class AccountGroup(Enum):
"""Permissions are set at group-level and accounts can belong to these groups."""
local_user = 1
"For performing basic actions on a node"
domain_user = 2
"For performing basic actions to the domain"
local_admin = 3
"For full access to actions on a node"
domain_admin = 4
"For full access"
class GroupMembershipValidator(ActionPermissionValidator):
"""Permit actions based on group membership."""
def __init__(self, allowed_groups: List[AccountGroup]) -> None:
"""TODO."""
self.allowed_groups = allowed_groups
def __call__(self, request: List[str], context: Dict) -> bool:
"""Permit the action if the request comes from an account which belongs to the right group."""
# if context request source is part of any groups mentioned in self.allow_groups, return true, otherwise false
requestor_groups: List[str] = context["request_source"]["groups"]
for allowed_group in self.allowed_groups:
if allowed_group.name in requestor_groups:
return True
return False
class DomainController(SimComponent):
"""Main object for controlling the domain."""
# owned objects
accounts: List(Account) = []
accounts: List[Account] = []
groups: Final[List[AccountGroup]] = list(AccountGroup)
group_membership: Dict[AccountGroup, List[Account]]
# references to non-owned objects
nodes: List(__temp_node) = []
applications: List(__temp_application) = []
folders: List(__temp_folder) = []
files: List(__temp_file) = []
nodes: List[temp_node] = []
applications: List[temp_application] = []
folders: List[temp_folder] = []
files: List[temp_file] = []
def _register_account(self, account: Account) -> None:
"""TODO."""

View File

@@ -0,0 +1,171 @@
from enum import Enum
from typing import Dict, List, Literal
import pytest
from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent
from primaite.simulator.domain.controller import AccountGroup, GroupMembershipValidator
def test_group_action_validation() -> None:
"""Check that actions are denied when an unauthorised request is made."""
class Folder(SimComponent):
name: str
def describe_state(self) -> Dict:
return super().describe_state()
class Node(SimComponent):
name: str
folders: List[Folder] = []
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.action_manager = ActionManager()
self.action_manager.add_action(
"create_folder",
Action(
func=lambda request, context: self.create_folder(request[0]),
validator=GroupMembershipValidator([AccountGroup.local_admin, AccountGroup.domain_admin]),
),
)
def describe_state(self) -> Dict:
return super().describe_state()
def create_folder(self, folder_name: str) -> None:
new_folder = Folder(uuid="0000-0000-0001", name=folder_name)
self.folders.append(new_folder)
def remove_folder(self, folder: Folder) -> None:
self.folders = [x for x in self.folders if x is not folder]
permitted_context = {"request_source": {"agent": "BLUE", "account": "User1", "groups": ["local_admin"]}}
my_node = Node(uuid="0000-0000-1234", name="pc")
my_node.apply_action(["create_folder", "memes"], context=permitted_context)
assert len(my_node.folders) == 1
assert my_node.folders[0].name == "memes"
invalid_context = {"request_source": {"agent": "BLUE", "account": "User1", "groups": ["local_user", "domain_user"]}}
my_node.apply_action(["create_folder", "memes2"], context=invalid_context)
assert len(my_node.folders) == 1
assert my_node.folders[0].name == "memes"
def test_hierarchical_action_with_validation() -> None:
"""Check that validation works with sub-objects"""
class Application(SimComponent):
name: str
state: Literal["on", "off", "disabled"] = "off"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.action_manager = ActionManager()
self.action_manager.add_action(
"turn_on",
Action(
func=lambda request, context: self.turn_on(),
validator=AllowAllValidator(),
),
)
self.action_manager.add_action(
"turn_off",
Action(
func=lambda request, context: self.turn_off(),
validator=AllowAllValidator(),
),
)
self.action_manager.add_action(
"disable",
Action(
func=lambda request, context: self.disable(),
validator=GroupMembershipValidator([AccountGroup.local_admin, AccountGroup.domain_admin]),
),
)
self.action_manager.add_action(
"enable",
Action(
func=lambda request, context: self.enable(),
validator=GroupMembershipValidator([AccountGroup.local_admin, AccountGroup.domain_admin]),
),
)
def describe_state(self) -> Dict:
return super().describe_state()
def disable(self) -> None:
self.status = "disabled"
def enable(self) -> None:
if self.status == "disabled":
self.status = "off"
def turn_on(self) -> None:
if self.status == "off":
self.status = "on"
def turn_off(self) -> None:
if self.status == "on":
self.status = "off"
class Node(SimComponent):
name: str
state: Literal["on", "off"] = "on"
apps: List[Application] = []
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.action_manager = ActionManager()
self.action_manager.add_action(
"apps",
Action(
func=lambda request, context: self.send_action_to_app(request.pop(0), request, context),
validator=AllowAllValidator(),
),
)
def describe_state(self) -> Dict:
return super().describe_state()
def install_app(self, app_name: str) -> None:
new_app = Application(name=app_name)
self.apps.append(new_app)
def send_action_to_app(self, app_name: str, options: List[str], context: Dict):
for app in self.apps:
if app_name == app.name:
app.apply_action(options)
break
else:
msg = f"Node has no app with name {app_name}"
raise LookupError(msg)
my_node = Node(name="pc")
my_node.install_app("Chrome")
my_node.install_app("Firefox")
non_admin_context = {
"request_source": {"agent": "BLUE", "account": "User1", "groups": ["local_user", "domain_user"]}
}
admin_context = {
"request_source": {
"agent": "BLUE",
"account": "User1",
"groups": ["local_admin", "domain_admin", "local_user", "domain_user"],
}
}
my_node.apply_action(["apps", "Chrome", "disable"], non_admin_context)
my_node.apply_action(["apps", "Firefox", "turn_on"], non_admin_context)
assert my_node.apps[0].name == "Chrome"
assert my_node.apps[1].name == "Firefox"
assert my_node.apps[0].state == ... # TODO: finish