2623 Implement basic action masking logic
This commit is contained in:
@@ -49,7 +49,7 @@ class AbstractAction(ABC):
|
||||
objects."""
|
||||
|
||||
@abstractmethod
|
||||
def form_request(self) -> List[str]:
|
||||
def form_request(self) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return []
|
||||
|
||||
@@ -67,7 +67,7 @@ class DoNothingAction(AbstractAction):
|
||||
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' parameter
|
||||
# with one option. This just aids the Action Manager to enumerate all possibilities.
|
||||
|
||||
def form_request(self, **kwargs) -> List[str]:
|
||||
def form_request(self, **kwargs) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["do_nothing"]
|
||||
|
||||
@@ -86,7 +86,7 @@ class NodeServiceAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, service_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, service_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
service_name = self.manager.get_service_name_by_idx(node_id, service_id)
|
||||
@@ -181,7 +181,7 @@ class NodeApplicationAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, application_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, application_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
application_name = self.manager.get_application_name_by_idx(node_id, application_id)
|
||||
@@ -229,7 +229,7 @@ class NodeApplicationInstallAction(AbstractAction):
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes}
|
||||
|
||||
def form_request(self, node_id: int, application_name: str) -> List[str]:
|
||||
def form_request(self, node_id: int, application_name: str) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None:
|
||||
@@ -324,7 +324,7 @@ class NodeApplicationRemoveAction(AbstractAction):
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes}
|
||||
|
||||
def form_request(self, node_id: int, application_name: str) -> List[str]:
|
||||
def form_request(self, node_id: int, application_name: str) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None:
|
||||
@@ -346,7 +346,7 @@ class NodeFolderAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
@@ -394,7 +394,9 @@ class NodeFileCreateAction(AbstractAction):
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "create"
|
||||
|
||||
def form_request(self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False) -> List[str]:
|
||||
def form_request(
|
||||
self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False
|
||||
) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None or folder_name is None or file_name is None:
|
||||
@@ -409,7 +411,7 @@ class NodeFolderCreateAction(AbstractAction):
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "create"
|
||||
|
||||
def form_request(self, node_id: int, folder_name: str) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_name: str) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None or folder_name is None:
|
||||
@@ -430,7 +432,7 @@ class NodeFileAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
@@ -463,7 +465,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb: str = "delete"
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
@@ -504,7 +506,7 @@ class NodeFileAccessAction(AbstractAction):
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "access"
|
||||
|
||||
def form_request(self, node_id: int, folder_name: str, file_name: str) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_name: str, file_name: str) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None or folder_name is None or file_name is None:
|
||||
@@ -525,7 +527,7 @@ class NodeAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int) -> RequestFormat:
    """
    Build the simulation request for this node-level action.

    The node index is translated to a node name through the action manager, and the request is addressed to that
    node with this action's ``verb`` as the final element.

    :param node_id: Index of the target node, resolved via ``self.manager.get_node_name_by_idx``.
    :return: Request of the form ``["network", "node", <node name>, <verb>]``.
    """
    return ["network", "node", self.manager.get_node_name_by_idx(node_id), self.verb]
|
||||
@@ -740,7 +742,7 @@ class RouterACLRemoveRuleAction(AbstractAction):
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"position": max_acl_rules}
|
||||
|
||||
def form_request(self, target_router: str, position: int) -> List[str]:
|
||||
def form_request(self, target_router: str, position: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", target_router, "acl", "remove_rule", position]
|
||||
|
||||
@@ -923,7 +925,7 @@ class HostNICAbstractAction(AbstractAction):
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, nic_id: int) -> List[str]:
|
||||
def form_request(self, node_id: int, nic_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_idx=node_id)
|
||||
nic_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=nic_id)
|
||||
@@ -960,7 +962,7 @@ class NetworkPortEnableAction(AbstractAction):
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
|
||||
|
||||
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
|
||||
def form_request(self, target_nodename: str, port_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if target_nodename is None or port_id is None:
|
||||
return ["do_nothing"]
|
||||
@@ -979,7 +981,7 @@ class NetworkPortDisableAction(AbstractAction):
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
|
||||
|
||||
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
|
||||
def form_request(self, target_nodename: str, port_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if target_nodename is None or port_id is None:
|
||||
return ["do_nothing"]
|
||||
@@ -1315,7 +1317,7 @@ class ActionManager:
|
||||
act_identifier, act_options = self.action_map[action]
|
||||
return act_identifier, act_options
|
||||
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> List[str]:
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
    """
    Convert a CAOS-format action into a PrimAITE request.

    The action object registered under ``action_identifier`` is looked up in ``self.actions`` and asked to form
    the request, with ``action_options`` expanded into keyword arguments.

    :param action_identifier: Key of the action object in ``self.actions``.
    :param action_options: Keyword arguments forwarded to the action object's ``form_request``.
    :return: Whatever the selected action object's ``form_request`` returns.
    """
    return self.actions[action_identifier].form_request(**action_options)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
from os import PathLike
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
@@ -40,6 +40,27 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Current episode number."""
|
||||
self.total_reward_per_episode: Dict[int, float] = {}
|
||||
"""Average rewards of agents per episode."""
|
||||
self.action_masking: bool = False
|
||||
"""Whether to use action masking."""
|
||||
|
||||
def action_masks(self) -> List[bool]:
    """
    Build the action mask for the agent for the current step.

    The mask has one boolean per entry of the agent's action map. ``True`` means the action may be chosen;
    ``False`` means the simulation's request manager reported the corresponding request as not valid right now.
    When ``self.action_masking`` is off, every action is reported as allowed and nothing is checked.

    :return: Action mask
    :rtype: List[bool]
    """
    manager = self.agent.action_manager
    allowed = [True] * len(manager.action_map)

    if self.action_masking:
        # Each action map entry is an (identifier, options) pair keyed by its position in the mask.
        for idx, (identifier, options) in manager.action_map.items():
            candidate = manager.form_request(action_identifier=identifier, action_options=options)
            allowed[idx] = self.game.simulation._request_manager.check_valid(candidate, {})

    return allowed
|
||||
|
||||
@property
|
||||
def agent(self) -> ProxyAgent:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"""Core of the PrimAITE Simulator."""
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import PrettyTable
|
||||
@@ -179,6 +179,23 @@ class RequestManager(BaseModel):
|
||||
table.add_rows(self.get_request_types_recursively())
|
||||
print(table)
|
||||
|
||||
def check_valid(self, request: RequestFormat, context: Dict) -> bool:
    """
    Check if this request would be valid in the current state of the simulation without invoking it.

    The first element of the request selects one of this manager's registered request types. If that request
    type's ``func`` is itself a ``RequestManager``, the remaining elements are checked recursively against it;
    otherwise they are passed, together with ``context``, to the request type's validator.

    :param request: The request to check. The first element is the key; the rest are the options.
    :param context: Passed through unchanged to nested managers and to the validator.
    :return: ``False`` if the request is empty or its first element is not a registered request type, otherwise
        the result of the nested check or of the validator.
    :rtype: bool
    """
    # An empty request has no key to route on. This happens when an action forms an empty request, or when a
    # request stops at a nested manager instead of reaching a leaf. Report it as invalid instead of raising.
    if not request:
        return False

    request_key = request[0]
    request_options = request[1:]

    if request_key not in self.request_types:
        return False

    request_type = self.request_types[request_key]

    # recurse if we are not at a leaf node
    if isinstance(request_type.func, RequestManager):
        return request_type.func.check_valid(request_options, context)

    return request_type.validator(request_options, context)
|
||||
|
||||
|
||||
class SimComponent(BaseModel):
|
||||
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""
|
||||
|
||||
Reference in New Issue
Block a user