2025-01-02 15:05:06 +00:00
|
|
|
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
2023-07-18 10:11:01 +01:00
|
|
|
import json
|
|
|
|
|
import shutil
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from pathlib import Path
|
2024-05-31 13:47:02 +01:00
|
|
|
from typing import Any, Dict, Final, Tuple
|
|
|
|
|
|
2024-08-07 10:07:19 +01:00
|
|
|
from report import build_benchmark_md_report, md2pdf
|
2024-05-31 13:47:02 +01:00
|
|
|
from stable_baselines3 import PPO
|
2023-07-20 08:48:18 +01:00
|
|
|
|
2023-07-18 10:11:01 +01:00
|
|
|
import primaite
|
2024-06-05 11:03:39 +01:00
|
|
|
from benchmark import BenchmarkPrimaiteGymEnv
|
2024-05-31 13:47:02 +01:00
|
|
|
from primaite.config.load import data_manipulation_config_path
|
2023-07-18 10:11:01 +01:00
|
|
|
|
|
|
|
|
_LOGGER = primaite.getLogger(__name__)

# Leading (major) component of the installed PrimAITE version string;
# results for every release of that major version live under one folder.
_MAJOR_V = primaite.__version__.partition(".")[0]

_BENCHMARK_ROOT = Path(__file__).parent

_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT.joinpath("results", f"v{_MAJOR_V}")
_VERSION_ROOT: Final[Path] = _RESULTS_ROOT.joinpath(f"v{primaite.__version__}")
_SESSION_METADATA_ROOT: Final[Path] = _VERSION_ROOT.joinpath("session_metadata")

# Created at import time so per-session metadata can be dumped straight into it.
_SESSION_METADATA_ROOT.mkdir(parents=True, exist_ok=True)
|
2023-07-18 10:11:01 +01:00
|
|
|
|
2023-07-20 08:48:18 +01:00
|
|
|
|
2024-05-31 13:47:02 +01:00
|
|
|
class BenchmarkSession:
    """Benchmark Session class."""

    gym_env: BenchmarkPrimaiteGymEnv
    """Gym environment used by the session to train."""

    num_episodes: int
    """Number of episodes to run the training session."""

    episode_len: int
    """The number of steps per episode."""

    n_steps: int
    """Number of steps PPO collects per rollout before each policy update."""

    total_steps: int
    """Number of steps to run the training session (num_episodes * episode_len)."""

    batch_size: int
    """Minibatch size used by PPO for each gradient update."""

    learning_rate: float
    """Learning rate for the model."""

    start_time: datetime
    """Start time for the session."""

    end_time: datetime
    """End time for the session."""

    session_metadata: Dict[str, Any]
    """Learning metadata produced at the end of :meth:`train`."""

    def __init__(
        self,
        gym_env: BenchmarkPrimaiteGymEnv,
        episode_len: int,
        num_episodes: int,
        n_steps: int,
        batch_size: int,
        learning_rate: float,
    ):
        """
        Initialise the BenchmarkSession.

        :param gym_env: Environment to train the PPO agent on.
        :param episode_len: Number of steps per episode.
        :param num_episodes: Number of episodes to train for.
        :param n_steps: PPO rollout length (steps collected per policy update).
        :param batch_size: PPO minibatch size.
        :param learning_rate: PPO learning rate.
        """
        self.gym_env = gym_env
        self.episode_len = episode_len
        self.n_steps = n_steps
        self.num_episodes = num_episodes
        # Total steps handed to the learner; also the denominator for timing stats.
        self.total_steps = self.num_episodes * self.episode_len
        self.batch_size = batch_size
        self.learning_rate = learning_rate

    def train(self):
        """
        Run the training session.

        Trains a PPO ``MlpPolicy`` on :attr:`gym_env` for :attr:`total_steps`
        timesteps, records the wall-clock start/end times, and stores the
        resulting metadata on :attr:`session_metadata`.
        """
        # start timer for session
        self.start_time = datetime.now()
        model = PPO(
            policy="MlpPolicy",
            env=self.gym_env,
            learning_rate=self.learning_rate,
            n_steps=self.n_steps,
            batch_size=self.batch_size,
            verbose=0,
            tensorboard_log="./PPO_UC2/",
        )
        model.learn(total_timesteps=self.total_steps)

        # end timer for session
        self.end_time = datetime.now()

        self.session_metadata = self.generate_learn_metadata_dict()

    def _learn_benchmark_durations(self) -> Tuple[float, float, float]:
        """
        Calculate and return the learning benchmark durations.

        Calculates the:
        - Total learning time in seconds
        - Total learning time per time step in seconds
        - Total learning time per 100 time steps per 10 nodes in seconds

        :return: The learning benchmark durations as a Tuple of three floats:
            Tuple[total_s, s_per_step, s_per_100_steps_10_nodes].
        """
        delta = self.end_time - self.start_time
        total_s = delta.total_seconds()

        # Normalise by the number of environment steps requested from the learner
        # (num_episodes * episode_len). The PPO minibatch size (batch_size) is not
        # a step count and previously produced meaningless per-step timings.
        total_steps = self.total_steps
        s_per_step = total_s / total_steps

        num_nodes = len(self.gym_env.game.simulation.network.nodes)
        num_intervals = total_steps / 100
        av_interval_time = total_s / num_intervals
        # Scale the per-100-step time to a notional 10-node network.
        s_per_100_steps_10_nodes = av_interval_time / (num_nodes / 10)

        return total_s, s_per_step, s_per_100_steps_10_nodes

    def generate_learn_metadata_dict(self) -> Dict[str, Any]:
        """
        Metadata specific to the learning session.

        Note: this removes the episode-0 entry from
        ``gym_env.total_reward_per_episode`` in place, so it should be called
        once per session (it is called from :meth:`train`).

        :return: A JSON-serialisable dict of episode/step counts, timing
            statistics and the per-episode total rewards.
        """
        total_s, s_per_step, s_per_100_steps_10_nodes = self._learn_benchmark_durations()
        # NOTE(review): presumably episode 0 is the entry recorded before the first
        # real episode begins — confirm against BenchmarkPrimaiteGymEnv.
        self.gym_env.total_reward_per_episode.pop(0)  # remove episode 0
        return {
            "total_episodes": self.gym_env.episode_counter,
            "total_time_steps": self.gym_env.total_time_steps,
            "total_s": total_s,
            "s_per_step": s_per_step,
            "s_per_100_steps_10_nodes": s_per_100_steps_10_nodes,
            "total_reward_per_episode": self.gym_env.total_reward_per_episode,
        }
|
|
|
|
|
|
|
|
|
|
|
2024-06-01 13:23:27 +01:00
|
|
|
def _get_benchmark_primaite_environment() -> BenchmarkPrimaiteGymEnv:
    """
    Build a fresh BenchmarkPrimaiteGymEnv for one training session.

    The environment is configured from the data manipulation config returned by
    ``data_manipulation_config_path()``, and is the environment the agents train on.
    """
    config_path = data_manipulation_config_path()
    return BenchmarkPrimaiteGymEnv(env_config=config_path)
|
2023-07-20 08:48:18 +01:00
|
|
|
|
|
|
|
|
|
2024-05-31 13:47:02 +01:00
|
|
|
def _prepare_session_directory():
    """
    Point PrimAITE's user sessions path at a benchmark-local ``sessions`` folder.

    Any existing ``sessions`` folder under the benchmark root is deleted first, so
    everything a benchmark run writes there is easy to clean up afterwards.
    """
    sessions_dir = _BENCHMARK_ROOT / "sessions"

    # Start from an empty directory on every run.
    if sessions_dir.is_dir():
        shutil.rmtree(sessions_dir)

    paths = primaite.PRIMAITE_PATHS
    paths.user_sessions_path = sessions_dir
    paths.user_sessions_path.mkdir(parents=True, exist_ok=True)
|
2023-07-20 08:48:18 +01:00
|
|
|
|
|
|
|
|
|
2024-05-31 15:20:10 +01:00
|
|
|
def run(
    number_of_sessions: int = 5,
    num_episodes: int = 1000,
    episode_len: int = 128,
    n_steps: int = 1280,
    batch_size: int = 32,
    learning_rate: float = 3e-4,
) -> None:
    """
    Run the PrimAITE benchmark.

    Trains ``number_of_sessions`` independent PPO sessions. After each session,
    its metadata is dumped to ``<session_metadata>/<i>.json``. Once all sessions
    have run, the metadata is reloaded and used to build a markdown benchmark
    report, which is then converted to PDF.

    :param number_of_sessions: Number of independent training sessions to run.
    :param num_episodes: Episodes per session.
    :param episode_len: Steps per episode.
    :param n_steps: PPO rollout length (steps collected per policy update).
    :param batch_size: PPO minibatch size.
    :param learning_rate: PPO learning rate.
    """
    # generate report folder
    v_str = f"v{primaite.__version__}"
    # _VERSION_ROOT is exactly _RESULTS_ROOT / f"v{primaite.__version__}".
    version_result_dir = _VERSION_ROOT
    version_result_dir.mkdir(exist_ok=True, parents=True)
    output_path = version_result_dir / f"PrimAITE {v_str} Benchmark Report.md"

    benchmark_start_time = datetime.now()

    session_metadata_dict: Dict[int, Dict[str, Any]] = {}

    _prepare_session_directory()

    # run training
    for i in range(1, number_of_sessions + 1):
        print(f"Starting Benchmark Session: {i}")

        with _get_benchmark_primaite_environment() as gym_env:
            session = BenchmarkSession(
                gym_env=gym_env,
                num_episodes=num_episodes,
                n_steps=n_steps,
                episode_len=episode_len,
                batch_size=batch_size,
                learning_rate=learning_rate,
            )
            session.train()

        # Dump the session metadata so that we're not holding it in memory as it's large
        with open(_SESSION_METADATA_ROOT / f"{i}.json", "w") as file:
            json.dump(session.session_metadata, file, indent=4)

    for i in range(1, number_of_sessions + 1):
        with open(_SESSION_METADATA_ROOT / f"{i}.json", "r") as file:
            session_metadata_dict[i] = json.load(file)

    # generate report
    build_benchmark_md_report(
        benchmark_start_time=benchmark_start_time,
        session_metadata=session_metadata_dict,
        config_path=data_manipulation_config_path(),
        results_root_path=_RESULTS_ROOT,
        output_path=output_path,
    )

    md2pdf(
        md_path=output_path,
        # Swap only the file suffix. str.replace(".md", ...) would also rewrite
        # any ".md" occurring in parent directory names.
        pdf_path=str(output_path.with_suffix(".pdf")),
        # NOTE(review): relative path — resolved against the current working
        # directory unless md2pdf resolves it itself; confirm.
        css_path="static/styles.css",
    )
|
2023-07-18 10:11:01 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: run the full benchmark with its default parameters.
    run()
|