- SB3 Agent loading
- rename agent.py -> agent_abc.py
- rename hardcoded.py -> hardcoded_abc.py
- Tests
- Added in test asset that is used to load the SB3 Agent
This commit is contained in:
Czar.Echavez
2023-07-13 16:24:03 +01:00
parent 54e4da1250
commit e2d5f0bcff
15 changed files with 12767 additions and 53 deletions

3
.gitignore vendored
View File

@@ -50,6 +50,9 @@ coverage.xml
.hypothesis/
.pytest_cache/
cover/
tests/assets/**/*.png
tests/assets/**/tensorboard_logs/
tests/assets/**/checkpoints/
# Translations
*.mo

View File

@@ -4,7 +4,7 @@ import json
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Dict, Final, Union
from typing import Dict, Optional, Union
from uuid import uuid4
import yaml
@@ -46,38 +46,63 @@ class AgentSessionABC(ABC):
"""
@abstractmethod
def __init__(self, training_config_path, lay_down_config_path):
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
"""
Initialise an agent session from config files.
Initialise an agent session from config files, or load a previous session.
If training configuration and laydown configuration are provided with a session path,
the session path will be used.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path, str]] = training_config_path
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# initialise variables
self._env: Primaite
self._agent = None
self._can_learn: bool = False
self._can_evaluate: bool = False
self.is_eval = False
self._uuid = str(uuid4())
self.session_timestamp: datetime = datetime.now()
"The session timestamp"
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
# convert session to path
if session_path is not None:
if not isinstance(session_path, Path):
session_path = Path(session_path)
# if a session path is provided, load it
if not session_path.exists():
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
# load session
self.load(session_path)
else:
# set training config path
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Union[Path, str] = training_config_path
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# set random UUID for session
self._uuid = str(uuid4())
"The session timestamp"
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
@property
def timestamp_str(self) -> str:
@@ -226,9 +251,7 @@ class AgentSessionABC(ABC):
def _get_latest_checkpoint(self):
pass
@classmethod
@abstractmethod
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
def load(self, path: Union[str, Path]):
"""Load an agent from file."""
if not isinstance(path, Path):
path = Path(path)
@@ -252,26 +275,29 @@ class AgentSessionABC(ABC):
with open(temp_ldc, "w") as file:
yaml.dump(md_dict["env"]["lay_down_config"], file)
agent = cls(temp_tc, temp_ldc)
# set training config path
self._training_config_path: Union[Path, str] = temp_tc
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
self._lay_down_config_path: Union[Path, str] = temp_ldc
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
agent.session_path = path
# set random UUID for session
self._uuid = md_dict["uuid"]
return agent
# set the session path
self.session_path = path
"The Session path"
else:
# Session path does not exist
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
pass
@property
def _saved_agent_path(self) -> Path:
file_name = (
f"{self._training_config.agent_framework}_"
f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip"
)
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}_" f".zip"
return self.learning_path / file_name
@abstractmethod

View File

@@ -4,7 +4,7 @@ import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.agents.hardcoded import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
get_node_of_ip,

View File

@@ -1,6 +1,6 @@
import numpy as np
from primaite.agents.hardcoded import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable

View File

@@ -14,7 +14,7 @@ from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite

View File

@@ -1,14 +1,15 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Union
from typing import Optional, Union
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
@@ -18,7 +19,12 @@ _LOGGER = getLogger(__name__)
class SB3Agent(AgentSessionABC):
"""An AgentSession class that implements a Stable Baselines3 agent."""
def __init__(self, training_config_path, lay_down_config_path):
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
"""
Initialise the SB3 Agent training session.
@@ -31,7 +37,7 @@ class SB3Agent(AgentSessionABC):
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
super().__init__(training_config_path, lay_down_config_path)
super().__init__(training_config_path, lay_down_config_path, session_path)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
@@ -47,7 +53,7 @@ class SB3Agent(AgentSessionABC):
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
self._setup()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
@@ -57,22 +63,48 @@ class SB3Agent(AgentSessionABC):
self.is_eval = False
self._setup()
def _setup(self):
super()._setup()
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
# check if there is a zip file that needs to be loaded
load_file = next(self.session_path.rglob("*.zip"), None)
if not load_file:
# create a new env and agent
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
else:
# load the file
self._agent = self._agent_class.load(load_file)
# set env values from session metadata
with open(self.session_path / "session_metadata.json", "r") as file:
md_dict = json.load(file)
if self.is_eval:
# evaluation always starts at 0
self._env.episode_count = 0
self._env.total_step_count = 0
else:
# carry on from previous learning sessions
self._env.episode_count = md_dict["learning"]["total_episodes"]
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
@@ -144,11 +176,6 @@ class SB3Agent(AgentSessionABC):
self._env.close()
super().evaluate()
@classmethod
def load(cls, path: Union[str, Path]) -> SB3Agent:
"""Load an agent from file."""
raise NotImplementedError
def save(self):
"""Save the agent."""
self._agent.save(self._saved_agent_path)

View File

@@ -1,4 +1,4 @@
from primaite.agents.hardcoded import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum

View File

@@ -5,7 +5,7 @@ from pathlib import Path
from typing import Dict, Final, Union
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent

View File

@@ -4,3 +4,6 @@ from typing import Final
TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config"
"The tests config root directory."
TEST_ASSETS_ROOT: Final[Path] = Path(__file__).parent / "assets"
"The tests assets root directory."

View File

@@ -0,0 +1,26 @@
Episode,Average Reward
1,-0.009857999999999992
2,-0.009857999999999992
3,-0.009857999999999992
4,-0.009857999999999992
5,-0.009857999999999992
6,-0.009857999999999992
7,-0.009857999999999992
8,-0.009857999999999992
9,-0.009857999999999992
10,-0.009857999999999992
11,-0.009857999999999992
12,-0.009857999999999992
13,-0.009857999999999992
14,-0.009857999999999992
15,-0.009857999999999992
16,-0.009857999999999992
17,-0.009857999999999992
18,-0.009857999999999992
19,-0.009857999999999992
20,-0.009857999999999992
21,-0.009857999999999992
22,-0.009857999999999992
23,-0.009857999999999992
24,-0.009857999999999992
25,-0.009857999999999992
1 Episode Average Reward
2 1 -0.009857999999999992
3 2 -0.009857999999999992
4 3 -0.009857999999999992
5 4 -0.009857999999999992
6 5 -0.009857999999999992
7 6 -0.009857999999999992
8 7 -0.009857999999999992
9 8 -0.009857999999999992
10 9 -0.009857999999999992
11 10 -0.009857999999999992
12 11 -0.009857999999999992
13 12 -0.009857999999999992
14 13 -0.009857999999999992
15 14 -0.009857999999999992
16 15 -0.009857999999999992
17 16 -0.009857999999999992
18 17 -0.009857999999999992
19 18 -0.009857999999999992
20 19 -0.009857999999999992
21 20 -0.009857999999999992
22 21 -0.009857999999999992
23 22 -0.009857999999999992
24 23 -0.009857999999999992
25 24 -0.009857999999999992
26 25 -0.009857999999999992

View File

@@ -0,0 +1,26 @@
Episode,Average Reward
1,-0.009281999999999969
2,-0.009727999999999978
3,-0.009469999999999977
4,-0.009285999999999971
5,-0.00960599999999997
6,-0.009449999999999986
7,-0.009779999999999981
8,-0.009439999999999974
9,-0.00967999999999998
10,-0.008985999999999994
11,-0.008893999999999982
12,-0.009083999999999983
13,-0.008361999999999984
14,-0.009489999999999964
15,-0.009027999999999977
16,-0.009441999999999996
17,-0.008733999999999988
18,-0.008675999999999984
19,-0.008569999999999984
20,-0.009071999999999988
21,-0.008043999999999997
22,-0.007955999999999982
23,-0.008277999999999976
24,-0.00803399999999999
25,-0.00856399999999999
1 Episode Average Reward
2 1 -0.009281999999999969
3 2 -0.009727999999999978
4 3 -0.009469999999999977
5 4 -0.009285999999999971
6 5 -0.00960599999999997
7 6 -0.009449999999999986
8 7 -0.009779999999999981
9 8 -0.009439999999999974
10 9 -0.00967999999999998
11 10 -0.008985999999999994
12 11 -0.008893999999999982
13 12 -0.009083999999999983
14 13 -0.008361999999999984
15 14 -0.009489999999999964
16 15 -0.009027999999999977
17 16 -0.009441999999999996
18 17 -0.008733999999999988
19 18 -0.008675999999999984
20 19 -0.008569999999999984
21 20 -0.009071999999999988
22 21 -0.008043999999999997
23 22 -0.007955999999999982
24 23 -0.008277999999999976
25 24 -0.00803399999999999
26 25 -0.00856399999999999

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,100 @@
import os.path
import shutil
import tempfile
from pathlib import Path
from typing import Union
from uuid import uuid4
from primaite import getLogger
from primaite.agents.sb3 import SB3Agent
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.utils.session_output_reader import av_rewards_dict
from tests import TEST_ASSETS_ROOT
_LOGGER = getLogger(__name__)
def copy_session_asset(asset_path: Union[str, Path]) -> str:
"""Copies the asset into a temporary test folder."""
if asset_path is None:
raise Exception("No path provided")
if isinstance(asset_path, Path):
asset_path = str(os.path.normpath(asset_path))
copy_path = str(Path(tempfile.gettempdir()) / "primaite" / str(uuid4()))
# copy the asset into a temp path
try:
shutil.copytree(asset_path, copy_path)
except Exception as e:
msg = f"Unable to copy directory: {asset_path}"
_LOGGER.error(msg, e)
print(msg, e)
_LOGGER.debug(f"Copied test asset to: {copy_path}")
# return the copied assets path
return copy_path
def test_load_sb3_session():
"""Test that loading an SB3 agent works."""
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
loaded_agent = SB3Agent(session_path=test_path)
# loaded agent should have the same UUID as the previous agent
assert loaded_agent.uuid == "8c196c83-b77d-4ef7-af4b-0a3ada30221c"
assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
assert loaded_agent._training_config.deterministic
assert str(loaded_agent.session_path) == str(test_path)
# run an evaluation
loaded_agent.evaluate()
# load the evaluation average reward csv file
eval_mean_reward = av_rewards_dict(
loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
)
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == -0.009857999999999992
# delete the test directory
shutil.rmtree(test_path)
def test_load_rllib_session():
"""Test that loading an RLlib agent works."""
# test_path = copy_session_asset(TEST_ASSETS_ROOT)
#
# loaded_agent = RLlibAgent(session_path=test_path)
#
# # loaded agent should have the same UUID as the previous agent
# assert loaded_agent.uuid == "58c7e648-c784-44e8-bec0-a1db95898270"
# assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
# assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
# assert loaded_agent._training_config.deterministic
# assert str(loaded_agent.session_path) == str(test_path)
#
# # run an evaluation
# loaded_agent.evaluate()
#
# # load the evaluation average reward csv file
# eval_mean_reward = av_rewards_dict(
# loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
# )
#
# # the agent config ran the evaluation in deterministic mode, so should have the same reward value
# assert len(set(eval_mean_reward.values())) == 1
#
# # the evaluation should be the same as a previous run
# assert next(iter(set(eval_mean_reward.values()))) == -0.00011132812500000003
#
# # delete the test directory
# shutil.rmtree(test_path)