214 lines
8.6 KiB
Python
214 lines
8.6 KiB
Python
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
|
import json
|
|
import random
|
|
import sys
|
|
from os import PathLike
|
|
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
|
|
|
import gymnasium
|
|
import numpy as np
|
|
from gymnasium.core import ActType, ObsType
|
|
|
|
from primaite import getLogger
|
|
from primaite.game.agent.interface import ProxyAgent
|
|
from primaite.game.game import PrimaiteGame
|
|
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
|
|
from primaite.session.io import PrimaiteIO
|
|
from primaite.simulator import SIM_OUTPUT
|
|
from primaite.simulator.system.core.packet_capture import PacketCapture
|
|
|
|
_LOGGER = getLogger(__name__)
|
|
|
|
# Check torch is installed. Torch is an optional dependency: if the import fails we
# only log it — set_random_seed() later consults sys.modules before touching torch.
try:
    import torch as th
except ModuleNotFoundError:
    # Informational only; the rest of this module works without torch.
    _LOGGER.debug("Torch not available for importing")
|
|
|
|
|
|
def set_random_seed(seed: Optional[int], generate_seed_value: bool) -> Union[None, int]:
    """
    Set random number generators.

    If seed is None or -1 and generate_seed_value is True randomly generate a
    seed value.
    If seed is > -1 and generate_seed_value is True ignore the latter and use
    the provided seed value.

    :param seed: The seed to apply, or None/-1 to request generation (or no seeding).
    :param generate_seed_value: When True and no explicit seed was given, generate one
        randomly; when False and no seed was given, perform no seeding at all.
    :raises ValueError: If seed is an int less than -1.
    :return: None if no seeding was performed, or the int representing the seed used.
    """
    if seed is None or seed == -1:
        if generate_seed_value:
            rng = np.random.default_rng()
            # 2**32-1 is highest value for python RNG seed.
            seed = int(rng.integers(low=0, high=2**32 - 1))
        else:
            return None
    elif seed < -1:
        raise ValueError("Invalid random number seed")
    # Seed python RNG
    random.seed(seed)
    # Seed numpy RNG
    np.random.seed(seed)
    # Seed the RNG for all devices (both CPU and CUDA).
    # BUGFIX: the original `if sys.modules["torch"]:` raises KeyError when torch was
    # never imported (the module-level import is wrapped in try/except); a membership
    # test is the correct way to probe for the optional dependency.
    if "torch" in sys.modules:
        th.manual_seed(seed)
        th.backends.cudnn.deterministic = True
        th.backends.cudnn.benchmark = False

    return seed
|
|
|
|
|
|
def log_seed_value(seed: int):
    """Record the seed that was used in ``seed.log`` inside the simulation output directory."""
    seed_log = SIM_OUTPUT.path / "seed.log"
    seed_log.write_text(f"Seed value = {seed}")
|
|
|
|
|
|
class PrimaiteGymEnv(gymnasium.Env):
    """
    Thin wrapper env to provide agents with a gymnasium API.

    This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
    assumptions about the agent list always having a list of length 1.
    """

    def __init__(self, env_config: Union[Dict, str, PathLike]):
        """
        Initialise the environment.

        :param env_config: A config dict, or a path to a config file, passed to
            ``build_scheduler`` to construct the per-episode config scheduler.
        """
        super().__init__()
        self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
        """Object that returns a config corresponding to the current episode."""
        self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
        """Get RNG seed from config file. NB: Must be before game instantiation."""
        # Whether to auto-generate a seed when none is supplied; read from the game config.
        self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value")
        # set_random_seed may return a freshly generated seed, or None if seeding is disabled.
        # Must run before PrimaiteGame.from_config so the sim is built under a known RNG state.
        self.seed = set_random_seed(self.seed, self.generate_seed_value)
        self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
        """Handles IO for the environment. This produces sys logs, agent logs, etc."""
        self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
        """Current game."""
        self._agent_name = next(iter(self.game.rl_agents))
        """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
        self.episode_counter: int = 0
        """Current episode number."""
        self.total_reward_per_episode: Dict[int, float] = {}
        """Average rewards of agents per episode."""

        _LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")

        log_seed_value(self.seed)

    def action_masks(self) -> np.ndarray:
        """
        Return the action mask for the agent.

        This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be
        performed during this step.

        :return: Action mask
        :rtype: List[bool]
        """
        # When masking is disabled in the agent settings, every action in the map is permitted.
        if not self.agent.config.agent_settings.action_masking:
            return np.asarray([True] * len(self.agent.action_manager.action_map))
        else:
            return self.game.action_mask(self._agent_name)

    @property
    def agent(self) -> ProxyAgent:
        """Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
        return self.game.rl_agents[self._agent_name]

    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
        """
        Perform a step in the environment.

        :param action: Action chosen by the RL policy for this step.
        :return: 5-tuple of (observation, reward, terminated, truncated, info) per the
            gymnasium API. ``terminated`` is always False here; episode end is signalled
            through ``truncated`` only.
        """
        # make ProxyAgent store the action chosen by the RL policy
        step = self.game.step_counter
        self.agent.store_action(action)
        self.game.pre_timestep()
        # apply_agent_actions accesses the action we just stored
        self.game.apply_agent_actions()
        self.game.advance_timestep()
        state = self.game.get_sim_state()
        self.game.update_agents(state)

        next_obs = self._get_obs()  # this doesn't update observation, just gets the current observation
        reward = self.agent.reward_function.current_reward
        _LOGGER.debug(f"step: {self.game.step_counter}, Blue reward: {reward}")
        terminated = False
        truncated = self.game.calculate_truncated()
        info = {
            "agent_actions": {name: agent.history[-1] for name, agent in self.game.agents.items()}
        }  # tell us what all the agents did for convenience.
        if self.game.save_step_metadata:
            self._write_step_metadata_json(step, action, state, reward)
        return next_obs, reward, terminated, truncated, info

    def _write_step_metadata_json(self, step: int, action: int, state: Dict, reward: int):
        """
        Dump this step's metadata (action, reward, full sim state) to a JSON file.

        Files are written to ``<sim output>/episode_<n>/step_metadata/step_<step>.json``.

        :param step: Step counter value captured at the start of the step.
        :param action: The action the RL agent took this step.
        :param state: Full simulation state dict for this step.
        :param reward: Reward received this step. NOTE(review): ``int(reward)`` below
            truncates fractional rewards — confirm that is intended.
        """
        output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"

        output_dir.mkdir(parents=True, exist_ok=True)
        path = output_dir / f"step_{step}.json"

        data = {
            "episode": self.episode_counter,
            "step": step,
            "action": int(action),
            "reward": int(reward),
            "state": state,
        }
        with open(path, "w") as file:
            json.dump(data, file)

    def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[ObsType, Dict[str, Any]]:
        """
        Reset the environment.

        Records the finished episode's total reward, optionally writes the agent action
        log, then rebuilds the game from the next episode's config.

        :param seed: Optional RNG seed to apply for the new episode.
        :param options: Unused; accepted for gymnasium API compatibility.
        :return: Tuple of (initial observation, info dict).
        """
        _LOGGER.info(
            f"Resetting environment, episode {self.episode_counter}, "
            f"avg. reward: {self.agent.reward_function.total_reward}"
        )
        if seed is not None:
            set_random_seed(seed, self.generate_seed_value)
        self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward

        if self.io.settings.save_agent_actions:
            all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
            self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
        self.episode_counter += 1
        # Clear packet-capture state so captures from the previous episode don't leak into this one.
        PacketCapture.clear()
        # The game is rebuilt from config each episode; any references held into the old
        # game become stale here — hence the fresh lookup in the `agent` property.
        self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter))
        self.game.setup_for_episode(episode=self.episode_counter)
        state = self.game.get_sim_state()
        self.game.update_agents(state=state)
        next_obs = self._get_obs()
        info = {}
        return next_obs, info

    @property
    def action_space(self) -> gymnasium.Space:
        """Return the action space of the environment."""
        return self.agent.action_manager.space

    @property
    def observation_space(self) -> gymnasium.Space:
        """Return the observation space of the environment."""
        # Flattened space mirrors the flattened observations returned by _get_obs.
        if self.agent.flatten_obs:
            return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
        else:
            return self.agent.observation_manager.space

    def _get_obs(self) -> ObsType:
        """Return the current observation, flattened if the agent is configured for it."""
        if self.agent.flatten_obs:
            unflat_space = self.agent.observation_manager.space
            unflat_obs = self.agent.observation_manager.current_observation
            return gymnasium.spaces.flatten(unflat_space, unflat_obs)
        else:
            return self.agent.observation_manager.current_observation

    def close(self):
        """Close the simulation."""
        # Final flush of the agent action log for the last (possibly unfinished) episode.
        if self.io.settings.save_agent_actions:
            all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
            self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|