Files
PrimAITE/src/primaite/simulator/core.py

282 lines
11 KiB
Python

# flake8: noqa
"""Core of the PrimAITE Simulator."""
import warnings
from abc import abstractmethod
from typing import Callable, Dict, List, Literal, Optional, Union
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
from primaite.interface.request import RequestFormat, RequestResponse
_LOGGER = getLogger(__name__)
class RequestPermissionValidator(BaseModel):
"""
Base class for request validators.
The permissions manager is designed to be generic. So, although in the first instance the permissions
are evaluated purely on membership to AccountGroup, this class can support validating permissions based on any
arbitrary criteria.
"""
@abstractmethod
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Use the request and context parameters to decide whether the request should be permitted."""
pass
@property
@abstractmethod
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return "request rejected"
class AllowAllValidator(RequestPermissionValidator):
"""Always allows the request."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Always allow the request."""
return True
@property
def fail_message(self) -> str:
"""
Message that is reported when a request is rejected by this validator.
This method should really never be called because this validator never rejects requests.
"""
warnings.warn("Something went wrong - AllowAllValidator rejected a request.")
return super().fail_message
class RequestType(BaseModel):
"""
This object stores data related to a single request type.
This includes the callable that can execute the request, and the validator that will decide whether
the request can be performed or not.
"""
func: Callable[[RequestFormat, Dict], RequestResponse]
"""
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
that invokes a class method of your SimComponent. For example if the component is a node and the request type is for
turning it off, then the SimComponent should have a turn_off(self) method that does not need to accept any args.
Then, this request will be given something like ``func = lambda request, context: self.turn_off()``.
``func`` can also be another request manager, since RequestManager is a callable with a signature that matches what is
expected by ``func``.
"""
validator: RequestPermissionValidator = AllowAllValidator()
"""
``validator`` is an instance of ``RequestPermissionValidator``. This is essentially a callable that
accepts `request` and `context` and returns a boolean to represent whether the permission is granted to perform
the request. The default validator will allow
"""
class RequestManager(BaseModel):
"""
RequestManager is used by `SimComponent` instances to keep track of requests.
Its main purpose is to be a lookup from request name to request function and corresponding validation function. This
class is responsible for providing a consistent API for processing requests as well as helpful error messages.
"""
request_types: Dict[str, RequestType] = {}
"""maps request name to an RequestType object."""
def __call__(self, request: RequestFormat, context: Dict) -> RequestResponse:
"""
Process an request request.
:param request: A list of strings describing the request. The first string must be one of the allowed
request names, i.e. it must be a key of self.request_types. The subsequent strings in the list are passed as
parameters to the request function.
:type request: List[str]
:param context: Dictionary of additional information necessary to process or validate the request.
:type context: Dict
:raises RuntimeError: If the request parameter does not have a valid request name as the first item.
"""
request_key = request[0]
request_options = request[1:]
if request_key not in self.request_types:
msg = (
f"Request {request} could not be processed because {request_key} is not a valid request name",
"within this RequestManager",
)
_LOGGER.debug(msg)
return RequestResponse(status="unreachable", data={"reason": msg})
request_type = self.request_types[request_key]
if not request_type.validator(request_options, context):
_LOGGER.debug(f"Request {request} was denied due to insufficient permissions")
return RequestResponse(status="failure", data={"reason": request_type.validator.fail_message})
return request_type.func(request_options, context)
def add_request(self, name: str, request_type: RequestType) -> None:
"""
Add a request type to this request manager.
:param name: The string associated to this request.
:type name: str
:param request_type: Request type object which contains information about how to resolve request.
:type request_type: RequestType
"""
if name in self.request_types:
msg = f"Overwriting request type {name}."
_LOGGER.debug(msg)
self.request_types[name] = request_type
def remove_request(self, name: str) -> None:
"""
Remove a request from this manager.
:param name: name identifier of the request
:type name: str
"""
if name not in self.request_types:
msg = f"Attempted to remove request {name} from request manager, but it was not registered."
_LOGGER.error(msg)
raise RuntimeError(msg)
self.request_types.pop(name)
def get_request_types_recursively(self) -> List[List[str]]:
"""Recursively generate request tree for this component."""
requests = []
for req_name, req in self.request_types.items():
if isinstance(req.func, RequestManager):
sub_requests = req.func.get_request_types_recursively()
sub_requests = [[req_name] + a for a in sub_requests]
requests.extend(sub_requests)
else:
requests.append([req_name])
return requests
class SimComponent(BaseModel):
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
"""Configure pydantic to allow arbitrary types and to let the instance have attributes not present in model."""
uuid: str = Field(default_factory=lambda: str(uuid4()))
"""The component UUID."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._request_manager: RequestManager = self._init_request_manager()
self._parent: Optional["SimComponent"] = None
def setup_for_episode(self, episode: int):
"""
Perform any additional setup on this component that can't happen during __init__.
For instance, some components may require for the entire network to exist before some configuration can be set.
"""
pass
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager for this component.
When using a hierarchy of components, the child classes should call the parent class's _init_request_manager and
add additional requests on top of the existing generic ones.
Example usage for inherited classes:
..code::python
class WebBrowser(Application):
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager() # all requests generic to any Application get initialised
rm.add_request(...) # initialise any requests specific to the web browser
return rm
:return: Request manager object belonging to this sim component.
:rtype: RequestManager
"""
return RequestManager()
@abstractmethod
def describe_state(self) -> Dict:
"""
Return a dictionary describing the state of this object and any objects managed by it.
This is similar to pydantic ``model_dump()``, but it only outputs information about the objects owned by this
object. If there are objects referenced by this object that are owned by something else, it is not included in
this output.
"""
state = {
"uuid": self.uuid,
}
return state
@validate_call
def apply_request(self, request: RequestFormat, context: Optional[Dict] = None) -> RequestResponse:
"""
Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings.
If the list only has one element, the request is intended to be applied directly to this object. If the list has
multiple entries, the request is passed to the child of this object specified by the first one or two entries.
This is essentially a namespace.
For example, ["turn_on",] is meant to apply a request of 'turn on' to this component.
However, ["services", "email_client", "turn_on"] is meant to 'turn on' this component's email client service.
:param request: List describing the request to apply to this object.
:type request: List[str]
:param: context: Dict containing context for requests
:type context: Dict
"""
if not context:
context = None
if self._request_manager is None:
return
return self._request_manager(request, context)
def pre_timestep(self, timestep: int) -> None:
"""
Apply any logic that needs to happen at the beginning of the timestep to ensure correct observations/rewards.
:param timestep: what's the current time
:type timestep: int
"""
pass
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep evolution to this component.
Override this method with anything that happens automatically in the component such as scheduled restarts or
sending data.
"""
pass
@property
def parent(self) -> "SimComponent":
"""Reference to the parent object which manages this object.
:return: Parent object.
:rtype: SimComponent
"""
return self._parent
@parent.setter
def parent(self, new_parent: Union["SimComponent", None]) -> None:
if self._parent and new_parent:
msg = f"Overwriting parent of {self.uuid}. Old parent: {self._parent.uuid}, New parent: {new_parent.uuid}"
_LOGGER.warning(msg)
raise RuntimeWarning(msg)
self._parent = new_parent