#2623 Make it possible to view currently valid simulation requests

This commit is contained in:
Marek Wolan
2024-07-08 15:17:35 +01:00
parent 2a0695d0d1
commit cbf54d442c
2 changed files with 26 additions and 7 deletions

View File

@@ -3,9 +3,10 @@
"""Core of the PrimAITE Simulator."""
import warnings
from abc import abstractmethod
from typing import Callable, Dict, List, Literal, Optional, Union
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
from uuid import uuid4
from prettytable import PrettyTable
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
@@ -150,18 +151,34 @@ class RequestManager(BaseModel):
self.request_types.pop(name)
def get_request_types_recursively(self) -> List[List[str]]:
"""Recursively generate request tree for this component."""
def get_request_types_recursively(self, _parent_valid: bool = True) -> List[Tuple[RequestFormat, bool]]:
"""
Recursively generate request tree for this component.
:param parent_valid: Whether this sub-request's parent request was valid. This value should not be specified by
users, it is used by the recursive call.
:type parent_valid: bool
:returns: A list of tuples where the first tuple element is the request string and the second is whether that
request is currently possible to execute.
:rtype: List[Tuple[RequestFormat, bool]]
"""
requests = []
for req_name, req in self.request_types.items():
valid = req.validator([], {}) and _parent_valid # if parent is invalid, all children are invalid
if isinstance(req.func, RequestManager):
sub_requests = req.func.get_request_types_recursively()
sub_requests = [[req_name] + a for a in sub_requests]
sub_requests = req.func.get_request_types_recursively(valid) # recurse
sub_requests = [([req_name] + a, valid) for a, valid in sub_requests] # prepend parent request to leaf
requests.extend(sub_requests)
else:
requests.append([req_name])
else: # leaf node found
requests.append(([req_name], valid))
return requests
def show(self) -> None:
table = PrettyTable(["request", "valid"])
table.align = "l"
table.add_rows(self.get_request_types_recursively())
print(table)
class SimComponent(BaseModel):
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""

View File

@@ -52,6 +52,8 @@ class GroupMembershipValidator(RequestPermissionValidator):
def __call__(self, request: List[str], context: Dict) -> bool:
"""Permit the action if the request comes from an account which belongs to the right group."""
# if context request source is part of any groups mentioned in self.allow_groups, return true, otherwise false
if not context:
return False
requestor_groups: List[str] = context["request_source"]["groups"]
for allowed_group in self.allowed_groups:
if allowed_group.name in requestor_groups: