From fbdb8aab28d5b014523019b114da41d1aff98f3d Mon Sep 17 00:00:00 2001 From: "Czar.Echavez" Date: Fri, 14 Jul 2023 14:14:03 +0100 Subject: [PATCH] #1595: - Added ability to load sessions via PrimaiteSession - PrimaiteSession loading test - Added a NotImplemented RLlib loading for now - Added the ability to load sessions for hardcoded agents - Moved Session metadata parsing to utils --- src/primaite/agents/agent_abc.py | 53 ++++------------ src/primaite/agents/hardcoded_abc.py | 11 +++- src/primaite/agents/rllib.py | 16 ++++- src/primaite/primaite_session.py | 61 +++++++++++++------ src/primaite/utils/session_metadata_parser.py | 58 ++++++++++++++++++ tests/test_session_loading.py | 58 +++++++++++++++++- 6 files changed, 195 insertions(+), 62 deletions(-) create mode 100644 src/primaite/utils/session_metadata_parser.py diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py index ec870781..e36196a0 100644 --- a/src/primaite/agents/agent_abc.py +++ b/src/primaite/agents/agent_abc.py @@ -7,14 +7,13 @@ from pathlib import Path from typing import Dict, Optional, Union from uuid import uuid4 -import yaml - import primaite from primaite import getLogger, SESSIONS_DIR from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig from primaite.data_viz.session_plots import plot_av_reward_per_episode from primaite.environment.primaite_env import Primaite +from primaite.utils.session_metadata_parser import parse_session_metadata _LOGGER = getLogger(__name__) @@ -253,47 +252,21 @@ class AgentSessionABC(ABC): def load(self, path: Union[str, Path]): """Load an agent from file.""" - if not isinstance(path, Path): - path = Path(path) + md_dict, training_config_path, laydown_config_path = parse_session_metadata(path) - if path.exists(): - # Unpack the session_metadata.json file - md_file = path / "session_metadata.json" - with open(md_file, "r") as file: - md_dict = json.load(file) + # set training 
config path + self._training_config_path: Union[Path, str] = training_config_path + self._training_config: TrainingConfig = training_config.load(self._training_config_path) + self._lay_down_config_path: Union[Path, str] = laydown_config_path + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) + self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level - # Create a temp directory and dump the training and lay down - # configs into it - temp_dir = path / ".temp" - temp_dir.mkdir(exist_ok=True) + # restore the session UUID from the saved metadata + self._uuid = md_dict["uuid"] - temp_tc = temp_dir / "tc.yaml" - with open(temp_tc, "w") as file: - yaml.dump(md_dict["env"]["training_config"], file) - - temp_ldc = temp_dir / "ldc.yaml" - with open(temp_ldc, "w") as file: - yaml.dump(md_dict["env"]["lay_down_config"], file) - - # set training config path - self._training_config_path: Union[Path, str] = temp_tc - self._training_config: TrainingConfig = training_config.load(self._training_config_path) - self._lay_down_config_path: Union[Path, str] = temp_ldc - self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) - self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level - - # set random UUID for session - self._uuid = md_dict["uuid"] - - # set the session path - self.session_path = path - "The Session path" - - else: - # Session path does not exist - msg = f"Failed to load PrimAITE Session, path does not exist: {path}" - _LOGGER.error(msg) - raise FileNotFoundError(msg) + # set the session path + self.session_path = path + "The Session path" @property def _saved_agent_path(self) -> Path: diff --git a/src/primaite/agents/hardcoded_abc.py b/src/primaite/agents/hardcoded_abc.py index 2c00c6c8..cfee3e16 100644 --- a/src/primaite/agents/hardcoded_abc.py +++ b/src/primaite/agents/hardcoded_abc.py @@ -1,5 +1,7 @@ import time from abc import abstractmethod +from pathlib import Path +from typing import
Optional, Union from primaite import getLogger from primaite.agents.agent_abc import AgentSessionABC @@ -16,7 +18,12 @@ class HardCodedAgentSessionABC(AgentSessionABC): implemented. """ - def __init__(self, training_config_path, lay_down_config_path): + def __init__( + self, + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, + ): """ Initialise a hardcoded agent session. @@ -26,7 +33,7 @@ class HardCodedAgentSessionABC(AgentSessionABC): :param lay_down_config_path: YAML file containing configurable items for generating network laydown. :type lay_down_config_path: Union[path, str] """ - super().__init__(training_config_path, lay_down_config_path) + super().__init__(training_config_path, lay_down_config_path, session_path) self._setup() def _setup(self): diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 190e4234..1707cb81 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -4,7 +4,7 @@ import json import shutil from datetime import datetime from pathlib import Path -from typing import Union +from typing import Optional, Union from uuid import uuid4 from ray.rllib.algorithms import Algorithm @@ -43,7 +43,12 @@ def _custom_log_creator(session_path: Path): class RLlibAgent(AgentSessionABC): """An AgentSession class that implements a Ray RLlib agent.""" - def __init__(self, training_config_path, lay_down_config_path): + def __init__( + self, + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, + ): """ Initialise the RLLib Agent training session. 
@@ -56,6 +61,13 @@ class RLlibAgent(AgentSessionABC): :raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO` or `A2C`) """ + # TODO: implement RLlib agent loading + if session_path is not None: + msg = "RLlib agent loading has not been implemented yet" + _LOGGER.error(msg) + print(msg) + raise NotImplementedError + super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.RLLIB: msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index da4d29f0..4dab5cb6 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -2,7 +2,7 @@ from __future__ import annotations from pathlib import Path -from typing import Dict, Final, Union +from typing import Dict, Final, Optional, Union from primaite import getLogger from primaite.agents.agent_abc import AgentSessionABC @@ -14,6 +14,7 @@ from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyA from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig +from primaite.utils.session_metadata_parser import parse_session_metadata _LOGGER = getLogger(__name__) @@ -27,8 +28,9 @@ class PrimaiteSession: def __init__( self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, ): """ The PrimaiteSession constructor. @@ -36,6 +38,25 @@ class PrimaiteSession: :param training_config_path: The training config path. :param lay_down_config_path: The lay down config path. 
""" + self._agent_session: AgentSessionABC = None # noqa + self.session_path: Path = session_path # noqa + self.timestamp_str: str = None # noqa + self.learning_path: Path = None # noqa + self.evaluation_path: Path = None # noqa + + # check if session path is provided + if session_path is not None: + # set load_session to true + self.is_load_session = True + if not isinstance(session_path, Path): + session_path = Path(session_path) + + # if a session path is provided, load it + if not session_path.exists(): + raise Exception(f"Session could not be loaded. Path does not exist: {session_path}") + + md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path) + if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path, str]] = training_config_path @@ -46,12 +67,6 @@ class PrimaiteSession: self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) - self._agent_session: AgentSessionABC = None # noqa - self.session_path: Path = None # noqa - self.timestamp_str: str = None # noqa - self.learning_path: Path = None # noqa - self.evaluation_path: Path = None # noqa - def setup(self): """Performs the session setup.""" if self._training_config.agent_framework == AgentFramework.CUSTOM: @@ -60,11 +75,15 @@ class PrimaiteSession: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}") if self._training_config.action_type == ActionType.NODE: # Deterministic Hardcoded Agent with Node Action Space - self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = HardCodedNodeAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - 
self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = HardCodedACLAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ANY: # Deterministic Hardcoded Agent with ANY Action Space @@ -77,11 +96,15 @@ class PrimaiteSession: elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}") if self._training_config.action_type == ActionType.NODE: - self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = DoNothingNodeAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = DoNothingACLAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.action_type == ActionType.ANY: # Deterministic Hardcoded Agent with ANY Action Space @@ -93,10 +116,14 @@ class PrimaiteSession: elif self._training_config.agent_identifier == AgentIdentifier.RANDOM: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}") - self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = RandomAgent( + self._training_config_path, self._lay_down_config_path, self.session_path + ) elif self._training_config.agent_identifier == AgentIdentifier.DUMMY: _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}") - self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = DummyAgent( + self._training_config_path, 
self._lay_down_config_path, self.session_path + ) else: # Invalid AgentFramework AgentIdentifier combo @@ -105,12 +132,12 @@ class PrimaiteSession: elif self._training_config.agent_framework == AgentFramework.SB3: _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}") # Stable Baselines3 Agent - self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path) + self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path, self.session_path) elif self._training_config.agent_framework == AgentFramework.RLLIB: _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}") # Ray RLlib Agent - self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path) + self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path) else: # Invalid AgentFramework diff --git a/src/primaite/utils/session_metadata_parser.py b/src/primaite/utils/session_metadata_parser.py new file mode 100644 index 00000000..936d3269 --- /dev/null +++ b/src/primaite/utils/session_metadata_parser.py @@ -0,0 +1,58 @@ +import json +from pathlib import Path +from typing import Union + +import yaml + +from primaite import getLogger + +_LOGGER = getLogger(__name__) + + +def parse_session_metadata(session_path: Union[Path, str], dict_only=False): + """ + Loads a session metadata from the given directory path. 
+ + :param session_path: Directory that contains the session_metadata.json file + :param dict_only: If dict_only is true, the function will only return the dict contents of session metadata + + :return: Dictionary which has all the session metadata contents + :rtype: Dict + + :return: Path to which the YAML copy of the training config is dumped + :rtype: Path + :return: Path to which the YAML copy of the laydown config is dumped + :rtype: Path + """ + if not isinstance(session_path, Path): + session_path = Path(session_path) + + if not session_path.exists(): + # Session path does not exist + msg = f"Failed to load PrimAITE Session, path does not exist: {session_path}" + _LOGGER.error(msg) + raise FileNotFoundError(msg) + + # Unpack the session_metadata.json file + md_file = session_path / "session_metadata.json" + with open(md_file, "r") as file: + md_dict = json.load(file) + + # if dict only, return dict without doing anything else + if dict_only: + return md_dict + + # Create a temp directory and dump the training and lay down + # configs into it + temp_dir = session_path / ".temp" + temp_dir.mkdir(exist_ok=True) + + temp_tc = temp_dir / "tc.yaml" + with open(temp_tc, "w") as file: + yaml.dump(md_dict["env"]["training_config"], file) + + temp_ldc = temp_dir / "ldc.yaml" + with open(temp_ldc, "w") as file: + yaml.dump(md_dict["env"]["lay_down_config"], file) + + return [md_dict, temp_tc, temp_ldc] diff --git a/tests/test_session_loading.py b/tests/test_session_loading.py index a59a6e00..d79b0dde 100644 --- a/tests/test_session_loading.py +++ b/tests/test_session_loading.py @@ -8,6 +8,7 @@ from uuid import uuid4 from primaite import getLogger from primaite.agents.sb3 import SB3Agent from primaite.common.enums import AgentFramework, AgentIdentifier +from primaite.primaite_session import PrimaiteSession from primaite.utils.session_output_reader import av_rewards_dict from tests import TEST_ASSETS_ROOT @@ -96,4 +97,59 @@ def test_load_sb3_session(): def
test_load_primaite_session(): """Test that loading a Primaite session works.""" - pass + expected_learn_mean_reward_per_episode = { + 10: 0, + 11: -0.008037109374999995, + 12: -0.007978515624999988, + 13: -0.008191406249999991, + 14: -0.00817578124999999, + 15: -0.008085937499999998, + 16: -0.007837890624999982, + 17: -0.007798828124999992, + 18: -0.007777343749999998, + 19: -0.007958984374999988, + 20: -0.0077499999999999835, + } + + test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") + + # create loaded session + session = PrimaiteSession(session_path=test_path) + + # run setup on session + session.setup() + + # make sure that the session was loaded correctly + assert session._agent_session.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde" + assert session._agent_session._training_config.agent_framework == AgentFramework.SB3.name + assert session._agent_session._training_config.agent_identifier == AgentIdentifier.PPO.name + assert session._agent_session._training_config.deterministic + assert session._agent_session._training_config.seed == 12345 + assert str(session._agent_session.session_path) == str(test_path) + + # run another learn session + session.learn() + + learn_mean_rewards = av_rewards_dict( + session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.csv" + ) + + # run is seeded so should have the expected learn value + assert learn_mean_rewards == expected_learn_mean_reward_per_episode + + # run an evaluation + session.evaluate() + + # load the evaluation average reward csv file + eval_mean_reward = av_rewards_dict( + session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.csv" + ) + + # the agent config ran the evaluation in deterministic mode, so should have the same reward value + assert len(set(eval_mean_reward.values())) == 1 + + # the evaluation should be the same as a previous run + assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988 + + # delete the test 
directory + shutil.rmtree(test_path)