This commit is contained in:
Czar Echavez
2024-06-05 11:03:39 +01:00
parent 3bad9aa51e
commit dbc30fc296
5 changed files with 33 additions and 137 deletions

21
benchmark/benchmark.py Normal file
View File

@@ -0,0 +1,21 @@
from typing import Any, Dict, Optional, Tuple
from gymnasium.core import ObsType
from primaite.session.environment import PrimaiteGymEnv
class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
"""
Class that extends the PrimaiteGymEnv.
The reset method is extended so that the average rewards per episode are recorded.
"""
total_time_steps: int = 0
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Overrides the PrimAITEGymEnv reset so that the total timesteps is saved."""
self.total_time_steps += self.game.step_counter
return super().reset(seed=seed)

View File

@@ -4,11 +4,11 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, Tuple
from report import build_benchmark_latex_report
from stable_baselines3 import PPO
import primaite
from benchmark.utils.benchmark import BenchmarkPrimaiteGymEnv
from benchmark.utils.report import build_benchmark_latex_report
from benchmark import BenchmarkPrimaiteGymEnv
from primaite.config.load import data_manipulation_config_path
_LOGGER = primaite.getLogger(__name__)
@@ -65,15 +65,12 @@ class BenchmarkSession:
"""Run the training session."""
# start timer for session
self.start_time = datetime.now()
model = PPO(
policy="MlpPolicy",
env=self.gym_env,
learning_rate=self.learning_rate,
n_steps=self.num_steps * self.num_episodes,
batch_size=self.num_steps * self.num_episodes,
)
model.learn(total_timesteps=self.num_episodes * self.num_steps)
# TODO check these parameters are correct
# EPISODE_LEN = 10
TOTAL_TIMESTEPS = 131072
LEARNING_RATE = 3e-4
model = PPO("MlpPolicy", self.gym_env, learning_rate=LEARNING_RATE, verbose=0, tensorboard_log="./PPO_UC2/")
model.learn(total_timesteps=TOTAL_TIMESTEPS)
# end timer for session
self.end_time = datetime.now()
@@ -142,10 +139,10 @@ def _prepare_session_directory():
def run(
number_of_sessions: int = 10,
num_episodes: int = 1000,
number_of_sessions: int = 5,
num_episodes: int = 512,
num_timesteps: int = 128,
batch_size: int = 1280,
batch_size: int = 128,
learning_rate: float = 3e-4,
) -> None: # 10 # 1000 # 256
"""Run the PrimAITE benchmark."""

View File

@@ -14,9 +14,9 @@ from pylatex import Command, Document
from pylatex import Figure as LatexFigure
from pylatex import Section, Subsection, Tabular
from pylatex.utils import bold
from utils import _get_system_info
import primaite
from benchmark.utils.utils import _get_system_info
PLOT_CONFIG = {
"size": {"auto_size": False, "width": 1500, "height": 900},

View File

@@ -1,122 +0,0 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from gymnasium.core import ObsType
from primaite.session.environment import PrimaiteGymEnv
class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
"""
Class that extends the PrimaiteGymEnv.
The reset method is extended so that the average rewards per episode are recorded.
"""
total_time_steps: int = 0
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Overrides the PrimAITEGymEnv reset so that the total timesteps is saved."""
self.total_time_steps += self.game.step_counter
return super().reset(seed=seed)
#####################################
# IGNORE BELOW FOR NOW
#####################################
class BenchMarkOSInfo:
"""Operating System Information about the machine that run the benchmark."""
operating_system: str
"""The operating system the benchmark was run on."""
operating_system_version: str
"""The operating system version the benchmark was run on."""
machine: str
"""The type of machine running the benchmark."""
processor: str
"""The processor used to run the benchmark."""
class BenchMarkCPUInfo:
"""CPU Information of the machine that ran the benchmark."""
physical_cores: int
"""The number of CPU cores the machine that ran the benchmark had."""
total_cores: int
"""The number of total cores the machine that run the benchmark had."""
max_frequency: int
"""The CPU's maximum clock speed."""
class BenchMarkMemoryInfo:
"""The Memory Information of the machine that ran the benchmark."""
total: str
"""The total amount of memory."""
swap_total: str
"""Virtual memory."""
class BenchMarkGPUInfo:
"""The GPU Information of the machine that ran the benchmark."""
name: str
"""GPU name."""
total_memory: str
"""GPU memory."""
class BenchMarkSystemInfo:
"""Overall system information of the machine that ran the benchmark."""
system: BenchMarkOSInfo
cpu: BenchMarkCPUInfo
memory: BenchMarkMemoryInfo
gpu: List[BenchMarkMemoryInfo]
class BenchMarkResult:
"""Class containing the relevant benchmark results."""
benchmark_start_time: datetime
"""Start time of the benchmark run."""
benchmark_end_time: datetime
"""End time of the benchmark run."""
primaite_version: str
"""The version of PrimAITE being benchmarked."""
system_info: BenchMarkSystemInfo
"""System information of the machine that ran the benchmark."""
total_sessions: int
"""The number of sessions that the benchmark ran."""
total_episodes: int
"""The number of episodes over all the sessions that the benchmark ran."""
total_timesteps: int
"""The number of timesteps over all the sessions that the benchmark ran."""
average_seconds_per_session: float
"""The average time per session."""
average_seconds_per_step: float
"""The average time per step."""
average_seconds_per_100_steps_and_10_nodes: float
"""The average time per 100 steps on a 10 node network."""
combined_average_reward_per_episode: Dict
"""tbd."""