#2706 - Merge branch 'dev' into feature/2706-Terminal_Sim_Component
This commit is contained in:
@@ -6,8 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Random Number Generator Seeding by specifying a random number seed in the config file.
|
||||
- Implemented Terminal service class, providing a generic terminal simulation.
|
||||
|
||||
### Changed
|
||||
@@ -25,7 +25,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML
|
||||
- Agent logging for agents' internal decision logic
|
||||
- Action masking in all PrimAITE environments
|
||||
|
||||
### Changed
|
||||
- Application registry was moved to the `Application` class and now updates automatically when Application is subclassed
|
||||
- Databases can no longer respond to requests while performing a backup
|
||||
|
||||
@@ -22,8 +22,6 @@ class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
"""Strict validation."""
|
||||
action_probabilities: Dict[int, float]
|
||||
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
|
||||
random_seed: Optional[int] = None
|
||||
"""Random seed. If set, each episode the agent will choose the same random sequence of actions."""
|
||||
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
|
||||
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
|
||||
|
||||
@@ -59,17 +57,18 @@ class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
num_actions = len(action_space.action_map)
|
||||
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
|
||||
|
||||
# If seed not specified, set it to None so that numpy chooses a random one.
|
||||
settings.setdefault("random_seed")
|
||||
|
||||
# The random number seed for np.random is dependent on whether a random number seed is set
|
||||
# in the config file. If there is one it is processed by set_random_seed() in environment.py
|
||||
# and as a consequence the sequence of rng_seeds used here will be repeatable.
|
||||
self.settings = ProbabilisticAgent.Settings(**settings)
|
||||
|
||||
self.rng = np.random.default_rng(self.settings.random_seed)
|
||||
rng_seed = np.random.randint(0, 65535)
|
||||
self.rng = np.random.default_rng(rng_seed)
|
||||
|
||||
# convert probabilities from
|
||||
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
|
||||
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
@@ -72,6 +72,8 @@ class PrimaiteGameOptions(BaseModel):
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
seed: int = None
|
||||
"""Random number seed for RNGs."""
|
||||
max_episode_length: int = 256
|
||||
"""Maximum number of episodes for the PrimAITE game."""
|
||||
ports: List[str]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
from os import PathLike
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
@@ -17,6 +19,36 @@ from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
# Check torch is installed
|
||||
try:
|
||||
import torch as th
|
||||
except ModuleNotFoundError:
|
||||
_LOGGER.debug("Torch not available for importing")
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> Union[None, int]:
    """
    Seed the Python, NumPy and (if it has been imported) PyTorch random number generators.

    :param seed: Seed value. ``None`` or ``-1`` means "do not seed" and leaves every generator untouched.
    :return: The seed that was applied, or ``None`` when no seeding took place.
    :raises ValueError: If ``seed`` is less than ``-1``.
    """
    if seed is None or seed == -1:
        return None
    if seed < -1:
        raise ValueError("Invalid random number seed")

    # Seed python RNG
    random.seed(seed)
    # Seed numpy RNG
    np.random.seed(seed)

    # Seed the RNG for all devices (both CPU and CUDA).
    # Torch is optional: look it up with .get() so that a missing (or blocked, i.e. None)
    # entry in sys.modules is skipped instead of raising KeyError.
    torch_module = sys.modules.get("torch")
    if torch_module is not None:
        torch_module.manual_seed(seed)
        torch_module.backends.cudnn.deterministic = True
        torch_module.backends.cudnn.benchmark = False

    return seed
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
@@ -31,6 +63,9 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
super().__init__()
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
|
||||
"""Get RNG seed from config file. NB: Must be before game instantiation."""
|
||||
self.seed = set_random_seed(self.seed)
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
|
||||
@@ -42,6 +77,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
self.total_reward_per_episode: Dict[int, float] = {}
|
||||
"""Average rewards of agents per episode."""
|
||||
|
||||
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
|
||||
|
||||
def action_masks(self) -> np.ndarray:
|
||||
"""
|
||||
Return the action mask for the agent.
|
||||
@@ -108,6 +145,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"Resetting environment, episode {self.episode_counter}, "
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if seed is not None:
|
||||
set_random_seed(seed)
|
||||
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
|
||||
@@ -63,6 +63,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
super().reset() # Ensure PRNG seed is set everywhere
|
||||
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
|
||||
|
||||
@@ -176,6 +177,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
super().reset() # Ensure PRNG seed is set everywhere
|
||||
if self.env.agent.action_masking:
|
||||
obs, *_ = self.env.reset(seed=seed)
|
||||
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
@@ -54,6 +54,12 @@ class DatabaseClientConnection(BaseModel):
|
||||
if self.client and self.is_active:
|
||||
self.client._disconnect(self.connection_id) # noqa
|
||||
|
||||
def __str__(self) -> str:
    """Return a short description made of the class name, connection id and active flag."""
    class_name = self.__class__.__name__
    conn_id = self.connection_id
    active = self.is_active
    return f"{class_name}(connection_id='{conn_id}', is_active={active})"
|
||||
|
||||
def __repr__(self) -> str:
    """Return the same text as ``str(self)`` so the repr and str forms always agree."""
    text = str(self)
    return text
|
||||
|
||||
|
||||
class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
"""
|
||||
@@ -76,7 +82,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
"""Connection ID to the Database Server."""
|
||||
client_connections: Dict[str, DatabaseClientConnection] = {}
|
||||
"""Keep track of active connections to Database Server."""
|
||||
_client_connection_requests: Dict[str, Optional[str]] = {}
|
||||
_client_connection_requests: Dict[str, Optional[Union[str, DatabaseClientConnection]]] = {}
|
||||
"""Dictionary of connection requests to Database Server."""
|
||||
connected: bool = False
|
||||
"""Boolean Value for whether connected to DB Server."""
|
||||
@@ -187,7 +193,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
return False
|
||||
return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
|
||||
|
||||
def _check_client_connection(self, connection_id: str) -> bool:
|
||||
def _validate_client_connection_request(self, connection_id: str) -> bool:
|
||||
"""Check that client_connection_id is valid."""
|
||||
return True if connection_id in self._client_connection_requests else False
|
||||
|
||||
@@ -211,23 +217,30 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if is_reattempt:
|
||||
valid_connection = self._check_client_connection(connection_id=connection_request_id)
|
||||
if valid_connection:
|
||||
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
|
||||
if valid_connection_request:
|
||||
database_client_connection = self._client_connection_requests.pop(connection_request_id)
|
||||
self.sys_log.info(
|
||||
f"{self.name}: DatabaseClient connection to {server_ip_address} authorised."
|
||||
f"Connection Request ID was {connection_request_id}."
|
||||
)
|
||||
self.connected = True
|
||||
self._last_connection_successful = True
|
||||
return database_client_connection
|
||||
if isinstance(database_client_connection, DatabaseClientConnection):
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} authorised. "
|
||||
f"Using connection id {database_client_connection}"
|
||||
)
|
||||
self.connected = True
|
||||
self._last_connection_successful = True
|
||||
return database_client_connection
|
||||
else:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined"
|
||||
)
|
||||
self._last_connection_successful = False
|
||||
return None
|
||||
else:
|
||||
self.sys_log.warning(
|
||||
f"{self.name}: DatabaseClient connection to {server_ip_address} declined."
|
||||
f"Connection Request ID was {connection_request_id}."
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined "
|
||||
f"due to unknown client-side connection request id"
|
||||
)
|
||||
self._last_connection_successful = False
|
||||
return None
|
||||
|
||||
payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id}
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
@@ -300,9 +313,14 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return None
|
||||
|
||||
connection_request_id = str(uuid4())
|
||||
self._client_connection_requests[connection_request_id] = None
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Sending new connection request ({connection_request_id}) to {self.server_ip_address}"
|
||||
)
|
||||
|
||||
return self._connect(
|
||||
server_ip_address=self.server_ip_address,
|
||||
password=self.server_password,
|
||||
|
||||
@@ -191,12 +191,16 @@ class DatabaseService(Service):
|
||||
:return: Response to connection request containing success info.
|
||||
:rtype: Dict[str, Union[int, Dict[str, bool]]]
|
||||
"""
|
||||
self.sys_log.info(f"{self.name}: Processing new connection request ({connection_request_id}) from {src_ip}")
|
||||
status_code = 500 # Default internal server error
|
||||
connection_id = None
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, service is at "
|
||||
f"capacity."
|
||||
)
|
||||
if self.health_state_actual in [
|
||||
SoftwareHealthState.GOOD,
|
||||
SoftwareHealthState.FIXING,
|
||||
@@ -208,12 +212,16 @@ class DatabaseService(Service):
|
||||
# try to create connection
|
||||
if not self.add_connection(connection_id=connection_id, session_id=session_id):
|
||||
status_code = 500
|
||||
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, "
|
||||
f"returning status code 500"
|
||||
)
|
||||
else:
|
||||
status_code = 401 # Unauthorised
|
||||
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} unauthorised "
|
||||
f"(incorrect password), returning status code 401"
|
||||
)
|
||||
else:
|
||||
status_code = 404 # service not found
|
||||
return {
|
||||
|
||||
@@ -313,7 +313,7 @@ class IOSoftware(Software):
|
||||
# if over or at capacity, set to overwhelmed
|
||||
if len(self._connections) >= self.max_sessions:
|
||||
self.set_health_state(SoftwareHealthState.OVERWHELMED)
|
||||
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
self.sys_log.warning(f"{self.name}: Connection request ({connection_id}) declined. Service is at capacity.")
|
||||
return False
|
||||
else:
|
||||
# if service was previously overwhelmed, set to good because there is enough space for connections
|
||||
@@ -330,11 +330,11 @@ class IOSoftware(Software):
|
||||
"ip_address": session_details.with_ip_address if session_details else None,
|
||||
"time": datetime.now(),
|
||||
}
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
self.sys_log.info(f"{self.name}: Connection request ({connection_id}) authorised")
|
||||
return True
|
||||
# connection with given id already exists
|
||||
self.sys_log.warning(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
|
||||
f"{self.name}: Connection request ({connection_id}) declined. Connection already exists."
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
50
tests/integration_tests/game_layer/test_RNG_seed.py
Normal file
50
tests/integration_tests/game_layer/test_RNG_seed.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from pprint import pprint
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
|
||||
|
||||
@pytest.fixture()
def create_env():
    """Build a ``PrimaiteGymEnv`` from the YAML config located by ``data_manipulation_config_path()``."""
    config_path = data_manipulation_config_path()
    with open(config_path, "r") as cfg_file:
        loaded_cfg = yaml.safe_load(cfg_file)

    return PrimaiteGymEnv(env_config=loaded_cfg)
|
||||
|
||||
|
||||
def test_rng_seed_set(create_env):
    """Two episodes reset with the same seed must give identical non-DONOTHING timesteps."""
    env = create_env
    active_timesteps = []

    # Run the same seeded 100-step episode twice, recording when the
    # `client_2_green_user` agent did something other than DONOTHING.
    for _episode in range(2):
        env.reset(seed=3)
        for _step in range(100):
            env.step(0)
        history = env.game.agents["client_2_green_user"].history
        active_timesteps.append([entry.timestep for entry in history if entry.action != "DONOTHING"])

    first_run, second_run = active_timesteps
    assert first_run == second_run
|
||||
|
||||
|
||||
def test_rng_seed_unset(create_env):
    """Two episodes reset without a seed are expected to give different non-DONOTHING timesteps."""
    env = create_env
    active_timesteps = []

    # Run two unseeded 100-step episodes, recording when the
    # `client_2_green_user` agent did something other than DONOTHING.
    for _episode in range(2):
        env.reset()
        for _step in range(100):
            env.step(0)
        history = env.game.agents["client_2_green_user"].history
        active_timesteps.append([entry.timestep for entry in history if entry.action != "DONOTHING"])

    first_run, second_run = active_timesteps
    assert first_run != second_run
|
||||
@@ -62,7 +62,6 @@ def test_probabilistic_agent():
|
||||
reward_function=reward_function,
|
||||
settings={
|
||||
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
|
||||
"random_seed": 120,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user