#1595: test to make sure that the loaded agent trains + remove unnecessary files + fixing agent save output name

This commit is contained in:
Czar Echavez
2023-07-14 10:56:28 +01:00
parent bc7c32697f
commit a92ef3f4ad
8 changed files with 44 additions and 12593 deletions

View File

@@ -297,7 +297,7 @@ class AgentSessionABC(ABC):
@property
def _saved_agent_path(self) -> Path:
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}_" f".zip"
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
return self.learning_path / file_name
@abstractmethod

View File

@@ -66,8 +66,7 @@ class SB3Agent(AgentSessionABC):
self._setup()
def _setup(self):
super()._setup()
"""Set up the SB3 Agent."""
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
@@ -85,18 +84,16 @@ class SB3Agent(AgentSessionABC):
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_steps,
n_steps=self._training_config.num_train_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
else:
# load the file
self._agent = self._agent_class.load(load_file)
# set env values from session metadata
with open(self.session_path / "session_metadata.json", "r") as file:
md_dict = json.load(file)
# load environment values
if self.is_eval:
# evaluation always starts at 0
self._env.episode_count = 0
@@ -106,6 +103,15 @@ class SB3Agent(AgentSessionABC):
self._env.episode_count = md_dict["learning"]["total_episodes"]
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
# load the file
self._agent = self._agent_class.load(load_file, env=self._env)
# set agent values
self._agent.verbose = self.sb3_output_verbose_level
self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
super()._setup()
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count

View File

@@ -1,26 +0,0 @@
Episode,Average Reward
1,-0.009857999999999992
2,-0.009857999999999992
3,-0.009857999999999992
4,-0.009857999999999992
5,-0.009857999999999992
6,-0.009857999999999992
7,-0.009857999999999992
8,-0.009857999999999992
9,-0.009857999999999992
10,-0.009857999999999992
11,-0.009857999999999992
12,-0.009857999999999992
13,-0.009857999999999992
14,-0.009857999999999992
15,-0.009857999999999992
16,-0.009857999999999992
17,-0.009857999999999992
18,-0.009857999999999992
19,-0.009857999999999992
20,-0.009857999999999992
21,-0.009857999999999992
22,-0.009857999999999992
23,-0.009857999999999992
24,-0.009857999999999992
25,-0.009857999999999992
1 Episode Average Reward
2 1 -0.009857999999999992
3 2 -0.009857999999999992
4 3 -0.009857999999999992
5 4 -0.009857999999999992
6 5 -0.009857999999999992
7 6 -0.009857999999999992
8 7 -0.009857999999999992
9 8 -0.009857999999999992
10 9 -0.009857999999999992
11 10 -0.009857999999999992
12 11 -0.009857999999999992
13 12 -0.009857999999999992
14 13 -0.009857999999999992
15 14 -0.009857999999999992
16 15 -0.009857999999999992
17 16 -0.009857999999999992
18 17 -0.009857999999999992
19 18 -0.009857999999999992
20 19 -0.009857999999999992
21 20 -0.009857999999999992
22 21 -0.009857999999999992
23 22 -0.009857999999999992
24 23 -0.009857999999999992
25 24 -0.009857999999999992
26 25 -0.009857999999999992

View File

@@ -1,26 +0,0 @@
Episode,Average Reward
1,-0.009281999999999969
2,-0.009727999999999978
3,-0.009469999999999977
4,-0.009285999999999971
5,-0.00960599999999997
6,-0.009449999999999986
7,-0.009779999999999981
8,-0.009439999999999974
9,-0.00967999999999998
10,-0.008985999999999994
11,-0.008893999999999982
12,-0.009083999999999983
13,-0.008361999999999984
14,-0.009489999999999964
15,-0.009027999999999977
16,-0.009441999999999996
17,-0.008733999999999988
18,-0.008675999999999984
19,-0.008569999999999984
20,-0.009071999999999988
21,-0.008043999999999997
22,-0.007955999999999982
23,-0.008277999999999976
24,-0.00803399999999999
25,-0.00856399999999999
1 Episode Average Reward
2 1 -0.009281999999999969
3 2 -0.009727999999999978
4 3 -0.009469999999999977
5 4 -0.009285999999999971
6 5 -0.00960599999999997
7 6 -0.009449999999999986
8 7 -0.009779999999999981
9 8 -0.009439999999999974
10 9 -0.00967999999999998
11 10 -0.008985999999999994
12 11 -0.008893999999999982
13 12 -0.009083999999999983
14 13 -0.008361999999999984
15 14 -0.009489999999999964
16 15 -0.009027999999999977
17 16 -0.009441999999999996
18 17 -0.008733999999999988
19 18 -0.008675999999999984
20 19 -0.008569999999999984
21 20 -0.009071999999999988
22 21 -0.008043999999999997
23 22 -0.007955999999999982
24 23 -0.008277999999999976
25 24 -0.00803399999999999
26 25 -0.00856399999999999

File diff suppressed because one or more lines are too long

View File

@@ -40,17 +40,42 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str:
def test_load_sb3_session():
"""Test that loading an SB3 agent works."""
expected_learn_mean_reward_per_episode = {
10: 0,
11: -0.008037109374999995,
12: -0.007978515624999988,
13: -0.008191406249999991,
14: -0.00817578124999999,
15: -0.008085937499999998,
16: -0.007837890624999982,
17: -0.007798828124999992,
18: -0.007777343749999998,
19: -0.007958984374999988,
20: -0.0077499999999999835,
}
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
loaded_agent = SB3Agent(session_path=test_path)
# loaded agent should have the same UUID as the previous agent
assert loaded_agent.uuid == "8c196c83-b77d-4ef7-af4b-0a3ada30221c"
assert loaded_agent.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
assert loaded_agent._training_config.deterministic
assert loaded_agent._training_config.seed == 12345
assert str(loaded_agent.session_path) == str(test_path)
# run another learn session
loaded_agent.learn()
learn_mean_rewards = av_rewards_dict(
loaded_agent.learning_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
# run an evaluation
loaded_agent.evaluate()
@@ -63,38 +88,12 @@ def test_load_sb3_session():
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == -0.009857999999999992
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
# delete the test directory
shutil.rmtree(test_path)
def test_load_rllib_session():
"""Test that loading an RLlib agent works."""
# test_path = copy_session_asset(TEST_ASSETS_ROOT)
#
# loaded_agent = RLlibAgent(session_path=test_path)
#
# # loaded agent should have the same UUID as the previous agent
# assert loaded_agent.uuid == "58c7e648-c784-44e8-bec0-a1db95898270"
# assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
# assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
# assert loaded_agent._training_config.deterministic
# assert str(loaded_agent.session_path) == str(test_path)
#
# # run an evaluation
# loaded_agent.evaluate()
#
# # load the evaluation average reward csv file
# eval_mean_reward = av_rewards_dict(
# loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
# )
#
# # the agent config ran the evaluation in deterministic mode, so should have the same reward value
# assert len(set(eval_mean_reward.values())) == 1
#
# # the evaluation should be the same as a previous run
# assert next(iter(set(eval_mean_reward.values()))) == -0.00011132812500000003
#
# # delete the test directory
# shutil.rmtree(test_path)
def test_load_primaite_session():
"""Test that loading a Primaite session works."""
pass