- SB3 Agent loading - rename agent.py -> agent_abc.py - rename hardcoded.py -> hardcoded_abc.py - Tests - Added in test asset that is used to load the SB3 Agent
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -50,6 +50,9 @@ coverage.xml
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
tests/assets/**/*.png
|
||||
tests/assets/**/tensorboard_logs/
|
||||
tests/assets/**/checkpoints/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Union
|
||||
from typing import Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
@@ -46,38 +46,63 @@ class AgentSessionABC(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Initialise an agent session from config files.
|
||||
Initialise an agent session from config files, or load a previous session.
|
||||
|
||||
If training configuration and laydown configuration are provided with a session path,
|
||||
the session path will be used.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path, str]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# initialise variables
|
||||
self._env: Primaite
|
||||
self._agent = None
|
||||
self._can_learn: bool = False
|
||||
self._can_evaluate: bool = False
|
||||
self.is_eval = False
|
||||
|
||||
self._uuid = str(uuid4())
|
||||
self.session_timestamp: datetime = datetime.now()
|
||||
"The session timestamp"
|
||||
self.session_path = get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
|
||||
# convert session to path
|
||||
if session_path is not None:
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
# if a session path is provided, load it
|
||||
if not session_path.exists():
|
||||
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
|
||||
|
||||
# load session
|
||||
self.load(session_path)
|
||||
else:
|
||||
# set training config path
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Union[Path, str] = training_config_path
|
||||
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# set random UUID for session
|
||||
self._uuid = str(uuid4())
|
||||
"The session timestamp"
|
||||
self.session_path = get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
|
||||
@property
|
||||
def timestamp_str(self) -> str:
|
||||
@@ -226,9 +251,7 @@ class AgentSessionABC(ABC):
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
|
||||
def load(self, path: Union[str, Path]):
|
||||
"""Load an agent from file."""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
@@ -252,26 +275,29 @@ class AgentSessionABC(ABC):
|
||||
with open(temp_ldc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["lay_down_config"], file)
|
||||
|
||||
agent = cls(temp_tc, temp_ldc)
|
||||
# set training config path
|
||||
self._training_config_path: Union[Path, str] = temp_tc
|
||||
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = temp_ldc
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
agent.session_path = path
|
||||
# set random UUID for session
|
||||
self._uuid = md_dict["uuid"]
|
||||
|
||||
return agent
|
||||
# set the session path
|
||||
self.session_path = path
|
||||
"The Session path"
|
||||
|
||||
else:
|
||||
# Session path does not exist
|
||||
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
pass
|
||||
|
||||
@property
|
||||
def _saved_agent_path(self) -> Path:
|
||||
file_name = (
|
||||
f"{self._training_config.agent_framework}_"
|
||||
f"{self._training_config.agent_identifier}_"
|
||||
f"{self.timestamp_str}.zip"
|
||||
)
|
||||
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}_" f".zip"
|
||||
return self.learning_path / file_name
|
||||
|
||||
@abstractmethod
|
||||
@@ -4,7 +4,7 @@ import numpy as np
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.agents.hardcoded import HardCodedAgentSessionABC
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
get_node_of_ip,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.hardcoded import HardCodedAgentSessionABC
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
@@ -18,7 +19,12 @@ _LOGGER = getLogger(__name__)
|
||||
class SB3Agent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Stable Baselines3 agent."""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Initialise the SB3 Agent training session.
|
||||
|
||||
@@ -31,7 +37,7 @@ class SB3Agent(AgentSessionABC):
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
||||
or `A2C`)
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.SB3:
|
||||
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
@@ -47,7 +53,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
|
||||
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
|
||||
self._setup()
|
||||
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
@@ -57,22 +63,48 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
self.is_eval = False
|
||||
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
super()._setup()
|
||||
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=self.sb3_output_verbose_level,
|
||||
n_steps=self._training_config.num_steps,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
seed=self._training_config.seed,
|
||||
)
|
||||
|
||||
# check if there is a zip file that needs to be loaded
|
||||
load_file = next(self.session_path.rglob("*.zip"), None)
|
||||
|
||||
if not load_file:
|
||||
# create a new env and agent
|
||||
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=self.sb3_output_verbose_level,
|
||||
n_steps=self._training_config.num_steps,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
seed=self._training_config.seed,
|
||||
)
|
||||
else:
|
||||
# load the file
|
||||
self._agent = self._agent_class.load(load_file)
|
||||
|
||||
# set env values from session metadata
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
md_dict = json.load(file)
|
||||
|
||||
if self.is_eval:
|
||||
# evaluation always starts at 0
|
||||
self._env.episode_count = 0
|
||||
self._env.total_step_count = 0
|
||||
else:
|
||||
# carry on from previous learning sessions
|
||||
self._env.episode_count = md_dict["learning"]["total_episodes"]
|
||||
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
@@ -144,11 +176,6 @@ class SB3Agent(AgentSessionABC):
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Union[str, Path]) -> SB3Agent:
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
self._agent.save(self._saved_agent_path)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from primaite.agents.hardcoded import HardCodedAgentSessionABC
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from typing import Dict, Final, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.agents.hardcoded_acl import HardCodedACLAgent
|
||||
from primaite.agents.hardcoded_node import HardCodedNodeAgent
|
||||
from primaite.agents.rllib import RLlibAgent
|
||||
|
||||
@@ -4,3 +4,6 @@ from typing import Final
|
||||
|
||||
TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config"
|
||||
"The tests config root directory."
|
||||
|
||||
TEST_ASSETS_ROOT: Final[Path] = Path(__file__).parent / "assets"
|
||||
"The tests assets root directory."
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,26 @@
|
||||
Episode,Average Reward
|
||||
1,-0.009857999999999992
|
||||
2,-0.009857999999999992
|
||||
3,-0.009857999999999992
|
||||
4,-0.009857999999999992
|
||||
5,-0.009857999999999992
|
||||
6,-0.009857999999999992
|
||||
7,-0.009857999999999992
|
||||
8,-0.009857999999999992
|
||||
9,-0.009857999999999992
|
||||
10,-0.009857999999999992
|
||||
11,-0.009857999999999992
|
||||
12,-0.009857999999999992
|
||||
13,-0.009857999999999992
|
||||
14,-0.009857999999999992
|
||||
15,-0.009857999999999992
|
||||
16,-0.009857999999999992
|
||||
17,-0.009857999999999992
|
||||
18,-0.009857999999999992
|
||||
19,-0.009857999999999992
|
||||
20,-0.009857999999999992
|
||||
21,-0.009857999999999992
|
||||
22,-0.009857999999999992
|
||||
23,-0.009857999999999992
|
||||
24,-0.009857999999999992
|
||||
25,-0.009857999999999992
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,26 @@
|
||||
Episode,Average Reward
|
||||
1,-0.009281999999999969
|
||||
2,-0.009727999999999978
|
||||
3,-0.009469999999999977
|
||||
4,-0.009285999999999971
|
||||
5,-0.00960599999999997
|
||||
6,-0.009449999999999986
|
||||
7,-0.009779999999999981
|
||||
8,-0.009439999999999974
|
||||
9,-0.00967999999999998
|
||||
10,-0.008985999999999994
|
||||
11,-0.008893999999999982
|
||||
12,-0.009083999999999983
|
||||
13,-0.008361999999999984
|
||||
14,-0.009489999999999964
|
||||
15,-0.009027999999999977
|
||||
16,-0.009441999999999996
|
||||
17,-0.008733999999999988
|
||||
18,-0.008675999999999984
|
||||
19,-0.008569999999999984
|
||||
20,-0.009071999999999988
|
||||
21,-0.008043999999999997
|
||||
22,-0.007955999999999982
|
||||
23,-0.008277999999999976
|
||||
24,-0.00803399999999999
|
||||
25,-0.00856399999999999
|
||||
|
File diff suppressed because one or more lines are too long
100
tests/test_session_loading.py
Normal file
100
tests/test_session_loading.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import os.path
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.utils.session_output_reader import av_rewards_dict
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def copy_session_asset(asset_path: Union[str, Path]) -> str:
|
||||
"""Copies the asset into a temporary test folder."""
|
||||
if asset_path is None:
|
||||
raise Exception("No path provided")
|
||||
|
||||
if isinstance(asset_path, Path):
|
||||
asset_path = str(os.path.normpath(asset_path))
|
||||
|
||||
copy_path = str(Path(tempfile.gettempdir()) / "primaite" / str(uuid4()))
|
||||
|
||||
# copy the asset into a temp path
|
||||
try:
|
||||
shutil.copytree(asset_path, copy_path)
|
||||
except Exception as e:
|
||||
msg = f"Unable to copy directory: {asset_path}"
|
||||
_LOGGER.error(msg, e)
|
||||
print(msg, e)
|
||||
|
||||
_LOGGER.debug(f"Copied test asset to: {copy_path}")
|
||||
|
||||
# return the copied assets path
|
||||
return copy_path
|
||||
|
||||
|
||||
def test_load_sb3_session():
|
||||
"""Test that loading an SB3 agent works."""
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
loaded_agent = SB3Agent(session_path=test_path)
|
||||
|
||||
# loaded agent should have the same UUID as the previous agent
|
||||
assert loaded_agent.uuid == "8c196c83-b77d-4ef7-af4b-0a3ada30221c"
|
||||
assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
|
||||
assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
|
||||
assert loaded_agent._training_config.deterministic
|
||||
assert str(loaded_agent.session_path) == str(test_path)
|
||||
|
||||
# run an evaluation
|
||||
loaded_agent.evaluate()
|
||||
|
||||
# load the evaluation average reward csv file
|
||||
eval_mean_reward = av_rewards_dict(
|
||||
loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
|
||||
)
|
||||
|
||||
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
|
||||
# the evaluation should be the same as a previous run
|
||||
assert next(iter(set(eval_mean_reward.values()))) == -0.009857999999999992
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
def test_load_rllib_session():
|
||||
"""Test that loading an RLlib agent works."""
|
||||
# test_path = copy_session_asset(TEST_ASSETS_ROOT)
|
||||
#
|
||||
# loaded_agent = RLlibAgent(session_path=test_path)
|
||||
#
|
||||
# # loaded agent should have the same UUID as the previous agent
|
||||
# assert loaded_agent.uuid == "58c7e648-c784-44e8-bec0-a1db95898270"
|
||||
# assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
|
||||
# assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
|
||||
# assert loaded_agent._training_config.deterministic
|
||||
# assert str(loaded_agent.session_path) == str(test_path)
|
||||
#
|
||||
# # run an evaluation
|
||||
# loaded_agent.evaluate()
|
||||
#
|
||||
# # load the evaluation average reward csv file
|
||||
# eval_mean_reward = av_rewards_dict(
|
||||
# loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
|
||||
# )
|
||||
#
|
||||
# # the agent config ran the evaluation in deterministic mode, so should have the same reward value
|
||||
# assert len(set(eval_mean_reward.values())) == 1
|
||||
#
|
||||
# # the evaluation should be the same as a previous run
|
||||
# assert next(iter(set(eval_mean_reward.values()))) == -0.00011132812500000003
|
||||
#
|
||||
# # delete the test directory
|
||||
# shutil.rmtree(test_path)
|
||||
Reference in New Issue
Block a user