Merged PR 498: RNG seed setting.

## Summary
Add support for setting random number seed in config file.

## Test process
Tested on all notebooks in PrimAITE Internal except Training-an-Ray-RLLIB-MARL-System.
Added specific test for seed setting.
## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [X] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [X] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [X] attended to any **TO-DOs** left in the code

Related work items: #2777
This commit is contained in:
Nick Todd
2024-08-07 08:04:49 +00:00
7 changed files with 101 additions and 10 deletions

View File

@@ -6,7 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added
- Random Number Generator Seeding by specifying a random number seed in the config file.
### Changed
- Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality.
@@ -22,7 +23,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML
- Agent logging for agents' internal decision logic
- Action masking in all PrimAITE environments
### Changed
- Application registry was moved to the `Application` class and now updates automatically when Application is subclassed
- Databases can no longer respond to request while performing a backup

View File

@@ -22,8 +22,6 @@ class ProbabilisticAgent(AbstractScriptedAgent):
"""Strict validation."""
action_probabilities: Dict[int, float]
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
random_seed: Optional[int] = None
"""Random seed. If set, each episode the agent will choose the same random sequence of actions."""
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
@@ -59,17 +57,18 @@ class ProbabilisticAgent(AbstractScriptedAgent):
num_actions = len(action_space.action_map)
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
# If seed not specified, set it to None so that numpy chooses a random one.
settings.setdefault("random_seed")
# The random number seed for np.random is dependent on whether a random number seed is set
# in the config file. If there is one it is processed by set_random_seed() in environment.py
# and as a consequence the the sequence of rng_seed's used here will be repeatable.
self.settings = ProbabilisticAgent.Settings(**settings)
self.rng = np.random.default_rng(self.settings.random_seed)
rng_seed = np.random.randint(0, 65535)
self.rng = np.random.default_rng(rng_seed)
# convert probabilities from
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
super().__init__(agent_name, action_space, observation_space, reward_function)
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -70,6 +70,8 @@ class PrimaiteGameOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
seed: int = None
"""Random number seed for RNGs."""
max_episode_length: int = 256
"""Maximum number of episodes for the PrimAITE game."""
ports: List[str]

View File

@@ -1,5 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import json
import random
import sys
from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
@@ -17,6 +19,36 @@ from primaite.simulator.system.core.packet_capture import PacketCapture
_LOGGER = getLogger(__name__)
# Check torch is installed
try:
import torch as th
except ModuleNotFoundError:
_LOGGER.debug("Torch not available for importing")
def set_random_seed(seed: int) -> Union[None, int]:
"""
Set random number generators.
:param seed: int
"""
if seed is None or seed == -1:
return None
elif seed < -1:
raise ValueError("Invalid random number seed")
# Seed python RNG
random.seed(seed)
# Seed numpy RNG
np.random.seed(seed)
# Seed the RNG for all devices (both CPU and CUDA)
# if torch not installed don't set random seed.
if sys.modules["torch"]:
th.manual_seed(seed)
th.backends.cudnn.deterministic = True
th.backends.cudnn.benchmark = False
return seed
class PrimaiteGymEnv(gymnasium.Env):
"""
@@ -31,6 +63,9 @@ class PrimaiteGymEnv(gymnasium.Env):
super().__init__()
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
"""Get RNG seed from config file. NB: Must be before game instantiation."""
self.seed = set_random_seed(self.seed)
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
@@ -42,6 +77,8 @@ class PrimaiteGymEnv(gymnasium.Env):
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
def action_masks(self) -> np.ndarray:
"""
Return the action mask for the agent.
@@ -108,6 +145,8 @@ class PrimaiteGymEnv(gymnasium.Env):
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if seed is not None:
set_random_seed(seed)
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
if self.io.settings.save_agent_actions:

View File

@@ -63,6 +63,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
@@ -176,6 +177,7 @@ class PrimaiteRayEnv(gymnasium.Env):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
if self.env.agent.action_masking:
obs, *_ = self.env.reset(seed=seed)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}

View File

@@ -0,0 +1,50 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from pprint import pprint
import pytest
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import AgentHistoryItem
from primaite.session.environment import PrimaiteGymEnv
@pytest.fixture()
def create_env():
with open(data_manipulation_config_path(), "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(env_config=cfg)
return env
def test_rng_seed_set(create_env):
"""Test with RNG seed set."""
env = create_env
env.reset(seed=3)
for i in range(100):
env.step(0)
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
env.reset(seed=3)
for i in range(100):
env.step(0)
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
assert a == b
def test_rng_seed_unset(create_env):
"""Test with no RNG seed."""
env = create_env
env.reset()
for i in range(100):
env.step(0)
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
env.reset()
for i in range(100):
env.step(0)
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
assert a != b

View File

@@ -62,7 +62,6 @@ def test_probabilistic_agent():
reward_function=reward_function,
settings={
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
"random_seed": 120,
},
)