Files
PrimAITE/benchmark/primaite_benchmark.py
2025-01-02 15:05:06 +00:00

214 lines
6.8 KiB
Python

# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, Tuple
from report import build_benchmark_md_report, md2pdf
from stable_baselines3 import PPO
import primaite
from benchmark import BenchmarkPrimaiteGymEnv
from primaite.config.load import data_manipulation_config_path
_LOGGER = primaite.getLogger(__name__)

# Benchmark results are grouped first by major version, then by the exact release string.
_MAJOR_V = primaite.__version__.partition(".")[0]
_BENCHMARK_ROOT = Path(__file__).parent
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT.joinpath("results", f"v{_MAJOR_V}")
_VERSION_ROOT: Final[Path] = _RESULTS_ROOT.joinpath(f"v{primaite.__version__}")
_SESSION_METADATA_ROOT: Final[Path] = _VERSION_ROOT.joinpath("session_metadata")

# Import-time side effect: make sure the per-session metadata folder exists before any run.
_SESSION_METADATA_ROOT.mkdir(parents=True, exist_ok=True)
class BenchmarkSession:
    """
    A single PPO training run used to produce benchmark figures.

    Construct it with an environment and PPO hyper-parameters, call :meth:`train`, then read
    :attr:`session_metadata` for the timing and reward figures that feed the benchmark report.
    """

    gym_env: BenchmarkPrimaiteGymEnv
    """Gym environment used by the session to train."""

    num_episodes: int
    """Number of episodes to run the training session."""

    episode_len: int
    """The number of steps per episode."""

    n_steps: int
    """Rollout length passed to PPO as ``n_steps`` (steps collected per policy update)."""

    total_steps: int
    """Number of steps to run the training session (``num_episodes * episode_len``)."""

    batch_size: int
    """Minibatch size passed to PPO as ``batch_size``."""

    learning_rate: float
    """Learning rate for the model."""

    start_time: datetime
    """Start time for the session."""

    end_time: datetime
    """End time for the session."""

    session_metadata: Dict[str, Any]
    """Learning metadata built by :meth:`generate_learn_metadata_dict` once :meth:`train` finishes."""

    def __init__(
        self,
        gym_env: BenchmarkPrimaiteGymEnv,
        episode_len: int,
        num_episodes: int,
        n_steps: int,
        batch_size: int,
        learning_rate: float,
    ):
        """
        Initialise the BenchmarkSession.

        :param gym_env: Environment the agent is trained in.
        :param episode_len: Number of steps per episode.
        :param num_episodes: Number of episodes to train for.
        :param n_steps: Rollout length passed to PPO.
        :param batch_size: Minibatch size passed to PPO.
        :param learning_rate: Learning rate passed to PPO.
        """
        self.gym_env = gym_env
        self.episode_len = episode_len
        self.n_steps = n_steps
        self.num_episodes = num_episodes
        # The training budget handed to ``model.learn``; also the denominator for per-step timings.
        self.total_steps = self.num_episodes * self.episode_len
        self.batch_size = batch_size
        self.learning_rate = learning_rate

    def train(self) -> None:
        """
        Train a PPO ``MlpPolicy`` agent for :attr:`total_steps` steps and record its timings.

        Wall-clock time is measured around model construction and learning. When learning
        completes, :attr:`session_metadata` is populated via :meth:`generate_learn_metadata_dict`.
        """
        # start timer for session
        self.start_time = datetime.now()

        model = PPO(
            policy="MlpPolicy",
            env=self.gym_env,
            learning_rate=self.learning_rate,
            n_steps=self.n_steps,
            batch_size=self.batch_size,
            verbose=0,
            tensorboard_log="./PPO_UC2/",
        )
        model.learn(total_timesteps=self.total_steps)

        # end timer for session
        self.end_time = datetime.now()

        self.session_metadata = self.generate_learn_metadata_dict()

    def _learn_benchmark_durations(self) -> Tuple[float, float, float]:
        """
        Calculate and return the learning benchmark durations.

        Calculates the:
        - Total learning time in seconds
        - Total learning time per time step in seconds
        - Total learning time per 100 time steps per 10 nodes in seconds

        :return: The learning benchmark durations as a Tuple of three floats:
            Tuple[total_s, s_per_step, s_per_100_steps_10_nodes].
        """
        delta = self.end_time - self.start_time
        total_s = delta.total_seconds()

        # Normalise by the number of steps requested from ``model.learn``
        # (num_episodes * episode_len). The previous denominator, ``batch_size * num_episodes``,
        # confused PPO's minibatch size with the episode length and mis-stated the per-step
        # timings by a factor of ``episode_len / batch_size``.
        # NOTE(review): results stored for earlier versions were computed with the old
        # denominator, so cross-version comparisons of these two figures are affected.
        total_steps = self.total_steps
        s_per_step = total_s / total_steps

        num_nodes = len(self.gym_env.game.simulation.network.nodes)
        num_intervals = total_steps / 100
        av_interval_time = total_s / num_intervals
        # Scale the per-100-step time to a nominal 10-node network.
        s_per_100_steps_10_nodes = av_interval_time / (num_nodes / 10)

        return total_s, s_per_step, s_per_100_steps_10_nodes

    def generate_learn_metadata_dict(self) -> Dict[str, Any]:
        """
        Build the metadata specific to the learning session.

        Note: this removes entry ``0`` from ``gym_env.total_reward_per_episode`` in place, so it
        is intended to be called once per session (it is called at the end of :meth:`train`).

        :return: Episode/step counts, the timing figures from
            :meth:`_learn_benchmark_durations`, and the total reward per episode.
        """
        total_s, s_per_step, s_per_100_steps_10_nodes = self._learn_benchmark_durations()
        self.gym_env.total_reward_per_episode.pop(0)  # remove episode 0

        return {
            "total_episodes": self.gym_env.episode_counter,
            "total_time_steps": self.gym_env.total_time_steps,
            "total_s": total_s,
            "s_per_step": s_per_step,
            "s_per_100_steps_10_nodes": s_per_100_steps_10_nodes,
            "total_reward_per_episode": self.gym_env.total_reward_per_episode,
        }
def _get_benchmark_primaite_environment() -> BenchmarkPrimaiteGymEnv:
    """
    Build a fresh benchmark environment for one training session.

    The environment is configured from the config file returned by
    ``data_manipulation_config_path()``; agents are trained on it.

    :return: A new ``BenchmarkPrimaiteGymEnv`` instance.
    """
    return BenchmarkPrimaiteGymEnv(env_config=data_manipulation_config_path())
def _prepare_session_directory():
    """
    Redirect PrimAITE's user sessions path to a clean ``sessions`` folder next to this script.

    Any existing ``sessions`` directory is removed first, so everything a benchmark run writes
    there can be cleaned up in one place afterwards.
    """
    sessions_dir = _BENCHMARK_ROOT / "sessions"

    # Start from a clean slate: discard sessions left over from a previous benchmark run.
    if sessions_dir.is_dir():
        shutil.rmtree(sessions_dir)

    paths = primaite.PRIMAITE_PATHS
    paths.user_sessions_path = sessions_dir
    paths.user_sessions_path.mkdir(parents=True, exist_ok=True)
def run(
    number_of_sessions: int = 5,
    num_episodes: int = 1000,
    episode_len: int = 128,
    n_steps: int = 1280,
    batch_size: int = 32,
    learning_rate: float = 3e-4,
) -> None:
    """
    Run the PrimAITE benchmark and write Markdown and PDF reports.

    Trains ``number_of_sessions`` independent PPO sessions, each on a freshly created benchmark
    environment. Each session's metadata is dumped to ``<session_metadata>/<i>.json``. The
    files are then reloaded to build the report in this version's results folder.

    :param number_of_sessions: Number of independent training sessions to run.
    :param num_episodes: Episodes per session.
    :param episode_len: Steps per episode.
    :param n_steps: PPO rollout length.
    :param batch_size: PPO minibatch size.
    :param learning_rate: PPO learning rate.
    """
    # generate report folder
    v_str = f"v{primaite.__version__}"
    version_result_dir = _RESULTS_ROOT / v_str
    version_result_dir.mkdir(exist_ok=True, parents=True)
    output_path = version_result_dir / f"PrimAITE {v_str} Benchmark Report.md"

    benchmark_start_time = datetime.now()

    session_metadata_dict = {}

    _prepare_session_directory()

    # run training
    for i in range(1, number_of_sessions + 1):
        print(f"Starting Benchmark Session: {i}")

        with _get_benchmark_primaite_environment() as gym_env:
            session = BenchmarkSession(
                gym_env=gym_env,
                num_episodes=num_episodes,
                n_steps=n_steps,
                episode_len=episode_len,
                batch_size=batch_size,
                learning_rate=learning_rate,
            )
            session.train()

        # Dump the session metadata so that we're not holding it in memory as it's large
        with open(_SESSION_METADATA_ROOT / f"{i}.json", "w", encoding="utf-8") as file:
            json.dump(session.session_metadata, file, indent=4)

    # Reload only the sessions written by this run (keys 1..number_of_sessions).
    for i in range(1, number_of_sessions + 1):
        with open(_SESSION_METADATA_ROOT / f"{i}.json", "r", encoding="utf-8") as file:
            session_metadata_dict[i] = json.load(file)

    # generate report
    build_benchmark_md_report(
        benchmark_start_time=benchmark_start_time,
        session_metadata=session_metadata_dict,
        config_path=data_manipulation_config_path(),
        results_root_path=_RESULTS_ROOT,
        output_path=output_path,
    )

    md2pdf(
        md_path=output_path,
        # Swap only the file extension; a string replace would also rewrite any ".md"
        # appearing elsewhere in the path.
        pdf_path=str(output_path.with_suffix(".pdf")),
        # Anchor the stylesheet to this script's folder rather than the current working
        # directory, consistent with every other benchmark path.
        css_path=str(_BENCHMARK_ROOT / "static" / "styles.css"),
    )
if __name__ == "__main__":
    # Executed as a script: run the full benchmark with the default parameters of ``run``.
    run()