1566 - added test file, edited configs to include separate train and eval step and episode counts, and modified agents to use the correct step and episode counts

This commit is contained in:
SunilSamra
2023-07-07 14:13:47 +01:00
parent 3ff081ea71
commit 79d98e977b
20 changed files with 652 additions and 60 deletions

View File

@@ -348,8 +348,8 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
obs = self._env.reset()
for episode in range(episodes):

View File

@@ -107,13 +107,13 @@ class RLlibAgent(AgentSessionABC):
),
)
self._agent_config.training(train_batch_size=self._training_config.num_steps)
self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
num_rollout_workers=1,
num_envs_per_worker=1,
horizon=self._training_config.num_steps,
horizon=self._training_config.num_train_steps,
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
@@ -121,7 +121,7 @@ class RLlibAgent(AgentSessionABC):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_train_episodes):
self._agent.save(str(self.checkpoints_path))
def learn(
@@ -133,8 +133,8 @@ class RLlibAgent(AgentSessionABC):
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
time_steps = self._training_config.num_train_steps
episodes = self._training_config.num_train_episodes
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):

View File

@@ -53,11 +53,12 @@ class SB3Agent(AgentSessionABC):
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_steps,
n_steps=self._training_config.num_eval_steps,
tensorboard_log=str(self._tensorboard_log_path),
)
@@ -82,8 +83,8 @@ class SB3Agent(AgentSessionABC):
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
time_steps = self._training_config.num_train_steps
episodes = self._training_config.num_train_episodes
self.is_eval = False
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
@@ -104,8 +105,8 @@ class SB3Agent(AgentSessionABC):
:param deterministic: Whether the evaluation is deterministic.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
self._env.set_as_eval()
self.is_eval = True
if deterministic:

View File

@@ -60,11 +60,17 @@ class TrainingConfig:
action_type: ActionType = ActionType.ANY
"The ActionType to use"
num_episodes: int = 10
"The number of episodes to train over"
num_train_episodes: int = 10
"The number of episodes to train over during a training session"
num_steps: int = 256
"The number of steps in an episode"
num_train_steps: int = 256
"The number of steps in an episode during a training session"
num_eval_episodes: int = 10
"The number of episodes to evaluate over during an evaluation session"
num_eval_steps: int = 256
"The number of steps in an episode during an evaluation session"
checkpoint_every_n_episodes: int = 5
"The agent will save a checkpoint every n episodes"
@@ -230,8 +236,17 @@ class TrainingConfig:
tc += f"{self.hard_coded_agent_view}, "
tc += f"{self.action_type}, "
tc += f"observation_space={self.observation_space}, "
tc += f"{self.num_episodes} episodes @ "
tc += f"{self.num_steps} steps"
# Append the episode/step budget that matches the configured session type.
if self.session_type.name == "TRAIN":
    tc += f"{self.num_train_episodes} episodes @ "
    tc += f"{self.num_train_steps} steps"
elif self.session_type.name == "EVAL":
    tc += f"{self.num_eval_episodes} episodes @ "
    tc += f"{self.num_eval_steps} steps"
else:
    # Any other session type reports both budgets. The training figures must
    # come from the num_train_* fields (they previously repeated the
    # num_eval_* values), and the two halves are separated by ", ".
    tc += f"Training: {self.num_train_episodes} episodes @ "
    tc += f"{self.num_train_steps} steps, "
    tc += f"Evaluation: {self.num_eval_episodes} episodes @ "
    tc += f"{self.num_eval_steps} steps"
return tc
@@ -320,7 +335,8 @@ def _get_new_key_from_legacy(legacy_key: str) -> str:
"""
key_mapping = {
"agentIdentifier": None,
"numEpisodes": "num_episodes",
"numEpisodes": "num_train_episodes",
"numSteps": "num_train_steps",
"timeDelay": "time_delay",
"configFilename": None,
"sessionType": "session_type",

View File

@@ -85,7 +85,12 @@ class Primaite(Env):
_LOGGER.info(f"Using: {str(self.training_config)}")
# Number of steps in an episode
self.episode_steps = self.training_config.num_steps
# Only a pure evaluation session starts on the evaluation step budget;
# every other session type starts on the training step budget.
is_eval_session = self.training_config.session_type == SessionType.EVAL
self.episode_steps = (
    self.training_config.num_eval_steps
    if is_eval_session
    else self.training_config.num_train_steps
)
super(Primaite, self).__init__()
@@ -254,6 +259,7 @@ class Primaite(Env):
self.episode_count = 0
self.step_count = 0
self.total_step_count = 0
self.episode_steps = self.training_config.num_eval_steps
def reset(self):
"""