#2628: committing to be reviewed

This commit is contained in:
Czar Echavez
2024-06-01 13:23:27 +01:00
parent 472040aa70
commit 3bad9aa51e
7 changed files with 12601 additions and 12 deletions

View File

@@ -13,7 +13,7 @@ from primaite.config.load import data_manipulation_config_path
_LOGGER = primaite.getLogger(__name__)
_BENCHMARK_ROOT = Path(__file__).parent / "benchmark_session"
_BENCHMARK_ROOT = Path(__file__).parent
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results"
_RESULTS_ROOT.mkdir(exist_ok=True, parents=True)
@@ -33,9 +33,15 @@ class BenchmarkSession:
num_episodes: int
"""Number of episodes to run the training session."""
num_steps: int
"""Number of steps to run the training session."""
batch_size: int
"""Number of steps for each episode."""
learning_rate: float
"""Learning rate for the model."""
start_time: datetime
"""Start time for the session."""
@@ -45,11 +51,15 @@ class BenchmarkSession:
session_metadata: Dict
"""Dict containing the metadata for the session - used to generate benchmark report."""
def __init__(self, gym_env: BenchmarkPrimaiteGymEnv, num_episodes: int, batch_size: int):
def __init__(
self, gym_env: BenchmarkPrimaiteGymEnv, num_episodes: int, num_steps: int, batch_size: int, learning_rate: float
):
"""Initialise the BenchmarkSession."""
self.gym_env = gym_env
self.num_episodes = num_episodes
self.num_steps = num_steps
self.batch_size = batch_size
self.learning_rate = learning_rate
def train(self):
"""Run the training session."""
@@ -59,10 +69,11 @@ class BenchmarkSession:
model = PPO(
policy="MlpPolicy",
env=self.gym_env,
batch_size=self.batch_size,
n_steps=self.batch_size * self.num_episodes,
learning_rate=self.learning_rate,
n_steps=self.num_steps * self.num_episodes,
batch_size=self.num_steps * self.num_episodes,
)
model.learn(total_timesteps=self.num_episodes * self.gym_env.game.options.max_episode_length)
model.learn(total_timesteps=self.num_episodes * self.num_steps)
# end timer for session
self.end_time = datetime.now()
@@ -108,14 +119,13 @@ class BenchmarkSession:
}
def _get_benchmark_primaite_environment(num_timesteps: int) -> BenchmarkPrimaiteGymEnv:
def _get_benchmark_primaite_environment() -> BenchmarkPrimaiteGymEnv:
"""
Create an instance of the BenchmarkPrimaiteGymEnv.
This environment will be used to train the agents on.
"""
env = BenchmarkPrimaiteGymEnv(env_config=data_manipulation_config_path())
env.game.options.max_episode_length = num_timesteps
return env
@@ -132,7 +142,11 @@ def _prepare_session_directory():
def run(
number_of_sessions: int = 3, num_episodes: int = 3, num_timesteps: int = 128, batch_size: int = 128
number_of_sessions: int = 10,
num_episodes: int = 1000,
num_timesteps: int = 128,
batch_size: int = 1280,
learning_rate: float = 3e-4,
) -> None: # 10 # 1000 # 256
"""Run the PrimAITE benchmark."""
benchmark_start_time = datetime.now()
@@ -145,8 +159,14 @@ def run(
for i in range(1, number_of_sessions + 1):
print(f"Starting Benchmark Session: {i}")
with _get_benchmark_primaite_environment(num_timesteps=num_timesteps) as gym_env:
session = BenchmarkSession(gym_env=gym_env, num_episodes=num_episodes, batch_size=batch_size)
with _get_benchmark_primaite_environment() as gym_env:
session = BenchmarkSession(
gym_env=gym_env,
num_episodes=num_episodes,
num_steps=num_timesteps,
batch_size=batch_size,
learning_rate=learning_rate,
)
session.train()
session_metadata_dict[i] = session.session_metadata