2623 Implement basic action masking logic

This commit is contained in:
Marek Wolan
2024-07-09 13:13:13 +01:00
parent cbf54d442c
commit 470fa28ee1
3 changed files with 60 additions and 20 deletions

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import json
from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union
import gymnasium
from gymnasium.core import ActType, ObsType
@@ -40,6 +40,27 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Current episode number."""
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
self.action_masking: bool = False
"""Whether to use action masking."""
def action_masks(self) -> List[bool]:
    """
    Build the validity mask over the agent's action map.

    The result holds one boolean per entry of the agent's action map. Every entry starts out as True. When
    ``self.action_masking`` is enabled, each entry is overwritten with the outcome of asking the simulation's
    request manager whether the request formed for that action is currently valid, so a False entry marks an
    action that cannot be performed during this step. When masking is disabled, the all-True mask is returned
    untouched.

    :return: Action mask, indexed by action-map key
    :rtype: List[bool]
    """
    allowed = [True] * len(self.agent.action_manager.action_map)
    if self.action_masking:
        action_manager = self.agent.action_manager
        request_manager = self.game.simulation._request_manager
        for index, entry in action_manager.action_map.items():
            # Each entry is indexed as (identifier, options); turn it into a simulation request first.
            candidate = action_manager.form_request(action_identifier=entry[0], action_options=entry[1])
            # The validity check is performed with an empty context dict.
            allowed[index] = request_manager.check_valid(candidate, {})
    return allowed
@property
def agent(self) -> ProxyAgent: