# Reconstructed from the following git patch (envelope kept for provenance):
#   From b7fa826d9512a109353b1bd2fd928b0b3e275bb2 Mon Sep 17 00:00:00 2001
#   From: Nick Todd
#   Date: Wed, 1 May 2024 11:16:45 +0100
#   Subject: [PATCH] #2442: Initial commit of MP test script
#   File: src/primaite/notebooks/mp.py (new file)
"""Multi-processing test script.

Builds several ``PrimaiteGymEnv`` instances, each in its own subprocess via
``SubprocVecEnv``, and trains a single PPO agent across all of them.
"""
from pathlib import Path
from typing import Callable

import yaml
from stable_baselines3 import PPO
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv

from primaite.session.environment import PrimaiteGymEnv

EPISODE_LEN = 128
NUM_EPISODES = 10
NO_STEPS = EPISODE_LEN * NUM_EPISODES
BATCH_SIZE = 32
LEARNING_RATE = 3e-4

# This file lives in src/primaite/notebooks/, so the packaged config is found
# relative to it instead of through an absolute path on one developer's machine.
CONFIG_PATH = Path(__file__).resolve().parent.parent / "config" / "_package_data" / "data_manipulation.yaml"

# Loaded at module level on purpose: worker processes that re-import this module
# need ``cfg`` to be available to the environment factory below.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)


def make_env(rank: int, seed: int = 0) -> Callable[[], PrimaiteGymEnv]:
    """Return a zero-argument factory that builds one seeded environment.

    The factory is what ``SubprocVecEnv`` expects: it is called inside the
    worker process and must only construct and return the environment. Model
    creation and training belong to the caller, which owns the vectorised env.

    :param rank: Index of the sub-environment; added to ``seed`` so that each
        environment is reset with a different seed.
    :param seed: Base random seed.
    :return: A callable that creates and resets a ``PrimaiteGymEnv``.
    """

    def _init() -> PrimaiteGymEnv:
        env = PrimaiteGymEnv(env_config=cfg)
        env.reset(seed=seed + rank)
        return env

    set_random_seed(seed)
    return _init


def main(n_procs: int = 4) -> None:
    """Train one PPO agent on ``n_procs`` environments running in subprocesses.

    :param n_procs: Number of worker processes / parallel environments.
    """
    # The ``i + n_procs`` rank offset is kept from the original script so the
    # per-environment seeds are unchanged.
    train_env = SubprocVecEnv([make_env(i + n_procs) for i in range(n_procs)])
    try:
        model = PPO(
            "MlpPolicy",
            train_env,
            learning_rate=LEARNING_RATE,
            n_steps=NO_STEPS,
            batch_size=BATCH_SIZE,
            verbose=0,
            tensorboard_log="./PPO_UC2/",
        )
        # ``total_timesteps`` counts steps across all environments, so scale it
        # to keep NO_STEPS steps per environment as in the original script.
        model.learn(total_timesteps=NO_STEPS * n_procs)
    finally:
        # Always shut the worker processes down, even if training fails.
        train_env.close()


if __name__ == "__main__":
    main()