#1386: added the ability to set deterministic and seeding RNG when training and evaluating + the fix provided in #1535
This commit is contained in:
@@ -14,8 +14,8 @@ from pathlib import Path
|
||||
from typing import Final, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
@@ -54,9 +54,6 @@ def run_generic(env: Primaite, config_values: TrainingConfig):
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(config_values.time_delay / 1000)
|
||||
|
||||
# Reset the environment at the end of the episode
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
@@ -90,7 +87,7 @@ def run_stable_baselines3_ppo(
|
||||
_LOGGER.error("Could not load agent")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
|
||||
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
@@ -103,8 +100,19 @@ def run_stable_baselines3_ppo(
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
_LOGGER.debug("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
obs = env.reset()
|
||||
|
||||
for step in range(0, config_values.num_steps):
|
||||
action, _states = agent.predict(
|
||||
obs,
|
||||
deterministic=env.training_config.deterministic
|
||||
)
|
||||
# convert to int if action is a numpy array
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = env.step(action)
|
||||
env.close()
|
||||
|
||||
|
||||
@@ -138,7 +146,7 @@ def run_stable_baselines3_a2c(
|
||||
_LOGGER.error("Could not load agent")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
|
||||
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
@@ -151,7 +159,18 @@ def run_stable_baselines3_a2c(
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
_LOGGER.debug("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
obs = env.reset()
|
||||
|
||||
for step in range(0, config_values.num_steps):
|
||||
action, _states = agent.predict(
|
||||
obs,
|
||||
deterministic=env.training_config.deterministic
|
||||
)
|
||||
# convert to int if action is a numpy array
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = env.step(action)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user