2623 Implement basic action masking logic
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
from os import PathLike
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
@@ -40,6 +40,27 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Current episode number."""
|
||||
self.total_reward_per_episode: Dict[int, float] = {}
|
||||
"""Average rewards of agents per episode."""
|
||||
self.action_masking: bool = False
|
||||
"""Whether to use action masking."""
|
||||
|
||||
def action_masks(self) -> List[bool]:
    """
    Return the action mask for the agent.

    This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be
    performed during this step.

    When ``self.action_masking`` is falsy, no validity checks are run and every entry is True. Otherwise each
    entry of the agent's action map is turned into a request and the entry at that action's key holds the result
    of the simulation request manager's ``check_valid`` call for that request.

    :return: Action mask
    :rtype: List[bool]
    """
    # ``self.agent`` is a property, so resolve it (and the action manager) once rather than on every iteration.
    action_manager = self.agent.action_manager
    action_map = action_manager.action_map

    mask = [True] * len(action_map)
    if not self.action_masking:
        return mask

    # Loop-invariant lookup hoisted out of the loop.
    request_manager = self.game.simulation._request_manager
    for i, action in action_map.items():
        # Each action map value is indexed as (identifier, options); the key ``i`` is the position in the mask.
        request = action_manager.form_request(action_identifier=action[0], action_options=action[1])
        # An empty context dict is passed to the validity check, as in the original implementation.
        mask[i] = request_manager.check_valid(request, {})
    return mask
|
||||
|
||||
@property
|
||||
def agent(self) -> ProxyAgent:
|
||||
|
||||
Reference in New Issue
Block a user