2623 Implement basic action masking logic
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
"""Core of the PrimAITE Simulator."""
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import PrettyTable
|
||||
@@ -179,6 +179,23 @@ class RequestManager(BaseModel):
|
||||
table.add_rows(self.get_request_types_recursively())
|
||||
print(table)
|
||||
|
||||
def check_valid(self, request: RequestFormat, context: Dict) -> bool:
|
||||
"""Check if this request would be valid in the current state of the simulation without invoking it."""
|
||||
|
||||
request_key = request[0]
|
||||
request_options = request[1:]
|
||||
|
||||
if request_key not in self.request_types:
|
||||
return False
|
||||
|
||||
request_type = self.request_types[request_key]
|
||||
|
||||
# recurse if we are not at a leaf node
|
||||
if isinstance(request_type.func, RequestManager):
|
||||
return request_type.func.check_valid(request_options, context)
|
||||
|
||||
return request_type.validator(request_options, context)
|
||||
|
||||
|
||||
class SimComponent(BaseModel):
|
||||
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""
|
||||
|
||||
Reference in New Issue
Block a user