diff --git a/docs/index.rst b/docs/index.rst index 2c7d4690..c0e7a007 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -43,6 +43,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! source/config source/primaite_session source/custom_agent + source/simulation PrimAITE API PrimAITE Tests source/dependencies diff --git a/docs/source/custom_agent.rst b/docs/source/custom_agent.rst index 8a95d3ae..040b4b3d 100644 --- a/docs/source/custom_agent.rst +++ b/docs/source/custom_agent.rst @@ -13,7 +13,7 @@ Integrating a user defined blue agent If you are planning to implement custom RL agents into PrimAITE, you must use the project as a repository. If you install PrimAITE as a python package from wheel, custom agents are not supported. -PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent.AgentSessionABC` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training and evaluation. If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint``, ``load()``, and ``save()`` methods. +PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent_abc.AgentSessionABC` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training and evaluation. If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent_abc.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint``, ``load()``, and ``save()`` methods. 
Below is a barebones example of a custom agent implementation: @@ -21,7 +21,7 @@ Below is a barebones example of a custom agent implementation: # src/primaite/agents/my_custom_agent.py - from primaite.agents.agent import AgentSessionABC + from primaite.agents.agent_abc import AgentSessionABC from primaite.common.enums import AgentFramework, AgentIdentifier class CustomAgent(AgentSessionABC): diff --git a/docs/source/simulation.rst b/docs/source/simulation.rst new file mode 100644 index 00000000..1620f6ba --- /dev/null +++ b/docs/source/simulation.rst @@ -0,0 +1,10 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +Simulation Structure +==================== + +The simulation is made up of many smaller components which are related to each other in a tree-like structure. At the top level, there is an object called the ``SimulationController`` *(doesn't exist yet)*, which has a physical network and a software controller for managing software and users. + +Each node of the simulation 'tree' has responsibility for creating, deleting, and updating its direct descendants. 
diff --git a/pyproject.toml b/pyproject.toml index b66b0168..4e8250d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,10 +34,10 @@ dependencies = [ "plotly==5.15.0", "polars==0.18.4", "PyYAML==6.0", - "ray[rllib]==2.2.0", "stable-baselines3==1.6.2", "tensorflow==2.12.0", - "typer[all]==0.9.0" + "typer[all]==0.9.0", + "pydantic" ] [tool.setuptools.dynamic] diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index ab1b3af3..96bb0737 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,286 +1,287 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -from __future__ import annotations +# # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK +# from __future__ import annotations -import json -import shutil -import zipfile -from datetime import datetime -from logging import Logger -from pathlib import Path -from typing import Any, Callable, Dict, Optional, Union -from uuid import uuid4 +# import json +# import shutil +# import zipfile +# from datetime import datetime +# from logging import Logger +# from pathlib import Path +# from typing import Any, Callable, Dict, Optional, Union +# from uuid import uuid4 -from ray.rllib.algorithms import Algorithm -from ray.rllib.algorithms.a2c import A2CConfig -from ray.rllib.algorithms.ppo import PPOConfig -from ray.tune.logger import UnifiedLogger -from ray.tune.registry import register_env +# from primaite import getLogger +# from primaite.agents.agent_abc import AgentSessionABC +# from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType +# from primaite.environment.primaite_env import Primaite -from primaite import getLogger -from primaite.agents.agent_abc import AgentSessionABC -from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType -from primaite.environment.primaite_env import Primaite -from primaite.exceptions import RLlibAgentError - -_LOGGER: Logger = getLogger(__name__) 
+# # from ray.rllib.algorithms import Algorithm +# # from ray.rllib.algorithms.a2c import A2CConfig +# # from ray.rllib.algorithms.ppo import PPOConfig +# # from ray.tune.logger import UnifiedLogger +# # from ray.tune.registry import register_env -# TODO: verify type of env_config -def _env_creator(env_config: Dict[str, Any]) -> Primaite: - return Primaite( - training_config_path=env_config["training_config_path"], - lay_down_config_path=env_config["lay_down_config_path"], - session_path=env_config["session_path"], - timestamp_str=env_config["timestamp_str"], - ) +# # from primaite.exceptions import RLlibAgentError + +# _LOGGER: Logger = getLogger(__name__) -# TODO: verify type hint return type -def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]: - logdir = session_path / "ray_results" - logdir.mkdir(parents=True, exist_ok=True) +# # TODO: verify type of env_config +# def _env_creator(env_config: Dict[str, Any]) -> Primaite: +# return Primaite( +# training_config_path=env_config["training_config_path"], +# lay_down_config_path=env_config["lay_down_config_path"], +# session_path=env_config["session_path"], +# timestamp_str=env_config["timestamp_str"], +# ) - def logger_creator(config: Dict) -> UnifiedLogger: - return UnifiedLogger(config, logdir, loggers=None) +# # # TODO: verify type hint return type +# # def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]: +# # logdir = session_path / "ray_results" +# # logdir.mkdir(parents=True, exist_ok=True) - return logger_creator +# # def logger_creator(config: Dict) -> UnifiedLogger: +# # return UnifiedLogger(config, logdir, loggers=None) + +# return logger_creator -class RLlibAgent(AgentSessionABC): - """An AgentSession class that implements a Ray RLlib agent.""" +# # class RLlibAgent(AgentSessionABC): +# # """An AgentSession class that implements a Ray RLlib agent.""" - def __init__( - self, - training_config_path: Optional[Union[str, Path]] = "", - 
lay_down_config_path: Optional[Union[str, Path]] = "", - session_path: Optional[Union[str, Path]] = None, - ) -> None: - """ - Initialise the RLLib Agent training session. +# # def __init__( +# # self, +# # training_config_path: Optional[Union[str, Path]] = "", +# # lay_down_config_path: Optional[Union[str, Path]] = "", +# # session_path: Optional[Union[str, Path]] = None, +# # ) -> None: +# # """ +# # Initialise the RLLib Agent training session. - :param training_config_path: YAML file containing configurable items defined in - `primaite.config.training_config.TrainingConfig` - :type training_config_path: Union[path, str] - :param lay_down_config_path: YAML file containing configurable items for generating network laydown. - :type lay_down_config_path: Union[path, str] - :raises ValueError: If the training config contains an unexpected value for agent_framework (should be "RLLIB") - :raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO` - or `A2C`) - """ - # TODO: implement RLlib agent loading - if session_path is not None: - msg = "RLlib agent loading has not been implemented yet" - _LOGGER.critical(msg) - raise NotImplementedError(msg) +# # :param training_config_path: YAML file containing configurable items defined in +# # `primaite.config.training_config.TrainingConfig` +# # :type training_config_path: Union[path, str] +# # :param lay_down_config_path: YAML file containing configurable items for generating network laydown. 
+# # :type lay_down_config_path: Union[path, str] +# # :raises ValueError: If the training config contains a bad value for agent_framework (should be "RLLIB") +# # :raises ValueError: If the training config contains a bad value for agent_identifies (should be `PPO` +# # or `A2C`) +# # """ +# # # TODO: implement RLlib agent loading +# # if session_path is not None: +# # msg = "RLlib agent loading has not been implemented yet" +# # _LOGGER.critical(msg) +# # raise NotImplementedError(msg) - super().__init__(training_config_path, lay_down_config_path) - if self._training_config.session_type == SessionType.EVAL: - msg = "Cannot evaluate an RLlib agent that hasn't been through training yet." - _LOGGER.critical(msg) - raise RLlibAgentError(msg) - if not self._training_config.agent_framework == AgentFramework.RLLIB: - msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" - _LOGGER.error(msg) - raise ValueError(msg) - self._agent_config_class: Union[PPOConfig, A2CConfig] - if self._training_config.agent_identifier == AgentIdentifier.PPO: - self._agent_config_class = PPOConfig - elif self._training_config.agent_identifier == AgentIdentifier.A2C: - self._agent_config_class = A2CConfig - else: - msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}" - _LOGGER.error(msg) - raise ValueError(msg) - self._agent_config: Union[PPOConfig, A2CConfig] +# # super().__init__(training_config_path, lay_down_config_path) +# # if self._training_config.session_type == SessionType.EVAL: +# # msg = "Cannot evaluate an RLlib agent that hasn't been through training yet." 
+# # _LOGGER.critical(msg) +# # raise RLlibAgentError(msg) +# # if not self._training_config.agent_framework == AgentFramework.RLLIB: +# # msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" +# # _LOGGER.error(msg) +# # raise ValueError(msg) +# # self._agent_config_class: Union[PPOConfig, A2CConfig] +# # if self._training_config.agent_identifier == AgentIdentifier.PPO: +# # self._agent_config_class = PPOConfig +# # elif self._training_config.agent_identifier == AgentIdentifier.A2C: +# # self._agent_config_class = A2CConfig +# # else: +# # msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}" +# # _LOGGER.error(msg) +# # raise ValueError(msg) +# # self._agent_config: Union[PPOConfig, A2CConfig] - self._current_result: dict - self._setup() - _LOGGER.debug( - f"Created {self.__class__.__name__} using: " - f"agent_framework={self._training_config.agent_framework}, " - f"agent_identifier=" - f"{self._training_config.agent_identifier}, " - f"deep_learning_framework=" - f"{self._training_config.deep_learning_framework}" - ) - self._train_agent = None # Required to capture the learning agent to close after eval +# # self._current_result: dict +# # self._setup() +# # _LOGGER.debug( +# # f"Created {self.__class__.__name__} using: " +# # f"agent_framework={self._training_config.agent_framework}, " +# # f"agent_identifier=" +# # f"{self._training_config.agent_identifier}, " +# # f"deep_learning_framework=" +# # f"{self._training_config.deep_learning_framework}" +# # ) +# # self._train_agent = None # Required to capture the learning agent to close after eval - def _update_session_metadata_file(self) -> None: - """ - Update the ``session_metadata.json`` file. +# # def _update_session_metadata_file(self) -> None: +# # """ +# # Update the ``session_metadata.json`` file. 
- Updates the `session_metadata.json`` in the ``session_path`` directory - with the following key/value pairs: +# # Updates the `session_metadata.json`` in the ``session_path`` directory +# # with the following key/value pairs: - - end_datetime: The date & time the session ended in iso format. - - total_episodes: The total number of training episodes completed. - - total_time_steps: The total number of training time steps completed. - """ - with open(self.session_path / "session_metadata.json", "r") as file: - metadata_dict = json.load(file) +# # - end_datetime: The date & time the session ended in iso format. +# # - total_episodes: The total number of training episodes completed. +# # - total_time_steps: The total number of training time steps completed. +# # """ +# # with open(self.session_path / "session_metadata.json", "r") as file: +# # metadata_dict = json.load(file) - metadata_dict["end_datetime"] = datetime.now().isoformat() - if not self.is_eval: - metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa - metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa - else: - metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa - metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa +# # metadata_dict["end_datetime"] = datetime.now().isoformat() +# # if not self.is_eval: +# # metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa +# # metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa +# # else: +# # metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa +# # metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa - filepath = self.session_path / "session_metadata.json" - _LOGGER.debug(f"Updating Session Metadata file: {filepath}") - with 
open(filepath, "w") as file: - json.dump(metadata_dict, file) - _LOGGER.debug("Finished updating session metadata file") +# # filepath = self.session_path / "session_metadata.json" +# # _LOGGER.debug(f"Updating Session Metadata file: {filepath}") +# # with open(filepath, "w") as file: +# # json.dump(metadata_dict, file) +# # _LOGGER.debug("Finished updating session metadata file") - def _setup(self) -> None: - super()._setup() - register_env("primaite", _env_creator) - self._agent_config = self._agent_config_class() +# # def _setup(self) -> None: +# # super()._setup() +# # register_env("primaite", _env_creator) +# # self._agent_config = self._agent_config_class() - self._agent_config.environment( - env="primaite", - env_config=dict( - training_config_path=self._training_config_path, - lay_down_config_path=self._lay_down_config_path, - session_path=self.session_path, - timestamp_str=self.timestamp_str, - ), - ) - self._agent_config.seed = self._training_config.seed +# # self._agent_config.environment( +# # env="primaite", +# # env_config=dict( +# # training_config_path=self._training_config_path, +# # lay_down_config_path=self._lay_down_config_path, +# # session_path=self.session_path, +# # timestamp_str=self.timestamp_str, +# # ), +# # ) +# # self._agent_config.seed = self._training_config.seed - self._agent_config.training(train_batch_size=self._training_config.num_train_steps) - self._agent_config.framework(framework="tf") +# # self._agent_config.training(train_batch_size=self._training_config.num_train_steps) +# # self._agent_config.framework(framework="tf") - self._agent_config.rollouts( - num_rollout_workers=1, - num_envs_per_worker=1, - horizon=self._training_config.num_train_steps, - ) - self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path)) +# # self._agent_config.rollouts( +# # num_rollout_workers=1, +# # num_envs_per_worker=1, +# # horizon=self._training_config.num_train_steps, +# # ) +# # self._agent: 
Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path)) - def _save_checkpoint(self) -> None: - checkpoint_n = self._training_config.checkpoint_every_n_episodes - episode_count = self._current_result["episodes_total"] - save_checkpoint = False - if checkpoint_n: - save_checkpoint = episode_count % checkpoint_n == 0 - if episode_count and save_checkpoint: - self._agent.save(str(self.checkpoints_path)) +# # def _save_checkpoint(self) -> None: +# # checkpoint_n = self._training_config.checkpoint_every_n_episodes +# # episode_count = self._current_result["episodes_total"] +# # save_checkpoint = False +# # if checkpoint_n: +# # save_checkpoint = episode_count % checkpoint_n == 0 +# # if episode_count and save_checkpoint: +# # self._agent.save(str(self.checkpoints_path)) - def learn( - self, - **kwargs: Any, - ) -> None: - """ - Evaluate the agent. +# # def learn( +# # self, +# # **kwargs: Any, +# # ) -> None: +# # """ +# # Evaluate the agent. - :param kwargs: Any agent-specific key-word args to be passed. - """ - time_steps = self._training_config.num_train_steps - episodes = self._training_config.num_train_episodes +# # :param kwargs: Any agent-specific key-word args to be passed. 
+# # """ +# # time_steps = self._training_config.num_train_steps +# # episodes = self._training_config.num_train_episodes - _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") - for i in range(episodes): - self._current_result = self._agent.train() - self._save_checkpoint() - self.save() - super().learn() - # Done this way as the RLlib eval can only be performed if the session hasn't been stopped - if self._training_config.session_type is not SessionType.TRAIN: - self._train_agent = self._agent - else: - self._agent.stop() - self._plot_av_reward_per_episode(learning_session=True) +# # _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") +# # for i in range(episodes): +# # self._current_result = self._agent.train() +# # self._save_checkpoint() +# # self.save() +# # super().learn() +# # # Done this way as the RLlib eval can only be performed if the session hasn't been stopped +# # if self._training_config.session_type is not SessionType.TRAIN: +# # self._train_agent = self._agent +# # else: +# # self._agent.stop() +# # self._plot_av_reward_per_episode(learning_session=True) - def _unpack_saved_agent_into_eval(self) -> Path: - """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval.""" - agent_restore_path = self.evaluation_path / "agent_restore" - if agent_restore_path.exists(): - shutil.rmtree(agent_restore_path) - agent_restore_path.mkdir() - with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file: - zip_file.extractall(agent_restore_path) - return agent_restore_path +# # def _unpack_saved_agent_into_eval(self) -> Path: +# # """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval.""" +# # agent_restore_path = self.evaluation_path / "agent_restore" +# # if agent_restore_path.exists(): +# # shutil.rmtree(agent_restore_path) +# # agent_restore_path.mkdir() +# # with zipfile.ZipFile(self._saved_agent_path, "r") 
as zip_file: +# # zip_file.extractall(agent_restore_path) +# # return agent_restore_path - def _setup_eval(self): - self._can_learn = False - self._can_evaluate = True - self._agent.restore(str(self._unpack_saved_agent_into_eval())) +# # def _setup_eval(self): +# # self._can_learn = False +# # self._can_evaluate = True +# # self._agent.restore(str(self._unpack_saved_agent_into_eval())) - def evaluate( - self, - **kwargs, - ): - """ - Evaluate the agent. +# # def evaluate( +# # self, +# # **kwargs, +# # ): +# # """ +# # Evaluate the agent. - :param kwargs: Any agent-specific key-word args to be passed. - """ - time_steps = self._training_config.num_eval_steps - episodes = self._training_config.num_eval_episodes +# # :param kwargs: Any agent-specific key-word args to be passed. +# # """ +# # time_steps = self._training_config.num_eval_steps +# # episodes = self._training_config.num_eval_episodes - self._setup_eval() +# # self._setup_eval() - self._env: Primaite = Primaite( - self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str - ) +# # self._env: Primaite = Primaite( +# # self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str +# # ) - self._env.set_as_eval() - self.is_eval = True - if self._training_config.deterministic: - deterministic_str = "deterministic" - else: - deterministic_str = "non-deterministic" - _LOGGER.info( - f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..." 
- ) - for episode in range(episodes): - obs = self._env.reset() - for step in range(time_steps): - action = self._agent.compute_single_action(observation=obs, explore=False) +# # self._env.set_as_eval() +# # self.is_eval = True +# # if self._training_config.deterministic: +# # deterministic_str = "deterministic" +# # else: +# # deterministic_str = "non-deterministic" +# # _LOGGER.info( +# # f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..." +# # ) +# # for episode in range(episodes): +# # obs = self._env.reset() +# # for step in range(time_steps): +# # action = self._agent.compute_single_action(observation=obs, explore=False) - obs, rewards, done, info = self._env.step(action) +# # obs, rewards, done, info = self._env.step(action) - self._env.reset() - self._env.close() - super().evaluate() - # Now we're safe to close the learning agent and write the mean rewards per episode for it - if self._training_config.session_type is not SessionType.TRAIN: - self._train_agent.stop() - self._plot_av_reward_per_episode(learning_session=True) - # Perform a clean-up of the unpacked agent - if (self.evaluation_path / "agent_restore").exists(): - shutil.rmtree((self.evaluation_path / "agent_restore")) +# # self._env.reset() +# # self._env.close() +# # super().evaluate() +# # # Now we're safe to close the learning agent and write the mean rewards per episode for it +# # if self._training_config.session_type is not SessionType.TRAIN: +# # self._train_agent.stop() +# # self._plot_av_reward_per_episode(learning_session=True) +# # # Perform a clean-up of the unpacked agent +# # if (self.evaluation_path / "agent_restore").exists(): +# # shutil.rmtree((self.evaluation_path / "agent_restore")) - def _get_latest_checkpoint(self) -> None: - raise NotImplementedError +# # def _get_latest_checkpoint(self) -> None: +# # raise NotImplementedError - @classmethod - def load(cls, path: Union[str, Path]) -> RLlibAgent: - """Load an agent from 
file.""" - raise NotImplementedError +# # @classmethod +# # def load(cls, path: Union[str, Path]) -> RLlibAgent: +# # """Load an agent from file.""" +# # raise NotImplementedError - def save(self, overwrite_existing: bool = True) -> None: - """Save the agent.""" - # Make temp dir to save in isolation - temp_dir = self.learning_path / str(uuid4()) - temp_dir.mkdir() +# # def save(self, overwrite_existing: bool = True) -> None: +# # """Save the agent.""" +# # # Make temp dir to save in isolation +# # temp_dir = self.learning_path / str(uuid4()) +# # temp_dir.mkdir() - # Save the agent to the temp dir - self._agent.save(str(temp_dir)) +# # # Save the agent to the temp dir +# # self._agent.save(str(temp_dir)) - # Capture the saved Rllib checkpoint inside the temp directory - for file in temp_dir.iterdir(): - checkpoint_dir = file - break +# # # Capture the saved Rllib checkpoint inside the temp directory +# # for file in temp_dir.iterdir(): +# # checkpoint_dir = file +# # break - # Zip the folder - shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa +# # # Zip the folder +# # shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa - # Drop the temp directory - shutil.rmtree(temp_dir) +# # # Drop the temp directory +# # shutil.rmtree(temp_dir) - def export(self) -> None: - """Export the agent to transportable file format.""" - raise NotImplementedError +# # def export(self) -> None: +# # """Export the agent to transportable file format.""" +# # raise NotImplementedError diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 006301f1..c33e764b 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -99,8 +99,8 @@ class AgentFramework(Enum): "Custom Agent" SB3 = 1 "Stable Baselines3" - RLLIB = 2 - "Ray RLlib" + # RLLIB = 2 + # "Ray RLlib" class DeepLearningFramework(Enum): diff --git a/src/primaite/config/training_config.py 
b/src/primaite/config/training_config.py index 7f5dc568..f81bb6f7 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -248,8 +248,8 @@ class TrainingConfig: def __str__(self) -> str: obs_str = ",".join([c["name"] for c in self.observation_space["components"]]) tc = f"{self.agent_framework}, " - if self.agent_framework is AgentFramework.RLLIB: - tc += f"{self.deep_learning_framework}, " + # if self.agent_framework is AgentFramework.RLLIB: + # tc += f"{self.deep_learning_framework}, " tc += f"{self.agent_identifier}, " if self.agent_identifier is AgentIdentifier.HARDCODED: tc += f"{self.hard_coded_agent_view}, " diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 62af6c5b..a809772f 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -17,9 +17,8 @@ from primaite import getLogger from primaite.acl.access_control_list import AccessControlList from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import ( +from primaite.common.enums import ( # AgentFramework, ActionType, - AgentFramework, AgentIdentifier, FileSystemState, HardwareState, @@ -236,7 +235,8 @@ class Primaite(Env): _LOGGER.debug("Action space type NODE selected") # Terms (for node action space): # [0, num nodes] - node ID (0 = nothing, node ID) - # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa + # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, # noqa + # service state, file system state) # noqa # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa self.action_dict = self.create_node_action_dict() @@ -271,8 +271,8 @@ class 
Primaite(Env): @property def actual_episode_count(self) -> int: """Shifts the episode_count by -1 for RLlib learning session.""" - if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval: - return self.episode_count - 1 + # if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval: + # return self.episode_count - 1 return self.episode_count def set_as_eval(self) -> None: diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 2cb0d5bd..7d5b2709 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -10,7 +10,8 @@ from primaite import getLogger from primaite.agents.agent_abc import AgentSessionABC from primaite.agents.hardcoded_acl import HardCodedACLAgent from primaite.agents.hardcoded_node import HardCodedNodeAgent -from primaite.agents.rllib import RLlibAgent + +# from primaite.agents.rllib import RLlibAgent from primaite.agents.sb3 import SB3Agent from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType @@ -157,10 +158,12 @@ class PrimaiteSession: self.legacy_lay_down_config, ) - elif self._training_config.agent_framework == AgentFramework.RLLIB: - _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}") - # Ray RLlib Agent - self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path) + # elif self._training_config.agent_framework == AgentFramework.RLLIB: + # _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}") + # # Ray RLlib Agent + # self._agent_session = RLlibAgent( + # self._training_config_path, self._lay_down_config_path, self.session_path + # ) else: # Invalid AgentFramework diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py new file mode 100644 index 00000000..e69de29b 
diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py new file mode 100644 index 00000000..5b9bea1f --- /dev/null +++ b/src/primaite/simulator/core.py @@ -0,0 +1,55 @@ +"""Core of the PrimAITE Simulator.""" +from abc import abstractmethod +from typing import Callable, Dict, List + +from pydantic import BaseModel + + +class SimComponent(BaseModel): + """Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator.""" + + @abstractmethod + def describe_state(self) -> Dict: + """ + Return a dictionary describing the state of this object and any objects managed by it. + + This is similar to pydantic ``model_dump()``, but it only outputs information about the objects owned by this + object. If there are objects referenced by this object that are owned by something else, it is not included in + this output. + """ + return {} + + def apply_action(self, action: List[str]) -> None: + """ + Apply an action to a simulation component. Action data is passed in as a 'namespaced' list of strings. + + If the list only has one element, the action is intended to be applied directly to this object. If the list has + multiple entries, the action is passed to the child of this object specified by the first one or two entries. + This is essentially a namespace. + + For example, ["turn_on",] is meant to apply an action of 'turn on' to this component. + + However, ["services", "email_client", "turn_on"] is meant to 'turn on' this component's email client service. + + :param action: List describing the action to apply to this object. 
+ :type action: List[str] + """ + possible_actions = self._possible_actions() + if action[0] in possible_actions: + # take the first element off the action list and pass the remaining arguments to the corresponding action + # function + possible_actions[action.pop(0)](action) + else: + raise ValueError(f"{self.__class__.__name__} received invalid action {action}") + + def _possible_actions(self) -> Dict[str, Callable[[List[str]], None]]: + return {} + + def apply_timestep(self) -> None: + """ + Apply a timestep evolution to this component. + + Override this method with anything that happens automatically in the component such as scheduled restarts or + sending data. + """ + pass diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index b76a2ecf..6e23b3ac 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -13,7 +13,7 @@ _LOGGER = getLogger(__name__) @pytest.mark.parametrize( "temp_primaite_session", [ - [TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()], + # [TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()], [TEST_CONFIG_ROOT / "session_test/training_config_main_sb3.yaml", dos_very_basic_config_path()], ], indirect=True, diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/primaite/__init__.py b/tests/unit_tests/primaite/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/primaite/simulator/__init__.py b/tests/unit_tests/primaite/simulator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/primaite/simulator/test_core.py b/tests/unit_tests/primaite/simulator/test_core.py new file mode 100644 index 00000000..de0732f9 --- /dev/null +++ b/tests/unit_tests/primaite/simulator/test_core.py @@ -0,0 +1,79 @@ +from typing import Callable, Dict, List, Literal, 
Tuple + +import pytest +from pydantic import ValidationError + +from primaite.simulator.core import SimComponent + + +class TestIsolatedSimComponent: + """Test the SimComponent class in isolation.""" + + def test_data_validation(self): + """ + Test that our derived class does not interfere with pydantic data validation. + + This test may seem like it's simply validating pydantic data validation, but + actually it is here to give us assurance that any custom functionality we add + to the SimComponent does not interfere with pydantic. + """ + + class TestComponent(SimComponent): + name: str + size: Tuple[float, float] + + def describe_state(self) -> Dict: + return {} + + comp = TestComponent(name="computer", size=(5, 10)) + assert isinstance(comp, TestComponent) + + with pytest.raises(ValidationError): + invalid_comp = TestComponent(name="computer", size="small") # noqa + + def test_serialisation(self): + """Validate that our added functionality does not interfere with pydantic.""" + + class TestComponent(SimComponent): + name: str + size: Tuple[float, float] + + def describe_state(self) -> Dict: + return {} + + comp = TestComponent(name="computer", size=(5, 10)) + dump = comp.model_dump() + assert dump == {"name": "computer", "size": (5, 10)} + + def test_apply_action(self): + """Validate that we can override apply_action behaviour and it updates the state of the component.""" + + class TestComponent(SimComponent): + name: str + status: Literal["on", "off"] = "off" + + def describe_state(self) -> Dict: + return {} + + def _possible_actions(self) -> Dict[str, Callable[[List[str]], None]]: + return { + "turn_off": self._turn_off, + "turn_on": self._turn_on, + } + + def _turn_off(self, options: List[str]) -> None: + assert len(options) == 0, "This action does not support options." + self.status = "off" + + def _turn_on(self, options: List[str]) -> None: + assert len(options) == 0, "This action does not support options." 
+ self.status = "on" + + comp = TestComponent(name="computer", status="off") + + assert comp.status == "off" + comp.apply_action(["turn_on"]) + assert comp.status == "on" + + with pytest.raises(ValueError): + comp.apply_action(["do_nothing"])