2623 Implement basic action masking logic

This commit is contained in:
Marek Wolan
2024-07-09 13:13:13 +01:00
parent cbf54d442c
commit 470fa28ee1
3 changed files with 60 additions and 20 deletions

View File

@@ -3,7 +3,7 @@
"""Core of the PrimAITE Simulator."""
import warnings
from abc import abstractmethod
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from uuid import uuid4
from prettytable import PrettyTable
@@ -179,6 +179,23 @@ class RequestManager(BaseModel):
table.add_rows(self.get_request_types_recursively())
print(table)
def check_valid(self, request: RequestFormat, context: Dict) -> bool:
"""Check if this request would be valid in the current state of the simulation without invoking it."""
request_key = request[0]
request_options = request[1:]
if request_key not in self.request_types:
return False
request_type = self.request_types[request_key]
# recurse if we are not at a leaf node
if isinstance(request_type.func, RequestManager):
return request_type.func.check_valid(request_options, context)
return request_type.validator(request_options, context)
class SimComponent(BaseModel):
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""