#2706 - Merge branch 'dev' into feature/2706-Terminal_Sim_Component

This commit is contained in:
Charlie Crane
2024-08-07 09:08:43 +01:00
10 changed files with 150 additions and 34 deletions

View File

@@ -6,8 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added
- Random Number Generator Seeding by specifying a random number seed in the config file.
- Implemented Terminal service class, providing a generic terminal simulation.
### Changed
@@ -25,7 +25,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML
- Agent logging for agents' internal decision logic
- Action masking in all PrimAITE environments
### Changed
- Application registry was moved to the `Application` class and now updates automatically when Application is subclassed
- Databases can no longer respond to request while performing a backup

View File

@@ -22,8 +22,6 @@ class ProbabilisticAgent(AbstractScriptedAgent):
"""Strict validation."""
action_probabilities: Dict[int, float]
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
random_seed: Optional[int] = None
"""Random seed. If set, each episode the agent will choose the same random sequence of actions."""
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
@@ -59,17 +57,18 @@ class ProbabilisticAgent(AbstractScriptedAgent):
num_actions = len(action_space.action_map)
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
# If seed not specified, set it to None so that numpy chooses a random one.
settings.setdefault("random_seed")
# The random number seed for np.random is dependent on whether a random number seed is set
# in the config file. If there is one it is processed by set_random_seed() in environment.py
# and as a consequence the the sequence of rng_seed's used here will be repeatable.
self.settings = ProbabilisticAgent.Settings(**settings)
self.rng = np.random.default_rng(self.settings.random_seed)
rng_seed = np.random.randint(0, 65535)
self.rng = np.random.default_rng(rng_seed)
# convert probabilities from
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
super().__init__(agent_name, action_space, observation_space, reward_function)
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -72,6 +72,8 @@ class PrimaiteGameOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
seed: int = None
"""Random number seed for RNGs."""
max_episode_length: int = 256
"""Maximum number of episodes for the PrimAITE game."""
ports: List[str]

View File

@@ -1,5 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import json
import random
import sys
from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
@@ -17,6 +19,36 @@ from primaite.simulator.system.core.packet_capture import PacketCapture
_LOGGER = getLogger(__name__)
# Check torch is installed
try:
import torch as th
except ModuleNotFoundError:
_LOGGER.debug("Torch not available for importing")
def set_random_seed(seed: int) -> Union[None, int]:
"""
Set random number generators.
:param seed: int
"""
if seed is None or seed == -1:
return None
elif seed < -1:
raise ValueError("Invalid random number seed")
# Seed python RNG
random.seed(seed)
# Seed numpy RNG
np.random.seed(seed)
# Seed the RNG for all devices (both CPU and CUDA)
# if torch not installed don't set random seed.
if sys.modules["torch"]:
th.manual_seed(seed)
th.backends.cudnn.deterministic = True
th.backends.cudnn.benchmark = False
return seed
class PrimaiteGymEnv(gymnasium.Env):
"""
@@ -31,6 +63,9 @@ class PrimaiteGymEnv(gymnasium.Env):
super().__init__()
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
"""Get RNG seed from config file. NB: Must be before game instantiation."""
self.seed = set_random_seed(self.seed)
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
@@ -42,6 +77,8 @@ class PrimaiteGymEnv(gymnasium.Env):
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
def action_masks(self) -> np.ndarray:
"""
Return the action mask for the agent.
@@ -108,6 +145,8 @@ class PrimaiteGymEnv(gymnasium.Env):
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if seed is not None:
set_random_seed(seed)
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
if self.io.settings.save_agent_actions:

View File

@@ -63,6 +63,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
@@ -176,6 +177,7 @@ class PrimaiteRayEnv(gymnasium.Env):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
if self.env.agent.action_masking:
obs, *_ = self.env.reset(seed=seed)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from uuid import uuid4
from prettytable import MARKDOWN, PrettyTable
@@ -54,6 +54,12 @@ class DatabaseClientConnection(BaseModel):
if self.client and self.is_active:
self.client._disconnect(self.connection_id) # noqa
def __str__(self) -> str:
return f"{self.__class__.__name__}(connection_id='{self.connection_id}', is_active={self.is_active})"
def __repr__(self) -> str:
return str(self)
class DatabaseClient(Application, identifier="DatabaseClient"):
"""
@@ -76,7 +82,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
"""Connection ID to the Database Server."""
client_connections: Dict[str, DatabaseClientConnection] = {}
"""Keep track of active connections to Database Server."""
_client_connection_requests: Dict[str, Optional[str]] = {}
_client_connection_requests: Dict[str, Optional[Union[str, DatabaseClientConnection]]] = {}
"""Dictionary of connection requests to Database Server."""
connected: bool = False
"""Boolean Value for whether connected to DB Server."""
@@ -187,7 +193,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
return False
return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
def _check_client_connection(self, connection_id: str) -> bool:
def _validate_client_connection_request(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return True if connection_id in self._client_connection_requests else False
@@ -211,23 +217,30 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
:type: is_reattempt: Optional[bool]
"""
if is_reattempt:
valid_connection = self._check_client_connection(connection_id=connection_request_id)
if valid_connection:
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
if valid_connection_request:
database_client_connection = self._client_connection_requests.pop(connection_request_id)
self.sys_log.info(
f"{self.name}: DatabaseClient connection to {server_ip_address} authorised."
f"Connection Request ID was {connection_request_id}."
)
self.connected = True
self._last_connection_successful = True
return database_client_connection
if isinstance(database_client_connection, DatabaseClientConnection):
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} authorised. "
f"Using connection id {database_client_connection}"
)
self.connected = True
self._last_connection_successful = True
return database_client_connection
else:
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined"
)
self._last_connection_successful = False
return None
else:
self.sys_log.warning(
f"{self.name}: DatabaseClient connection to {server_ip_address} declined."
f"Connection Request ID was {connection_request_id}."
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined "
f"due to unknown client-side connection request id"
)
self._last_connection_successful = False
return None
payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id}
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
@@ -300,9 +313,14 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
"""
if not self._can_perform_action():
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None
self.sys_log.info(
f"{self.name}: Sending new connection request ({connection_request_id}) to {self.server_ip_address}"
)
return self._connect(
server_ip_address=self.server_ip_address,
password=self.server_password,

View File

@@ -191,12 +191,16 @@ class DatabaseService(Service):
:return: Response to connection request containing success info.
:rtype: Dict[str, Union[int, Dict[str, bool]]]
"""
self.sys_log.info(f"{self.name}: Processing new connection request ({connection_request_id}) from {src_ip}")
status_code = 500 # Default internal server error
connection_id = None
if self.operating_state == ServiceOperatingState.RUNNING:
status_code = 503 # service unavailable
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, service is at "
f"capacity."
)
if self.health_state_actual in [
SoftwareHealthState.GOOD,
SoftwareHealthState.FIXING,
@@ -208,12 +212,16 @@ class DatabaseService(Service):
# try to create connection
if not self.add_connection(connection_id=connection_id, session_id=session_id):
status_code = 500
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
else:
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, "
f"returning status code 500"
)
else:
status_code = 401 # Unauthorised
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} unauthorised "
f"(incorrect password), returning status code 401"
)
else:
status_code = 404 # service not found
return {

View File

@@ -313,7 +313,7 @@ class IOSoftware(Software):
# if over or at capacity, set to overwhelmed
if len(self._connections) >= self.max_sessions:
self.set_health_state(SoftwareHealthState.OVERWHELMED)
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
self.sys_log.warning(f"{self.name}: Connection request ({connection_id}) declined. Service is at capacity.")
return False
else:
# if service was previously overwhelmed, set to good because there is enough space for connections
@@ -330,11 +330,11 @@ class IOSoftware(Software):
"ip_address": session_details.with_ip_address if session_details else None,
"time": datetime.now(),
}
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
self.sys_log.info(f"{self.name}: Connection request ({connection_id}) authorised")
return True
# connection with given id already exists
self.sys_log.warning(
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
f"{self.name}: Connection request ({connection_id}) declined. Connection already exists."
)
return False

View File

@@ -0,0 +1,50 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from pprint import pprint
import pytest
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import AgentHistoryItem
from primaite.session.environment import PrimaiteGymEnv
@pytest.fixture()
def create_env():
with open(data_manipulation_config_path(), "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(env_config=cfg)
return env
def test_rng_seed_set(create_env):
"""Test with RNG seed set."""
env = create_env
env.reset(seed=3)
for i in range(100):
env.step(0)
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
env.reset(seed=3)
for i in range(100):
env.step(0)
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
assert a == b
def test_rng_seed_unset(create_env):
"""Test with no RNG seed."""
env = create_env
env.reset()
for i in range(100):
env.step(0)
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
env.reset()
for i in range(100):
env.step(0)
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
assert a != b

View File

@@ -62,7 +62,6 @@ def test_probabilistic_agent():
reward_function=reward_function,
settings={
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
"random_seed": 120,
},
)