Merged PR 546: Log the random seed used on each run
## Summary Added changes so that the value of the random number seed is recorded in a log file when it's specified or the user asks for a RNG seed to be automatically generated. ## Test process Updated existing RNG tests; added new test. ## Checklist - [X] PR is linked to a **work item** - [X] **acceptance criteria** of linked ticket are met - [X] performed **self-review** of the code - [X] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [X] updated the **change log** - [X] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code Related work items: #2879
This commit is contained in:
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
- ACL's are no longer applied to layer-2 traffic.
|
||||
- Random number seed values are recorded in simulation/seed.log if the seed is set in the config file
|
||||
or `generate_seed_value` is set to `true`.
|
||||
- ARP .show() method will now include the port number associated with each entry.
|
||||
- Added `services_requires_scan` and `applications_requires_scan` to agent observation space config to allow the agents to be able to see actual health states of services and applications without requiring scans (Default `True`, set to `False` to allow agents to see actual health state without scanning).
|
||||
|
||||
|
||||
@@ -80,6 +80,8 @@ class PrimaiteGameOptions(BaseModel):
|
||||
|
||||
seed: int = None
|
||||
"""Random number seed for RNGs."""
|
||||
generate_seed_value: bool = False
|
||||
"""Internally generated seed value."""
|
||||
max_episode_length: int = 256
|
||||
"""Maximum number of episodes for the PrimAITE game."""
|
||||
ports: List[str]
|
||||
|
||||
@@ -26,14 +26,26 @@ except ModuleNotFoundError:
|
||||
_LOGGER.debug("Torch not available for importing")
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> Union[None, int]:
|
||||
def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]:
|
||||
"""
|
||||
Set random number generators.
|
||||
|
||||
If seed is None or -1 and generate_seed_value is True randomly generate a
|
||||
seed value.
|
||||
If seed is > -1 and generate_seed_value is True ignore the latter and use
|
||||
the provide seed value.
|
||||
|
||||
:param seed: int
|
||||
:param generate_seed_value: bool
|
||||
:return: None or the int representing the seed used.
|
||||
"""
|
||||
if seed is None or seed == -1:
|
||||
return None
|
||||
if generate_seed_value:
|
||||
rng = np.random.default_rng()
|
||||
# 2**32-1 is highest value for python RNG seed.
|
||||
seed = int(rng.integers(low=0, high=2**32 - 1))
|
||||
else:
|
||||
return None
|
||||
elif seed < -1:
|
||||
raise ValueError("Invalid random number seed")
|
||||
# Seed python RNG
|
||||
@@ -50,6 +62,13 @@ def set_random_seed(seed: int) -> Union[None, int]:
|
||||
return seed
|
||||
|
||||
|
||||
def log_seed_value(seed: int):
|
||||
"""Log the selected seed value to file."""
|
||||
path = SIM_OUTPUT.path / "seed.log"
|
||||
with open(path, "w") as file:
|
||||
file.write(f"Seed value = {seed}")
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
Thin wrapper env to provide agents with a gymnasium API.
|
||||
@@ -65,7 +84,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
|
||||
"""Get RNG seed from config file. NB: Must be before game instantiation."""
|
||||
self.seed = set_random_seed(self.seed)
|
||||
self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value")
|
||||
self.seed = set_random_seed(self.seed, self.generate_seed_value)
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
|
||||
@@ -79,6 +99,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
|
||||
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
|
||||
|
||||
log_seed_value(self.seed)
|
||||
|
||||
def action_masks(self) -> np.ndarray:
|
||||
"""
|
||||
Return the action mask for the agent.
|
||||
@@ -146,7 +168,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if seed is not None:
|
||||
set_random_seed(seed)
|
||||
set_random_seed(seed, self.generate_seed_value)
|
||||
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
|
||||
@@ -7,6 +7,7 @@ import yaml
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -33,6 +34,11 @@ def test_rng_seed_set(create_env):
|
||||
|
||||
assert a == b
|
||||
|
||||
# Check that seed log file was created.
|
||||
path = SIM_OUTPUT.path / "seed.log"
|
||||
with open(path, "r") as file:
|
||||
assert file
|
||||
|
||||
|
||||
def test_rng_seed_unset(create_env):
|
||||
"""Test with no RNG seed."""
|
||||
@@ -48,3 +54,19 @@ def test_rng_seed_unset(create_env):
|
||||
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
|
||||
|
||||
assert a != b
|
||||
|
||||
|
||||
def test_for_generated_seed():
|
||||
"""
|
||||
Show that setting generate_seed_value to true producess a valid seed.
|
||||
"""
|
||||
with open(data_manipulation_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
cfg["game"]["generate_seed_value"] = True
|
||||
PrimaiteGymEnv(env_config=cfg)
|
||||
path = SIM_OUTPUT.path / "seed.log"
|
||||
with open(path, "r") as file:
|
||||
data = file.read()
|
||||
|
||||
assert data.split(" ")[3] != None
|
||||
|
||||
Reference in New Issue
Block a user