Make a type alias for request & fix typo

This commit is contained in:
Marek Wolan
2024-03-08 17:14:41 +00:00
parent 0447a05084
commit 289b5c548a
2 changed files with 8 additions and 9 deletions

View File

@@ -11,6 +11,8 @@ from primaite.interface.request import RequestResponse
_LOGGER = getLogger(__name__)
RequestFormat = List[Union[str, int, float]]
class RequestPermissionValidator(BaseModel):
"""
@@ -22,7 +24,7 @@ class RequestPermissionValidator(BaseModel):
"""
@abstractmethod
def __call__(self, request: List[str], context: Dict) -> bool:
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Use the request and context parameters to decide whether the request should be permitted."""
pass
@@ -30,7 +32,7 @@ class RequestPermissionValidator(BaseModel):
class AllowAllValidator(RequestPermissionValidator):
"""Always allows the request."""
def __call__(self, request: List[str], context: Dict) -> bool:
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Always allow the request."""
return True
@@ -43,7 +45,7 @@ class RequestType(BaseModel):
the request can be performed or not.
"""
func: Callable[[List[Union[str, int, float]], Dict], RequestResponse]
func: Callable[[RequestFormat, Dict], RequestResponse]
"""
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
that invokes a class method of your SimComponent. For example if the component is a node and the request type is for
@@ -72,8 +74,7 @@ class RequestManager(BaseModel):
request_types: Dict[str, RequestType] = {}
"""maps request name to an RequestType object."""
@validate_call
def __call__(self, request: List[Union[str, int, float]], context: Dict) -> RequestResponse:
def __call__(self, request: RequestFormat, context: Dict) -> RequestResponse:
"""
Process an request request.
@@ -93,8 +94,6 @@ class RequestManager(BaseModel):
f"Request {request} could not be processed because {request_key} is not a valid request name",
"within this RequestManager",
)
# _LOGGER.error(msg)
# raise RuntimeError(msg)
_LOGGER.debug(msg)
return RequestResponse(status="unreachable", data={"reason": msg})
@@ -207,7 +206,7 @@ class SimComponent(BaseModel):
return state
@validate_call
def apply_request(self, request: List[Union[str, int, float]], context: Dict = {}) -> RequestResponse:
def apply_request(self, request: RequestFormat, context: Dict = {}) -> RequestResponse:
"""
Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings.

View File

@@ -62,7 +62,7 @@ class FileSystem(SimComponent):
self._restore_manager.add_request(
name="file",
request_type=RequestType(
func=lambda request, context: RequestResponse(
func=lambda request, context: RequestResponse.from_bool(
self.restore_file(folder_name=request[0], file_name=request[1])
)
),