Merged PR 149: Base class for Simulation Components
## Summary This introduces a base class for all simulation components. The idea behind this is to formalise the way in which data is extracted from the simulator and the way actions are applied to the different aspects of the simulator. The intention is that any class that simulates something will inherit from SimComponent (which inherits from pydantic BaseModel). Actions enter the simulator as a list of strings that is intended to be peeled back as you go down the layers of the simulation. For example we could have an action of `["network", "nodes", "node3", "network_interface_card", "disable"]` This list is passed to the `apply_action()` function of the overall simulation controller. The simulation controller looks at the first word on the list, `network` and uses this to select a method that can apply the function. It passes the remainder of the list as an argument to that function. In this case it will be `["nodes", "node3", "network_interface_card", "disable"]`. To the reviewers, please validate that you're happy with the implicit design choices I've made while implementing this. Especially the contract passing actions down the components tree. (also I changed some mentions of agent to agent_abc in the docs as it was complaining and refusing to build.) ## Test process I have written basic unit tests to check that the custom functionality added to SimComponent doesn't interfere with basic pydantic functionality. I also started doc pages that explains these concepts to potential developers, although once there are subclasses of this core class, it will be easier to populate the docs with concrete examples. ## Checklist - [x] This PR is linked to a **work item** - [x] I have performed **self-review** of the code - [x] I have written **tests** for any new functionality added with this PR - [x] I have updated the **documentation** if this PR changes or adds functionality - [x] I have run **pre-commit** checks for code style - [x] I have **type hinted** all the code I changed. Related work items: #1709
This commit is contained in:
@@ -43,6 +43,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
source/config
|
||||
source/primaite_session
|
||||
source/custom_agent
|
||||
source/simulation
|
||||
PrimAITE API <source/_autosummary/primaite>
|
||||
PrimAITE Tests <source/_autosummary/tests>
|
||||
source/dependencies
|
||||
|
||||
@@ -13,7 +13,7 @@ Integrating a user defined blue agent
|
||||
|
||||
If you are planning to implement custom RL agents into PrimAITE, you must use the project as a repository. If you install PrimAITE as a python package from wheel, custom agents are not supported.
|
||||
|
||||
PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent.AgentSessionABC<Agent Session>` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training and evaluation. If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint``, ``load()``, and ``save()`` methods.
|
||||
PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent_abc.AgentSessionABC<Agent Session>` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training and evaluation. If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent_abc.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint``, ``load()``, and ``save()`` methods.
|
||||
|
||||
Below is a barebones example of a custom agent implementation:
|
||||
|
||||
@@ -21,7 +21,7 @@ Below is a barebones example of a custom agent implementation:
|
||||
|
||||
# src/primaite/agents/my_custom_agent.py
|
||||
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
|
||||
class CustomAgent(AgentSessionABC):
|
||||
|
||||
10
docs/source/simulation.rst
Normal file
10
docs/source/simulation.rst
Normal file
@@ -0,0 +1,10 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
Simulation Strucutre
|
||||
====================
|
||||
|
||||
The simulation is made up of many smaller components which are related to each other in a tree-like structure. At the top level, there is an object called the ``SimulationController`` _(doesn't exist yet)_, which has a physical network and a software controller for managing software and users.
|
||||
|
||||
Each node of the simulation 'tree' has responsibility for creating, deleting, and updating its direct descendants.
|
||||
@@ -34,10 +34,10 @@ dependencies = [
|
||||
"plotly==5.15.0",
|
||||
"polars==0.18.4",
|
||||
"PyYAML==6.0",
|
||||
"ray[rllib]==2.2.0",
|
||||
"stable-baselines3==1.6.2",
|
||||
"tensorflow==2.12.0",
|
||||
"typer[all]==0.9.0"
|
||||
"typer[all]==0.9.0",
|
||||
"pydantic"
|
||||
]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
|
||||
@@ -1,286 +1,287 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
# # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
# from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
# import json
|
||||
# import shutil
|
||||
# import zipfile
|
||||
# from datetime import datetime
|
||||
# from logging import Logger
|
||||
# from pathlib import Path
|
||||
# from typing import Any, Callable, Dict, Optional, Union
|
||||
# from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.a2c import A2CConfig
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import register_env
|
||||
# from primaite import getLogger
|
||||
# from primaite.agents.agent_abc import AgentSessionABC
|
||||
# from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
|
||||
# from primaite.environment.primaite_env import Primaite
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.exceptions import RLlibAgentError
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
# # from ray.rllib.algorithms import Algorithm
|
||||
# # from ray.rllib.algorithms.a2c import A2CConfig
|
||||
# # from ray.rllib.algorithms.ppo import PPOConfig
|
||||
# # from ray.tune.logger import UnifiedLogger
|
||||
# # from ray.tune.registry import register_env
|
||||
|
||||
|
||||
# TODO: verify type of env_config
|
||||
def _env_creator(env_config: Dict[str, Any]) -> Primaite:
|
||||
return Primaite(
|
||||
training_config_path=env_config["training_config_path"],
|
||||
lay_down_config_path=env_config["lay_down_config_path"],
|
||||
session_path=env_config["session_path"],
|
||||
timestamp_str=env_config["timestamp_str"],
|
||||
)
|
||||
# # from primaite.exceptions import RLlibAgentError
|
||||
|
||||
# _LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: verify type hint return type
|
||||
def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
|
||||
logdir = session_path / "ray_results"
|
||||
logdir.mkdir(parents=True, exist_ok=True)
|
||||
# # TODO: verify type of env_config
|
||||
# def _env_creator(env_config: Dict[str, Any]) -> Primaite:
|
||||
# return Primaite(
|
||||
# training_config_path=env_config["training_config_path"],
|
||||
# lay_down_config_path=env_config["lay_down_config_path"],
|
||||
# session_path=env_config["session_path"],
|
||||
# timestamp_str=env_config["timestamp_str"],
|
||||
# )
|
||||
|
||||
def logger_creator(config: Dict) -> UnifiedLogger:
|
||||
return UnifiedLogger(config, logdir, loggers=None)
|
||||
# # # TODO: verify type hint return type
|
||||
# # def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
|
||||
# # logdir = session_path / "ray_results"
|
||||
# # logdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return logger_creator
|
||||
# # def logger_creator(config: Dict) -> UnifiedLogger:
|
||||
# # return UnifiedLogger(config, logdir, loggers=None)
|
||||
|
||||
# return logger_creator
|
||||
|
||||
|
||||
class RLlibAgent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Ray RLlib agent."""
|
||||
# # class RLlibAgent(AgentSessionABC):
|
||||
# # """An AgentSession class that implements a Ray RLlib agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the RLLib Agent training session.
|
||||
# # def __init__(
|
||||
# # self,
|
||||
# # training_config_path: Optional[Union[str, Path]] = "",
|
||||
# # lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
# # session_path: Optional[Union[str, Path]] = None,
|
||||
# # ) -> None:
|
||||
# # """
|
||||
# # Initialise the RLLib Agent training session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "RLLIB")
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
||||
or `A2C`)
|
||||
"""
|
||||
# TODO: implement RLlib agent loading
|
||||
if session_path is not None:
|
||||
msg = "RLlib agent loading has not been implemented yet"
|
||||
_LOGGER.critical(msg)
|
||||
raise NotImplementedError(msg)
|
||||
# # :param training_config_path: YAML file containing configurable items defined in
|
||||
# # `primaite.config.training_config.TrainingConfig`
|
||||
# # :type training_config_path: Union[path, str]
|
||||
# # :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
# # :type lay_down_config_path: Union[path, str]
|
||||
# # :raises ValueError: If the training config contains a bad value for agent_framework (should be "RLLIB")
|
||||
# # :raises ValueError: If the training config contains a bad value for agent_identifies (should be `PPO`
|
||||
# # or `A2C`)
|
||||
# # """
|
||||
# # # TODO: implement RLlib agent loading
|
||||
# # if session_path is not None:
|
||||
# # msg = "RLlib agent loading has not been implemented yet"
|
||||
# # _LOGGER.critical(msg)
|
||||
# # raise NotImplementedError(msg)
|
||||
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if self._training_config.session_type == SessionType.EVAL:
|
||||
msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
|
||||
_LOGGER.critical(msg)
|
||||
raise RLlibAgentError(msg)
|
||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config_class: Union[PPOConfig, A2CConfig]
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_config_class = PPOConfig
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_config_class = A2CConfig
|
||||
else:
|
||||
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config: Union[PPOConfig, A2CConfig]
|
||||
# # super().__init__(training_config_path, lay_down_config_path)
|
||||
# # if self._training_config.session_type == SessionType.EVAL:
|
||||
# # msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
|
||||
# # _LOGGER.critical(msg)
|
||||
# # raise RLlibAgentError(msg)
|
||||
# # if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
# # msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
# # _LOGGER.error(msg)
|
||||
# # raise ValueError(msg)
|
||||
# # self._agent_config_class: Union[PPOConfig, A2CConfig]
|
||||
# # if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
# # self._agent_config_class = PPOConfig
|
||||
# # elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
# # self._agent_config_class = A2CConfig
|
||||
# # else:
|
||||
# # msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
|
||||
# # _LOGGER.error(msg)
|
||||
# # raise ValueError(msg)
|
||||
# # self._agent_config: Union[PPOConfig, A2CConfig]
|
||||
|
||||
self._current_result: dict
|
||||
self._setup()
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
f"agent_identifier="
|
||||
f"{self._training_config.agent_identifier}, "
|
||||
f"deep_learning_framework="
|
||||
f"{self._training_config.deep_learning_framework}"
|
||||
)
|
||||
self._train_agent = None # Required to capture the learning agent to close after eval
|
||||
# # self._current_result: dict
|
||||
# # self._setup()
|
||||
# # _LOGGER.debug(
|
||||
# # f"Created {self.__class__.__name__} using: "
|
||||
# # f"agent_framework={self._training_config.agent_framework}, "
|
||||
# # f"agent_identifier="
|
||||
# # f"{self._training_config.agent_identifier}, "
|
||||
# # f"deep_learning_framework="
|
||||
# # f"{self._training_config.deep_learning_framework}"
|
||||
# # )
|
||||
# # self._train_agent = None # Required to capture the learning agent to close after eval
|
||||
|
||||
def _update_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
# # def _update_session_metadata_file(self) -> None:
|
||||
# # """
|
||||
# # Update the ``session_metadata.json`` file.
|
||||
|
||||
Updates the `session_metadata.json`` in the ``session_path`` directory
|
||||
with the following key/value pairs:
|
||||
# # Updates the `session_metadata.json`` in the ``session_path`` directory
|
||||
# # with the following key/value pairs:
|
||||
|
||||
- end_datetime: The date & time the session ended in iso format.
|
||||
- total_episodes: The total number of training episodes completed.
|
||||
- total_time_steps: The total number of training time steps completed.
|
||||
"""
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
# # - end_datetime: The date & time the session ended in iso format.
|
||||
# # - total_episodes: The total number of training episodes completed.
|
||||
# # - total_time_steps: The total number of training time steps completed.
|
||||
# # """
|
||||
# # with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
# # metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
if not self.is_eval:
|
||||
metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
|
||||
metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
|
||||
else:
|
||||
metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
|
||||
metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
|
||||
# # metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
# # if not self.is_eval:
|
||||
# # metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
|
||||
# # metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
|
||||
# # else:
|
||||
# # metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
|
||||
# # metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
with open(filepath, "w") as file:
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
# # filepath = self.session_path / "session_metadata.json"
|
||||
# # _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
# # with open(filepath, "w") as file:
|
||||
# # json.dump(metadata_dict, file)
|
||||
# # _LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
def _setup(self) -> None:
|
||||
super()._setup()
|
||||
register_env("primaite", _env_creator)
|
||||
self._agent_config = self._agent_config_class()
|
||||
# # def _setup(self) -> None:
|
||||
# # super()._setup()
|
||||
# # register_env("primaite", _env_creator)
|
||||
# # self._agent_config = self._agent_config_class()
|
||||
|
||||
self._agent_config.environment(
|
||||
env="primaite",
|
||||
env_config=dict(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
),
|
||||
)
|
||||
self._agent_config.seed = self._training_config.seed
|
||||
# # self._agent_config.environment(
|
||||
# # env="primaite",
|
||||
# # env_config=dict(
|
||||
# # training_config_path=self._training_config_path,
|
||||
# # lay_down_config_path=self._lay_down_config_path,
|
||||
# # session_path=self.session_path,
|
||||
# # timestamp_str=self.timestamp_str,
|
||||
# # ),
|
||||
# # )
|
||||
# # self._agent_config.seed = self._training_config.seed
|
||||
|
||||
self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
|
||||
self._agent_config.framework(framework="tf")
|
||||
# # self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
|
||||
# # self._agent_config.framework(framework="tf")
|
||||
|
||||
self._agent_config.rollouts(
|
||||
num_rollout_workers=1,
|
||||
num_envs_per_worker=1,
|
||||
horizon=self._training_config.num_train_steps,
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
# # self._agent_config.rollouts(
|
||||
# # num_rollout_workers=1,
|
||||
# # num_envs_per_worker=1,
|
||||
# # horizon=self._training_config.num_train_steps,
|
||||
# # )
|
||||
# # self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
def _save_checkpoint(self) -> None:
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
save_checkpoint = False
|
||||
if checkpoint_n:
|
||||
save_checkpoint = episode_count % checkpoint_n == 0
|
||||
if episode_count and save_checkpoint:
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
# # def _save_checkpoint(self) -> None:
|
||||
# # checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
# # episode_count = self._current_result["episodes_total"]
|
||||
# # save_checkpoint = False
|
||||
# # if checkpoint_n:
|
||||
# # save_checkpoint = episode_count % checkpoint_n == 0
|
||||
# # if episode_count and save_checkpoint:
|
||||
# # self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
# # def learn(
|
||||
# # self,
|
||||
# # **kwargs: Any,
|
||||
# # ) -> None:
|
||||
# # """
|
||||
# # Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_train_steps
|
||||
episodes = self._training_config.num_train_episodes
|
||||
# # :param kwargs: Any agent-specific key-word args to be passed.
|
||||
# # """
|
||||
# # time_steps = self._training_config.num_train_steps
|
||||
# # episodes = self._training_config.num_train_episodes
|
||||
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
self.save()
|
||||
super().learn()
|
||||
# Done this way as the RLlib eval can only be performed if the session hasn't been stopped
|
||||
if self._training_config.session_type is not SessionType.TRAIN:
|
||||
self._train_agent = self._agent
|
||||
else:
|
||||
self._agent.stop()
|
||||
self._plot_av_reward_per_episode(learning_session=True)
|
||||
# # _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
# # for i in range(episodes):
|
||||
# # self._current_result = self._agent.train()
|
||||
# # self._save_checkpoint()
|
||||
# # self.save()
|
||||
# # super().learn()
|
||||
# # # Done this way as the RLlib eval can only be performed if the session hasn't been stopped
|
||||
# # if self._training_config.session_type is not SessionType.TRAIN:
|
||||
# # self._train_agent = self._agent
|
||||
# # else:
|
||||
# # self._agent.stop()
|
||||
# # self._plot_av_reward_per_episode(learning_session=True)
|
||||
|
||||
def _unpack_saved_agent_into_eval(self) -> Path:
|
||||
"""Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
|
||||
agent_restore_path = self.evaluation_path / "agent_restore"
|
||||
if agent_restore_path.exists():
|
||||
shutil.rmtree(agent_restore_path)
|
||||
agent_restore_path.mkdir()
|
||||
with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
|
||||
zip_file.extractall(agent_restore_path)
|
||||
return agent_restore_path
|
||||
# # def _unpack_saved_agent_into_eval(self) -> Path:
|
||||
# # """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
|
||||
# # agent_restore_path = self.evaluation_path / "agent_restore"
|
||||
# # if agent_restore_path.exists():
|
||||
# # shutil.rmtree(agent_restore_path)
|
||||
# # agent_restore_path.mkdir()
|
||||
# # with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
|
||||
# # zip_file.extractall(agent_restore_path)
|
||||
# # return agent_restore_path
|
||||
|
||||
def _setup_eval(self):
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
self._agent.restore(str(self._unpack_saved_agent_into_eval()))
|
||||
# # def _setup_eval(self):
|
||||
# # self._can_learn = False
|
||||
# # self._can_evaluate = True
|
||||
# # self._agent.restore(str(self._unpack_saved_agent_into_eval()))
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
# # def evaluate(
|
||||
# # self,
|
||||
# # **kwargs,
|
||||
# # ):
|
||||
# # """
|
||||
# # Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
# # :param kwargs: Any agent-specific key-word args to be passed.
|
||||
# # """
|
||||
# # time_steps = self._training_config.num_eval_steps
|
||||
# # episodes = self._training_config.num_eval_episodes
|
||||
|
||||
self._setup_eval()
|
||||
# # self._setup_eval()
|
||||
|
||||
self._env: Primaite = Primaite(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
|
||||
)
|
||||
# # self._env: Primaite = Primaite(
|
||||
# # self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
|
||||
# # )
|
||||
|
||||
self._env.set_as_eval()
|
||||
self.is_eval = True
|
||||
if self._training_config.deterministic:
|
||||
deterministic_str = "deterministic"
|
||||
else:
|
||||
deterministic_str = "non-deterministic"
|
||||
_LOGGER.info(
|
||||
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
|
||||
)
|
||||
for episode in range(episodes):
|
||||
obs = self._env.reset()
|
||||
for step in range(time_steps):
|
||||
action = self._agent.compute_single_action(observation=obs, explore=False)
|
||||
# # self._env.set_as_eval()
|
||||
# # self.is_eval = True
|
||||
# # if self._training_config.deterministic:
|
||||
# # deterministic_str = "deterministic"
|
||||
# # else:
|
||||
# # deterministic_str = "non-deterministic"
|
||||
# # _LOGGER.info(
|
||||
# # f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
|
||||
# # )
|
||||
# # for episode in range(episodes):
|
||||
# # obs = self._env.reset()
|
||||
# # for step in range(time_steps):
|
||||
# # action = self._agent.compute_single_action(observation=obs, explore=False)
|
||||
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
# # obs, rewards, done, info = self._env.step(action)
|
||||
|
||||
self._env.reset()
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
# Now we're safe to close the learning agent and write the mean rewards per episode for it
|
||||
if self._training_config.session_type is not SessionType.TRAIN:
|
||||
self._train_agent.stop()
|
||||
self._plot_av_reward_per_episode(learning_session=True)
|
||||
# Perform a clean-up of the unpacked agent
|
||||
if (self.evaluation_path / "agent_restore").exists():
|
||||
shutil.rmtree((self.evaluation_path / "agent_restore"))
|
||||
# # self._env.reset()
|
||||
# # self._env.close()
|
||||
# # super().evaluate()
|
||||
# # # Now we're safe to close the learning agent and write the mean rewards per episode for it
|
||||
# # if self._training_config.session_type is not SessionType.TRAIN:
|
||||
# # self._train_agent.stop()
|
||||
# # self._plot_av_reward_per_episode(learning_session=True)
|
||||
# # # Perform a clean-up of the unpacked agent
|
||||
# # if (self.evaluation_path / "agent_restore").exists():
|
||||
# # shutil.rmtree((self.evaluation_path / "agent_restore"))
|
||||
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
raise NotImplementedError
|
||||
# # def _get_latest_checkpoint(self) -> None:
|
||||
# # raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Union[str, Path]) -> RLlibAgent:
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
# # @classmethod
|
||||
# # def load(cls, path: Union[str, Path]) -> RLlibAgent:
|
||||
# # """Load an agent from file."""
|
||||
# # raise NotImplementedError
|
||||
|
||||
def save(self, overwrite_existing: bool = True) -> None:
|
||||
"""Save the agent."""
|
||||
# Make temp dir to save in isolation
|
||||
temp_dir = self.learning_path / str(uuid4())
|
||||
temp_dir.mkdir()
|
||||
# # def save(self, overwrite_existing: bool = True) -> None:
|
||||
# # """Save the agent."""
|
||||
# # # Make temp dir to save in isolation
|
||||
# # temp_dir = self.learning_path / str(uuid4())
|
||||
# # temp_dir.mkdir()
|
||||
|
||||
# Save the agent to the temp dir
|
||||
self._agent.save(str(temp_dir))
|
||||
# # # Save the agent to the temp dir
|
||||
# # self._agent.save(str(temp_dir))
|
||||
|
||||
# Capture the saved Rllib checkpoint inside the temp directory
|
||||
for file in temp_dir.iterdir():
|
||||
checkpoint_dir = file
|
||||
break
|
||||
# # # Capture the saved Rllib checkpoint inside the temp directory
|
||||
# # for file in temp_dir.iterdir():
|
||||
# # checkpoint_dir = file
|
||||
# # break
|
||||
|
||||
# Zip the folder
|
||||
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
|
||||
# # # Zip the folder
|
||||
# # shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
|
||||
|
||||
# Drop the temp directory
|
||||
shutil.rmtree(temp_dir)
|
||||
# # # Drop the temp directory
|
||||
# # shutil.rmtree(temp_dir)
|
||||
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
# # def export(self) -> None:
|
||||
# # """Export the agent to transportable file format."""
|
||||
# # raise NotImplementedError
|
||||
|
||||
@@ -99,8 +99,8 @@ class AgentFramework(Enum):
|
||||
"Custom Agent"
|
||||
SB3 = 1
|
||||
"Stable Baselines3"
|
||||
RLLIB = 2
|
||||
"Ray RLlib"
|
||||
# RLLIB = 2
|
||||
# "Ray RLlib"
|
||||
|
||||
|
||||
class DeepLearningFramework(Enum):
|
||||
|
||||
@@ -248,8 +248,8 @@ class TrainingConfig:
|
||||
def __str__(self) -> str:
|
||||
obs_str = ",".join([c["name"] for c in self.observation_space["components"]])
|
||||
tc = f"{self.agent_framework}, "
|
||||
if self.agent_framework is AgentFramework.RLLIB:
|
||||
tc += f"{self.deep_learning_framework}, "
|
||||
# if self.agent_framework is AgentFramework.RLLIB:
|
||||
# tc += f"{self.deep_learning_framework}, "
|
||||
tc += f"{self.agent_identifier}, "
|
||||
if self.agent_identifier is AgentIdentifier.HARDCODED:
|
||||
tc += f"{self.hard_coded_agent_view}, "
|
||||
|
||||
@@ -17,9 +17,8 @@ from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
from primaite.common.enums import ( # AgentFramework,
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
@@ -236,7 +235,8 @@ class Primaite(Env):
|
||||
_LOGGER.debug("Action space type NODE selected")
|
||||
# Terms (for node action space):
|
||||
# [0, num nodes] - node ID (0 = nothing, node ID)
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, # noqa
|
||||
# service state, file system state) # noqa
|
||||
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
|
||||
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
|
||||
self.action_dict = self.create_node_action_dict()
|
||||
@@ -271,8 +271,8 @@ class Primaite(Env):
|
||||
@property
|
||||
def actual_episode_count(self) -> int:
|
||||
"""Shifts the episode_count by -1 for RLlib learning session."""
|
||||
if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval:
|
||||
return self.episode_count - 1
|
||||
# if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval:
|
||||
# return self.episode_count - 1
|
||||
return self.episode_count
|
||||
|
||||
def set_as_eval(self) -> None:
|
||||
|
||||
@@ -10,7 +10,8 @@ from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.agents.hardcoded_acl import HardCodedACLAgent
|
||||
from primaite.agents.hardcoded_node import HardCodedNodeAgent
|
||||
from primaite.agents.rllib import RLlibAgent
|
||||
|
||||
# from primaite.agents.rllib import RLlibAgent
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
|
||||
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
|
||||
@@ -157,10 +158,12 @@ class PrimaiteSession:
|
||||
self.legacy_lay_down_config,
|
||||
)
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
|
||||
# Ray RLlib Agent
|
||||
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)
|
||||
# elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
# _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
|
||||
# # Ray RLlib Agent
|
||||
# self._agent_session = RLlibAgent(
|
||||
# self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
# )
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
|
||||
0
src/primaite/simulator/__init__.py
Normal file
0
src/primaite/simulator/__init__.py
Normal file
55
src/primaite/simulator/core.py
Normal file
55
src/primaite/simulator/core.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Core of the PrimAITE Simulator."""
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SimComponent(BaseModel):
|
||||
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Return a dictionary describing the state of this object and any objects managed by it.
|
||||
|
||||
This is similar to pydantic ``model_dump()``, but it only outputs information about the objects owned by this
|
||||
object. If there are objects referenced by this object that are owned by something else, it is not included in
|
||||
this output.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def apply_action(self, action: List[str]) -> None:
|
||||
"""
|
||||
Apply an action to a simulation component. Action data is passed in as a 'namespaced' list of strings.
|
||||
|
||||
If the list only has one element, the action is intended to be applied directly to this object. If the list has
|
||||
multiple entries, the action is passed to the child of this object specified by the first one or two entries.
|
||||
This is essentially a namespace.
|
||||
|
||||
For example, ["turn_on",] is meant to apply an action of 'turn on' to this component.
|
||||
|
||||
However, ["services", "email_client", "turn_on"] is meant to 'turn on' this component's email client service.
|
||||
|
||||
:param action: List describing the action to apply to this object.
|
||||
:type action: List[str]
|
||||
"""
|
||||
possible_actions = self._possible_actions()
|
||||
if action[0] in possible_actions:
|
||||
# take the first element off the action list and pass the remaining arguments to the corresponding action
|
||||
# funciton
|
||||
possible_actions[action.pop(0)](action)
|
||||
else:
|
||||
raise ValueError(f"{self.__class__.__name__} received invalid action {action}")
|
||||
|
||||
def _possible_actions(self) -> Dict[str, Callable[[List[str]], None]]:
|
||||
return {}
|
||||
|
||||
def apply_timestep(self) -> None:
|
||||
"""
|
||||
Apply a timestep evolution to this component.
|
||||
|
||||
Override this method with anything that happens automatically in the component such as scheduled restarts or
|
||||
sending data.
|
||||
"""
|
||||
pass
|
||||
@@ -13,7 +13,7 @@ _LOGGER = getLogger(__name__)
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()],
|
||||
# [TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()],
|
||||
[TEST_CONFIG_ROOT / "session_test/training_config_main_sb3.yaml", dos_very_basic_config_path()],
|
||||
],
|
||||
indirect=True,
|
||||
|
||||
0
tests/unit_tests/__init__.py
Normal file
0
tests/unit_tests/__init__.py
Normal file
0
tests/unit_tests/primaite/__init__.py
Normal file
0
tests/unit_tests/primaite/__init__.py
Normal file
0
tests/unit_tests/primaite/simulator/__init__.py
Normal file
0
tests/unit_tests/primaite/simulator/__init__.py
Normal file
79
tests/unit_tests/primaite/simulator/test_core.py
Normal file
79
tests/unit_tests/primaite/simulator/test_core.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Callable, Dict, List, Literal, Tuple
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from primaite.simulator.core import SimComponent
|
||||
|
||||
|
||||
class TestIsolatedSimComponent:
|
||||
"""Test the SimComponent class in isolation."""
|
||||
|
||||
def test_data_validation(self):
|
||||
"""
|
||||
Test that our derived class does not interfere with pydantic data validation.
|
||||
|
||||
This test may seem like it's simply validating pydantic data validation, but
|
||||
actually it is here to give us assurance that any custom functionality we add
|
||||
to the SimComponent does not interfere with pydantic.
|
||||
"""
|
||||
|
||||
class TestComponent(SimComponent):
|
||||
name: str
|
||||
size: Tuple[float, float]
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
return {}
|
||||
|
||||
comp = TestComponent(name="computer", size=(5, 10))
|
||||
assert isinstance(comp, TestComponent)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
invalid_comp = TestComponent(name="computer", size="small") # noqa
|
||||
|
||||
def test_serialisation(self):
|
||||
"""Validate that our added functionality does not interfere with pydantic."""
|
||||
|
||||
class TestComponent(SimComponent):
|
||||
name: str
|
||||
size: Tuple[float, float]
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
return {}
|
||||
|
||||
comp = TestComponent(name="computer", size=(5, 10))
|
||||
dump = comp.model_dump()
|
||||
assert dump == {"name": "computer", "size": (5, 10)}
|
||||
|
||||
def test_apply_action(self):
|
||||
"""Validate that we can override apply_action behaviour and it updates the state of the component."""
|
||||
|
||||
class TestComponent(SimComponent):
|
||||
name: str
|
||||
status: Literal["on", "off"] = "off"
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
return {}
|
||||
|
||||
def _possible_actions(self) -> Dict[str, Callable[[List[str]], None]]:
|
||||
return {
|
||||
"turn_off": self._turn_off,
|
||||
"turn_on": self._turn_on,
|
||||
}
|
||||
|
||||
def _turn_off(self, options: List[str]) -> None:
|
||||
assert len(options) == 0, "This action does not support options."
|
||||
self.status = "off"
|
||||
|
||||
def _turn_on(self, options: List[str]) -> None:
|
||||
assert len(options) == 0, "This action does not support options."
|
||||
self.status = "on"
|
||||
|
||||
comp = TestComponent(name="computer", status="off")
|
||||
|
||||
assert comp.status == "off"
|
||||
comp.apply_action(["turn_on"])
|
||||
assert comp.status == "on"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
comp.apply_action(["do_nothing"])
|
||||
Reference in New Issue
Block a user