Merged PR 546: Log the random seed used on each run

## Summary
Added changes so that the value of the random number seed is recorded in a log file when it's specified or the user asks for a RNG seed to be automatically generated.

## Test process
Updated existing RNG tests; added new test.

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [X] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #2879
This commit is contained in:
Nick Todd
2024-09-17 13:06:10 +00:00
4 changed files with 52 additions and 4 deletions

View File

@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- ACL's are no longer applied to layer-2 traffic.
- Random number seed values are recorded in simulation/seed.log if the seed is set in the config file
or `generate_seed_value` is set to `true`.
- ARP .show() method will now include the port number associated with each entry.
- Added `services_requires_scan` and `applications_requires_scan` to agent observation space config to allow the agents to be able to see actual health states of services and applications without requiring scans (Default `True`, set to `False` to allow agents to see actual health state without scanning).

View File

@@ -80,6 +80,8 @@ class PrimaiteGameOptions(BaseModel):
seed: int = None
"""Random number seed for RNGs."""
generate_seed_value: bool = False
"""Internally generated seed value."""
max_episode_length: int = 256
"""Maximum number of episodes for the PrimAITE game."""
ports: List[str]

View File

@@ -26,14 +26,26 @@ except ModuleNotFoundError:
_LOGGER.debug("Torch not available for importing")
def set_random_seed(seed: int) -> Union[None, int]:
def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]:
"""
Set random number generators.
If seed is None or -1 and generate_seed_value is True randomly generate a
seed value.
If seed is > -1 and generate_seed_value is True ignore the latter and use
the provide seed value.
:param seed: int
:param generate_seed_value: bool
:return: None or the int representing the seed used.
"""
if seed is None or seed == -1:
return None
if generate_seed_value:
rng = np.random.default_rng()
# 2**32-1 is highest value for python RNG seed.
seed = int(rng.integers(low=0, high=2**32 - 1))
else:
return None
elif seed < -1:
raise ValueError("Invalid random number seed")
# Seed python RNG
@@ -50,6 +62,13 @@ def set_random_seed(seed: int) -> Union[None, int]:
return seed
def log_seed_value(seed: int):
"""Log the selected seed value to file."""
path = SIM_OUTPUT.path / "seed.log"
with open(path, "w") as file:
file.write(f"Seed value = {seed}")
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
@@ -65,7 +84,8 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Object that returns a config corresponding to the current episode."""
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
"""Get RNG seed from config file. NB: Must be before game instantiation."""
self.seed = set_random_seed(self.seed)
self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value")
self.seed = set_random_seed(self.seed, self.generate_seed_value)
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
@@ -79,6 +99,8 @@ class PrimaiteGymEnv(gymnasium.Env):
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
log_seed_value(self.seed)
def action_masks(self) -> np.ndarray:
"""
Return the action mask for the agent.
@@ -146,7 +168,7 @@ class PrimaiteGymEnv(gymnasium.Env):
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if seed is not None:
set_random_seed(seed)
set_random_seed(seed, self.generate_seed_value)
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
if self.io.settings.save_agent_actions:

View File

@@ -7,6 +7,7 @@ import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import AgentHistoryItem
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator import SIM_OUTPUT
@pytest.fixture()
@@ -33,6 +34,11 @@ def test_rng_seed_set(create_env):
assert a == b
# Check that seed log file was created.
path = SIM_OUTPUT.path / "seed.log"
with open(path, "r") as file:
assert file
def test_rng_seed_unset(create_env):
"""Test with no RNG seed."""
@@ -48,3 +54,19 @@ def test_rng_seed_unset(create_env):
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
assert a != b
def test_for_generated_seed():
"""
Show that setting generate_seed_value to true producess a valid seed.
"""
with open(data_manipulation_config_path(), "r") as f:
cfg = yaml.safe_load(f)
cfg["game"]["generate_seed_value"] = True
PrimaiteGymEnv(env_config=cfg)
path = SIM_OUTPUT.path / "seed.log"
with open(path, "r") as file:
data = file.read()
assert data.split(" ")[3] != None