#2628: commit
This commit is contained in:
21
benchmark/benchmark.py
Normal file
21
benchmark/benchmark.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from gymnasium.core import ObsType
|
||||||
|
|
||||||
|
from primaite.session.environment import PrimaiteGymEnv
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
|
||||||
|
"""
|
||||||
|
Class that extends the PrimaiteGymEnv.
|
||||||
|
|
||||||
|
The reset method is extended so that the average rewards per episode are recorded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
total_time_steps: int = 0
|
||||||
|
|
||||||
|
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||||
|
"""Overrides the PrimAITEGymEnv reset so that the total timesteps is saved."""
|
||||||
|
self.total_time_steps += self.game.step_counter
|
||||||
|
|
||||||
|
return super().reset(seed=seed)
|
||||||
@@ -4,11 +4,11 @@ from datetime import datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Final, Tuple
|
from typing import Any, Dict, Final, Tuple
|
||||||
|
|
||||||
|
from report import build_benchmark_latex_report
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
|
|
||||||
import primaite
|
import primaite
|
||||||
from benchmark.utils.benchmark import BenchmarkPrimaiteGymEnv
|
from benchmark import BenchmarkPrimaiteGymEnv
|
||||||
from benchmark.utils.report import build_benchmark_latex_report
|
|
||||||
from primaite.config.load import data_manipulation_config_path
|
from primaite.config.load import data_manipulation_config_path
|
||||||
|
|
||||||
_LOGGER = primaite.getLogger(__name__)
|
_LOGGER = primaite.getLogger(__name__)
|
||||||
@@ -65,15 +65,12 @@ class BenchmarkSession:
|
|||||||
"""Run the training session."""
|
"""Run the training session."""
|
||||||
# start timer for session
|
# start timer for session
|
||||||
self.start_time = datetime.now()
|
self.start_time = datetime.now()
|
||||||
|
# TODO check these parameters are correct
|
||||||
model = PPO(
|
# EPISODE_LEN = 10
|
||||||
policy="MlpPolicy",
|
TOTAL_TIMESTEPS = 131072
|
||||||
env=self.gym_env,
|
LEARNING_RATE = 3e-4
|
||||||
learning_rate=self.learning_rate,
|
model = PPO("MlpPolicy", self.gym_env, learning_rate=LEARNING_RATE, verbose=0, tensorboard_log="./PPO_UC2/")
|
||||||
n_steps=self.num_steps * self.num_episodes,
|
model.learn(total_timesteps=TOTAL_TIMESTEPS)
|
||||||
batch_size=self.num_steps * self.num_episodes,
|
|
||||||
)
|
|
||||||
model.learn(total_timesteps=self.num_episodes * self.num_steps)
|
|
||||||
|
|
||||||
# end timer for session
|
# end timer for session
|
||||||
self.end_time = datetime.now()
|
self.end_time = datetime.now()
|
||||||
@@ -142,10 +139,10 @@ def _prepare_session_directory():
|
|||||||
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
number_of_sessions: int = 10,
|
number_of_sessions: int = 5,
|
||||||
num_episodes: int = 1000,
|
num_episodes: int = 512,
|
||||||
num_timesteps: int = 128,
|
num_timesteps: int = 128,
|
||||||
batch_size: int = 1280,
|
batch_size: int = 128,
|
||||||
learning_rate: float = 3e-4,
|
learning_rate: float = 3e-4,
|
||||||
) -> None: # 10 # 1000 # 256
|
) -> None: # 10 # 1000 # 256
|
||||||
"""Run the PrimAITE benchmark."""
|
"""Run the PrimAITE benchmark."""
|
||||||
|
|||||||
@@ -14,9 +14,9 @@ from pylatex import Command, Document
|
|||||||
from pylatex import Figure as LatexFigure
|
from pylatex import Figure as LatexFigure
|
||||||
from pylatex import Section, Subsection, Tabular
|
from pylatex import Section, Subsection, Tabular
|
||||||
from pylatex.utils import bold
|
from pylatex.utils import bold
|
||||||
|
from utils import _get_system_info
|
||||||
|
|
||||||
import primaite
|
import primaite
|
||||||
from benchmark.utils.utils import _get_system_info
|
|
||||||
|
|
||||||
PLOT_CONFIG = {
|
PLOT_CONFIG = {
|
||||||
"size": {"auto_size": False, "width": 1500, "height": 900},
|
"size": {"auto_size": False, "width": 1500, "height": 900},
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
from gymnasium.core import ObsType
|
|
||||||
|
|
||||||
from primaite.session.environment import PrimaiteGymEnv
|
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
|
|
||||||
"""
|
|
||||||
Class that extends the PrimaiteGymEnv.
|
|
||||||
|
|
||||||
The reset method is extended so that the average rewards per episode are recorded.
|
|
||||||
"""
|
|
||||||
|
|
||||||
total_time_steps: int = 0
|
|
||||||
|
|
||||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
|
||||||
"""Overrides the PrimAITEGymEnv reset so that the total timesteps is saved."""
|
|
||||||
self.total_time_steps += self.game.step_counter
|
|
||||||
|
|
||||||
return super().reset(seed=seed)
|
|
||||||
|
|
||||||
|
|
||||||
#####################################
|
|
||||||
# IGNORE BELOW FOR NOW
|
|
||||||
#####################################
|
|
||||||
|
|
||||||
|
|
||||||
class BenchMarkOSInfo:
|
|
||||||
"""Operating System Information about the machine that run the benchmark."""
|
|
||||||
|
|
||||||
operating_system: str
|
|
||||||
"""The operating system the benchmark was run on."""
|
|
||||||
|
|
||||||
operating_system_version: str
|
|
||||||
"""The operating system version the benchmark was run on."""
|
|
||||||
|
|
||||||
machine: str
|
|
||||||
"""The type of machine running the benchmark."""
|
|
||||||
|
|
||||||
processor: str
|
|
||||||
"""The processor used to run the benchmark."""
|
|
||||||
|
|
||||||
|
|
||||||
class BenchMarkCPUInfo:
|
|
||||||
"""CPU Information of the machine that ran the benchmark."""
|
|
||||||
|
|
||||||
physical_cores: int
|
|
||||||
"""The number of CPU cores the machine that ran the benchmark had."""
|
|
||||||
|
|
||||||
total_cores: int
|
|
||||||
"""The number of total cores the machine that run the benchmark had."""
|
|
||||||
|
|
||||||
max_frequency: int
|
|
||||||
"""The CPU's maximum clock speed."""
|
|
||||||
|
|
||||||
|
|
||||||
class BenchMarkMemoryInfo:
|
|
||||||
"""The Memory Information of the machine that ran the benchmark."""
|
|
||||||
|
|
||||||
total: str
|
|
||||||
"""The total amount of memory."""
|
|
||||||
|
|
||||||
swap_total: str
|
|
||||||
"""Virtual memory."""
|
|
||||||
|
|
||||||
|
|
||||||
class BenchMarkGPUInfo:
|
|
||||||
"""The GPU Information of the machine that ran the benchmark."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
"""GPU name."""
|
|
||||||
|
|
||||||
total_memory: str
|
|
||||||
"""GPU memory."""
|
|
||||||
|
|
||||||
|
|
||||||
class BenchMarkSystemInfo:
|
|
||||||
"""Overall system information of the machine that ran the benchmark."""
|
|
||||||
|
|
||||||
system: BenchMarkOSInfo
|
|
||||||
cpu: BenchMarkCPUInfo
|
|
||||||
memory: BenchMarkMemoryInfo
|
|
||||||
gpu: List[BenchMarkMemoryInfo]
|
|
||||||
|
|
||||||
|
|
||||||
class BenchMarkResult:
|
|
||||||
"""Class containing the relevant benchmark results."""
|
|
||||||
|
|
||||||
benchmark_start_time: datetime
|
|
||||||
"""Start time of the benchmark run."""
|
|
||||||
|
|
||||||
benchmark_end_time: datetime
|
|
||||||
"""End time of the benchmark run."""
|
|
||||||
|
|
||||||
primaite_version: str
|
|
||||||
"""The version of PrimAITE being benchmarked."""
|
|
||||||
|
|
||||||
system_info: BenchMarkSystemInfo
|
|
||||||
"""System information of the machine that ran the benchmark."""
|
|
||||||
|
|
||||||
total_sessions: int
|
|
||||||
"""The number of sessions that the benchmark ran."""
|
|
||||||
|
|
||||||
total_episodes: int
|
|
||||||
"""The number of episodes over all the sessions that the benchmark ran."""
|
|
||||||
|
|
||||||
total_timesteps: int
|
|
||||||
"""The number of timesteps over all the sessions that the benchmark ran."""
|
|
||||||
|
|
||||||
average_seconds_per_session: float
|
|
||||||
"""The average time per session."""
|
|
||||||
|
|
||||||
average_seconds_per_step: float
|
|
||||||
"""The average time per step."""
|
|
||||||
|
|
||||||
average_seconds_per_100_steps_and_10_nodes: float
|
|
||||||
"""The average time per 100 steps on a 10 node network."""
|
|
||||||
|
|
||||||
combined_average_reward_per_episode: Dict
|
|
||||||
"""tbd."""
|
|
||||||
Reference in New Issue
Block a user