Files
PrimAITE/benchmark/primaite_benchmark.py

160 lines
5.3 KiB
Python
Raw Normal View History

# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
2023-07-18 10:11:01 +01:00
import shutil
from datetime import datetime
from pathlib import Path
2024-05-31 13:47:02 +01:00
from typing import Any, Dict, Final, Tuple
from stable_baselines3 import PPO
2023-07-18 10:11:01 +01:00
import primaite
2024-05-31 13:47:02 +01:00
from benchmark.utils.benchmark import BenchmarkPrimaiteGymEnv
from benchmark.utils.report import build_benchmark_latex_report
from primaite.config.load import data_manipulation_config_path
2023-07-18 10:11:01 +01:00
_LOGGER = primaite.getLogger(__name__)

# Root directory for the benchmark's on-disk artefacts, located next to this script.
_BENCHMARK_ROOT: Final[Path] = Path(__file__).parent / "benchmark_session"

# Passed to the report builder as its results root; kept between runs and created if missing.
# parents=True here also creates _BENCHMARK_ROOT itself.
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results"
_RESULTS_ROOT.mkdir(exist_ok=True, parents=True)

# Output directory: cleared and recreated every time this module is imported.
# The plain mkdir() below relies on _BENCHMARK_ROOT already existing (created above).
_OUTPUT_ROOT: Final[Path] = _BENCHMARK_ROOT / "output"

# Clear and recreate the output directory
if _OUTPUT_ROOT.exists():
    shutil.rmtree(_OUTPUT_ROOT)
_OUTPUT_ROOT.mkdir()
2024-05-31 13:47:02 +01:00
class BenchmarkSession:
    """
    A single timed training session used to benchmark PrimAITE.

    The session trains a Stable-Baselines3 PPO agent on the supplied gym environment, records the
    wall-clock start and end time of training, and collects the timing and reward statistics that
    the benchmark report is built from.
    """

    gym_env: BenchmarkPrimaiteGymEnv
    """Gym environment used by the session to train."""

    num_episodes: int
    """Number of episodes to run the training session."""

    batch_size: int
    """Number of steps for each episode."""

    start_time: datetime
    """Start time for the session (set by train())."""

    end_time: datetime
    """End time for the session (set by train())."""

    session_metadata: Dict
    """Dict containing the metadata for the session - used to generate benchmark report."""

    def __init__(self, gym_env: BenchmarkPrimaiteGymEnv, num_episodes: int, batch_size: int):
        """
        Initialise the BenchmarkSession.

        :param gym_env: Environment the agent is trained on.
        :param num_episodes: Number of episodes to train for.
        :param batch_size: Number of steps per episode; also used as the PPO batch size.
        """
        self.gym_env = gym_env
        self.num_episodes = num_episodes
        self.batch_size = batch_size

    def train(self) -> None:
        """
        Run the training session.

        Trains a PPO ``MlpPolicy`` agent for ``batch_size * num_episodes`` time steps, timing the
        construction and training of the model, then stores the resulting statistics in
        ``session_metadata``.
        """
        total_timesteps = self.batch_size * self.num_episodes

        # start timer for session
        self.start_time = datetime.now()
        model = PPO(
            policy="MlpPolicy",
            env=self.gym_env,
            batch_size=self.batch_size,
            # n_steps equals the whole timestep budget, so (assuming a single, non-vectorised env)
            # one rollout buffer spans the entire session.
            n_steps=total_timesteps,
        )

        model.learn(total_timesteps=total_timesteps)
        # end timer for session
        self.end_time = datetime.now()

        self.session_metadata = self.generate_learn_metadata_dict()

    def _learn_benchmark_durations(self) -> Tuple[float, float, float]:
        """
        Calculate and return the learning benchmark durations.

        Calculates the:
        - Total learning time in seconds
        - Total learning time per time step in seconds
        - Total learning time per 100 time steps per 10 nodes in seconds

        Must be called after train() has set ``start_time`` and ``end_time``.

        :return: The learning benchmark durations as a Tuple of three floats:
            Tuple[total_s, s_per_step, s_per_100_steps_10_nodes].
        :raises ZeroDivisionError: If the session had no time steps or the network has no nodes.
        """
        delta = self.end_time - self.start_time
        total_s = delta.total_seconds()

        total_steps = self.batch_size * self.num_episodes
        s_per_step = total_s / total_steps

        num_nodes = len(self.gym_env.game.simulation.network.nodes)

        num_intervals = total_steps / 100
        av_interval_time = total_s / num_intervals
        # Scale the per-100-step time to a nominal 10-node network, as the metric's name describes.
        s_per_100_steps_10_nodes = av_interval_time / (num_nodes / 10)
        return total_s, s_per_step, s_per_100_steps_10_nodes

    def generate_learn_metadata_dict(self) -> Dict[str, Any]:
        """
        Metadata specific to the learning session.

        :return: A dict holding the episode and time-step counts reported by the environment, the
            learning durations from _learn_benchmark_durations(), and the average reward per
            episode with episode 0 removed.
        """
        total_s, s_per_step, s_per_100_steps_10_nodes = self._learn_benchmark_durations()

        # Work on a copy so the environment's own reward record is left untouched; popping in place
        # made this method drop one more episode every time it was called.
        av_reward_per_episode = self.gym_env.average_reward_per_episode.copy()
        av_reward_per_episode.pop(0)  # remove episode 0

        return {
            "total_episodes": self.gym_env.episode_counter,
            "total_time_steps": self.gym_env.total_time_steps,
            "total_s": total_s,
            "s_per_step": s_per_step,
            "s_per_100_steps_10_nodes": s_per_100_steps_10_nodes,
            "av_reward_per_episode": av_reward_per_episode,
        }
2024-05-31 13:47:02 +01:00
def _get_benchmark_primaite_environment() -> BenchmarkPrimaiteGymEnv:
    """
    Build a new BenchmarkPrimaiteGymEnv for the agents to train on.

    The environment is configured from the path returned by ``data_manipulation_config_path()``.
    """
    config_path = data_manipulation_config_path()
    return BenchmarkPrimaiteGymEnv(env_config=config_path)
2024-05-31 13:47:02 +01:00
def _prepare_session_directory():
    """
    Point PrimAITE's user session path at a benchmark-local ``sessions`` directory.

    Any existing ``sessions`` directory under the benchmark root is deleted and then recreated, so
    everything written there during the benchmark is easy to clean up afterwards.
    """
    sessions_dir = _BENCHMARK_ROOT / "sessions"

    # Start from a clean directory rather than accumulating sessions from earlier runs.
    if sessions_dir.is_dir():
        shutil.rmtree(sessions_dir)

    # Override where PrimAITE writes its session output, then make sure that location exists.
    paths = primaite.PRIMAITE_PATHS
    paths.user_sessions_path = sessions_dir
    paths.user_sessions_path.mkdir(exist_ok=True, parents=True)
2024-05-31 13:47:02 +01:00
def run(number_of_sessions: int = 1, num_episodes: int = 3, batch_size: int = 128) -> None:
    """
    Run the PrimAITE benchmark and build the LaTeX benchmark report.

    Each session trains on a newly created environment; the metadata of every session is collected
    and passed to the report builder once all sessions have finished.

    Values previously noted beside this signature (presumably for a full-size benchmark) were
    ``number_of_sessions=10``, ``num_episodes=1000`` and ``batch_size=256``.

    :param number_of_sessions: Number of training sessions to run.
    :param num_episodes: Number of episodes per session.
    :param batch_size: Number of steps per episode; also used as the PPO batch size.
    :raises ValueError: If any argument is less than 1.
    """
    # Fail fast on nonsensical sizes instead of deep inside PPO or the report builder.
    for arg_name, arg_value in (
        ("number_of_sessions", number_of_sessions),
        ("num_episodes", num_episodes),
        ("batch_size", batch_size),
    ):
        if arg_value < 1:
            raise ValueError(f"{arg_name} must be at least 1, got {arg_value}")

    benchmark_start_time = datetime.now()

    # Maps 1-based session number -> that session's metadata dict.
    session_metadata_dict = {}

    _prepare_session_directory()

    # run training
    for i in range(1, number_of_sessions + 1):
        print(f"Starting Benchmark Session: {i}")

        with _get_benchmark_primaite_environment() as gym_env:
            session = BenchmarkSession(gym_env=gym_env, num_episodes=num_episodes, batch_size=batch_size)
            session.train()
            session_metadata_dict[i] = session.session_metadata

    # generate report
    build_benchmark_latex_report(
        benchmark_start_time=benchmark_start_time,
        session_metadata=session_metadata_dict,
        config_path=data_manipulation_config_path(),
        results_root_path=_RESULTS_ROOT,
    )
2023-07-18 10:11:01 +01:00
if __name__ == "__main__":
    # Script entry point: run the benchmark using run()'s default arguments.
    run()