#2628: commit
This commit is contained in:
21
benchmark/benchmark.py
Normal file
21
benchmark/benchmark.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
|
||||
|
||||
class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
    """
    PrimaiteGymEnv subclass that accumulates step counts across episodes.

    ``reset`` is overridden so that the number of steps taken in the episode
    being ended is added to ``total_time_steps`` before delegating to the
    parent implementation.
    """

    # Running total of steps taken over all completed episodes.
    total_time_steps: int = 0

    def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
        """Record the finished episode's step count, then perform the normal reset."""
        steps_this_episode = self.game.step_counter
        self.total_time_steps = self.total_time_steps + steps_this_episode
        return super().reset(seed=seed)
|
||||
@@ -4,11 +4,11 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Tuple
|
||||
|
||||
from report import build_benchmark_latex_report
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
import primaite
|
||||
from benchmark.utils.benchmark import BenchmarkPrimaiteGymEnv
|
||||
from benchmark.utils.report import build_benchmark_latex_report
|
||||
from benchmark import BenchmarkPrimaiteGymEnv
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
|
||||
_LOGGER = primaite.getLogger(__name__)
|
||||
@@ -65,15 +65,12 @@ class BenchmarkSession:
|
||||
"""Run the training session."""
|
||||
# start timer for session
|
||||
self.start_time = datetime.now()
|
||||
|
||||
model = PPO(
|
||||
policy="MlpPolicy",
|
||||
env=self.gym_env,
|
||||
learning_rate=self.learning_rate,
|
||||
n_steps=self.num_steps * self.num_episodes,
|
||||
batch_size=self.num_steps * self.num_episodes,
|
||||
)
|
||||
model.learn(total_timesteps=self.num_episodes * self.num_steps)
|
||||
# TODO check these parameters are correct
|
||||
# EPISODE_LEN = 10
|
||||
TOTAL_TIMESTEPS = 131072
|
||||
LEARNING_RATE = 3e-4
|
||||
model = PPO("MlpPolicy", self.gym_env, learning_rate=LEARNING_RATE, verbose=0, tensorboard_log="./PPO_UC2/")
|
||||
model.learn(total_timesteps=TOTAL_TIMESTEPS)
|
||||
|
||||
# end timer for session
|
||||
self.end_time = datetime.now()
|
||||
@@ -142,10 +139,10 @@ def _prepare_session_directory():
|
||||
|
||||
|
||||
def run(
|
||||
number_of_sessions: int = 10,
|
||||
num_episodes: int = 1000,
|
||||
number_of_sessions: int = 5,
|
||||
num_episodes: int = 512,
|
||||
num_timesteps: int = 128,
|
||||
batch_size: int = 1280,
|
||||
batch_size: int = 128,
|
||||
learning_rate: float = 3e-4,
|
||||
) -> None: # 10 # 1000 # 256
|
||||
"""Run the PrimAITE benchmark."""
|
||||
|
||||
@@ -14,9 +14,9 @@ from pylatex import Command, Document
|
||||
from pylatex import Figure as LatexFigure
|
||||
from pylatex import Section, Subsection, Tabular
|
||||
from pylatex.utils import bold
|
||||
from utils import _get_system_info
|
||||
|
||||
import primaite
|
||||
from benchmark.utils.utils import _get_system_info
|
||||
|
||||
PLOT_CONFIG = {
|
||||
"size": {"auto_size": False, "width": 1500, "height": 900},
|
||||
@@ -1,122 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
|
||||
|
||||
class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
    """
    Class that extends the PrimaiteGymEnv.

    The reset method is extended so that the total timesteps across episodes
    are recorded.  NOTE(review): an earlier description said "average rewards
    per episode are recorded", but the code below only accumulates step
    counts — confirm intended behavior.
    """

    # Cumulative number of steps taken over all completed episodes.
    total_time_steps: int = 0

    def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
        """Overrides the PrimAITEGymEnv reset so that the total timesteps is saved."""
        # step_counter presumably holds the steps taken in the episode that is
        # ending; accumulate it before delegating to the parent reset.
        self.total_time_steps += self.game.step_counter

        return super().reset(seed=seed)
|
||||
|
||||
|
||||
#####################################
|
||||
# IGNORE BELOW FOR NOW
|
||||
#####################################
|
||||
|
||||
|
||||
# NOTE(review): annotation-only placeholder (section is marked "IGNORE BELOW
# FOR NOW"); no dataclass/model base is visible, so these annotations create
# no instance attributes by themselves.
class BenchMarkOSInfo:
    """Operating System Information about the machine that ran the benchmark."""

    operating_system: str
    """The operating system the benchmark was run on."""

    operating_system_version: str
    """The operating system version the benchmark was run on."""

    machine: str
    """The type of machine running the benchmark."""

    processor: str
    """The processor used to run the benchmark."""
|
||||
|
||||
|
||||
# NOTE(review): annotation-only placeholder; see "IGNORE BELOW FOR NOW" banner.
class BenchMarkCPUInfo:
    """CPU Information of the machine that ran the benchmark."""

    physical_cores: int
    """The number of CPU cores the machine that ran the benchmark had."""

    total_cores: int
    """The number of total cores the machine that ran the benchmark had."""

    # Assumes an integer frequency value; units (MHz vs Hz) are not stated
    # here — TODO confirm against whatever populates this field.
    max_frequency: int
    """The CPU's maximum clock speed."""
|
||||
|
||||
|
||||
# NOTE(review): annotation-only placeholder; see "IGNORE BELOW FOR NOW" banner.
class BenchMarkMemoryInfo:
    """The Memory Information of the machine that ran the benchmark."""

    # Stored as a string, presumably a human-readable size ("16.0GB") rather
    # than a byte count — confirm against the producer.
    total: str
    """The total amount of memory."""

    swap_total: str
    """Virtual memory."""
|
||||
|
||||
|
||||
# NOTE(review): annotation-only placeholder; see "IGNORE BELOW FOR NOW" banner.
class BenchMarkGPUInfo:
    """The GPU Information of the machine that ran the benchmark."""

    name: str
    """GPU name."""

    total_memory: str
    """GPU memory."""
|
||||
|
||||
|
||||
class BenchMarkSystemInfo:
    """Overall system information of the machine that ran the benchmark.

    Aggregates the OS, CPU, memory and GPU info records collected for one
    benchmark run.
    """

    system: BenchMarkOSInfo  # operating-system details
    cpu: BenchMarkCPUInfo  # core counts and clock speed
    memory: BenchMarkMemoryInfo  # RAM / swap totals
    # Fixed: was annotated List[BenchMarkMemoryInfo] — a copy-paste error;
    # the field holds one BenchMarkGPUInfo entry per GPU on the machine.
    gpu: List[BenchMarkGPUInfo]
|
||||
|
||||
|
||||
# NOTE(review): annotation-only placeholder; see "IGNORE BELOW FOR NOW" banner.
class BenchMarkResult:
    """Class containing the relevant benchmark results."""

    benchmark_start_time: datetime
    """Start time of the benchmark run."""

    benchmark_end_time: datetime
    """End time of the benchmark run."""

    primaite_version: str
    """The version of PrimAITE being benchmarked."""

    system_info: BenchMarkSystemInfo
    """System information of the machine that ran the benchmark."""

    total_sessions: int
    """The number of sessions that the benchmark ran."""

    total_episodes: int
    """The number of episodes over all the sessions that the benchmark ran."""

    total_timesteps: int
    """The number of timesteps over all the sessions that the benchmark ran."""

    average_seconds_per_session: float
    """The average time per session."""

    average_seconds_per_step: float
    """The average time per step."""

    average_seconds_per_100_steps_and_10_nodes: float
    """The average time per 100 steps on a 10 node network."""

    combined_average_reward_per_episode: Dict
    """Per-episode average rewards combined across sessions (exact key/value
    schema is not defined here — TODO confirm against the code that builds it)."""
|
||||
Reference in New Issue
Block a user