Merged PR 119: Loading SB3 Agents + Loading agent via PrimaiteSession
## Summary - Added a feature which allows a user to load a previous SB3 session - Added a feature which allows a user to load a previous PrimaiteSession - Added a feature which allows a user to load a previous session via the CLI: `primaite session --load "<SESSION_PATH>"` - RLlib is TODO in another ticket #1626 - Parallel tests via the [pytest-xdist](https://pypi.org/project/pytest-xdist/) dependency (MIT licensed) - Moved hardcoded agent into hardcoded_abc.py - renamed agent.py to agent_abc.py to clarify it is an abstract base class - Added documentation to clarify how to use the feature via CLI or using the run function via main.py ## Test process Created [test_session_loading.py](https://dev.azure.com/ma-dev-uk/PrimAITE/_git/PrimAITE/pullrequest/119?_a=files&path=/tests/test_session_loading.py) which loads a previously run session and then performs a learn and evaluation run on the loaded agent/Primate session. The test copies the saved session into a temporary folder, which is then set as the test session path. Once the test is done, the temporary folder should then be deleted ## Checklist - [X] This PR is linked to a **work item** - [X] I have performed **self-review** of the code - [X] I have written **tests** for any new functionality added with this PR - [X] I have updated the **documentation** if this PR changes or adds functionality - [X] I have run **pre-commit** checks for code style Related work items: #1595
This commit is contained in:
@@ -86,5 +86,5 @@ stages:
|
||||
displayName: 'Perform PrimAITE Setup'
|
||||
|
||||
- script: |
|
||||
pytest tests/
|
||||
pytest -n 4
|
||||
displayName: 'Run tests'
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -50,6 +50,9 @@ coverage.xml
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
tests/assets/**/*.png
|
||||
tests/assets/**/tensorboard_logs/
|
||||
tests/assets/**/checkpoints/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
|
||||
@@ -43,6 +43,36 @@ The sub-directory is formatted as such: ``~/primaite/sessions/<yyyy-mm-dd>/<yyyy
|
||||
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
|
||||
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
|
||||
|
||||
Loading a session
|
||||
-------
|
||||
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
|
||||
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
|
||||
|
||||
.. tabs::
|
||||
|
||||
.. code-tab:: bash
|
||||
:caption: Unix CLI
|
||||
|
||||
cd ~/primaite
|
||||
source ./.venv/bin/activate
|
||||
primaite session --load "path/to/session"
|
||||
|
||||
.. code-tab:: bash
|
||||
:caption: Powershell CLI
|
||||
|
||||
cd ~\primaite
|
||||
.\.venv\Scripts\activate
|
||||
primaite session --load "path\to\session"
|
||||
|
||||
|
||||
.. code-tab:: python
|
||||
:caption: Python
|
||||
|
||||
from primaite.main import run
|
||||
|
||||
run(session_path=<previous session directory>)
|
||||
|
||||
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
|
||||
|
||||
Outputs
|
||||
-------
|
||||
|
||||
@@ -58,6 +58,7 @@ dev = [
|
||||
"pip-licenses==4.3.0",
|
||||
"pre-commit==2.20.0",
|
||||
"pytest==7.2.0",
|
||||
"pytest-xdist==3.3.1",
|
||||
"pytest-cov==4.0.0",
|
||||
"pytest-flake8==1.1.1",
|
||||
"setuptools==66",
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Union
|
||||
from typing import Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
|
||||
import primaite
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.data_viz.session_plots import plot_av_reward_per_episode
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -47,38 +45,63 @@ class AgentSessionABC(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Initialise an agent session from config files.
|
||||
Initialise an agent session from config files, or load a previous session.
|
||||
|
||||
If training configuration and laydown configuration are provided with a session path,
|
||||
the session path will be used.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path, str]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# initialise variables
|
||||
self._env: Primaite
|
||||
self._agent = None
|
||||
self._can_learn: bool = False
|
||||
self._can_evaluate: bool = False
|
||||
self.is_eval = False
|
||||
|
||||
self._uuid = str(uuid4())
|
||||
self.session_timestamp: datetime = datetime.now()
|
||||
"The session timestamp"
|
||||
self.session_path = get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
|
||||
# convert session to path
|
||||
if session_path is not None:
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
# if a session path is provided, load it
|
||||
if not session_path.exists():
|
||||
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
|
||||
|
||||
# load session
|
||||
self.load(session_path)
|
||||
else:
|
||||
# set training config path
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Union[Path, str] = training_config_path
|
||||
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# set random UUID for session
|
||||
self._uuid = str(uuid4())
|
||||
"The session timestamp"
|
||||
self.session_path = get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
|
||||
@property
|
||||
def timestamp_str(self) -> str:
|
||||
@@ -227,52 +250,27 @@ class AgentSessionABC(ABC):
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
|
||||
def load(self, path: Union[str, Path]):
|
||||
"""Load an agent from file."""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
|
||||
|
||||
if path.exists():
|
||||
# Unpack the session_metadata.json file
|
||||
md_file = path / "session_metadata.json"
|
||||
with open(md_file, "r") as file:
|
||||
md_dict = json.load(file)
|
||||
# set training config path
|
||||
self._training_config_path: Union[Path, str] = training_config_path
|
||||
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = laydown_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# Create a temp directory and dump the training and lay down
|
||||
# configs into it
|
||||
temp_dir = path / ".temp"
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
# set random UUID for session
|
||||
self._uuid = md_dict["uuid"]
|
||||
|
||||
temp_tc = temp_dir / "tc.yaml"
|
||||
with open(temp_tc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["training_config"], file)
|
||||
|
||||
temp_ldc = temp_dir / "ldc.yaml"
|
||||
with open(temp_ldc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["lay_down_config"], file)
|
||||
|
||||
agent = cls(temp_tc, temp_ldc)
|
||||
|
||||
agent.session_path = path
|
||||
|
||||
return agent
|
||||
|
||||
else:
|
||||
# Session path does not exist
|
||||
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
pass
|
||||
# set the session path
|
||||
self.session_path = path
|
||||
"The Session path"
|
||||
|
||||
@property
|
||||
def _saved_agent_path(self) -> Path:
|
||||
file_name = (
|
||||
f"{self._training_config.agent_framework}_"
|
||||
f"{self._training_config.agent_identifier}_"
|
||||
f"{self.timestamp_str}.zip"
|
||||
)
|
||||
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
|
||||
return self.learning_path / file_name
|
||||
|
||||
@abstractmethod
|
||||
@@ -308,104 +306,3 @@ class AgentSessionABC(ABC):
|
||||
fig = plot_av_reward_per_episode(path, title, subtitle)
|
||||
fig.write_image(image_path)
|
||||
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
|
||||
|
||||
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
"""
|
||||
An Agent Session ABC for evaluation deterministic agents.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
|
||||
implemented.
|
||||
"""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
super()._setup()
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self):
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs):
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
for step in range(time_steps):
|
||||
# Calculate action
|
||||
action = self._calculate_action(obs)
|
||||
|
||||
# Perform the step
|
||||
obs, reward, done, info = self._env.step(action)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self._training_config.time_delay / 1000)
|
||||
obs = self._env.reset()
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
@classmethod
|
||||
def load(cls):
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
115
src/primaite/agents/hardcoded_abc.py
Normal file
115
src/primaite/agents/hardcoded_abc.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
"""
|
||||
An Agent Session ABC for evaluation deterministic agents.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
|
||||
implemented.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
super()._setup()
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self):
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs):
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
for step in range(time_steps):
|
||||
# Calculate action
|
||||
action = self._calculate_action(obs)
|
||||
|
||||
# Perform the step
|
||||
obs, reward, done, info = self._env.step(action)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self._training_config.time_delay / 1000)
|
||||
obs = self._env.reset()
|
||||
self._env.close()
|
||||
|
||||
@classmethod
|
||||
def load(cls, path=None):
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
@@ -4,7 +4,7 @@ import numpy as np
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
get_node_of_ip,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
@@ -14,7 +14,7 @@ from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
@@ -43,7 +43,12 @@ def _custom_log_creator(session_path: Path):
|
||||
class RLlibAgent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Ray RLlib agent."""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Initialise the RLLib Agent training session.
|
||||
|
||||
@@ -56,6 +61,13 @@ class RLlibAgent(AgentSessionABC):
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
||||
or `A2C`)
|
||||
"""
|
||||
# TODO: implement RLlib agent loading
|
||||
if session_path is not None:
|
||||
msg = "RLlib agent loading has not been implemented yet"
|
||||
_LOGGER.error(msg)
|
||||
print(msg)
|
||||
raise NotImplementedError
|
||||
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
@@ -18,7 +19,12 @@ _LOGGER = getLogger(__name__)
|
||||
class SB3Agent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Stable Baselines3 agent."""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Initialise the SB3 Agent training session.
|
||||
|
||||
@@ -31,7 +37,7 @@ class SB3Agent(AgentSessionABC):
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
||||
or `A2C`)
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.SB3:
|
||||
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
@@ -47,7 +53,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
|
||||
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
|
||||
self._setup()
|
||||
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
@@ -57,8 +63,10 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
self.is_eval = False
|
||||
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
super()._setup()
|
||||
"""Set up the SB3 Agent."""
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
@@ -66,14 +74,43 @@ class SB3Agent(AgentSessionABC):
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=self.sb3_output_verbose_level,
|
||||
n_steps=self._training_config.num_train_steps,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
seed=self._training_config.seed,
|
||||
)
|
||||
# check if there is a zip file that needs to be loaded
|
||||
load_file = next(self.session_path.rglob("*.zip"), None)
|
||||
|
||||
if not load_file:
|
||||
# create a new env and agent
|
||||
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=self.sb3_output_verbose_level,
|
||||
n_steps=self._training_config.num_train_steps,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
seed=self._training_config.seed,
|
||||
)
|
||||
else:
|
||||
# set env values from session metadata
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
md_dict = json.load(file)
|
||||
|
||||
# load environment values
|
||||
if self.is_eval:
|
||||
# evaluation always starts at 0
|
||||
self._env.episode_count = 0
|
||||
self._env.total_step_count = 0
|
||||
else:
|
||||
# carry on from previous learning sessions
|
||||
self._env.episode_count = md_dict["learning"]["total_episodes"]
|
||||
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
|
||||
|
||||
# load the file
|
||||
self._agent = self._agent_class.load(load_file, env=self._env)
|
||||
|
||||
# set agent values
|
||||
self._agent.verbose = self.sb3_output_verbose_level
|
||||
self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
|
||||
|
||||
super()._setup()
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
@@ -145,11 +182,6 @@ class SB3Agent(AgentSessionABC):
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Union[str, Path]) -> SB3Agent:
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
self._agent.save(self._saved_agent_path)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True):
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None):
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
@@ -162,11 +162,19 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||
ldc: The lay down config file path. Optional. If no value is passed then
|
||||
example default lay down config is used from:
|
||||
~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
|
||||
|
||||
load: The directory of a previous session. Optional. If no value is passed, then the session
|
||||
will use the default training config and laydown config. Inversely, if a training config and laydown config
|
||||
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
|
||||
and laydown config.
|
||||
"""
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
from primaite.main import run
|
||||
|
||||
if load is not None:
|
||||
run(session_path=load)
|
||||
|
||||
if not tc:
|
||||
tc = main_training_config_path()
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"""The main PrimAITE session runner module."""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
@@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: The training config filepath.
|
||||
:param lay_down_config_path: The lay down config filepath.
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
session = PrimaiteSession(training_config_path, lay_down_config_path)
|
||||
session = PrimaiteSession(training_config_path, lay_down_config_path, session_path)
|
||||
|
||||
session.setup()
|
||||
session.learn()
|
||||
@@ -31,9 +36,14 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tc")
|
||||
parser.add_argument("--ldc")
|
||||
parser.add_argument("--load")
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.tc:
|
||||
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
||||
if not args.ldc:
|
||||
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
if args.load:
|
||||
run(session_path=args.load)
|
||||
else:
|
||||
if not args.tc:
|
||||
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
||||
if not args.ldc:
|
||||
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Union
|
||||
from typing import Dict, Final, Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.agents.hardcoded_acl import HardCodedACLAgent
|
||||
from primaite.agents.hardcoded_node import HardCodedNodeAgent
|
||||
from primaite.agents.rllib import RLlibAgent
|
||||
@@ -14,6 +14,7 @@ from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyA
|
||||
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -27,15 +28,39 @@ class PrimaiteSession:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
|
||||
:param training_config_path: The training config path.
|
||||
:param lay_down_config_path: The lay down config path.
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = session_path # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
|
||||
# check if session path is provided
|
||||
if session_path is not None:
|
||||
# set load_session to true
|
||||
self.is_load_session = True
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
# if a session path is provided, load it
|
||||
if not session_path.exists():
|
||||
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
|
||||
|
||||
md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)
|
||||
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path, str]] = training_config_path
|
||||
@@ -46,12 +71,6 @@ class PrimaiteSession:
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = None # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
|
||||
def setup(self):
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
@@ -60,11 +79,15 @@ class PrimaiteSession:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = HardCodedNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = HardCodedACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
@@ -77,11 +100,15 @@ class PrimaiteSession:
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = DoNothingNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = DoNothingACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
@@ -93,10 +120,14 @@ class PrimaiteSession:
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
|
||||
self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = RandomAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
|
||||
self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = DummyAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework AgentIdentifier combo
|
||||
@@ -105,12 +136,12 @@ class PrimaiteSession:
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
|
||||
# Stable Baselines3 Agent
|
||||
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path, self.session_path)
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
|
||||
# Ray RLlib Agent
|
||||
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
|
||||
58
src/primaite/utils/session_metadata_parser.py
Normal file
58
src/primaite/utils/session_metadata_parser.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
|
||||
"""
|
||||
Loads a session metadata from the given directory path.
|
||||
|
||||
:param session_path: Directory where the session metadata file is in
|
||||
:param dict_only: If dict_only is true, the function will only return the dict contents of session metadata
|
||||
|
||||
:return: Dictionary which has all the session metadata contents
|
||||
:rtype: Dict
|
||||
|
||||
:return: Path where the YAML copy of the training config is dumped into
|
||||
:rtype: str
|
||||
:return: Path where the YAML copy of the laydown config is dumped into
|
||||
:rtype: str
|
||||
"""
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
if not session_path.exists():
|
||||
# Session path does not exist
|
||||
msg = f"Failed to load PrimAITE Session, path does not exist: {session_path}"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
# Unpack the session_metadata.json file
|
||||
md_file = session_path / "session_metadata.json"
|
||||
with open(md_file, "r") as file:
|
||||
md_dict = json.load(file)
|
||||
|
||||
# if dict only, return dict without doing anything else
|
||||
if dict_only:
|
||||
return md_dict
|
||||
|
||||
# Create a temp directory and dump the training and lay down
|
||||
# configs into it
|
||||
temp_dir = session_path / ".temp"
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
|
||||
temp_tc = temp_dir / "tc.yaml"
|
||||
with open(temp_tc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["training_config"], file)
|
||||
|
||||
temp_ldc = temp_dir / "ldc.yaml"
|
||||
with open(temp_ldc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["lay_down_config"], file)
|
||||
|
||||
return [md_dict, temp_tc, temp_ldc]
|
||||
@@ -4,3 +4,6 @@ from typing import Final
|
||||
|
||||
TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config"
|
||||
"The tests config root directory."
|
||||
|
||||
TEST_ASSETS_ROOT: Final[Path] = Path(__file__).parent / "assets"
|
||||
"The tests assets root directory."
|
||||
|
||||
BIN
tests/assets/example_sb3_agent_session/learning/SB3_PPO.zip
Normal file
BIN
tests/assets/example_sb3_agent_session/learning/SB3_PPO.zip
Normal file
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -7,6 +7,14 @@
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: CUSTOM
|
||||
|
||||
# Sets which deep learning framework will be used (by RLlib ONLY).
|
||||
# Default is TF (Tensorflow).
|
||||
# Options are:
|
||||
# "TF" (Tensorflow)
|
||||
# TF2 (Tensorflow 2.X)
|
||||
# TORCH (PyTorch)
|
||||
deep_learning_framework: TF2
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
@@ -17,32 +25,78 @@ agent_framework: CUSTOM
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: DUMMY
|
||||
|
||||
# Sets whether Red Agent POL and IER is randomised.
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
random_red_agent: False
|
||||
|
||||
# The (integer) seed to be used in random number generation
|
||||
# Default is None (null)
|
||||
seed: null
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
deterministic: False
|
||||
|
||||
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||
# Options are:
|
||||
# "BASIC" (The current observation space only)
|
||||
# "FULL" (Full environment view with actions taken and reward feedback)
|
||||
hard_coded_agent_view: FULL
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
|
||||
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 10
|
||||
|
||||
# Number of time_steps for training per episode
|
||||
num_train_steps: 256
|
||||
|
||||
# Number of episodes for evaluation to run per session
|
||||
num_eval_episodes: 1
|
||||
|
||||
# Number of time_steps for evaluation per episode
|
||||
num_eval_steps: 15
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 10
|
||||
|
||||
# Time delay (milliseconds) between steps for CUSTOM agents.
|
||||
time_delay: 5
|
||||
|
||||
# Type of session to be run. Options are:
|
||||
# "TRAIN" (Trains an agent)
|
||||
# "EVAL" (Evaluates an agent)
|
||||
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||
session_type: EVAL
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
# "INFO" (Info Messages (such as devices and wrappers used))
|
||||
# "DEBUG" (All Messages)
|
||||
sb3_output_verbose_level: NONE
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
|
||||
@@ -5,7 +5,15 @@
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: CUSTOM
|
||||
agent_framework: SB3
|
||||
|
||||
# Sets which deep learning framework will be used (by RLlib ONLY).
|
||||
# Default is TF (Tensorflow).
|
||||
# Options are:
|
||||
# "TF" (Tensorflow)
|
||||
# TF2 (Tensorflow 2.X)
|
||||
# TORCH (PyTorch)
|
||||
deep_learning_framework: TF2
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
@@ -15,7 +23,7 @@ agent_framework: CUSTOM
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: DUMMY
|
||||
agent_identifier: PPO
|
||||
|
||||
# Sets whether Red Agent POL and IER is randomised.
|
||||
# Options are:
|
||||
@@ -23,92 +31,128 @@ agent_identifier: DUMMY
|
||||
# False
|
||||
random_red_agent: True
|
||||
|
||||
# The (integer) seed to be used in random number generation
|
||||
# Default is None (null)
|
||||
seed: null
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
deterministic: False
|
||||
|
||||
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||
# Options are:
|
||||
# "BASIC" (The current observation space only)
|
||||
# "FULL" (Full environment view with actions taken and reward feedback)
|
||||
hard_coded_agent_view: FULL
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
|
||||
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 2
|
||||
num_train_episodes: 10
|
||||
|
||||
# Number of time_steps for training per episode
|
||||
num_train_steps: 15
|
||||
num_train_steps: 256
|
||||
|
||||
# Number of episodes for evaluation to run per session
|
||||
num_eval_episodes: 2
|
||||
num_eval_episodes: 1
|
||||
|
||||
# Number of time_steps for evaluation per episode
|
||||
num_eval_steps: 15
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
num_eval_steps: 256
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: EVAL
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 10
|
||||
|
||||
# Time delay (milliseconds) between steps for CUSTOM agents.
|
||||
time_delay: 5
|
||||
|
||||
# Type of session to be run. Options are:
|
||||
# "TRAIN" (Trains an agent)
|
||||
# "EVAL" (Evaluates an agent)
|
||||
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||
session_type: TRAIN_EVAL
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
# "INFO" (Info Messages (such as devices and wrappers used))
|
||||
# "DEBUG" (All Messages)
|
||||
sb3_output_verbose_level: NONE
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -10
|
||||
off_should_be_resetting: -5
|
||||
on_should_be_off: -2
|
||||
on_should_be_resetting: -5
|
||||
resetting_should_be_on: -5
|
||||
resetting_should_be_off: -2
|
||||
resetting: -3
|
||||
off_should_be_on: -0.001
|
||||
off_should_be_resetting: -0.0005
|
||||
on_should_be_off: -0.0002
|
||||
on_should_be_resetting: -0.0005
|
||||
resetting_should_be_on: -0.0005
|
||||
resetting_should_be_off: -0.0002
|
||||
resetting: -0.0003
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 2
|
||||
good_should_be_compromised: 5
|
||||
good_should_be_overwhelmed: 5
|
||||
patching_should_be_good: -5
|
||||
patching_should_be_compromised: 2
|
||||
patching_should_be_overwhelmed: 2
|
||||
patching: -3
|
||||
compromised_should_be_good: -20
|
||||
compromised_should_be_patching: -20
|
||||
compromised_should_be_overwhelmed: -20
|
||||
compromised: -20
|
||||
overwhelmed_should_be_good: -20
|
||||
overwhelmed_should_be_patching: -20
|
||||
overwhelmed_should_be_compromised: -20
|
||||
overwhelmed: -20
|
||||
good_should_be_patching: 0.0002
|
||||
good_should_be_compromised: 0.0005
|
||||
good_should_be_overwhelmed: 0.0005
|
||||
patching_should_be_good: -0.0005
|
||||
patching_should_be_compromised: 0.0002
|
||||
patching_should_be_overwhelmed: 0.0002
|
||||
patching: -0.0003
|
||||
compromised_should_be_good: -0.002
|
||||
compromised_should_be_patching: -0.002
|
||||
compromised_should_be_overwhelmed: -0.002
|
||||
compromised: -0.002
|
||||
overwhelmed_should_be_good: -0.002
|
||||
overwhelmed_should_be_patching: -0.002
|
||||
overwhelmed_should_be_compromised: -0.002
|
||||
overwhelmed: -0.002
|
||||
# Node File System State
|
||||
good_should_be_repairing: 2
|
||||
good_should_be_restoring: 2
|
||||
good_should_be_corrupt: 5
|
||||
good_should_be_destroyed: 10
|
||||
repairing_should_be_good: -5
|
||||
repairing_should_be_restoring: 2
|
||||
repairing_should_be_corrupt: 2
|
||||
repairing_should_be_destroyed: 0
|
||||
repairing: -3
|
||||
restoring_should_be_good: -10
|
||||
restoring_should_be_repairing: -2
|
||||
restoring_should_be_corrupt: 1
|
||||
restoring_should_be_destroyed: 2
|
||||
restoring: -6
|
||||
corrupt_should_be_good: -10
|
||||
corrupt_should_be_repairing: -10
|
||||
corrupt_should_be_restoring: -10
|
||||
corrupt_should_be_destroyed: 2
|
||||
corrupt: -10
|
||||
destroyed_should_be_good: -20
|
||||
destroyed_should_be_repairing: -20
|
||||
destroyed_should_be_restoring: -20
|
||||
destroyed_should_be_corrupt: -20
|
||||
destroyed: -20
|
||||
scanning: -2
|
||||
good_should_be_repairing: 0.0002
|
||||
good_should_be_restoring: 0.0002
|
||||
good_should_be_corrupt: 0.0005
|
||||
good_should_be_destroyed: 0.001
|
||||
repairing_should_be_good: -0.0005
|
||||
repairing_should_be_restoring: 0.0002
|
||||
repairing_should_be_corrupt: 0.0002
|
||||
repairing_should_be_destroyed: 0.0000
|
||||
repairing: -0.0003
|
||||
restoring_should_be_good: -0.001
|
||||
restoring_should_be_repairing: -0.0002
|
||||
restoring_should_be_corrupt: 0.0001
|
||||
restoring_should_be_destroyed: 0.0002
|
||||
restoring: -0.0006
|
||||
corrupt_should_be_good: -0.001
|
||||
corrupt_should_be_repairing: -0.001
|
||||
corrupt_should_be_restoring: -0.001
|
||||
corrupt_should_be_destroyed: 0.0002
|
||||
corrupt: -0.001
|
||||
destroyed_should_be_good: -0.002
|
||||
destroyed_should_be_repairing: -0.002
|
||||
destroyed_should_be_restoring: -0.002
|
||||
destroyed_should_be_corrupt: -0.002
|
||||
destroyed: -0.002
|
||||
scanning: -0.0002
|
||||
# IER status
|
||||
red_ier_running: -5
|
||||
green_ier_blocked: -10
|
||||
red_ier_running: -0.0005
|
||||
green_ier_blocked: -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
|
||||
@@ -62,7 +62,6 @@ class TempPrimaiteSession(PrimaiteSession):
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
shutil.rmtree(self.session_path)
|
||||
shutil.rmtree(self.session_path.parent)
|
||||
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
|
||||
|
||||
|
||||
@@ -114,7 +113,7 @@ def temp_primaite_session(request):
|
||||
"""
|
||||
training_config_path = request.param[0]
|
||||
lay_down_config_path = request.param[1]
|
||||
with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck:
|
||||
with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck:
|
||||
mck.session_timestamp = datetime.now()
|
||||
|
||||
return TempPrimaiteSession(training_config_path, lay_down_config_path)
|
||||
|
||||
188
tests/test_session_loading.py
Normal file
188
tests/test_session_loading.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import os.path
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.main import run
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
from primaite.utils.session_output_reader import av_rewards_dict
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def copy_session_asset(asset_path: Union[str, Path]) -> str:
|
||||
"""Copies the asset into a temporary test folder."""
|
||||
if asset_path is None:
|
||||
raise Exception("No path provided")
|
||||
|
||||
if isinstance(asset_path, Path):
|
||||
asset_path = str(os.path.normpath(asset_path))
|
||||
|
||||
copy_path = str(Path(tempfile.gettempdir()) / "primaite" / str(uuid4()))
|
||||
|
||||
# copy the asset into a temp path
|
||||
try:
|
||||
shutil.copytree(asset_path, copy_path)
|
||||
except Exception as e:
|
||||
msg = f"Unable to copy directory: {asset_path}"
|
||||
_LOGGER.error(msg, e)
|
||||
print(msg, e)
|
||||
|
||||
_LOGGER.debug(f"Copied test asset to: {copy_path}")
|
||||
|
||||
# return the copied assets path
|
||||
return copy_path
|
||||
|
||||
|
||||
def test_load_sb3_session():
|
||||
"""Test that loading an SB3 agent works."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
10: 0,
|
||||
11: -0.008037109374999995,
|
||||
12: -0.007978515624999988,
|
||||
13: -0.008191406249999991,
|
||||
14: -0.00817578124999999,
|
||||
15: -0.008085937499999998,
|
||||
16: -0.007837890624999982,
|
||||
17: -0.007798828124999992,
|
||||
18: -0.007777343749999998,
|
||||
19: -0.007958984374999988,
|
||||
20: -0.0077499999999999835,
|
||||
}
|
||||
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
loaded_agent = SB3Agent(session_path=test_path)
|
||||
|
||||
# loaded agent should have the same UUID as the previous agent
|
||||
assert loaded_agent.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
|
||||
assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
|
||||
assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
|
||||
assert loaded_agent._training_config.deterministic
|
||||
assert loaded_agent._training_config.seed == 12345
|
||||
assert str(loaded_agent.session_path) == str(test_path)
|
||||
|
||||
# run another learn session
|
||||
loaded_agent.learn()
|
||||
|
||||
learn_mean_rewards = av_rewards_dict(
|
||||
loaded_agent.learning_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
|
||||
|
||||
# run an evaluation
|
||||
loaded_agent.evaluate()
|
||||
|
||||
# load the evaluation average reward csv file
|
||||
eval_mean_reward = av_rewards_dict(
|
||||
loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
|
||||
)
|
||||
|
||||
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
|
||||
# the evaluation should be the same as a previous run
|
||||
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
def test_load_primaite_session():
|
||||
"""Test that loading a Primaite session works."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
10: 0,
|
||||
11: -0.008037109374999995,
|
||||
12: -0.007978515624999988,
|
||||
13: -0.008191406249999991,
|
||||
14: -0.00817578124999999,
|
||||
15: -0.008085937499999998,
|
||||
16: -0.007837890624999982,
|
||||
17: -0.007798828124999992,
|
||||
18: -0.007777343749999998,
|
||||
19: -0.007958984374999988,
|
||||
20: -0.0077499999999999835,
|
||||
}
|
||||
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
# create loaded session
|
||||
session = PrimaiteSession(session_path=test_path)
|
||||
|
||||
# run setup on session
|
||||
session.setup()
|
||||
|
||||
# make sure that the session was loaded correctly
|
||||
assert session._agent_session.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
|
||||
assert session._agent_session._training_config.agent_framework == AgentFramework.SB3.name
|
||||
assert session._agent_session._training_config.agent_identifier == AgentIdentifier.PPO.name
|
||||
assert session._agent_session._training_config.deterministic
|
||||
assert session._agent_session._training_config.seed == 12345
|
||||
assert str(session._agent_session.session_path) == str(test_path)
|
||||
|
||||
# run another learn session
|
||||
session.learn()
|
||||
|
||||
learn_mean_rewards = av_rewards_dict(
|
||||
session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
|
||||
|
||||
# run an evaluation
|
||||
session.evaluate()
|
||||
|
||||
# load the evaluation average reward csv file
|
||||
eval_mean_reward = av_rewards_dict(
|
||||
session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
|
||||
)
|
||||
|
||||
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
|
||||
# the evaluation should be the same as a previous run
|
||||
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
def test_run_loading():
|
||||
"""Test loading session via main.run."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
10: 0,
|
||||
11: -0.008037109374999995,
|
||||
12: -0.007978515624999988,
|
||||
13: -0.008191406249999991,
|
||||
14: -0.00817578124999999,
|
||||
15: -0.008085937499999998,
|
||||
16: -0.007837890624999982,
|
||||
17: -0.007798828124999992,
|
||||
18: -0.007777343749999998,
|
||||
19: -0.007958984374999988,
|
||||
20: -0.0077499999999999835,
|
||||
}
|
||||
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
# create loaded session
|
||||
run(session_path=test_path)
|
||||
|
||||
learn_mean_rewards = av_rewards_dict(
|
||||
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
Reference in New Issue
Block a user