- SB3 Agent loading
- rename agent.py -> agent_abc.py
- rename hardcoded.py -> hardcoded_abc.py
- Tests
- Added in test asset that is used to load the SB3 Agent
This commit is contained in:
Czar.Echavez
2023-07-13 16:24:03 +01:00
parent 54e4da1250
commit e2d5f0bcff
15 changed files with 12767 additions and 53 deletions

View File

@@ -4,3 +4,6 @@ from typing import Final
TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config"
"The tests config root directory."
TEST_ASSETS_ROOT: Final[Path] = Path(__file__).parent / "assets"
"The tests assets root directory."

View File

@@ -0,0 +1,26 @@
Episode,Average Reward
1,-0.009857999999999992
2,-0.009857999999999992
3,-0.009857999999999992
4,-0.009857999999999992
5,-0.009857999999999992
6,-0.009857999999999992
7,-0.009857999999999992
8,-0.009857999999999992
9,-0.009857999999999992
10,-0.009857999999999992
11,-0.009857999999999992
12,-0.009857999999999992
13,-0.009857999999999992
14,-0.009857999999999992
15,-0.009857999999999992
16,-0.009857999999999992
17,-0.009857999999999992
18,-0.009857999999999992
19,-0.009857999999999992
20,-0.009857999999999992
21,-0.009857999999999992
22,-0.009857999999999992
23,-0.009857999999999992
24,-0.009857999999999992
25,-0.009857999999999992
1 Episode Average Reward
2 1 -0.009857999999999992
3 2 -0.009857999999999992
4 3 -0.009857999999999992
5 4 -0.009857999999999992
6 5 -0.009857999999999992
7 6 -0.009857999999999992
8 7 -0.009857999999999992
9 8 -0.009857999999999992
10 9 -0.009857999999999992
11 10 -0.009857999999999992
12 11 -0.009857999999999992
13 12 -0.009857999999999992
14 13 -0.009857999999999992
15 14 -0.009857999999999992
16 15 -0.009857999999999992
17 16 -0.009857999999999992
18 17 -0.009857999999999992
19 18 -0.009857999999999992
20 19 -0.009857999999999992
21 20 -0.009857999999999992
22 21 -0.009857999999999992
23 22 -0.009857999999999992
24 23 -0.009857999999999992
25 24 -0.009857999999999992
26 25 -0.009857999999999992

View File

@@ -0,0 +1,26 @@
Episode,Average Reward
1,-0.009281999999999969
2,-0.009727999999999978
3,-0.009469999999999977
4,-0.009285999999999971
5,-0.00960599999999997
6,-0.009449999999999986
7,-0.009779999999999981
8,-0.009439999999999974
9,-0.00967999999999998
10,-0.008985999999999994
11,-0.008893999999999982
12,-0.009083999999999983
13,-0.008361999999999984
14,-0.009489999999999964
15,-0.009027999999999977
16,-0.009441999999999996
17,-0.008733999999999988
18,-0.008675999999999984
19,-0.008569999999999984
20,-0.009071999999999988
21,-0.008043999999999997
22,-0.007955999999999982
23,-0.008277999999999976
24,-0.00803399999999999
25,-0.00856399999999999
1 Episode Average Reward
2 1 -0.009281999999999969
3 2 -0.009727999999999978
4 3 -0.009469999999999977
5 4 -0.009285999999999971
6 5 -0.00960599999999997
7 6 -0.009449999999999986
8 7 -0.009779999999999981
9 8 -0.009439999999999974
10 9 -0.00967999999999998
11 10 -0.008985999999999994
12 11 -0.008893999999999982
13 12 -0.009083999999999983
14 13 -0.008361999999999984
15 14 -0.009489999999999964
16 15 -0.009027999999999977
17 16 -0.009441999999999996
18 17 -0.008733999999999988
19 18 -0.008675999999999984
20 19 -0.008569999999999984
21 20 -0.009071999999999988
22 21 -0.008043999999999997
23 22 -0.007955999999999982
24 23 -0.008277999999999976
25 24 -0.00803399999999999
26 25 -0.00856399999999999

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,100 @@
import os.path
import shutil
import tempfile
from pathlib import Path
from typing import Union
from uuid import uuid4
from primaite import getLogger
from primaite.agents.sb3 import SB3Agent
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.utils.session_output_reader import av_rewards_dict
from tests import TEST_ASSETS_ROOT
_LOGGER = getLogger(__name__)
def copy_session_asset(asset_path: Union[str, Path]) -> str:
"""Copies the asset into a temporary test folder."""
if asset_path is None:
raise Exception("No path provided")
if isinstance(asset_path, Path):
asset_path = str(os.path.normpath(asset_path))
copy_path = str(Path(tempfile.gettempdir()) / "primaite" / str(uuid4()))
# copy the asset into a temp path
try:
shutil.copytree(asset_path, copy_path)
except Exception as e:
msg = f"Unable to copy directory: {asset_path}"
_LOGGER.error(msg, e)
print(msg, e)
_LOGGER.debug(f"Copied test asset to: {copy_path}")
# return the copied assets path
return copy_path
def test_load_sb3_session():
"""Test that loading an SB3 agent works."""
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
loaded_agent = SB3Agent(session_path=test_path)
# loaded agent should have the same UUID as the previous agent
assert loaded_agent.uuid == "8c196c83-b77d-4ef7-af4b-0a3ada30221c"
assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
assert loaded_agent._training_config.deterministic
assert str(loaded_agent.session_path) == str(test_path)
# run an evaluation
loaded_agent.evaluate()
# load the evaluation average reward csv file
eval_mean_reward = av_rewards_dict(
loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
)
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == -0.009857999999999992
# delete the test directory
shutil.rmtree(test_path)
def test_load_rllib_session():
"""Test that loading an RLlib agent works."""
# test_path = copy_session_asset(TEST_ASSETS_ROOT)
#
# loaded_agent = RLlibAgent(session_path=test_path)
#
# # loaded agent should have the same UUID as the previous agent
# assert loaded_agent.uuid == "58c7e648-c784-44e8-bec0-a1db95898270"
# assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
# assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
# assert loaded_agent._training_config.deterministic
# assert str(loaded_agent.session_path) == str(test_path)
#
# # run an evaluation
# loaded_agent.evaluate()
#
# # load the evaluation average reward csv file
# eval_mean_reward = av_rewards_dict(
# loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
# )
#
# # the agent config ran the evaluation in deterministic mode, so should have the same reward value
# assert len(set(eval_mean_reward.values())) == 1
#
# # the evaluation should be the same as a previous run
# assert next(iter(set(eval_mean_reward.values()))) == -0.00011132812500000003
#
# # delete the test directory
# shutil.rmtree(test_path)