From 5dcc0189a0655a47cd5e51dc17f98a06e890c117 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 2 Aug 2024 11:30:45 +0100 Subject: [PATCH 1/9] #2777: Implementation of RNG seed --- .../scripted_agents/probabilistic_agent.py | 14 ++++---- src/primaite/game/game.py | 2 ++ src/primaite/session/environment.py | 36 +++++++++++++++++++ src/primaite/session/ray_envs.py | 2 ++ 4 files changed, 47 insertions(+), 7 deletions(-) diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index f5905ad0..ce1da3f2 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -22,8 +22,6 @@ class ProbabilisticAgent(AbstractScriptedAgent): """Strict validation.""" action_probabilities: Dict[int, float] """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" - random_seed: Optional[int] = None - """Random seed. If set, each episode the agent will choose the same random sequence of actions.""" # TODO: give the option to still set a random seed, but have it vary each episode in a predictable way # for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed. @@ -59,17 +57,19 @@ class ProbabilisticAgent(AbstractScriptedAgent): num_actions = len(action_space.action_map) settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}} - # If seed not specified, set it to None so that numpy chooses a random one. - settings.setdefault("random_seed") - + # The random number seed for np.random is dependent on whether a random number seed is set + # in the config file. If there is one it is processed by set_random_seed() in environment.py + # and as a consequence the the sequence of rng_seed's used here will be repeatable. self.settings = ProbabilisticAgent.Settings(**settings) - - self.rng = np.random.default_rng(self.settings.random_seed) + rng_seed = np.random.randint(0, 65535) + self.rng = np.random.default_rng(rng_seed) + print(f"Probabilistic Agent - rng_seed: {rng_seed}") # convert probabilities from self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) super().__init__(agent_name, action_space, observation_space, reward_function) + self.logger.info(f"ProbabilisticAgent RNG seed: {rng_seed}") def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 5ef8c14c..a4325b3e 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -70,6 +70,8 @@ class PrimaiteGameOptions(BaseModel): model_config = ConfigDict(extra="forbid") + seed: int = None + """Random number seed for RNGs.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" ports: List[str] diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a87f0cde..359932c7 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,5 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK import json +import random +import sys from os import PathLike from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union @@ -17,6 +19,33 @@ from primaite.simulator.system.core.packet_capture import PacketCapture _LOGGER = getLogger(__name__) +# Check torch is installed +try: + import torch as th +except ModuleNotFoundError: + _LOGGER.debug("Torch not available for importing") + + +def set_random_seed(seed: int) -> Union[None, int]: + """ + Set random number generators. + + :param seed: int + """ + if seed is None or seed == -1: + return None + elif seed < -1: + raise ValueError("Invalid random number seed") + # Seed python RNG + random.seed(seed) + # Seed numpy RNG + np.random.seed(seed) + # Seed the RNG for all devices (both CPU and CUDA) + # if torch not installed don't set random seed. + if sys.modules["torch"]: + th.manual_seed(seed) + return seed + class PrimaiteGymEnv(gymnasium.Env): """ @@ -31,6 +60,9 @@ class PrimaiteGymEnv(gymnasium.Env): super().__init__() self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config) """Object that returns a config corresponding to the current episode.""" + self.seed = self.episode_scheduler(0).get("game").get("seed") + """Get RNG seed from config file. NB: Must be before game instantiation.""" + self.seed = set_random_seed(self.seed) self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0)) @@ -42,6 +74,8 @@ class PrimaiteGymEnv(gymnasium.Env): self.total_reward_per_episode: Dict[int, float] = {} """Average rewards of agents per episode.""" + _LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}") + def action_masks(self) -> np.ndarray: """ Return the action mask for the agent. @@ -108,6 +142,8 @@ class PrimaiteGymEnv(gymnasium.Env): f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.agent.reward_function.total_reward}" ) + if seed is not None: + set_random_seed(seed) self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward if self.io.settings.save_agent_actions: diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 1adc324c..33c74b0e 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -63,6 +63,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" + super().reset() # Ensure PRNG seed is set everywhere rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()} _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}") @@ -176,6 +177,7 @@ class PrimaiteRayEnv(gymnasium.Env): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" + super().reset() # Ensure PRNG seed is set everywhere if self.env.agent.action_masking: obs, *_ = self.env.reset(seed=seed) new_obs = {"action_mask": self.env.action_masks(), "observations": obs} From a1e1a17c2a9fe87099b8bfcd9e3c3c0eab3bc408 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 2 Aug 2024 12:49:17 +0100 Subject: [PATCH 2/9] #2777: Add RNG test --- .../game_layer/test_RNG_seed.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/integration_tests/game_layer/test_RNG_seed.py diff --git a/tests/integration_tests/game_layer/test_RNG_seed.py b/tests/integration_tests/game_layer/test_RNG_seed.py new file mode 100644 index 00000000..c1bb7bb0 --- /dev/null +++ b/tests/integration_tests/game_layer/test_RNG_seed.py @@ -0,0 +1,43 @@ +from primaite.config.load import data_manipulation_config_path +from primaite.session.environment import PrimaiteGymEnv +from primaite.game.agent.interface import AgentHistoryItem +import yaml +from pprint import pprint +import pytest + +@pytest.fixture() +def create_env(): + with open(data_manipulation_config_path(), 'r') as f: + cfg = yaml.safe_load(f) + + env = PrimaiteGymEnv(env_config = cfg) + return env + +def test_rng_seed_set(create_env): + env = create_env + env.reset(seed=3) + for i in range(100): + env.step(0) + a = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + + env.reset(seed=3) + for i in range(100): + env.step(0) + b = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + + assert a==b + +def test_rng_seed_unset(create_env): + env = create_env + env.reset() + for i in range(100): + env.step(0) + a = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + + env.reset() + for i in range(100): + env.step(0) + b = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + + assert a!=b + From 0cc724be605fff5e65a893d9ddebd5ed2517f342 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 2 Aug 2024 12:50:40 +0100 Subject: [PATCH 3/9] #2777: Updated CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cebc2569..7d7ba9c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Bandwidth Tracking**: Tracks data transmission across each frequency. - **New Tests**: Added to validate the respect of bandwidth capacities and the correct parsing of airspace configurations from YAML files. - **New Logging**: Added a new agent behaviour log which are more human friendly than agent history. These Logs are found in session log directory and can be enabled in the I/O settings in a yaml configuration file. - +- **Random Number Generator Seeding**: Added support for specifying a random number seed in the config file. ### Changed - **NetworkInterface Speed Type**: The `speed` attribute of `NetworkInterface` has been changed from `int` to `float`. From 2e4a1c37d1708ba7d01e3c16005f81938e1f9796 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 10:34:06 +0100 Subject: [PATCH 4/9] #2777: Pre-commit fixes to test --- .../game_layer/test_RNG_seed.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/tests/integration_tests/game_layer/test_RNG_seed.py b/tests/integration_tests/game_layer/test_RNG_seed.py index c1bb7bb0..0c6d567d 100644 --- a/tests/integration_tests/game_layer/test_RNG_seed.py +++ b/tests/integration_tests/game_layer/test_RNG_seed.py @@ -1,43 +1,50 @@ -from primaite.config.load import data_manipulation_config_path -from primaite.session.environment import PrimaiteGymEnv -from primaite.game.agent.interface import AgentHistoryItem -import yaml +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from pprint import pprint + import pytest +import yaml + +from primaite.config.load import data_manipulation_config_path +from primaite.game.agent.interface import AgentHistoryItem +from primaite.session.environment import PrimaiteGymEnv + @pytest.fixture() def create_env(): - with open(data_manipulation_config_path(), 'r') as f: + with open(data_manipulation_config_path(), "r") as f: cfg = yaml.safe_load(f) - env = PrimaiteGymEnv(env_config = cfg) + env = PrimaiteGymEnv(env_config=cfg) return env + def test_rng_seed_set(create_env): + """Test with RNG seed set.""" env = create_env env.reset(seed=3) for i in range(100): env.step(0) - a = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] env.reset(seed=3) for i in range(100): env.step(0) - b = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + + assert a == b - assert a==b def test_rng_seed_unset(create_env): + """Test with no RNG seed.""" env = create_env env.reset() for i in range(100): env.step(0) - a = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] env.reset() for i in range(100): env.step(0) - b = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] - - assert a!=b + b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + assert a != b From 7d7117e6246d96a46bae4a1a0c6c619c219a44b5 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 11:13:32 +0100 Subject: [PATCH 5/9] #2777: Merge with dev --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 68745913..c52f4678 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML - Agent logging for agents' internal decision logic - Action masking in all PrimAITE environments -- **Random Number Generator Seeding**: Added support for specifying a random number seed in the config file. +- Random Number Generator Seeding by specifying a random number seed in the config file. ### Changed - Application registry was moved to the `Application` class and now updates automatically when Application is subclassed - Databases can no longer respond to request while performing a backup From 966542c2ca1b00d128594ae4afdd638d45160972 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 15:08:31 +0100 Subject: [PATCH 6/9] #2777: Add determinism to torch backends when seed set. --- src/primaite/session/environment.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 359932c7..a12d2eb7 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -44,6 +44,10 @@ def set_random_seed(seed: int) -> Union[None, int]: # if torch not installed don't set random seed. if sys.modules["torch"]: th.manual_seed(seed) + + th.backends.cudnn.deterministic = True + th.backends.cudnn.benchmark = False + return seed From d059ddceaba77ac60ed9f24b4120e3375bfc384c Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 15:11:57 +0100 Subject: [PATCH 7/9] #2777: Remove debug print statement --- src/primaite/game/agent/scripted_agents/probabilistic_agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index ce1da3f2..ab2e69ef 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -63,7 +63,6 @@ class ProbabilisticAgent(AbstractScriptedAgent): self.settings = ProbabilisticAgent.Settings(**settings) rng_seed = np.random.randint(0, 65535) self.rng = np.random.default_rng(rng_seed) - print(f"Probabilistic Agent - rng_seed: {rng_seed}") # convert probabilities from self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) From 3253dd80547125635c8c13693689f15bbafc6e67 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 16:27:54 +0100 Subject: [PATCH 8/9] #2777: Update test --- .../_primaite/_game/_agent/test_probabilistic_agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index f3b3c6eb..ec18f1fb 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -62,7 +62,6 @@ def test_probabilistic_agent(): reward_function=reward_function, settings={ "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, - "random_seed": 120, }, ) From 3441dd25092aff65c7c9f5e9e0d11855f7bad8d7 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 17:45:01 +0100 Subject: [PATCH 9/9] #2777: Code review changes. --- CHANGELOG.md | 4 ++-- .../game/agent/scripted_agents/probabilistic_agent.py | 2 +- src/primaite/session/environment.py | 7 +++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c52f4678..8b3cfbb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] - +### Added +- Random Number Generator Seeding by specifying a random number seed in the config file. ### Changed - Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality. @@ -22,7 +23,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML - Agent logging for agents' internal decision logic - Action masking in all PrimAITE environments -- Random Number Generator Seeding by specifying a random number seed in the config file. ### Changed - Application registry was moved to the `Application` class and now updates automatically when Application is subclassed - Databases can no longer respond to request while performing a backup diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index ab2e69ef..cd44644f 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -68,7 +68,7 @@ class ProbabilisticAgent(AbstractScriptedAgent): self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) super().__init__(agent_name, action_space, observation_space, reward_function) - self.logger.info(f"ProbabilisticAgent RNG seed: {rng_seed}") + self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}") def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a12d2eb7..c66663e3 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -44,9 +44,8 @@ def set_random_seed(seed: int) -> Union[None, int]: # if torch not installed don't set random seed. if sys.modules["torch"]: th.manual_seed(seed) - - th.backends.cudnn.deterministic = True - th.backends.cudnn.benchmark = False + th.backends.cudnn.deterministic = True + th.backends.cudnn.benchmark = False return seed @@ -64,7 +63,7 @@ class PrimaiteGymEnv(gymnasium.Env): super().__init__() self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config) """Object that returns a config corresponding to the current episode.""" - self.seed = self.episode_scheduler(0).get("game").get("seed") + self.seed = self.episode_scheduler(0).get("game", {}).get("seed") """Get RNG seed from config file. NB: Must be before game instantiation.""" self.seed = set_random_seed(self.seed) self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))