2623 Implement basic action masking logic

This commit is contained in:
Marek Wolan
2024-07-09 13:13:13 +01:00
parent cbf54d442c
commit 470fa28ee1
3 changed files with 60 additions and 20 deletions

View File

@@ -49,7 +49,7 @@ class AbstractAction(ABC):
objects."""
@abstractmethod
def form_request(self) -> List[str]:
def form_request(self) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return []
@@ -67,7 +67,7 @@ class DoNothingAction(AbstractAction):
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter
# with one option. This just aids the Action Manager to enumerate all possibilities.
def form_request(self, **kwargs) -> List[str]:
def form_request(self, **kwargs) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["do_nothing"]
@@ -86,7 +86,7 @@ class NodeServiceAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, service_id: int) -> List[str]:
def form_request(self, node_id: int, service_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
service_name = self.manager.get_service_name_by_idx(node_id, service_id)
@@ -181,7 +181,7 @@ class NodeApplicationAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, application_id: int) -> List[str]:
def form_request(self, node_id: int, application_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
application_name = self.manager.get_application_name_by_idx(node_id, application_id)
@@ -229,7 +229,7 @@ class NodeApplicationInstallAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
def form_request(self, node_id: int, application_name: str) -> List[str]:
def form_request(self, node_id: int, application_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
@@ -324,7 +324,7 @@ class NodeApplicationRemoveAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
def form_request(self, node_id: int, application_name: str) -> List[str]:
def form_request(self, node_id: int, application_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
@@ -346,7 +346,7 @@ class NodeFolderAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int) -> List[str]:
def form_request(self, node_id: int, folder_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
@@ -394,7 +394,9 @@ class NodeFileCreateAction(AbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "create"
def form_request(self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False) -> List[str]:
def form_request(
self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False
) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None or folder_name is None or file_name is None:
@@ -409,7 +411,7 @@ class NodeFolderCreateAction(AbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "create"
def form_request(self, node_id: int, folder_name: str) -> List[str]:
def form_request(self, node_id: int, folder_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None or folder_name is None:
@@ -430,7 +432,7 @@ class NodeFileAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
@@ -463,7 +465,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb: str = "delete"
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
@@ -504,7 +506,7 @@ class NodeFileAccessAction(AbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "access"
def form_request(self, node_id: int, folder_name: str, file_name: str) -> List[str]:
def form_request(self, node_id: int, folder_name: str, file_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None or folder_name is None or file_name is None:
@@ -525,7 +527,7 @@ class NodeAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int) -> List[str]:
def form_request(self, node_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return ["network", "node", node_name, self.verb]
@@ -740,7 +742,7 @@ class RouterACLRemoveRuleAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"position": max_acl_rules}
def form_request(self, target_router: str, position: int) -> List[str]:
def form_request(self, target_router: str, position: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", target_router, "acl", "remove_rule", position]
@@ -923,7 +925,7 @@ class HostNICAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, nic_id: int) -> List[str]:
def form_request(self, node_id: int, nic_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_idx=node_id)
nic_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=nic_id)
@@ -960,7 +962,7 @@ class NetworkPortEnableAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
def form_request(self, target_nodename: str, port_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if target_nodename is None or port_id is None:
return ["do_nothing"]
@@ -979,7 +981,7 @@ class NetworkPortDisableAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
def form_request(self, target_nodename: str, port_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if target_nodename is None or port_id is None:
return ["do_nothing"]
@@ -1315,7 +1317,7 @@ class ActionManager:
act_identifier, act_options = self.action_map[action]
return act_identifier, act_options
def form_request(self, action_identifier: str, action_options: Dict) -> List[str]:
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
act_obj = self.actions[action_identifier]
return act_obj.form_request(**action_options)

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import json
from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union
import gymnasium
from gymnasium.core import ActType, ObsType
@@ -40,6 +40,27 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Current episode number."""
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
self.action_masking: bool = False
"""Whether to use action masking."""
def action_masks(self) -> List[bool]:
"""
Return the action mask for the agent.
This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be
performed during this step.
:return: Action mask
:rtype: List[bool]
"""
mask = [True] * len(self.agent.action_manager.action_map)
if not self.action_masking:
return mask
for i, action in self.agent.action_manager.action_map.items():
request = self.agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
mask[i] = self.game.simulation._request_manager.check_valid(request, {})
return mask
@property
def agent(self) -> ProxyAgent:

View File

@@ -3,7 +3,7 @@
"""Core of the PrimAITE Simulator."""
import warnings
from abc import abstractmethod
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from uuid import uuid4
from prettytable import PrettyTable
@@ -179,6 +179,23 @@ class RequestManager(BaseModel):
table.add_rows(self.get_request_types_recursively())
print(table)
def check_valid(self, request: RequestFormat, context: Dict) -> bool:
"""Check if this request would be valid in the current state of the simulation without invoking it."""
request_key = request[0]
request_options = request[1:]
if request_key not in self.request_types:
return False
request_type = self.request_types[request_key]
# recurse if we are not at a leaf node
if isinstance(request_type.func, RequestManager):
return request_type.func.check_valid(request_options, context)
return request_type.validator(request_options, context)
class SimComponent(BaseModel):
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""