Merged PR 149: Base class for Simulation Components

## Summary
This PR introduces a base class for all simulation components. The idea is to formalise both the way data is extracted from the simulator and the way actions are applied to its different aspects.

The intention is that any class that simulates something will inherit from `SimComponent` (which itself inherits from the pydantic `BaseModel`).
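As a rough sketch of what subclassing might look like (the `ComputerNode` class and its fields are hypothetical, and `SimComponent` is cut down here to the part relevant to this point):

```python
from typing import Dict

from pydantic import BaseModel


# Stand-down stand-in for primaite.simulator.core.SimComponent, reduced to the
# state-reporting part; the real class in this PR also handles action routing.
class SimComponent(BaseModel):
    def describe_state(self) -> Dict:
        """Return a dictionary describing the state of this object."""
        return {}


# Hypothetical subclass: simulated fields get pydantic validation for free.
class ComputerNode(SimComponent):
    hostname: str
    num_cpus: int = 1

    def describe_state(self) -> Dict:
        # Report only state owned by this component.
        return {"hostname": self.hostname, "num_cpus": self.num_cpus}


node = ComputerNode(hostname="node3")
assert node.describe_state() == {"hostname": "node3", "num_cpus": 1}
```

Because fields go through pydantic, something like `ComputerNode(hostname="node3", num_cpus="two")` raises a `ValidationError` rather than silently storing a bad value.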

Actions enter the simulator as a list of strings that is intended to be peeled back as it descends the layers of the simulation. For example, we could have an action `["network", "nodes", "node3", "network_interface_card", "disable"]`. This list is passed to the `apply_action()` method of the overall simulation controller. The controller looks at the first word in the list, `network`, uses it to select a method that can apply the action, and passes the remainder of the list, `["nodes", "node3", "network_interface_card", "disable"]`, as the argument to that method.
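A minimal sketch of this peel-and-dispatch pattern (all class and attribute names here are illustrative, not the actual PrimAITE API):

```python
from typing import Callable, Dict, List


class Node:
    """Illustrative leaf component: handles the final action words itself."""

    def __init__(self) -> None:
        self.nic_enabled = True

    def apply_action(self, action: List[str]) -> None:
        if action == ["network_interface_card", "disable"]:
            self.nic_enabled = False
        else:
            raise ValueError(f"Node received invalid action {action}")


class Network:
    """Illustrative mid-level component: routes actions to nodes by name."""

    def __init__(self) -> None:
        self.nodes: Dict[str, Node] = {"node3": Node()}

    def apply_action(self, action: List[str]) -> None:
        # Peel off the namespace word, then the node name; pass the rest down.
        if action.pop(0) != "nodes":
            raise ValueError("Network received an unknown action namespace")
        self.nodes[action.pop(0)].apply_action(action)


class SimulationController:
    """Illustrative top-level controller: routes on the first word."""

    def __init__(self) -> None:
        self.network = Network()

    def apply_action(self, action: List[str]) -> None:
        handlers: Dict[str, Callable[[List[str]], None]] = {
            "network": self.network.apply_action,
        }
        # First word selects the handler; the remainder is its argument.
        handlers[action.pop(0)](action)


controller = SimulationController()
controller.apply_action(
    ["network", "nodes", "node3", "network_interface_card", "disable"]
)
assert controller.network.nodes["node3"].nic_enabled is False
```

Each layer consumes only the words it understands and delegates the rest, so components never need to know the full action vocabulary of their descendants.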

To the reviewers: please check that you're happy with the implicit design choices I've made while implementing this, especially the contract for passing actions down the component tree.

(I also changed some mentions of `agent` to `agent_abc` in the docs, as the docs build was complaining and refusing to build otherwise.)

## Test process
I have written basic unit tests to check that the custom functionality added to SimComponent doesn't interfere with basic pydantic functionality.
I also started doc pages that explain these concepts to potential developers; once there are subclasses of this core class, it will be easier to populate the docs with concrete examples.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [x] I have written **tests** for any new functionality added with this PR
- [x] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style
- [x] I have **type hinted** all the code I changed.

Related work items: #1709
This commit is contained in:
Marek Wolan
2023-07-31 12:43:07 +00:00
16 changed files with 409 additions and 260 deletions

View File

@@ -43,6 +43,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
source/config
source/primaite_session
source/custom_agent
source/simulation
PrimAITE API <source/_autosummary/primaite>
PrimAITE Tests <source/_autosummary/tests>
source/dependencies

View File

@@ -13,7 +13,7 @@ Integrating a user defined blue agent
If you are planning to implement custom RL agents into PrimAITE, you must use the project as a repository. If you install PrimAITE as a python package from wheel, custom agents are not supported.
PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent.AgentSessionABC<Agent Session>` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training and evaluation. If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint``, ``load()``, and ``save()`` methods.
PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent_abc.AgentSessionABC<Agent Session>` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training and evaluation. If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent_abc.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint``, ``load()``, and ``save()`` methods.
Below is a barebones example of a custom agent implementation:
@@ -21,7 +21,7 @@ Below is a barebones example of a custom agent implementation:
# src/primaite/agents/my_custom_agent.py
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
class CustomAgent(AgentSessionABC):

View File

@@ -0,0 +1,10 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Simulation Structure
====================
The simulation is made up of many smaller components which are related to each other in a tree-like structure. At the top level, there is an object called the ``SimulationController`` *(doesn't exist yet)*, which has a physical network and a software controller for managing software and users.
Each node of the simulation 'tree' has responsibility for creating, deleting, and updating its direct descendants.

View File

@@ -34,10 +34,10 @@ dependencies = [
"plotly==5.15.0",
"polars==0.18.4",
"PyYAML==6.0",
"ray[rllib]==2.2.0",
"stable-baselines3==1.6.2",
"tensorflow==2.12.0",
"typer[all]==0.9.0"
"typer[all]==0.9.0",
"pydantic"
]
[tool.setuptools.dynamic]

View File

@@ -1,286 +1,287 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
# # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# from __future__ import annotations
import json
import shutil
import zipfile
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
from uuid import uuid4
# import json
# import shutil
# import zipfile
# from datetime import datetime
# from logging import Logger
# from pathlib import Path
# from typing import Any, Callable, Dict, Optional, Union
# from uuid import uuid4
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
# from primaite import getLogger
# from primaite.agents.agent_abc import AgentSessionABC
# from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
# from primaite.environment.primaite_env import Primaite
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
from primaite.environment.primaite_env import Primaite
from primaite.exceptions import RLlibAgentError
_LOGGER: Logger = getLogger(__name__)
# # from ray.rllib.algorithms import Algorithm
# # from ray.rllib.algorithms.a2c import A2CConfig
# # from ray.rllib.algorithms.ppo import PPOConfig
# # from ray.tune.logger import UnifiedLogger
# # from ray.tune.registry import register_env
# TODO: verify type of env_config
def _env_creator(env_config: Dict[str, Any]) -> Primaite:
return Primaite(
training_config_path=env_config["training_config_path"],
lay_down_config_path=env_config["lay_down_config_path"],
session_path=env_config["session_path"],
timestamp_str=env_config["timestamp_str"],
)
# # from primaite.exceptions import RLlibAgentError
# _LOGGER: Logger = getLogger(__name__)
# TODO: verify type hint return type
def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
logdir = session_path / "ray_results"
logdir.mkdir(parents=True, exist_ok=True)
# # TODO: verify type of env_config
# def _env_creator(env_config: Dict[str, Any]) -> Primaite:
# return Primaite(
# training_config_path=env_config["training_config_path"],
# lay_down_config_path=env_config["lay_down_config_path"],
# session_path=env_config["session_path"],
# timestamp_str=env_config["timestamp_str"],
# )
def logger_creator(config: Dict) -> UnifiedLogger:
return UnifiedLogger(config, logdir, loggers=None)
# # # TODO: verify type hint return type
# # def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
# # logdir = session_path / "ray_results"
# # logdir.mkdir(parents=True, exist_ok=True)
return logger_creator
# # def logger_creator(config: Dict) -> UnifiedLogger:
# # return UnifiedLogger(config, logdir, loggers=None)
# return logger_creator
class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
# # class RLlibAgent(AgentSessionABC):
# # """An AgentSession class that implements a Ray RLlib agent."""
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Initialise the RLLib Agent training session.
# # def __init__(
# # self,
# # training_config_path: Optional[Union[str, Path]] = "",
# # lay_down_config_path: Optional[Union[str, Path]] = "",
# # session_path: Optional[Union[str, Path]] = None,
# # ) -> None:
# # """
# # Initialise the RLLib Agent training session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "RLLIB")
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
# TODO: implement RLlib agent loading
if session_path is not None:
msg = "RLlib agent loading has not been implemented yet"
_LOGGER.critical(msg)
raise NotImplementedError(msg)
# # :param training_config_path: YAML file containing configurable items defined in
# # `primaite.config.training_config.TrainingConfig`
# # :type training_config_path: Union[path, str]
# # :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
# # :type lay_down_config_path: Union[path, str]
# # :raises ValueError: If the training config contains a bad value for agent_framework (should be "RLLIB")
# # :raises ValueError: If the training config contains a bad value for agent_identifies (should be `PPO`
# # or `A2C`)
# # """
# # # TODO: implement RLlib agent loading
# # if session_path is not None:
# # msg = "RLlib agent loading has not been implemented yet"
# # _LOGGER.critical(msg)
# # raise NotImplementedError(msg)
super().__init__(training_config_path, lay_down_config_path)
if self._training_config.session_type == SessionType.EVAL:
msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
_LOGGER.critical(msg)
raise RLlibAgentError(msg)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config_class: Union[PPOConfig, A2CConfig]
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_config_class = PPOConfig
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_config_class = A2CConfig
else:
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config: Union[PPOConfig, A2CConfig]
# # super().__init__(training_config_path, lay_down_config_path)
# # if self._training_config.session_type == SessionType.EVAL:
# # msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
# # _LOGGER.critical(msg)
# # raise RLlibAgentError(msg)
# # if not self._training_config.agent_framework == AgentFramework.RLLIB:
# # msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
# # _LOGGER.error(msg)
# # raise ValueError(msg)
# # self._agent_config_class: Union[PPOConfig, A2CConfig]
# # if self._training_config.agent_identifier == AgentIdentifier.PPO:
# # self._agent_config_class = PPOConfig
# # elif self._training_config.agent_identifier == AgentIdentifier.A2C:
# # self._agent_config_class = A2CConfig
# # else:
# # msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
# # _LOGGER.error(msg)
# # raise ValueError(msg)
# # self._agent_config: Union[PPOConfig, A2CConfig]
self._current_result: dict
self._setup()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"agent_identifier="
f"{self._training_config.agent_identifier}, "
f"deep_learning_framework="
f"{self._training_config.deep_learning_framework}"
)
self._train_agent = None # Required to capture the learning agent to close after eval
# # self._current_result: dict
# # self._setup()
# # _LOGGER.debug(
# # f"Created {self.__class__.__name__} using: "
# # f"agent_framework={self._training_config.agent_framework}, "
# # f"agent_identifier="
# # f"{self._training_config.agent_identifier}, "
# # f"deep_learning_framework="
# # f"{self._training_config.deep_learning_framework}"
# # )
# # self._train_agent = None # Required to capture the learning agent to close after eval
def _update_session_metadata_file(self) -> None:
"""
Update the ``session_metadata.json`` file.
# # def _update_session_metadata_file(self) -> None:
# # """
# # Update the ``session_metadata.json`` file.
Updates the `session_metadata.json`` in the ``session_path`` directory
with the following key/value pairs:
# # Updates the `session_metadata.json`` in the ``session_path`` directory
# # with the following key/value pairs:
- end_datetime: The date & time the session ended in iso format.
- total_episodes: The total number of training episodes completed.
- total_time_steps: The total number of training time steps completed.
"""
with open(self.session_path / "session_metadata.json", "r") as file:
metadata_dict = json.load(file)
# # - end_datetime: The date & time the session ended in iso format.
# # - total_episodes: The total number of training episodes completed.
# # - total_time_steps: The total number of training time steps completed.
# # """
# # with open(self.session_path / "session_metadata.json", "r") as file:
# # metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
if not self.is_eval:
metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
else:
metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
# # metadata_dict["end_datetime"] = datetime.now().isoformat()
# # if not self.is_eval:
# # metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
# # metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
# # else:
# # metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
# # metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
# # filepath = self.session_path / "session_metadata.json"
# # _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
# # with open(filepath, "w") as file:
# # json.dump(metadata_dict, file)
# # _LOGGER.debug("Finished updating session metadata file")
def _setup(self) -> None:
super()._setup()
register_env("primaite", _env_creator)
self._agent_config = self._agent_config_class()
# # def _setup(self) -> None:
# # super()._setup()
# # register_env("primaite", _env_creator)
# # self._agent_config = self._agent_config_class()
self._agent_config.environment(
env="primaite",
env_config=dict(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
),
)
self._agent_config.seed = self._training_config.seed
# # self._agent_config.environment(
# # env="primaite",
# # env_config=dict(
# # training_config_path=self._training_config_path,
# # lay_down_config_path=self._lay_down_config_path,
# # session_path=self.session_path,
# # timestamp_str=self.timestamp_str,
# # ),
# # )
# # self._agent_config.seed = self._training_config.seed
self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
self._agent_config.framework(framework="tf")
# # self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
# # self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
num_rollout_workers=1,
num_envs_per_worker=1,
horizon=self._training_config.num_train_steps,
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
# # self._agent_config.rollouts(
# # num_rollout_workers=1,
# # num_envs_per_worker=1,
# # horizon=self._training_config.num_train_steps,
# # )
# # self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
self._agent.save(str(self.checkpoints_path))
# # def _save_checkpoint(self) -> None:
# # checkpoint_n = self._training_config.checkpoint_every_n_episodes
# # episode_count = self._current_result["episodes_total"]
# # save_checkpoint = False
# # if checkpoint_n:
# # save_checkpoint = episode_count % checkpoint_n == 0
# # if episode_count and save_checkpoint:
# # self._agent.save(str(self.checkpoints_path))
def learn(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
# # def learn(
# # self,
# # **kwargs: Any,
# # ) -> None:
# # """
# # Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_train_steps
episodes = self._training_config.num_train_episodes
# # :param kwargs: Any agent-specific key-word args to be passed.
# # """
# # time_steps = self._training_config.num_train_steps
# # episodes = self._training_config.num_train_episodes
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self.save()
super().learn()
# Done this way as the RLlib eval can only be performed if the session hasn't been stopped
if self._training_config.session_type is not SessionType.TRAIN:
self._train_agent = self._agent
else:
self._agent.stop()
self._plot_av_reward_per_episode(learning_session=True)
# # _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
# # for i in range(episodes):
# # self._current_result = self._agent.train()
# # self._save_checkpoint()
# # self.save()
# # super().learn()
# # # Done this way as the RLlib eval can only be performed if the session hasn't been stopped
# # if self._training_config.session_type is not SessionType.TRAIN:
# # self._train_agent = self._agent
# # else:
# # self._agent.stop()
# # self._plot_av_reward_per_episode(learning_session=True)
def _unpack_saved_agent_into_eval(self) -> Path:
"""Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
agent_restore_path = self.evaluation_path / "agent_restore"
if agent_restore_path.exists():
shutil.rmtree(agent_restore_path)
agent_restore_path.mkdir()
with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
zip_file.extractall(agent_restore_path)
return agent_restore_path
# # def _unpack_saved_agent_into_eval(self) -> Path:
# # """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
# # agent_restore_path = self.evaluation_path / "agent_restore"
# # if agent_restore_path.exists():
# # shutil.rmtree(agent_restore_path)
# # agent_restore_path.mkdir()
# # with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
# # zip_file.extractall(agent_restore_path)
# # return agent_restore_path
def _setup_eval(self):
self._can_learn = False
self._can_evaluate = True
self._agent.restore(str(self._unpack_saved_agent_into_eval()))
# # def _setup_eval(self):
# # self._can_learn = False
# # self._can_evaluate = True
# # self._agent.restore(str(self._unpack_saved_agent_into_eval()))
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
# # def evaluate(
# # self,
# # **kwargs,
# # ):
# # """
# # Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
# # :param kwargs: Any agent-specific key-word args to be passed.
# # """
# # time_steps = self._training_config.num_eval_steps
# # episodes = self._training_config.num_eval_episodes
self._setup_eval()
# # self._setup_eval()
self._env: Primaite = Primaite(
self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
)
# # self._env: Primaite = Primaite(
# # self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
# # )
self._env.set_as_eval()
self.is_eval = True
if self._training_config.deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
_LOGGER.info(
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
)
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action = self._agent.compute_single_action(observation=obs, explore=False)
# # self._env.set_as_eval()
# # self.is_eval = True
# # if self._training_config.deterministic:
# # deterministic_str = "deterministic"
# # else:
# # deterministic_str = "non-deterministic"
# # _LOGGER.info(
# # f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
# # )
# # for episode in range(episodes):
# # obs = self._env.reset()
# # for step in range(time_steps):
# # action = self._agent.compute_single_action(observation=obs, explore=False)
obs, rewards, done, info = self._env.step(action)
# # obs, rewards, done, info = self._env.step(action)
self._env.reset()
self._env.close()
super().evaluate()
# Now we're safe to close the learning agent and write the mean rewards per episode for it
if self._training_config.session_type is not SessionType.TRAIN:
self._train_agent.stop()
self._plot_av_reward_per_episode(learning_session=True)
# Perform a clean-up of the unpacked agent
if (self.evaluation_path / "agent_restore").exists():
shutil.rmtree((self.evaluation_path / "agent_restore"))
# # self._env.reset()
# # self._env.close()
# # super().evaluate()
# # # Now we're safe to close the learning agent and write the mean rewards per episode for it
# # if self._training_config.session_type is not SessionType.TRAIN:
# # self._train_agent.stop()
# # self._plot_av_reward_per_episode(learning_session=True)
# # # Perform a clean-up of the unpacked agent
# # if (self.evaluation_path / "agent_restore").exists():
# # shutil.rmtree((self.evaluation_path / "agent_restore"))
def _get_latest_checkpoint(self) -> None:
raise NotImplementedError
# # def _get_latest_checkpoint(self) -> None:
# # raise NotImplementedError
@classmethod
def load(cls, path: Union[str, Path]) -> RLlibAgent:
"""Load an agent from file."""
raise NotImplementedError
# # @classmethod
# # def load(cls, path: Union[str, Path]) -> RLlibAgent:
# # """Load an agent from file."""
# # raise NotImplementedError
def save(self, overwrite_existing: bool = True) -> None:
"""Save the agent."""
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
temp_dir.mkdir()
# # def save(self, overwrite_existing: bool = True) -> None:
# # """Save the agent."""
# # # Make temp dir to save in isolation
# # temp_dir = self.learning_path / str(uuid4())
# # temp_dir.mkdir()
# Save the agent to the temp dir
self._agent.save(str(temp_dir))
# # # Save the agent to the temp dir
# # self._agent.save(str(temp_dir))
# Capture the saved Rllib checkpoint inside the temp directory
for file in temp_dir.iterdir():
checkpoint_dir = file
break
# # # Capture the saved Rllib checkpoint inside the temp directory
# # for file in temp_dir.iterdir():
# # checkpoint_dir = file
# # break
# Zip the folder
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# # # Zip the folder
# # shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# Drop the temp directory
shutil.rmtree(temp_dir)
# # # Drop the temp directory
# # shutil.rmtree(temp_dir)
def export(self) -> None:
"""Export the agent to transportable file format."""
raise NotImplementedError
# # def export(self) -> None:
# # """Export the agent to transportable file format."""
# # raise NotImplementedError

View File

@@ -99,8 +99,8 @@ class AgentFramework(Enum):
"Custom Agent"
SB3 = 1
"Stable Baselines3"
RLLIB = 2
"Ray RLlib"
# RLLIB = 2
# "Ray RLlib"
class DeepLearningFramework(Enum):

View File

@@ -248,8 +248,8 @@ class TrainingConfig:
def __str__(self) -> str:
obs_str = ",".join([c["name"] for c in self.observation_space["components"]])
tc = f"{self.agent_framework}, "
if self.agent_framework is AgentFramework.RLLIB:
tc += f"{self.deep_learning_framework}, "
# if self.agent_framework is AgentFramework.RLLIB:
# tc += f"{self.deep_learning_framework}, "
tc += f"{self.agent_identifier}, "
if self.agent_identifier is AgentIdentifier.HARDCODED:
tc += f"{self.hard_coded_agent_view}, "

View File

@@ -17,9 +17,8 @@ from primaite import getLogger
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
from primaite.common.enums import ( # AgentFramework,
ActionType,
AgentFramework,
AgentIdentifier,
FileSystemState,
HardwareState,
@@ -236,7 +235,8 @@ class Primaite(Env):
_LOGGER.debug("Action space type NODE selected")
# Terms (for node action space):
# [0, num nodes] - node ID (0 = nothing, node ID)
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, # noqa
# service state, file system state) # noqa
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
self.action_dict = self.create_node_action_dict()
@@ -271,8 +271,8 @@ class Primaite(Env):
@property
def actual_episode_count(self) -> int:
"""Shifts the episode_count by -1 for RLlib learning session."""
if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval:
return self.episode_count - 1
# if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval:
# return self.episode_count - 1
return self.episode_count
def set_as_eval(self) -> None:

View File

@@ -10,7 +10,8 @@ from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
# from primaite.agents.rllib import RLlibAgent
from primaite.agents.sb3 import SB3Agent
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
@@ -157,10 +158,12 @@ class PrimaiteSession:
self.legacy_lay_down_config,
)
elif self._training_config.agent_framework == AgentFramework.RLLIB:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
# Ray RLlib Agent
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)
# elif self._training_config.agent_framework == AgentFramework.RLLIB:
# _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
# # Ray RLlib Agent
# self._agent_session = RLlibAgent(
# self._training_config_path, self._lay_down_config_path, self.session_path
# )
else:
# Invalid AgentFramework

View File

@@ -0,0 +1,55 @@
"""Core of the PrimAITE Simulator."""
from abc import abstractmethod
from typing import Callable, Dict, List
from pydantic import BaseModel
class SimComponent(BaseModel):
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""
@abstractmethod
def describe_state(self) -> Dict:
"""
Return a dictionary describing the state of this object and any objects managed by it.
This is similar to pydantic ``model_dump()``, but it only outputs information about the objects owned by this
object. If there are objects referenced by this object that are owned by something else, it is not included in
this output.
"""
return {}
def apply_action(self, action: List[str]) -> None:
"""
Apply an action to a simulation component. Action data is passed in as a 'namespaced' list of strings.
If the list only has one element, the action is intended to be applied directly to this object. If the list has
multiple entries, the action is passed to the child of this object specified by the first one or two entries.
This is essentially a namespace.
For example, ["turn_on",] is meant to apply an action of 'turn on' to this component.
However, ["services", "email_client", "turn_on"] is meant to 'turn on' this component's email client service.
:param action: List describing the action to apply to this object.
:type action: List[str]
"""
possible_actions = self._possible_actions()
if action[0] in possible_actions:
# take the first element off the action list and pass the remaining arguments to the corresponding action
# function
possible_actions[action.pop(0)](action)
else:
raise ValueError(f"{self.__class__.__name__} received invalid action {action}")
def _possible_actions(self) -> Dict[str, Callable[[List[str]], None]]:
return {}
def apply_timestep(self) -> None:
"""
Apply a timestep evolution to this component.
Override this method with anything that happens automatically in the component such as scheduled restarts or
sending data.
"""
pass

View File

@@ -13,7 +13,7 @@ _LOGGER = getLogger(__name__)
@pytest.mark.parametrize(
"temp_primaite_session",
[
[TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()],
# [TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()],
[TEST_CONFIG_ROOT / "session_test/training_config_main_sb3.yaml", dos_very_basic_config_path()],
],
indirect=True,

View File

@@ -0,0 +1,79 @@
from typing import Callable, Dict, List, Literal, Tuple
import pytest
from pydantic import ValidationError
from primaite.simulator.core import SimComponent
class TestIsolatedSimComponent:
"""Test the SimComponent class in isolation."""
def test_data_validation(self):
"""
Test that our derived class does not interfere with pydantic data validation.
This test may seem like it's simply validating pydantic data validation, but
actually it is here to give us assurance that any custom functionality we add
to the SimComponent does not interfere with pydantic.
"""
class TestComponent(SimComponent):
name: str
size: Tuple[float, float]
def describe_state(self) -> Dict:
return {}
comp = TestComponent(name="computer", size=(5, 10))
assert isinstance(comp, TestComponent)
with pytest.raises(ValidationError):
invalid_comp = TestComponent(name="computer", size="small") # noqa
def test_serialisation(self):
"""Validate that our added functionality does not interfere with pydantic."""
class TestComponent(SimComponent):
name: str
size: Tuple[float, float]
def describe_state(self) -> Dict:
return {}
comp = TestComponent(name="computer", size=(5, 10))
dump = comp.model_dump()
assert dump == {"name": "computer", "size": (5, 10)}
def test_apply_action(self):
"""Validate that we can override apply_action behaviour and it updates the state of the component."""
class TestComponent(SimComponent):
name: str
status: Literal["on", "off"] = "off"
def describe_state(self) -> Dict:
return {}
def _possible_actions(self) -> Dict[str, Callable[[List[str]], None]]:
return {
"turn_off": self._turn_off,
"turn_on": self._turn_on,
}
def _turn_off(self, options: List[str]) -> None:
assert len(options) == 0, "This action does not support options."
self.status = "off"
def _turn_on(self, options: List[str]) -> None:
assert len(options) == 0, "This action does not support options."
self.status = "on"
comp = TestComponent(name="computer", status="off")
assert comp.status == "off"
comp.apply_action(["turn_on"])
assert comp.status == "on"
with pytest.raises(ValueError):
comp.apply_action(["do_nothing"])