#2777: Implementation of RNG seed
This commit is contained in:
@@ -22,8 +22,6 @@ class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
"""Strict validation."""
|
||||
action_probabilities: Dict[int, float]
|
||||
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
|
||||
random_seed: Optional[int] = None
|
||||
"""Random seed. If set, each episode the agent will choose the same random sequence of actions."""
|
||||
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
|
||||
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
|
||||
|
||||
@@ -59,17 +57,19 @@ class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
num_actions = len(action_space.action_map)
|
||||
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
|
||||
|
||||
# If seed not specified, set it to None so that numpy chooses a random one.
|
||||
settings.setdefault("random_seed")
|
||||
|
||||
# The random number seed for np.random is dependent on whether a random number seed is set
|
||||
# in the config file. If there is one it is processed by set_random_seed() in environment.py
|
||||
# and as a consequence the the sequence of rng_seed's used here will be repeatable.
|
||||
self.settings = ProbabilisticAgent.Settings(**settings)
|
||||
|
||||
self.rng = np.random.default_rng(self.settings.random_seed)
|
||||
rng_seed = np.random.randint(0, 65535)
|
||||
self.rng = np.random.default_rng(rng_seed)
|
||||
print(f"Probabilistic Agent - rng_seed: {rng_seed}")
|
||||
|
||||
# convert probabilities from
|
||||
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
|
||||
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
self.logger.info(f"ProbabilisticAgent RNG seed: {rng_seed}")
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
@@ -70,6 +70,8 @@ class PrimaiteGameOptions(BaseModel):
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
seed: int = None
|
||||
"""Random number seed for RNGs."""
|
||||
max_episode_length: int = 256
|
||||
"""Maximum number of episodes for the PrimAITE game."""
|
||||
ports: List[str]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
from os import PathLike
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
@@ -17,6 +19,33 @@ from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
# Check torch is installed
|
||||
try:
|
||||
import torch as th
|
||||
except ModuleNotFoundError:
|
||||
_LOGGER.debug("Torch not available for importing")
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> Union[None, int]:
|
||||
"""
|
||||
Set random number generators.
|
||||
|
||||
:param seed: int
|
||||
"""
|
||||
if seed is None or seed == -1:
|
||||
return None
|
||||
elif seed < -1:
|
||||
raise ValueError("Invalid random number seed")
|
||||
# Seed python RNG
|
||||
random.seed(seed)
|
||||
# Seed numpy RNG
|
||||
np.random.seed(seed)
|
||||
# Seed the RNG for all devices (both CPU and CUDA)
|
||||
# if torch not installed don't set random seed.
|
||||
if sys.modules["torch"]:
|
||||
th.manual_seed(seed)
|
||||
return seed
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
@@ -31,6 +60,9 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
super().__init__()
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.seed = self.episode_scheduler(0).get("game").get("seed")
|
||||
"""Get RNG seed from config file. NB: Must be before game instantiation."""
|
||||
self.seed = set_random_seed(self.seed)
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
|
||||
@@ -42,6 +74,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
self.total_reward_per_episode: Dict[int, float] = {}
|
||||
"""Average rewards of agents per episode."""
|
||||
|
||||
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
|
||||
|
||||
def action_masks(self) -> np.ndarray:
|
||||
"""
|
||||
Return the action mask for the agent.
|
||||
@@ -108,6 +142,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"Resetting environment, episode {self.episode_counter}, "
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if seed is not None:
|
||||
set_random_seed(seed)
|
||||
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
|
||||
@@ -63,6 +63,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
super().reset() # Ensure PRNG seed is set everywhere
|
||||
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
|
||||
|
||||
@@ -176,6 +177,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
super().reset() # Ensure PRNG seed is set everywhere
|
||||
if self.env.agent.action_masking:
|
||||
obs, *_ = self.env.reset(seed=seed)
|
||||
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
|
||||
|
||||
Reference in New Issue
Block a user