This commit is contained in:
Czar Echavez
2024-06-05 11:03:39 +01:00
parent 3bad9aa51e
commit dbc30fc296
5 changed files with 33 additions and 137 deletions

21
benchmark/benchmark.py Normal file
View File

@@ -0,0 +1,21 @@
from typing import Any, Dict, Optional, Tuple
from gymnasium.core import ObsType
from primaite.session.environment import PrimaiteGymEnv
class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
    """
    PrimaiteGymEnv variant that keeps a running total of timesteps for benchmarking.

    Every call to ``reset`` adds the game's current step counter to ``total_time_steps``
    before the reset itself is delegated to the parent class.
    """

    # Class-level default of zero; the first update in ``reset`` stores the running
    # total on the instance.
    total_time_steps: int = 0

    def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
        """
        Record the game's step counter in the running total, then reset the environment.

        :param seed: Optional seed, forwarded unchanged to the parent ``reset``.
        :return: Whatever the parent ``reset`` returns (observation and info dictionary).
        """
        # Read the counter before delegating: the parent reset is called afterwards.
        accumulated = self.total_time_steps
        self.total_time_steps = accumulated + self.game.step_counter
        return super().reset(seed=seed)

View File

@@ -4,11 +4,11 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Final, Tuple from typing import Any, Dict, Final, Tuple
from report import build_benchmark_latex_report
from stable_baselines3 import PPO from stable_baselines3 import PPO
import primaite import primaite
from benchmark.utils.benchmark import BenchmarkPrimaiteGymEnv from benchmark import BenchmarkPrimaiteGymEnv
from benchmark.utils.report import build_benchmark_latex_report
from primaite.config.load import data_manipulation_config_path from primaite.config.load import data_manipulation_config_path
_LOGGER = primaite.getLogger(__name__) _LOGGER = primaite.getLogger(__name__)
@@ -65,15 +65,12 @@ class BenchmarkSession:
"""Run the training session.""" """Run the training session."""
# start timer for session # start timer for session
self.start_time = datetime.now() self.start_time = datetime.now()
# TODO check these parameters are correct
model = PPO( # EPISODE_LEN = 10
policy="MlpPolicy", TOTAL_TIMESTEPS = 131072
env=self.gym_env, LEARNING_RATE = 3e-4
learning_rate=self.learning_rate, model = PPO("MlpPolicy", self.gym_env, learning_rate=LEARNING_RATE, verbose=0, tensorboard_log="./PPO_UC2/")
n_steps=self.num_steps * self.num_episodes, model.learn(total_timesteps=TOTAL_TIMESTEPS)
batch_size=self.num_steps * self.num_episodes,
)
model.learn(total_timesteps=self.num_episodes * self.num_steps)
# end timer for session # end timer for session
self.end_time = datetime.now() self.end_time = datetime.now()
@@ -142,10 +139,10 @@ def _prepare_session_directory():
def run( def run(
number_of_sessions: int = 10, number_of_sessions: int = 5,
num_episodes: int = 1000, num_episodes: int = 512,
num_timesteps: int = 128, num_timesteps: int = 128,
batch_size: int = 1280, batch_size: int = 128,
learning_rate: float = 3e-4, learning_rate: float = 3e-4,
) -> None: # 10 # 1000 # 256 ) -> None: # 10 # 1000 # 256
"""Run the PrimAITE benchmark.""" """Run the PrimAITE benchmark."""

View File

@@ -14,9 +14,9 @@ from pylatex import Command, Document
from pylatex import Figure as LatexFigure from pylatex import Figure as LatexFigure
from pylatex import Section, Subsection, Tabular from pylatex import Section, Subsection, Tabular
from pylatex.utils import bold from pylatex.utils import bold
from utils import _get_system_info
import primaite import primaite
from benchmark.utils.utils import _get_system_info
PLOT_CONFIG = { PLOT_CONFIG = {
"size": {"auto_size": False, "width": 1500, "height": 900}, "size": {"auto_size": False, "width": 1500, "height": 900},

View File

@@ -1,122 +0,0 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from gymnasium.core import ObsType
from primaite.session.environment import PrimaiteGymEnv
class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
    """
    PrimaiteGymEnv variant that keeps a running total of timesteps for benchmarking.

    Every call to ``reset`` adds the game's current step counter to ``total_time_steps``
    before the reset itself is delegated to the parent class.
    """

    # Class-level default of zero; the first update in ``reset`` stores the running
    # total on the instance.
    total_time_steps: int = 0

    def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
        """
        Record the game's step counter in the running total, then reset the environment.

        :param seed: Optional seed, forwarded unchanged to the parent ``reset``.
        :return: Whatever the parent ``reset`` returns (observation and info dictionary).
        """
        # Read the counter before delegating: the parent reset is called afterwards.
        accumulated = self.total_time_steps
        self.total_time_steps = accumulated + self.game.step_counter
        return super().reset(seed=seed)
#####################################
# IGNORE BELOW FOR NOW
#####################################
class BenchMarkOSInfo:
    """Operating system details of the machine that ran the benchmark."""

    operating_system: str
    """Name of the operating system the benchmark ran on."""

    operating_system_version: str
    """Version of the operating system the benchmark ran on."""

    machine: str
    """Type of machine that ran the benchmark."""

    processor: str
    """Processor that ran the benchmark."""
class BenchMarkCPUInfo:
    """CPU details of the machine that ran the benchmark."""

    physical_cores: int
    """Number of physical CPU cores on the benchmark machine."""

    total_cores: int
    """Total number of cores on the benchmark machine."""

    max_frequency: int
    """Maximum clock speed of the CPU."""
class BenchMarkMemoryInfo:
    """Memory details of the machine that ran the benchmark."""

    total: str
    """Total amount of memory, stored as a string."""

    swap_total: str
    """Virtual memory total, stored as a string."""
class BenchMarkGPUInfo:
    """GPU details of the machine that ran the benchmark."""

    name: str
    """Name of the GPU."""

    total_memory: str
    """Memory of the GPU, stored as a string."""
class BenchMarkSystemInfo:
    """Overall system information of the machine that ran the benchmark."""

    system: BenchMarkOSInfo
    """Operating system information."""

    cpu: BenchMarkCPUInfo
    """CPU information."""

    memory: BenchMarkMemoryInfo
    """Memory information."""

    # Previously annotated as List[BenchMarkMemoryInfo]; GPU entries are described by
    # BenchMarkGPUInfo (name / total_memory), so the element type is corrected here.
    gpu: List[BenchMarkGPUInfo]
    """Information for each GPU, one entry per device."""
class BenchMarkResult:
    """Container for the results of a benchmark run."""

    benchmark_start_time: datetime
    """When the benchmark run started."""

    benchmark_end_time: datetime
    """When the benchmark run ended."""

    primaite_version: str
    """Version of PrimAITE that was benchmarked."""

    system_info: BenchMarkSystemInfo
    """System information of the machine that ran the benchmark."""

    total_sessions: int
    """Number of sessions the benchmark ran."""

    total_episodes: int
    """Number of episodes across all sessions of the benchmark."""

    total_timesteps: int
    """Number of timesteps across all sessions of the benchmark."""

    average_seconds_per_session: float
    """Average time per session, in seconds."""

    average_seconds_per_step: float
    """Average time per step, in seconds."""

    average_seconds_per_100_steps_and_10_nodes: float
    """Average time per 100 steps on a 10 node network, in seconds."""

    combined_average_reward_per_episode: Dict
    """Contents not yet specified (tbd)."""