diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ce4bbf2..8d999607 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] - ### Added +- Random Number Generator Seeding by specifying a random number seed in the config file. - Implemented Terminal service class, providing a generic terminal simulation. ### Changed @@ -25,7 +25,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML - Agent logging for agents' internal decision logic - Action masking in all PrimAITE environments - ### Changed - Application registry was moved to the `Application` class and now updates automatically when Application is subclassed - Databases can no longer respond to request while performing a backup diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index f5905ad0..cd44644f 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -22,8 +22,6 @@ class ProbabilisticAgent(AbstractScriptedAgent): """Strict validation.""" action_probabilities: Dict[int, float] """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" - random_seed: Optional[int] = None - """Random seed. If set, each episode the agent will choose the same random sequence of actions.""" # TODO: give the option to still set a random seed, but have it vary each episode in a predictable way # for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed. @@ -59,17 +57,18 @@ class ProbabilisticAgent(AbstractScriptedAgent): num_actions = len(action_space.action_map) settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}} - # If seed not specified, set it to None so that numpy chooses a random one. - settings.setdefault("random_seed") - + # The random number seed for np.random is dependent on whether a random number seed is set + # in the config file. If there is one it is processed by set_random_seed() in environment.py + # and as a consequence the the sequence of rng_seed's used here will be repeatable. self.settings = ProbabilisticAgent.Settings(**settings) - - self.rng = np.random.default_rng(self.settings.random_seed) + rng_seed = np.random.randint(0, 65535) + self.rng = np.random.default_rng(rng_seed) # convert probabilities from self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) super().__init__(agent_name, action_space, observation_space, reward_function) + self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}") def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index c197a8de..6a97ad25 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -72,6 +72,8 @@ class PrimaiteGameOptions(BaseModel): model_config = ConfigDict(extra="forbid") + seed: int = None + """Random number seed for RNGs.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" ports: List[str] diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a87f0cde..c66663e3 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,5 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK import json +import random +import sys from os import PathLike from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union @@ -17,6 +19,36 @@ from primaite.simulator.system.core.packet_capture import PacketCapture _LOGGER = getLogger(__name__) +# Check torch is installed +try: + import torch as th +except ModuleNotFoundError: + _LOGGER.debug("Torch not available for importing") + + +def set_random_seed(seed: int) -> Union[None, int]: + """ + Set random number generators. + + :param seed: int + """ + if seed is None or seed == -1: + return None + elif seed < -1: + raise ValueError("Invalid random number seed") + # Seed python RNG + random.seed(seed) + # Seed numpy RNG + np.random.seed(seed) + # Seed the RNG for all devices (both CPU and CUDA) + # if torch not installed don't set random seed. + if sys.modules["torch"]: + th.manual_seed(seed) + th.backends.cudnn.deterministic = True + th.backends.cudnn.benchmark = False + + return seed + class PrimaiteGymEnv(gymnasium.Env): """ @@ -31,6 +63,9 @@ class PrimaiteGymEnv(gymnasium.Env): super().__init__() self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config) """Object that returns a config corresponding to the current episode.""" + self.seed = self.episode_scheduler(0).get("game", {}).get("seed") + """Get RNG seed from config file. NB: Must be before game instantiation.""" + self.seed = set_random_seed(self.seed) self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0)) @@ -42,6 +77,8 @@ class PrimaiteGymEnv(gymnasium.Env): self.total_reward_per_episode: Dict[int, float] = {} """Average rewards of agents per episode.""" + _LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}") + def action_masks(self) -> np.ndarray: """ Return the action mask for the agent. @@ -108,6 +145,8 @@ class PrimaiteGymEnv(gymnasium.Env): f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.agent.reward_function.total_reward}" ) + if seed is not None: + set_random_seed(seed) self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward if self.io.settings.save_agent_actions: diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 1adc324c..33c74b0e 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -63,6 +63,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" + super().reset() # Ensure PRNG seed is set everywhere rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()} _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}") @@ -176,6 +177,7 @@ class PrimaiteRayEnv(gymnasium.Env): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" + super().reset() # Ensure PRNG seed is set everywhere if self.env.agent.action_masking: obs, *_ = self.env.reset(seed=seed) new_obs = {"action_mask": self.env.action_masks(), "observations": obs} diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 06d22126..e6cfa343 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -2,7 +2,7 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable @@ -54,6 +54,12 @@ class DatabaseClientConnection(BaseModel): if self.client and self.is_active: self.client._disconnect(self.connection_id) # noqa + def __str__(self) -> str: + return f"{self.__class__.__name__}(connection_id='{self.connection_id}', is_active={self.is_active})" + + def __repr__(self) -> str: + return str(self) + class DatabaseClient(Application, identifier="DatabaseClient"): """ @@ -76,7 +82,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"): """Connection ID to the Database Server.""" client_connections: Dict[str, DatabaseClientConnection] = {} """Keep track of active connections to Database Server.""" - _client_connection_requests: Dict[str, Optional[str]] = {} + _client_connection_requests: Dict[str, Optional[Union[str, DatabaseClientConnection]]] = {} """Dictionary of connection requests to Database Server.""" connected: bool = False """Boolean Value for whether connected to DB Server.""" @@ -187,7 +193,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"): return False return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id) - def _check_client_connection(self, connection_id: str) -> bool: + def _validate_client_connection_request(self, connection_id: str) -> bool: """Check that client_connection_id is valid.""" return True if connection_id in self._client_connection_requests else False @@ -211,23 +217,30 @@ class DatabaseClient(Application, identifier="DatabaseClient"): :type: is_reattempt: Optional[bool] """ if is_reattempt: - valid_connection = self._check_client_connection(connection_id=connection_request_id) - if valid_connection: + valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id) + if valid_connection_request: database_client_connection = self._client_connection_requests.pop(connection_request_id) - self.sys_log.info( - f"{self.name}: DatabaseClient connection to {server_ip_address} authorised." - f"Connection Request ID was {connection_request_id}." - ) - self.connected = True - self._last_connection_successful = True - return database_client_connection + if isinstance(database_client_connection, DatabaseClientConnection): + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} authorised. " + f"Using connection id {database_client_connection}" + ) + self.connected = True + self._last_connection_successful = True + return database_client_connection + else: + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined" + ) + self._last_connection_successful = False + return None else: - self.sys_log.warning( - f"{self.name}: DatabaseClient connection to {server_ip_address} declined." - f"Connection Request ID was {connection_request_id}." + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined " + f"due to unknown client-side connection request id" ) - self._last_connection_successful = False return None + payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id} software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( @@ -300,9 +313,14 @@ class DatabaseClient(Application, identifier="DatabaseClient"): """ if not self._can_perform_action(): return None + connection_request_id = str(uuid4()) self._client_connection_requests[connection_request_id] = None + self.sys_log.info( + f"{self.name}: Sending new connection request ({connection_request_id}) to {self.server_ip_address}" + ) + return self._connect( server_ip_address=self.server_ip_address, password=self.server_password, diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 22ae0ff3..74ef51ee 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -191,12 +191,16 @@ class DatabaseService(Service): :return: Response to connection request containing success info. :rtype: Dict[str, Union[int, Dict[str, bool]]] """ + self.sys_log.info(f"{self.name}: Processing new connection request ({connection_request_id}) from {src_ip}") status_code = 500 # Default internal server error connection_id = None if self.operating_state == ServiceOperatingState.RUNNING: status_code = 503 # service unavailable if self.health_state_actual == SoftwareHealthState.OVERWHELMED: - self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.") + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, service is at " + f"capacity." + ) if self.health_state_actual in [ SoftwareHealthState.GOOD, SoftwareHealthState.FIXING, @@ -208,12 +212,16 @@ class DatabaseService(Service): # try to create connection if not self.add_connection(connection_id=connection_id, session_id=session_id): status_code = 500 - self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined") - else: - self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, " + f"returning status code 500" + ) else: status_code = 401 # Unauthorised - self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined") + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} unauthorised " + f"(incorrect password), returning status code 401" + ) else: status_code = 404 # service not found return { diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 7c27534a..efa8c9b1 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -313,7 +313,7 @@ class IOSoftware(Software): # if over or at capacity, set to overwhelmed if len(self._connections) >= self.max_sessions: self.set_health_state(SoftwareHealthState.OVERWHELMED) - self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.") + self.sys_log.warning(f"{self.name}: Connection request ({connection_id}) declined. Service is at capacity.") return False else: # if service was previously overwhelmed, set to good because there is enough space for connections @@ -330,11 +330,11 @@ class IOSoftware(Software): "ip_address": session_details.with_ip_address if session_details else None, "time": datetime.now(), } - self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") + self.sys_log.info(f"{self.name}: Connection request ({connection_id}) authorised") return True # connection with given id already exists self.sys_log.warning( - f"{self.name}: Connect request for {connection_id=} declined. Connection already exists." + f"{self.name}: Connection request ({connection_id}) declined. Connection already exists." ) return False diff --git a/tests/integration_tests/game_layer/test_RNG_seed.py b/tests/integration_tests/game_layer/test_RNG_seed.py new file mode 100644 index 00000000..0c6d567d --- /dev/null +++ b/tests/integration_tests/game_layer/test_RNG_seed.py @@ -0,0 +1,50 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from pprint import pprint + +import pytest +import yaml + +from primaite.config.load import data_manipulation_config_path +from primaite.game.agent.interface import AgentHistoryItem +from primaite.session.environment import PrimaiteGymEnv + + +@pytest.fixture() +def create_env(): + with open(data_manipulation_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + env = PrimaiteGymEnv(env_config=cfg) + return env + + +def test_rng_seed_set(create_env): + """Test with RNG seed set.""" + env = create_env + env.reset(seed=3) + for i in range(100): + env.step(0) + a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + + env.reset(seed=3) + for i in range(100): + env.step(0) + b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + + assert a == b + + +def test_rng_seed_unset(create_env): + """Test with no RNG seed.""" + env = create_env + env.reset() + for i in range(100): + env.step(0) + a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + + env.reset() + for i in range(100): + env.step(0) + b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + + assert a != b diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index f3b3c6eb..ec18f1fb 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -62,7 +62,6 @@ def test_probabilistic_agent(): reward_function=reward_function, settings={ "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, - "random_seed": 120, }, )