1566 - added test file and edited configs to include types of num steps and modifed agents to use correct step and episode counts
This commit is contained in:
@@ -348,8 +348,8 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
|
||||
@@ -107,13 +107,13 @@ class RLlibAgent(AgentSessionABC):
|
||||
),
|
||||
)
|
||||
|
||||
self._agent_config.training(train_batch_size=self._training_config.num_steps)
|
||||
self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
|
||||
self._agent_config.framework(framework="tf")
|
||||
|
||||
self._agent_config.rollouts(
|
||||
num_rollout_workers=1,
|
||||
num_envs_per_worker=1,
|
||||
horizon=self._training_config.num_steps,
|
||||
horizon=self._training_config.num_train_steps,
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
@@ -121,7 +121,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_train_episodes):
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
def learn(
|
||||
@@ -133,8 +133,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
time_steps = self._training_config.num_train_steps
|
||||
episodes = self._training_config.num_train_episodes
|
||||
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
|
||||
@@ -53,11 +53,12 @@ class SB3Agent(AgentSessionABC):
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=self.sb3_output_verbose_level,
|
||||
n_steps=self._training_config.num_steps,
|
||||
n_steps=self._training_config.num_eval_steps,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
)
|
||||
|
||||
@@ -82,8 +83,8 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
time_steps = self._training_config.num_train_steps
|
||||
episodes = self._training_config.num_train_episodes
|
||||
self.is_eval = False
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
@@ -104,8 +105,8 @@ class SB3Agent(AgentSessionABC):
|
||||
:param deterministic: Whether the evaluation is deterministic.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
self._env.set_as_eval()
|
||||
self.is_eval = True
|
||||
if deterministic:
|
||||
|
||||
@@ -60,11 +60,17 @@ class TrainingConfig:
|
||||
action_type: ActionType = ActionType.ANY
|
||||
"The ActionType to use"
|
||||
|
||||
num_episodes: int = 10
|
||||
"The number of episodes to train over"
|
||||
num_train_episodes: int = 10
|
||||
"The number of episodes to train over during an training session"
|
||||
|
||||
num_steps: int = 256
|
||||
"The number of steps in an episode"
|
||||
num_train_steps: int = 256
|
||||
"The number of steps in an episode during an training session"
|
||||
|
||||
num_eval_episodes: int = 10
|
||||
"The number of episodes to train over during an evaluation session"
|
||||
|
||||
num_eval_steps: int = 256
|
||||
"The number of steps in an episode during an evaluation session"
|
||||
|
||||
checkpoint_every_n_episodes: int = 5
|
||||
"The agent will save a checkpoint every n episodes"
|
||||
@@ -230,8 +236,17 @@ class TrainingConfig:
|
||||
tc += f"{self.hard_coded_agent_view}, "
|
||||
tc += f"{self.action_type}, "
|
||||
tc += f"observation_space={self.observation_space}, "
|
||||
tc += f"{self.num_episodes} episodes @ "
|
||||
tc += f"{self.num_steps} steps"
|
||||
if self.session_type.name == "TRAIN":
|
||||
tc += f"{self.num_train_episodes} episodes @ "
|
||||
tc += f"{self.num_train_steps} steps"
|
||||
elif self.session_type.name == "EVAL":
|
||||
tc += f"{self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
else:
|
||||
tc += f"Training: {self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
tc += f"Evaluation: {self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
return tc
|
||||
|
||||
|
||||
@@ -320,7 +335,8 @@ def _get_new_key_from_legacy(legacy_key: str) -> str:
|
||||
"""
|
||||
key_mapping = {
|
||||
"agentIdentifier": None,
|
||||
"numEpisodes": "num_episodes",
|
||||
"numEpisodes": "num_train_episodes",
|
||||
"numSteps": "num_train_steps",
|
||||
"timeDelay": "time_delay",
|
||||
"configFilename": None,
|
||||
"sessionType": "session_type",
|
||||
|
||||
@@ -85,7 +85,12 @@ class Primaite(Env):
|
||||
_LOGGER.info(f"Using: {str(self.training_config)}")
|
||||
|
||||
# Number of steps in an episode
|
||||
self.episode_steps = self.training_config.num_steps
|
||||
if self.training_config.session_type == SessionType.TRAIN:
|
||||
self.episode_steps = self.training_config.num_train_steps
|
||||
elif self.training_config.session_type == SessionType.EVAL:
|
||||
self.episode_steps = self.training_config.num_eval_steps
|
||||
else:
|
||||
self.episode_steps = self.training_config.num_train_steps
|
||||
|
||||
super(Primaite, self).__init__()
|
||||
|
||||
@@ -254,6 +259,7 @@ class Primaite(Env):
|
||||
self.episode_count = 0
|
||||
self.step_count = 0
|
||||
self.total_step_count = 0
|
||||
self.episode_steps = self.training_config.num_eval_steps
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user