diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index b3b7189c..9a5fedc9 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -49,7 +49,7 @@ class AbstractAction(ABC): objects.""" @abstractmethod - def form_request(self) -> List[str]: + def form_request(self) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return [] @@ -67,7 +67,7 @@ class DoNothingAction(AbstractAction): # i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter # with one option. This just aids the Action Manager to enumerate all possibilities. - def form_request(self, **kwargs) -> List[str]: + def form_request(self, **kwargs) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return ["do_nothing"] @@ -86,7 +86,7 @@ class NodeServiceAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, service_id: int) -> List[str]: + def form_request(self, node_id: int, service_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) service_name = self.manager.get_service_name_by_idx(node_id, service_id) @@ -181,7 +181,7 @@ class NodeApplicationAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, application_id: int) -> List[str]: + def form_request(self, node_id: int, application_id: int) -> RequestFormat: """Return the action formatted as a request 
which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) application_name = self.manager.get_application_name_by_idx(node_id, application_id) @@ -229,7 +229,7 @@ class NodeApplicationInstallAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes} - def form_request(self, node_id: int, application_name: str) -> List[str]: + def form_request(self, node_id: int, application_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None: @@ -324,7 +324,7 @@ class NodeApplicationRemoveAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes} - def form_request(self, node_id: int, application_name: str) -> List[str]: + def form_request(self, node_id: int, application_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None: @@ -346,7 +346,7 @@ class NodeFolderAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, folder_id: int) -> List[str]: + def form_request(self, node_id: int, folder_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) @@ -394,7 +394,9 @@ class NodeFileCreateAction(AbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "create" - def form_request(self, 
node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False) -> List[str]: + def form_request( + self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False + ) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None or folder_name is None or file_name is None: @@ -409,7 +411,7 @@ class NodeFolderCreateAction(AbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "create" - def form_request(self, node_id: int, folder_name: str) -> List[str]: + def form_request(self, node_id: int, folder_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None or folder_name is None: @@ -430,7 +432,7 @@ class NodeFileAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: + def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) @@ -463,7 +465,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb: str = "delete" - def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: + def form_request(self, node_id: int, folder_id: int, file_id: int) -> 
RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) @@ -504,7 +506,7 @@ class NodeFileAccessAction(AbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "access" - def form_request(self, node_id: int, folder_name: str, file_name: str) -> List[str]: + def form_request(self, node_id: int, folder_name: str, file_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None or folder_name is None or file_name is None: @@ -525,7 +527,7 @@ class NodeAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int) -> List[str]: + def form_request(self, node_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) return ["network", "node", node_name, self.verb] @@ -740,7 +742,7 @@ class RouterACLRemoveRuleAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"position": max_acl_rules} - def form_request(self, target_router: str, position: int) -> List[str]: + def form_request(self, target_router: str, position: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return ["network", "node", target_router, "acl", "remove_rule", position] @@ -923,7 +925,7 @@ class HostNICAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} self.verb: str # define 
but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, nic_id: int) -> List[str]: + def form_request(self, node_id: int, nic_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_idx=node_id) nic_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=nic_id) @@ -960,7 +962,7 @@ class NetworkPortEnableAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"port_id": max_nics_per_node} - def form_request(self, target_nodename: str, port_id: int) -> List[str]: + def form_request(self, target_nodename: str, port_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" if target_nodename is None or port_id is None: return ["do_nothing"] @@ -979,7 +981,7 @@ class NetworkPortDisableAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"port_id": max_nics_per_node} - def form_request(self, target_nodename: str, port_id: int) -> List[str]: + def form_request(self, target_nodename: str, port_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" if target_nodename is None or port_id is None: return ["do_nothing"] @@ -1315,7 +1317,7 @@ class ActionManager: act_identifier, act_options = self.action_map[action] return act_identifier, act_options - def form_request(self, action_identifier: str, action_options: Dict) -> List[str]: + def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat: """Take action in CAOS format and use the execution definition to change it into PrimAITE request format.""" act_obj = self.actions[action_identifier] return act_obj.form_request(**action_options) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py 
def action_masks(self) -> List[bool]:
    """
    Build the action mask for the agent.

    The result has one boolean per entry of the agent's action map. A False entry means the corresponding action
    cannot be performed during this step. When action masking is switched off every entry is True.

    :return: Action mask
    :rtype: List[bool]
    """
    action_manager = self.agent.action_manager
    mask = [True for _ in range(len(action_manager.action_map))]

    if self.action_masking:
        # Each action map value is an (identifier, options) pair; turn it into a request and ask the simulation's
        # request manager whether that request is currently valid, without executing it.
        for idx, (identifier, options) in action_manager.action_map.items():
            request = action_manager.form_request(action_identifier=identifier, action_options=options)
            mask[idx] = self.game.simulation._request_manager.check_valid(request, {})

    return mask
def check_valid(self, request: RequestFormat, context: Dict) -> bool:
    """
    Check if this request would be valid in the current state of the simulation without invoking it.

    The first element of the request selects an entry in ``self.request_types``. If that entry's ``func`` is itself
    a ``RequestManager``, the remaining elements are checked against it recursively. Otherwise the entry's
    ``validator`` is called with the remaining elements and the context, and its result is returned.

    :param request: The request to check; its first element is the request key, the rest are its options.
    :param context: Passed through unchanged to nested request managers and to the validator.
    :return: False if the request is empty or its key is not a registered request type, otherwise the result of
        the nested check or of the validator.
    :rtype: bool
    """
    # An empty request (e.g. one that is shorter than the depth of the nested request managers, so the recursive
    # call receives no remaining elements) cannot match any request type. Report it as invalid instead of letting
    # request[0] raise IndexError.
    if not request:
        return False

    request_key = request[0]
    request_options = request[1:]

    if request_key not in self.request_types:
        return False

    request_type = self.request_types[request_key]

    # recurse if we are not at a leaf node
    if isinstance(request_type.func, RequestManager):
        return request_type.func.check_valid(request_options, context)

    return request_type.validator(request_options, context)