#2628 - Ran the PrimAITE v3 benchmarking script
This commit is contained in:
@@ -33,7 +33,10 @@ class BenchmarkSession:
|
||||
num_episodes: int
|
||||
"""Number of episodes to run the training session."""
|
||||
|
||||
num_steps: int
|
||||
episode_len: int
|
||||
"""The number of steps per episode."""
|
||||
|
||||
total_steps: int
|
||||
"""Number of steps to run the training session."""
|
||||
|
||||
batch_size: int
|
||||
@@ -52,12 +55,20 @@ class BenchmarkSession:
|
||||
"""Dict containing the metadata for the session - used to generate benchmark report."""
|
||||
|
||||
def __init__(
|
||||
self, gym_env: BenchmarkPrimaiteGymEnv, num_episodes: int, num_steps: int, batch_size: int, learning_rate: float
|
||||
self,
|
||||
gym_env: BenchmarkPrimaiteGymEnv,
|
||||
episode_len: int,
|
||||
num_episodes: int,
|
||||
n_steps: int,
|
||||
batch_size: int,
|
||||
learning_rate: float,
|
||||
):
|
||||
"""Initialise the BenchmarkSession."""
|
||||
self.gym_env = gym_env
|
||||
self.episode_len = episode_len
|
||||
self.n_steps = n_steps
|
||||
self.num_episodes = num_episodes
|
||||
self.num_steps = num_steps
|
||||
self.total_steps = self.num_episodes * self.episode_len
|
||||
self.batch_size = batch_size
|
||||
self.learning_rate = learning_rate
|
||||
|
||||
@@ -65,12 +76,16 @@ class BenchmarkSession:
|
||||
"""Run the training session."""
|
||||
# start timer for session
|
||||
self.start_time = datetime.now()
|
||||
# TODO check these parameters are correct
|
||||
# EPISODE_LEN = 10
|
||||
TOTAL_TIMESTEPS = 131072
|
||||
LEARNING_RATE = 3e-4
|
||||
model = PPO("MlpPolicy", self.gym_env, learning_rate=LEARNING_RATE, verbose=0, tensorboard_log="./PPO_UC2/")
|
||||
model.learn(total_timesteps=TOTAL_TIMESTEPS)
|
||||
model = PPO(
|
||||
policy="MlpPolicy",
|
||||
env=self.gym_env,
|
||||
learning_rate=self.learning_rate,
|
||||
n_steps=self.n_steps,
|
||||
batch_size=self.batch_size,
|
||||
verbose=0,
|
||||
tensorboard_log="./PPO_UC2/",
|
||||
)
|
||||
model.learn(total_timesteps=self.total_steps)
|
||||
|
||||
# end timer for session
|
||||
self.end_time = datetime.now()
|
||||
@@ -140,11 +155,12 @@ def _prepare_session_directory():
|
||||
|
||||
def run(
|
||||
number_of_sessions: int = 5,
|
||||
num_episodes: int = 512,
|
||||
num_timesteps: int = 128,
|
||||
batch_size: int = 128,
|
||||
num_episodes: int = 1000,
|
||||
episode_len: int = 128,
|
||||
n_steps: int = 1280,
|
||||
batch_size: int = 32,
|
||||
learning_rate: float = 3e-4,
|
||||
) -> None: # 10 # 1000 # 256
|
||||
) -> None:
|
||||
"""Run the PrimAITE benchmark."""
|
||||
benchmark_start_time = datetime.now()
|
||||
|
||||
@@ -160,7 +176,8 @@ def run(
|
||||
session = BenchmarkSession(
|
||||
gym_env=gym_env,
|
||||
num_episodes=num_episodes,
|
||||
num_steps=num_timesteps,
|
||||
n_steps=n_steps,
|
||||
episode_len=episode_len,
|
||||
batch_size=batch_size,
|
||||
learning_rate=learning_rate,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user