- Added ability to load sessions via PrimaiteSession - PrimaiteSession loading test - Added a NotImplemented RLlib loading for now - Added the ability to load sessions for hardcoded agents - Moved Session metadata parsing to utils
This commit is contained in:
@@ -7,14 +7,13 @@ from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
|
||||
import primaite
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.data_viz.session_plots import plot_av_reward_per_episode
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -253,47 +252,21 @@ class AgentSessionABC(ABC):
|
||||
|
||||
def load(self, path: Union[str, Path]):
|
||||
"""Load an agent from file."""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
|
||||
|
||||
if path.exists():
|
||||
# Unpack the session_metadata.json file
|
||||
md_file = path / "session_metadata.json"
|
||||
with open(md_file, "r") as file:
|
||||
md_dict = json.load(file)
|
||||
# set training config path
|
||||
self._training_config_path: Union[Path, str] = training_config_path
|
||||
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = laydown_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# Create a temp directory and dump the training and lay down
|
||||
# configs into it
|
||||
temp_dir = path / ".temp"
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
# set random UUID for session
|
||||
self._uuid = md_dict["uuid"]
|
||||
|
||||
temp_tc = temp_dir / "tc.yaml"
|
||||
with open(temp_tc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["training_config"], file)
|
||||
|
||||
temp_ldc = temp_dir / "ldc.yaml"
|
||||
with open(temp_ldc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["lay_down_config"], file)
|
||||
|
||||
# set training config path
|
||||
self._training_config_path: Union[Path, str] = temp_tc
|
||||
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = temp_ldc
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# set random UUID for session
|
||||
self._uuid = md_dict["uuid"]
|
||||
|
||||
# set the session path
|
||||
self.session_path = path
|
||||
"The Session path"
|
||||
|
||||
else:
|
||||
# Session path does not exist
|
||||
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
# set the session path
|
||||
self.session_path = path
|
||||
"The Session path"
|
||||
|
||||
@property
|
||||
def _saved_agent_path(self) -> Path:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
@@ -16,7 +18,12 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
implemented.
|
||||
"""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
@@ -26,7 +33,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
@@ -43,7 +43,12 @@ def _custom_log_creator(session_path: Path):
|
||||
class RLlibAgent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Ray RLlib agent."""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Initialise the RLLib Agent training session.
|
||||
|
||||
@@ -56,6 +61,13 @@ class RLlibAgent(AgentSessionABC):
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
||||
or `A2C`)
|
||||
"""
|
||||
# TODO: implement RLlib agent loading
|
||||
if session_path is not None:
|
||||
msg = "RLlib agent loading has not been implemented yet"
|
||||
_LOGGER.error(msg)
|
||||
print(msg)
|
||||
raise NotImplementedError
|
||||
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Union
|
||||
from typing import Dict, Final, Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
@@ -14,6 +14,7 @@ from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyA
|
||||
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -27,8 +28,9 @@ class PrimaiteSession:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
@@ -36,6 +38,25 @@ class PrimaiteSession:
|
||||
:param training_config_path: The training config path.
|
||||
:param lay_down_config_path: The lay down config path.
|
||||
"""
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = session_path # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
|
||||
# check if session path is provided
|
||||
if session_path is not None:
|
||||
# set load_session to true
|
||||
self.is_load_session = True
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
# if a session path is provided, load it
|
||||
if not session_path.exists():
|
||||
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
|
||||
|
||||
md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)
|
||||
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path, str]] = training_config_path
|
||||
@@ -46,12 +67,6 @@ class PrimaiteSession:
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = None # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
|
||||
def setup(self):
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
@@ -60,11 +75,15 @@ class PrimaiteSession:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = HardCodedNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = HardCodedACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
@@ -77,11 +96,15 @@ class PrimaiteSession:
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = DoNothingNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = DoNothingACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
@@ -93,10 +116,14 @@ class PrimaiteSession:
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
|
||||
self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = RandomAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
|
||||
self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = DummyAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework AgentIdentifier combo
|
||||
@@ -105,12 +132,12 @@ class PrimaiteSession:
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
|
||||
# Stable Baselines3 Agent
|
||||
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path, self.session_path)
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
|
||||
# Ray RLlib Agent
|
||||
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
|
||||
58
src/primaite/utils/session_metadata_parser.py
Normal file
58
src/primaite/utils/session_metadata_parser.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
|
||||
"""
|
||||
Loads a session metadata from the given directory path.
|
||||
|
||||
:param session_path: Directory where the session metadata file is in
|
||||
:param dict_only: If dict_only is true, the function will only return the dict contents of session metadata
|
||||
|
||||
:return: Dictionary which has all the session metadata contents
|
||||
:rtype: Dict
|
||||
|
||||
:return: Path where the YAML copy of the training config is dumped into
|
||||
:rtype: str
|
||||
:return: Path where the YAML copy of the laydown config is dumped into
|
||||
:rtype: str
|
||||
"""
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
if not session_path.exists():
|
||||
# Session path does not exist
|
||||
msg = f"Failed to load PrimAITE Session, path does not exist: {session_path}"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
# Unpack the session_metadata.json file
|
||||
md_file = session_path / "session_metadata.json"
|
||||
with open(md_file, "r") as file:
|
||||
md_dict = json.load(file)
|
||||
|
||||
# if dict only, return dict without doing anything else
|
||||
if dict_only:
|
||||
return md_dict
|
||||
|
||||
# Create a temp directory and dump the training and lay down
|
||||
# configs into it
|
||||
temp_dir = session_path / ".temp"
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
|
||||
temp_tc = temp_dir / "tc.yaml"
|
||||
with open(temp_tc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["training_config"], file)
|
||||
|
||||
temp_ldc = temp_dir / "ldc.yaml"
|
||||
with open(temp_ldc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["lay_down_config"], file)
|
||||
|
||||
return [md_dict, temp_tc, temp_ldc]
|
||||
@@ -8,6 +8,7 @@ from uuid import uuid4
|
||||
from primaite import getLogger
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
from primaite.utils.session_output_reader import av_rewards_dict
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
@@ -96,4 +97,59 @@ def test_load_sb3_session():
|
||||
|
||||
def test_load_primaite_session():
|
||||
"""Test that loading a Primaite session works."""
|
||||
pass
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
10: 0,
|
||||
11: -0.008037109374999995,
|
||||
12: -0.007978515624999988,
|
||||
13: -0.008191406249999991,
|
||||
14: -0.00817578124999999,
|
||||
15: -0.008085937499999998,
|
||||
16: -0.007837890624999982,
|
||||
17: -0.007798828124999992,
|
||||
18: -0.007777343749999998,
|
||||
19: -0.007958984374999988,
|
||||
20: -0.0077499999999999835,
|
||||
}
|
||||
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
# create loaded session
|
||||
session = PrimaiteSession(session_path=test_path)
|
||||
|
||||
# run setup on session
|
||||
session.setup()
|
||||
|
||||
# make sure that the session was loaded correctly
|
||||
assert session._agent_session.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
|
||||
assert session._agent_session._training_config.agent_framework == AgentFramework.SB3.name
|
||||
assert session._agent_session._training_config.agent_identifier == AgentIdentifier.PPO.name
|
||||
assert session._agent_session._training_config.deterministic
|
||||
assert session._agent_session._training_config.seed == 12345
|
||||
assert str(session._agent_session.session_path) == str(test_path)
|
||||
|
||||
# run another learn session
|
||||
session.learn()
|
||||
|
||||
learn_mean_rewards = av_rewards_dict(
|
||||
session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
|
||||
|
||||
# run an evaluation
|
||||
session.evaluate()
|
||||
|
||||
# load the evaluation average reward csv file
|
||||
eval_mean_reward = av_rewards_dict(
|
||||
session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
|
||||
)
|
||||
|
||||
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
|
||||
# the evaluation should be the same as a previous run
|
||||
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
Reference in New Issue
Block a user