#2879: Handle generate_seed_value option
This commit is contained in:
@@ -26,14 +26,25 @@ except ModuleNotFoundError:
|
||||
_LOGGER.debug("Torch not available for importing")
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> Union[None, int]:
|
||||
def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]:
|
||||
"""
|
||||
Set random number generators.
|
||||
|
||||
If seed is None or -1 and generate_seed_value is True randomly generate a
|
||||
seed value.
|
||||
If seed is > -1 and generate_seed_value is True ignore the latter and use
|
||||
the provide seed value.
|
||||
|
||||
:param seed: int
|
||||
:param generate_seed_value: bool
|
||||
:return: None or the int representing the seed used.
|
||||
"""
|
||||
if seed is None or seed == -1:
|
||||
return None
|
||||
if generate_seed_value:
|
||||
rng = np.random.default_rng()
|
||||
seed = int(rng.integers(low=0, high=2**63))
|
||||
else:
|
||||
return None
|
||||
elif seed < -1:
|
||||
raise ValueError("Invalid random number seed")
|
||||
# Seed python RNG
|
||||
@@ -65,7 +76,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
|
||||
"""Get RNG seed from config file. NB: Must be before game instantiation."""
|
||||
self.seed = set_random_seed(self.seed)
|
||||
self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value")
|
||||
self.seed = set_random_seed(self.seed, self.generate_seed_value)
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
|
||||
|
||||
Reference in New Issue
Block a user