diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py
new file mode 100644
index 00000000..5212b5d2
--- /dev/null
+++ b/benchmark/benchmark.py
@@ -0,0 +1,21 @@
+from typing import Any, Dict, Optional, Tuple
+
+from gymnasium.core import ObsType
+
+from primaite.session.environment import PrimaiteGymEnv
+
+
+class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
+    """
+    Class that extends the PrimaiteGymEnv.
+
+    The reset method is extended so that the average rewards per episode are recorded.
+    """
+
+    total_time_steps: int = 0
+
+    def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
+        """Overrides the PrimAITEGymEnv reset so that the total timesteps is saved."""
+        self.total_time_steps += self.game.step_counter
+
+        return super().reset(seed=seed)
diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py
index a6bf908d..296f8cc8 100644
--- a/benchmark/primaite_benchmark.py
+++ b/benchmark/primaite_benchmark.py
@@ -4,11 +4,11 @@
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, Final, Tuple
+from report import build_benchmark_latex_report
 from stable_baselines3 import PPO
 
 import primaite
-from benchmark.utils.benchmark import BenchmarkPrimaiteGymEnv
-from benchmark.utils.report import build_benchmark_latex_report
+from benchmark import BenchmarkPrimaiteGymEnv
 from primaite.config.load import data_manipulation_config_path
 
 _LOGGER = primaite.getLogger(__name__)
@@ -65,15 +65,12 @@
         """Run the training session."""
         # start timer for session
         self.start_time = datetime.now()
-
-        model = PPO(
-            policy="MlpPolicy",
-            env=self.gym_env,
-            learning_rate=self.learning_rate,
-            n_steps=self.num_steps * self.num_episodes,
-            batch_size=self.num_steps * self.num_episodes,
-        )
-        model.learn(total_timesteps=self.num_episodes * self.num_steps)
+        # TODO check these parameters are correct
+        # EPISODE_LEN = 10
+        TOTAL_TIMESTEPS = 131072
+        LEARNING_RATE = 3e-4
+        model = PPO("MlpPolicy", self.gym_env, learning_rate=LEARNING_RATE, verbose=0, tensorboard_log="./PPO_UC2/")
+        model.learn(total_timesteps=TOTAL_TIMESTEPS)
 
         # end timer for session
         self.end_time = datetime.now()
@@ -142,10 +139,10 @@ def _prepare_session_directory():
 
 
 def run(
-    number_of_sessions: int = 10,
-    num_episodes: int = 1000,
+    number_of_sessions: int = 5,
+    num_episodes: int = 512,
     num_timesteps: int = 128,
-    batch_size: int = 1280,
+    batch_size: int = 128,
     learning_rate: float = 3e-4,
 ) -> None:  # 10  # 1000  # 256
     """Run the PrimAITE benchmark."""
diff --git a/benchmark/utils/report.py b/benchmark/report.py
similarity index 99%
rename from benchmark/utils/report.py
rename to benchmark/report.py
index b0b0e52a..d4d8ec76 100644
--- a/benchmark/utils/report.py
+++ b/benchmark/report.py
@@ -14,9 +14,9 @@ from pylatex import Command, Document
 from pylatex import Figure as LatexFigure
 from pylatex import Section, Subsection, Tabular
 from pylatex.utils import bold
+from utils import _get_system_info
 
 import primaite
-from benchmark.utils.utils import _get_system_info
 
 PLOT_CONFIG = {
     "size": {"auto_size": False, "width": 1500, "height": 900},
diff --git a/benchmark/utils/utils.py b/benchmark/utils.py
similarity index 100%
rename from benchmark/utils/utils.py
rename to benchmark/utils.py
diff --git a/benchmark/utils/benchmark.py b/benchmark/utils/benchmark.py
deleted file mode 100644
index fc457a03..00000000
--- a/benchmark/utils/benchmark.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from datetime import datetime
-from typing import Any, Dict, List, Optional, Tuple
-
-from gymnasium.core import ObsType
-
-from primaite.session.environment import PrimaiteGymEnv
-
-
-class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
-    """
-    Class that extends the PrimaiteGymEnv.
-
-    The reset method is extended so that the average rewards per episode are recorded.
-    """
-
-    total_time_steps: int = 0
-
-    def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
-        """Overrides the PrimAITEGymEnv reset so that the total timesteps is saved."""
-        self.total_time_steps += self.game.step_counter
-
-        return super().reset(seed=seed)
-
-
-#####################################
-# IGNORE BELOW FOR NOW
-#####################################
-
-
-class BenchMarkOSInfo:
-    """Operating System Information about the machine that run the benchmark."""
-
-    operating_system: str
-    """The operating system the benchmark was run on."""
-
-    operating_system_version: str
-    """The operating system version the benchmark was run on."""
-
-    machine: str
-    """The type of machine running the benchmark."""
-
-    processor: str
-    """The processor used to run the benchmark."""
-
-
-class BenchMarkCPUInfo:
-    """CPU Information of the machine that ran the benchmark."""
-
-    physical_cores: int
-    """The number of CPU cores the machine that ran the benchmark had."""
-
-    total_cores: int
-    """The number of total cores the machine that run the benchmark had."""
-
-    max_frequency: int
-    """The CPU's maximum clock speed."""
-
-
-class BenchMarkMemoryInfo:
-    """The Memory Information of the machine that ran the benchmark."""
-
-    total: str
-    """The total amount of memory."""
-
-    swap_total: str
-    """Virtual memory."""
-
-
-class BenchMarkGPUInfo:
-    """The GPU Information of the machine that ran the benchmark."""
-
-    name: str
-    """GPU name."""
-
-    total_memory: str
-    """GPU memory."""
-
-
-class BenchMarkSystemInfo:
-    """Overall system information of the machine that ran the benchmark."""
-
-    system: BenchMarkOSInfo
-    cpu: BenchMarkCPUInfo
-    memory: BenchMarkMemoryInfo
-    gpu: List[BenchMarkMemoryInfo]
-
-
-class BenchMarkResult:
-    """Class containing the relevant benchmark results."""
-
-    benchmark_start_time: datetime
-    """Start time of the benchmark run."""
-
-    benchmark_end_time: datetime
-    """End time of the benchmark run."""
-
-    primaite_version: str
-    """The version of PrimAITE being benchmarked."""
-
-    system_info: BenchMarkSystemInfo
-    """System information of the machine that ran the benchmark."""
-
-    total_sessions: int
-    """The number of sessions that the benchmark ran."""
-
-    total_episodes: int
-    """The number of episodes over all the sessions that the benchmark ran."""
-
-    total_timesteps: int
-    """The number of timesteps over all the sessions that the benchmark ran."""
-
-    average_seconds_per_session: float
-    """The average time per session."""
-
-    average_seconds_per_step: float
-    """The average time per step."""
-
-    average_seconds_per_100_steps_and_10_nodes: float
-    """The average time per 100 steps on a 10 node network."""
-
-    combined_average_reward_per_episode: Dict
-    """tbd."""