diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index ab1b3af3..96bb0737 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,286 +1,287 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -from __future__ import annotations +# # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK +# from __future__ import annotations -import json -import shutil -import zipfile -from datetime import datetime -from logging import Logger -from pathlib import Path -from typing import Any, Callable, Dict, Optional, Union -from uuid import uuid4 +# import json +# import shutil +# import zipfile +# from datetime import datetime +# from logging import Logger +# from pathlib import Path +# from typing import Any, Callable, Dict, Optional, Union +# from uuid import uuid4 -from ray.rllib.algorithms import Algorithm -from ray.rllib.algorithms.a2c import A2CConfig -from ray.rllib.algorithms.ppo import PPOConfig -from ray.tune.logger import UnifiedLogger -from ray.tune.registry import register_env +# from primaite import getLogger +# from primaite.agents.agent_abc import AgentSessionABC +# from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType +# from primaite.environment.primaite_env import Primaite -from primaite import getLogger -from primaite.agents.agent_abc import AgentSessionABC -from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType -from primaite.environment.primaite_env import Primaite -from primaite.exceptions import RLlibAgentError - -_LOGGER: Logger = getLogger(__name__) +# # from ray.rllib.algorithms import Algorithm +# # from ray.rllib.algorithms.a2c import A2CConfig +# # from ray.rllib.algorithms.ppo import PPOConfig +# # from ray.tune.logger import UnifiedLogger +# # from ray.tune.registry import register_env -# TODO: verify type of env_config -def _env_creator(env_config: Dict[str, Any]) -> Primaite: - return Primaite( - training_config_path=env_config["training_config_path"], - lay_down_config_path=env_config["lay_down_config_path"], - session_path=env_config["session_path"], - timestamp_str=env_config["timestamp_str"], - ) +# # from primaite.exceptions import RLlibAgentError + +# _LOGGER: Logger = getLogger(__name__) -# TODO: verify type hint return type -def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]: - logdir = session_path / "ray_results" - logdir.mkdir(parents=True, exist_ok=True) +# # TODO: verify type of env_config +# def _env_creator(env_config: Dict[str, Any]) -> Primaite: +# return Primaite( +# training_config_path=env_config["training_config_path"], +# lay_down_config_path=env_config["lay_down_config_path"], +# session_path=env_config["session_path"], +# timestamp_str=env_config["timestamp_str"], +# ) - def logger_creator(config: Dict) -> UnifiedLogger: - return UnifiedLogger(config, logdir, loggers=None) +# # # TODO: verify type hint return type +# # def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]: +# # logdir = session_path / "ray_results" +# # logdir.mkdir(parents=True, exist_ok=True) - return logger_creator +# # def logger_creator(config: Dict) -> UnifiedLogger: +# # return UnifiedLogger(config, logdir, loggers=None) + +# return logger_creator -class RLlibAgent(AgentSessionABC): - """An AgentSession class that implements a Ray RLlib agent.""" +# # class RLlibAgent(AgentSessionABC): +# # """An AgentSession class that implements a Ray RLlib agent.""" - def __init__( - self, - training_config_path: Optional[Union[str, Path]] = "", - lay_down_config_path: Optional[Union[str, Path]] = "", - session_path: Optional[Union[str, Path]] = None, - ) -> None: - """ - Initialise the RLLib Agent training session. +# # def __init__( +# # self, +# # training_config_path: Optional[Union[str, Path]] = "", +# # lay_down_config_path: Optional[Union[str, Path]] = "", +# # session_path: Optional[Union[str, Path]] = None, +# # ) -> None: +# # """ +# # Initialise the RLLib Agent training session. - :param training_config_path: YAML file containing configurable items defined in - `primaite.config.training_config.TrainingConfig` - :type training_config_path: Union[path, str] - :param lay_down_config_path: YAML file containing configurable items for generating network laydown. - :type lay_down_config_path: Union[path, str] - :raises ValueError: If the training config contains an unexpected value for agent_framework (should be "RLLIB") - :raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO` - or `A2C`) - """ - # TODO: implement RLlib agent loading - if session_path is not None: - msg = "RLlib agent loading has not been implemented yet" - _LOGGER.critical(msg) - raise NotImplementedError(msg) +# # :param training_config_path: YAML file containing configurable items defined in +# # `primaite.config.training_config.TrainingConfig` +# # :type training_config_path: Union[path, str] +# # :param lay_down_config_path: YAML file containing configurable items for generating network laydown. +# # :type lay_down_config_path: Union[path, str] +# # :raises ValueError: If the training config contains a bad value for agent_framework (should be "RLLIB") +# # :raises ValueError: If the training config contains a bad value for agent_identifies (should be `PPO` +# # or `A2C`) +# # """ +# # # TODO: implement RLlib agent loading +# # if session_path is not None: +# # msg = "RLlib agent loading has not been implemented yet" +# # _LOGGER.critical(msg) +# # raise NotImplementedError(msg) - super().__init__(training_config_path, lay_down_config_path) - if self._training_config.session_type == SessionType.EVAL: - msg = "Cannot evaluate an RLlib agent that hasn't been through training yet." - _LOGGER.critical(msg) - raise RLlibAgentError(msg) - if not self._training_config.agent_framework == AgentFramework.RLLIB: - msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" - _LOGGER.error(msg) - raise ValueError(msg) - self._agent_config_class: Union[PPOConfig, A2CConfig] - if self._training_config.agent_identifier == AgentIdentifier.PPO: - self._agent_config_class = PPOConfig - elif self._training_config.agent_identifier == AgentIdentifier.A2C: - self._agent_config_class = A2CConfig - else: - msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}" - _LOGGER.error(msg) - raise ValueError(msg) - self._agent_config: Union[PPOConfig, A2CConfig] +# # super().__init__(training_config_path, lay_down_config_path) +# # if self._training_config.session_type == SessionType.EVAL: +# # msg = "Cannot evaluate an RLlib agent that hasn't been through training yet." +# # _LOGGER.critical(msg) +# # raise RLlibAgentError(msg) +# # if not self._training_config.agent_framework == AgentFramework.RLLIB: +# # msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" +# # _LOGGER.error(msg) +# # raise ValueError(msg) +# # self._agent_config_class: Union[PPOConfig, A2CConfig] +# # if self._training_config.agent_identifier == AgentIdentifier.PPO: +# # self._agent_config_class = PPOConfig +# # elif self._training_config.agent_identifier == AgentIdentifier.A2C: +# # self._agent_config_class = A2CConfig +# # else: +# # msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}" +# # _LOGGER.error(msg) +# # raise ValueError(msg) +# # self._agent_config: Union[PPOConfig, A2CConfig] - self._current_result: dict - self._setup() - _LOGGER.debug( - f"Created {self.__class__.__name__} using: " - f"agent_framework={self._training_config.agent_framework}, " - f"agent_identifier=" - f"{self._training_config.agent_identifier}, " - f"deep_learning_framework=" - f"{self._training_config.deep_learning_framework}" - ) - self._train_agent = None # Required to capture the learning agent to close after eval +# # self._current_result: dict +# # self._setup() +# # _LOGGER.debug( +# # f"Created {self.__class__.__name__} using: " +# # f"agent_framework={self._training_config.agent_framework}, " +# # f"agent_identifier=" +# # f"{self._training_config.agent_identifier}, " +# # f"deep_learning_framework=" +# # f"{self._training_config.deep_learning_framework}" +# # ) +# # self._train_agent = None # Required to capture the learning agent to close after eval - def _update_session_metadata_file(self) -> None: - """ - Update the ``session_metadata.json`` file. +# # def _update_session_metadata_file(self) -> None: +# # """ +# # Update the ``session_metadata.json`` file. - Updates the `session_metadata.json`` in the ``session_path`` directory - with the following key/value pairs: +# # Updates the `session_metadata.json`` in the ``session_path`` directory +# # with the following key/value pairs: - - end_datetime: The date & time the session ended in iso format. - - total_episodes: The total number of training episodes completed. - - total_time_steps: The total number of training time steps completed. - """ - with open(self.session_path / "session_metadata.json", "r") as file: - metadata_dict = json.load(file) +# # - end_datetime: The date & time the session ended in iso format. +# # - total_episodes: The total number of training episodes completed. +# # - total_time_steps: The total number of training time steps completed. +# # """ +# # with open(self.session_path / "session_metadata.json", "r") as file: +# # metadata_dict = json.load(file) - metadata_dict["end_datetime"] = datetime.now().isoformat() - if not self.is_eval: - metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa - metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa - else: - metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa - metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa +# # metadata_dict["end_datetime"] = datetime.now().isoformat() +# # if not self.is_eval: +# # metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa +# # metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa +# # else: +# # metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa +# # metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa - filepath = self.session_path / "session_metadata.json" - _LOGGER.debug(f"Updating Session Metadata file: {filepath}") - with open(filepath, "w") as file: - json.dump(metadata_dict, file) - _LOGGER.debug("Finished updating session metadata file") +# # filepath = self.session_path / "session_metadata.json" +# # _LOGGER.debug(f"Updating Session Metadata file: {filepath}") +# # with open(filepath, "w") as file: +# # json.dump(metadata_dict, file) +# # _LOGGER.debug("Finished updating session metadata file") - def _setup(self) -> None: - super()._setup() - register_env("primaite", _env_creator) - self._agent_config = self._agent_config_class() +# # def _setup(self) -> None: +# # super()._setup() +# # register_env("primaite", _env_creator) +# # self._agent_config = self._agent_config_class() - self._agent_config.environment( - env="primaite", - env_config=dict( - training_config_path=self._training_config_path, - lay_down_config_path=self._lay_down_config_path, - session_path=self.session_path, - timestamp_str=self.timestamp_str, - ), - ) - self._agent_config.seed = self._training_config.seed +# # self._agent_config.environment( +# # env="primaite", +# # env_config=dict( +# # training_config_path=self._training_config_path, +# # lay_down_config_path=self._lay_down_config_path, +# # session_path=self.session_path, +# # timestamp_str=self.timestamp_str, +# # ), +# # ) +# # self._agent_config.seed = self._training_config.seed - self._agent_config.training(train_batch_size=self._training_config.num_train_steps) - self._agent_config.framework(framework="tf") +# # self._agent_config.training(train_batch_size=self._training_config.num_train_steps) +# # self._agent_config.framework(framework="tf") - self._agent_config.rollouts( - num_rollout_workers=1, - num_envs_per_worker=1, - horizon=self._training_config.num_train_steps, - ) - self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path)) +# # self._agent_config.rollouts( +# # num_rollout_workers=1, +# # num_envs_per_worker=1, +# # horizon=self._training_config.num_train_steps, +# # ) +# # self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path)) - def _save_checkpoint(self) -> None: - checkpoint_n = self._training_config.checkpoint_every_n_episodes - episode_count = self._current_result["episodes_total"] - save_checkpoint = False - if checkpoint_n: - save_checkpoint = episode_count % checkpoint_n == 0 - if episode_count and save_checkpoint: - self._agent.save(str(self.checkpoints_path)) +# # def _save_checkpoint(self) -> None: +# # checkpoint_n = self._training_config.checkpoint_every_n_episodes +# # episode_count = self._current_result["episodes_total"] +# # save_checkpoint = False +# # if checkpoint_n: +# # save_checkpoint = episode_count % checkpoint_n == 0 +# # if episode_count and save_checkpoint: +# # self._agent.save(str(self.checkpoints_path)) - def learn( - self, - **kwargs: Any, - ) -> None: - """ - Evaluate the agent. +# # def learn( +# # self, +# # **kwargs: Any, +# # ) -> None: +# # """ +# # Evaluate the agent. - :param kwargs: Any agent-specific key-word args to be passed. - """ - time_steps = self._training_config.num_train_steps - episodes = self._training_config.num_train_episodes +# # :param kwargs: Any agent-specific key-word args to be passed. +# # """ +# # time_steps = self._training_config.num_train_steps +# # episodes = self._training_config.num_train_episodes - _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") - for i in range(episodes): - self._current_result = self._agent.train() - self._save_checkpoint() - self.save() - super().learn() - # Done this way as the RLlib eval can only be performed if the session hasn't been stopped - if self._training_config.session_type is not SessionType.TRAIN: - self._train_agent = self._agent - else: - self._agent.stop() - self._plot_av_reward_per_episode(learning_session=True) +# # _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") +# # for i in range(episodes): +# # self._current_result = self._agent.train() +# # self._save_checkpoint() +# # self.save() +# # super().learn() +# # # Done this way as the RLlib eval can only be performed if the session hasn't been stopped +# # if self._training_config.session_type is not SessionType.TRAIN: +# # self._train_agent = self._agent +# # else: +# # self._agent.stop() +# # self._plot_av_reward_per_episode(learning_session=True) - def _unpack_saved_agent_into_eval(self) -> Path: - """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval.""" - agent_restore_path = self.evaluation_path / "agent_restore" - if agent_restore_path.exists(): - shutil.rmtree(agent_restore_path) - agent_restore_path.mkdir() - with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file: - zip_file.extractall(agent_restore_path) - return agent_restore_path +# # def _unpack_saved_agent_into_eval(self) -> Path: +# # """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval.""" +# # agent_restore_path = self.evaluation_path / "agent_restore" +# # if agent_restore_path.exists(): +# # shutil.rmtree(agent_restore_path) +# # agent_restore_path.mkdir() +# # with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file: +# # zip_file.extractall(agent_restore_path) +# # return agent_restore_path - def _setup_eval(self): - self._can_learn = False - self._can_evaluate = True - self._agent.restore(str(self._unpack_saved_agent_into_eval())) +# # def _setup_eval(self): +# # self._can_learn = False +# # self._can_evaluate = True +# # self._agent.restore(str(self._unpack_saved_agent_into_eval())) - def evaluate( - self, - **kwargs, - ): - """ - Evaluate the agent. +# # def evaluate( +# # self, +# # **kwargs, +# # ): +# # """ +# # Evaluate the agent. - :param kwargs: Any agent-specific key-word args to be passed. - """ - time_steps = self._training_config.num_eval_steps - episodes = self._training_config.num_eval_episodes +# # :param kwargs: Any agent-specific key-word args to be passed. +# # """ +# # time_steps = self._training_config.num_eval_steps +# # episodes = self._training_config.num_eval_episodes - self._setup_eval() +# # self._setup_eval() - self._env: Primaite = Primaite( - self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str - ) +# # self._env: Primaite = Primaite( +# # self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str +# # ) - self._env.set_as_eval() - self.is_eval = True - if self._training_config.deterministic: - deterministic_str = "deterministic" - else: - deterministic_str = "non-deterministic" - _LOGGER.info( - f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..." - ) - for episode in range(episodes): - obs = self._env.reset() - for step in range(time_steps): - action = self._agent.compute_single_action(observation=obs, explore=False) +# # self._env.set_as_eval() +# # self.is_eval = True +# # if self._training_config.deterministic: +# # deterministic_str = "deterministic" +# # else: +# # deterministic_str = "non-deterministic" +# # _LOGGER.info( +# # f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..." +# # ) +# # for episode in range(episodes): +# # obs = self._env.reset() +# # for step in range(time_steps): +# # action = self._agent.compute_single_action(observation=obs, explore=False) - obs, rewards, done, info = self._env.step(action) +# # obs, rewards, done, info = self._env.step(action) - self._env.reset() - self._env.close() - super().evaluate() - # Now we're safe to close the learning agent and write the mean rewards per episode for it - if self._training_config.session_type is not SessionType.TRAIN: - self._train_agent.stop() - self._plot_av_reward_per_episode(learning_session=True) - # Perform a clean-up of the unpacked agent - if (self.evaluation_path / "agent_restore").exists(): - shutil.rmtree((self.evaluation_path / "agent_restore")) +# # self._env.reset() +# # self._env.close() +# # super().evaluate() +# # # Now we're safe to close the learning agent and write the mean rewards per episode for it +# # if self._training_config.session_type is not SessionType.TRAIN: +# # self._train_agent.stop() +# # self._plot_av_reward_per_episode(learning_session=True) +# # # Perform a clean-up of the unpacked agent +# # if (self.evaluation_path / "agent_restore").exists(): +# # shutil.rmtree((self.evaluation_path / "agent_restore")) - def _get_latest_checkpoint(self) -> None: - raise NotImplementedError +# # def _get_latest_checkpoint(self) -> None: +# # raise NotImplementedError - @classmethod - def load(cls, path: Union[str, Path]) -> RLlibAgent: - """Load an agent from file.""" - raise NotImplementedError +# # @classmethod +# # def load(cls, path: Union[str, Path]) -> RLlibAgent: +# # """Load an agent from file.""" +# # raise NotImplementedError - def save(self, overwrite_existing: bool = True) -> None: - """Save the agent.""" - # Make temp dir to save in isolation - temp_dir = self.learning_path / str(uuid4()) - temp_dir.mkdir() +# # def save(self, overwrite_existing: bool = True) -> None: +# # """Save the agent.""" +# # # Make temp dir to save in isolation +# # temp_dir = self.learning_path / str(uuid4()) +# # temp_dir.mkdir() - # Save the agent to the temp dir - self._agent.save(str(temp_dir)) +# # # Save the agent to the temp dir +# # self._agent.save(str(temp_dir)) - # Capture the saved Rllib checkpoint inside the temp directory - for file in temp_dir.iterdir(): - checkpoint_dir = file - break +# # # Capture the saved Rllib checkpoint inside the temp directory +# # for file in temp_dir.iterdir(): +# # checkpoint_dir = file +# # break - # Zip the folder - shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa +# # # Zip the folder +# # shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa - # Drop the temp directory - shutil.rmtree(temp_dir) +# # # Drop the temp directory +# # shutil.rmtree(temp_dir) - def export(self) -> None: - """Export the agent to transportable file format.""" - raise NotImplementedError +# # def export(self) -> None: +# # """Export the agent to transportable file format.""" +# # raise NotImplementedError diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 006301f1..c33e764b 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -99,8 +99,8 @@ class AgentFramework(Enum): "Custom Agent" SB3 = 1 "Stable Baselines3" - RLLIB = 2 - "Ray RLlib" + # RLLIB = 2 + # "Ray RLlib" class DeepLearningFramework(Enum): diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 7f5dc568..f81bb6f7 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -248,8 +248,8 @@ class TrainingConfig: def __str__(self) -> str: obs_str = ",".join([c["name"] for c in self.observation_space["components"]]) tc = f"{self.agent_framework}, " - if self.agent_framework is AgentFramework.RLLIB: - tc += f"{self.deep_learning_framework}, " + # if self.agent_framework is AgentFramework.RLLIB: + # tc += f"{self.deep_learning_framework}, " tc += f"{self.agent_identifier}, " if self.agent_identifier is AgentIdentifier.HARDCODED: tc += f"{self.hard_coded_agent_view}, " diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 62af6c5b..a809772f 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -17,9 +17,8 @@ from primaite import getLogger from primaite.acl.access_control_list import AccessControlList from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import ( +from primaite.common.enums import ( # AgentFramework, ActionType, - AgentFramework, AgentIdentifier, FileSystemState, HardwareState, @@ -236,7 +235,8 @@ class Primaite(Env): _LOGGER.debug("Action space type NODE selected") # Terms (for node action space): # [0, num nodes] - node ID (0 = nothing, node ID) - # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa + # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, # noqa + # service state, file system state) # noqa # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa self.action_dict = self.create_node_action_dict() @@ -271,8 +271,8 @@ class Primaite(Env): @property def actual_episode_count(self) -> int: """Shifts the episode_count by -1 for RLlib learning session.""" - if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval: - return self.episode_count - 1 + # if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval: + # return self.episode_count - 1 return self.episode_count def set_as_eval(self) -> None: diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 2cb0d5bd..7d5b2709 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -10,7 +10,8 @@ from primaite import getLogger from primaite.agents.agent_abc import AgentSessionABC from primaite.agents.hardcoded_acl import HardCodedACLAgent from primaite.agents.hardcoded_node import HardCodedNodeAgent -from primaite.agents.rllib import RLlibAgent + +# from primaite.agents.rllib import RLlibAgent from primaite.agents.sb3 import SB3Agent from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType @@ -157,10 +158,12 @@ class PrimaiteSession: self.legacy_lay_down_config, ) - elif self._training_config.agent_framework == AgentFramework.RLLIB: - _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}") - # Ray RLlib Agent - self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path) + # elif self._training_config.agent_framework == AgentFramework.RLLIB: + # _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}") + # # Ray RLlib Agent + # self._agent_session = RLlibAgent( + # self._training_config_path, self._lay_down_config_path, self.session_path + # ) else: # Invalid AgentFramework diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index b76a2ecf..6e23b3ac 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -13,7 +13,7 @@ _LOGGER = getLogger(__name__) @pytest.mark.parametrize( "temp_primaite_session", [ - [TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()], + # [TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()], [TEST_CONFIG_ROOT / "session_test/training_config_main_sb3.yaml", dos_very_basic_config_path()], ], indirect=True,