#901 - merged dev into branch

This commit is contained in:
SunilSamra
2023-07-17 19:54:07 +01:00
27 changed files with 1010 additions and 304 deletions

View File

@@ -86,5 +86,5 @@ stages:
displayName: 'Perform PrimAITE Setup'
- script: |
pytest tests/
pytest -n 4
displayName: 'Run tests'

3
.gitignore vendored
View File

@@ -50,6 +50,9 @@ coverage.xml
.hypothesis/
.pytest_cache/
cover/
tests/assets/**/*.png
tests/assets/**/tensorboard_logs/
tests/assets/**/checkpoints/
# Translations
*.mo

View File

@@ -43,6 +43,36 @@ The sub-directory is formatted as such: ``~/primaite/sessions/<yyyy-mm-dd>/<yyyy
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
Loading a session
-------
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
.. tabs::
.. code-tab:: bash
:caption: Unix CLI
cd ~/primaite
source ./.venv/bin/activate
primaite session --load "path/to/session"
.. code-tab:: bash
:caption: Powershell CLI
cd ~\primaite
.\.venv\Scripts\activate
primaite session --load "path\to\session"
.. code-tab:: python
:caption: Python
from primaite.main import run
run(session_path=<previous session directory>)
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
Outputs
-------

View File

@@ -58,6 +58,7 @@ dev = [
"pip-licenses==4.3.0",
"pre-commit==2.20.0",
"pytest==7.2.0",
"pytest-xdist==3.3.1",
"pytest-cov==4.0.0",
"pytest-flake8==1.1.1",
"setuptools==66",

View File

@@ -1,21 +1,19 @@
from __future__ import annotations
import json
import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Dict, Final, Union
from typing import Dict, Optional, Union
from uuid import uuid4
import yaml
import primaite
from primaite import getLogger, SESSIONS_DIR
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
from primaite.utils.session_metadata_parser import parse_session_metadata
_LOGGER = getLogger(__name__)
@@ -47,38 +45,63 @@ class AgentSessionABC(ABC):
"""
@abstractmethod
def __init__(self, training_config_path, lay_down_config_path):
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
):
"""
Initialise an agent session from config files.
Initialise an agent session from config files, or load a previous session.
If training configuration and laydown configuration are provided with a session path,
the session path will be used.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path, str]] = training_config_path
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# initialise variables
self._env: Primaite
self._agent = None
self._can_learn: bool = False
self._can_evaluate: bool = False
self.is_eval = False
self._uuid = str(uuid4())
self.session_timestamp: datetime = datetime.now()
"The session timestamp"
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
# convert session to path
if session_path is not None:
if not isinstance(session_path, Path):
session_path = Path(session_path)
# if a session path is provided, load it
if not session_path.exists():
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
# load session
self.load(session_path)
else:
# set training config path
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Union[Path, str] = training_config_path
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# set random UUID for session
self._uuid = str(uuid4())
"The session timestamp"
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
@property
def timestamp_str(self) -> str:
@@ -227,52 +250,27 @@ class AgentSessionABC(ABC):
def _get_latest_checkpoint(self):
pass
@classmethod
@abstractmethod
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
def load(self, path: Union[str, Path]):
"""Load an agent from file."""
if not isinstance(path, Path):
path = Path(path)
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
if path.exists():
# Unpack the session_metadata.json file
md_file = path / "session_metadata.json"
with open(md_file, "r") as file:
md_dict = json.load(file)
# set training config path
self._training_config_path: Union[Path, str] = training_config_path
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
self._lay_down_config_path: Union[Path, str] = laydown_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# Create a temp directory and dump the training and lay down
# configs into it
temp_dir = path / ".temp"
temp_dir.mkdir(exist_ok=True)
# set random UUID for session
self._uuid = md_dict["uuid"]
temp_tc = temp_dir / "tc.yaml"
with open(temp_tc, "w") as file:
yaml.dump(md_dict["env"]["training_config"], file)
temp_ldc = temp_dir / "ldc.yaml"
with open(temp_ldc, "w") as file:
yaml.dump(md_dict["env"]["lay_down_config"], file)
agent = cls(temp_tc, temp_ldc)
agent.session_path = path
return agent
else:
# Session path does not exist
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
pass
# set the session path
self.session_path = path
"The Session path"
@property
def _saved_agent_path(self) -> Path:
file_name = (
f"{self._training_config.agent_framework}_"
f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip"
)
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
return self.learning_path / file_name
@abstractmethod
@@ -308,104 +306,3 @@ class AgentSessionABC(ABC):
fig = plot_av_reward_per_episode(path, title, subtitle)
fig.write_image(image_path)
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluation deterministic agents.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
"""
def __init__(self, training_config_path, lay_down_config_path):
"""
Initialise a hardcoded agent session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
"""
super().__init__(training_config_path, lay_down_config_path)
self._setup()
def _setup(self):
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
pass
def _get_latest_checkpoint(self):
pass
def learn(
self,
**kwargs,
):
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
pass
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
super().evaluate()
@classmethod
def load(cls):
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -0,0 +1,115 @@
import time
from abc import abstractmethod
from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluation deterministic agents.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
"""
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
"""
Initialise a hardcoded agent session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
"""
super().__init__(training_config_path, lay_down_config_path, session_path)
self._setup()
def _setup(self):
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
pass
def _get_latest_checkpoint(self):
pass
def learn(
self,
**kwargs,
):
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
pass
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
@classmethod
def load(cls, path=None):
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -4,7 +4,7 @@ import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
get_node_of_ip,

View File

@@ -1,6 +1,6 @@
import numpy as np
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable

View File

@@ -4,7 +4,7 @@ import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Union
from typing import Optional, Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
@@ -14,7 +14,7 @@ from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
@@ -43,7 +43,12 @@ def _custom_log_creator(session_path: Path):
class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
def __init__(self, training_config_path, lay_down_config_path):
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
"""
Initialise the RLLib Agent training session.
@@ -56,6 +61,13 @@ class RLlibAgent(AgentSessionABC):
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
# TODO: implement RLlib agent loading
if session_path is not None:
msg = "RLlib agent loading has not been implemented yet"
_LOGGER.error(msg)
print(msg)
raise NotImplementedError
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"

View File

@@ -1,14 +1,15 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Union
from typing import Optional, Union
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
@@ -18,7 +19,12 @@ _LOGGER = getLogger(__name__)
class SB3Agent(AgentSessionABC):
"""An AgentSession class that implements a Stable Baselines3 agent."""
def __init__(self, training_config_path, lay_down_config_path):
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
):
"""
Initialise the SB3 Agent training session.
@@ -31,7 +37,7 @@ class SB3Agent(AgentSessionABC):
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
super().__init__(training_config_path, lay_down_config_path)
super().__init__(training_config_path, lay_down_config_path, session_path)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
@@ -47,7 +53,7 @@ class SB3Agent(AgentSessionABC):
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
self._setup()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
@@ -57,8 +63,10 @@ class SB3Agent(AgentSessionABC):
self.is_eval = False
self._setup()
def _setup(self):
super()._setup()
"""Set up the SB3 Agent."""
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
@@ -66,14 +74,43 @@ class SB3Agent(AgentSessionABC):
timestamp_str=self.timestamp_str,
)
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_train_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
# check if there is a zip file that needs to be loaded
load_file = next(self.session_path.rglob("*.zip"), None)
if not load_file:
# create a new env and agent
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_train_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
else:
# set env values from session metadata
with open(self.session_path / "session_metadata.json", "r") as file:
md_dict = json.load(file)
# load environment values
if self.is_eval:
# evaluation always starts at 0
self._env.episode_count = 0
self._env.total_step_count = 0
else:
# carry on from previous learning sessions
self._env.episode_count = md_dict["learning"]["total_episodes"]
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
# load the file
self._agent = self._agent_class.load(load_file, env=self._env)
# set agent values
self._agent.verbose = self.sb3_output_verbose_level
self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
super()._setup()
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
@@ -145,11 +182,6 @@ class SB3Agent(AgentSessionABC):
self._env.close()
super().evaluate()
@classmethod
def load(cls, path: Union[str, Path]) -> SB3Agent:
"""Load an agent from file."""
raise NotImplementedError
def save(self):
"""Save the agent."""
self._agent.save(self._saved_agent_path)

View File

@@ -1,4 +1,4 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum

View File

@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True):
@app.command()
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None):
"""
Run a PrimAITE session.
@@ -162,11 +162,19 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
load: The directory of a previous session. Optional. If no value is passed, then the session
will use the default training config and laydown config. Inversely, if a training config and laydown config
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
and laydown config.
"""
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
from primaite.main import run
if load is not None:
run(session_path=load)
if not tc:
tc = main_training_config_path()

View File

@@ -2,7 +2,7 @@
"""The main PrimAITE session runner module."""
import argparse
from pathlib import Path
from typing import Union
from typing import Optional, Union
from primaite import getLogger
from primaite.primaite_session import PrimaiteSession
@@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__)
def run(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
"""
Run the PrimAITE Session.
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
session = PrimaiteSession(training_config_path, lay_down_config_path)
session = PrimaiteSession(training_config_path, lay_down_config_path, session_path)
session.setup()
session.learn()
@@ -31,9 +36,14 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tc")
parser.add_argument("--ldc")
parser.add_argument("--load")
args = parser.parse_args()
if not args.tc:
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
if args.load:
run(session_path=args.load)
else:
if not args.tc:
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)

View File

@@ -2,10 +2,10 @@
from __future__ import annotations
from pathlib import Path
from typing import Dict, Final, Union
from typing import Dict, Final, Optional, Union
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
@@ -14,6 +14,7 @@ from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyA
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.utils.session_metadata_parser import parse_session_metadata
_LOGGER = getLogger(__name__)
@@ -27,15 +28,39 @@ class PrimaiteSession:
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
"""
The PrimaiteSession constructor.
:param training_config_path: The training config path.
:param lay_down_config_path: The lay down config path.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = session_path # noqa
self.timestamp_str: str = None # noqa
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
# check if session path is provided
if session_path is not None:
# set load_session to true
self.is_load_session = True
if not isinstance(session_path, Path):
session_path = Path(session_path)
# if a session path is provided, load it
if not session_path.exists():
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path, str]] = training_config_path
@@ -46,12 +71,6 @@ class PrimaiteSession:
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = None # noqa
self.timestamp_str: str = None # noqa
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
def setup(self):
"""Performs the session setup."""
if self._training_config.agent_framework == AgentFramework.CUSTOM:
@@ -60,11 +79,15 @@ class PrimaiteSession:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = HardCodedNodeAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = HardCodedACLAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
@@ -77,11 +100,15 @@ class PrimaiteSession:
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
if self._training_config.action_type == ActionType.NODE:
self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = DoNothingNodeAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = DoNothingACLAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
@@ -93,10 +120,14 @@ class PrimaiteSession:
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = RandomAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = DummyAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
else:
# Invalid AgentFramework AgentIdentifier combo
@@ -105,12 +136,12 @@ class PrimaiteSession:
elif self._training_config.agent_framework == AgentFramework.SB3:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
# Stable Baselines3 Agent
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path, self.session_path)
elif self._training_config.agent_framework == AgentFramework.RLLIB:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
# Ray RLlib Agent
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)
else:
# Invalid AgentFramework

View File

@@ -0,0 +1,58 @@
import json
from pathlib import Path
from typing import Union
import yaml
from primaite import getLogger
_LOGGER = getLogger(__name__)
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
"""
Loads a session metadata from the given directory path.
:param session_path: Directory where the session metadata file is in
:param dict_only: If dict_only is true, the function will only return the dict contents of session metadata
:return: Dictionary which has all the session metadata contents
:rtype: Dict
:return: Path where the YAML copy of the training config is dumped into
:rtype: str
:return: Path where the YAML copy of the laydown config is dumped into
:rtype: str
"""
if not isinstance(session_path, Path):
session_path = Path(session_path)
if not session_path.exists():
# Session path does not exist
msg = f"Failed to load PrimAITE Session, path does not exist: {session_path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
# Unpack the session_metadata.json file
md_file = session_path / "session_metadata.json"
with open(md_file, "r") as file:
md_dict = json.load(file)
# if dict only, return dict without doing anything else
if dict_only:
return md_dict
# Create a temp directory and dump the training and lay down
# configs into it
temp_dir = session_path / ".temp"
temp_dir.mkdir(exist_ok=True)
temp_tc = temp_dir / "tc.yaml"
with open(temp_tc, "w") as file:
yaml.dump(md_dict["env"]["training_config"], file)
temp_ldc = temp_dir / "ldc.yaml"
with open(temp_ldc, "w") as file:
yaml.dump(md_dict["env"]["lay_down_config"], file)
return [md_dict, temp_tc, temp_ldc]

View File

@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Dict, Union
from typing import Any, Dict, Tuple, Union
# Using polars as it's faster than Pandas; it will speed things up when
# files get big!
@@ -13,8 +13,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
The dictionary keys are the episode number, and the values are the mean reward that episode.
:param av_rewards_csv_file: The average rewards per episode csv file path.
:return: The average rewards per episode cdv as a dict.
:return: The average rewards per episode csv as a dict.
"""
df = pl.read_csv(av_rewards_csv_file).to_dict()
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])}
return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""
Read an all transactions csv file and return as a dict.
The dict keys are a tuple with the structure (episode, step). The dict
values are the remaining columns as a dict.
:param all_transactions_csv_file: The all transactions csv file path.
:return: The all transactions csv file as a dict.
"""
df_dict = pl.read_csv(all_transactions_csv_file).to_dict()
new_dict = {}
episodes = df_dict["Episode"]
steps = df_dict["Step"]
keys = list(df_dict.keys())
for i in range(len(episodes)):
key = (episodes[i], steps[i])
value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]}
new_dict[key] = value_dict
return new_dict

View File

@@ -4,3 +4,6 @@ from typing import Final
TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config"
"The tests config root directory."
TEST_ASSETS_ROOT: Final[Path] = Path(__file__).parent / "assets"
"The tests assets root directory."

File diff suppressed because one or more lines are too long

View File

@@ -7,6 +7,14 @@
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
@@ -17,27 +25,66 @@ agent_framework: CUSTOM
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: DUMMY
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
@@ -45,6 +92,13 @@ observation_space_high_value: 1000000000
implicit_acl_rule: DENY
max_number_acl_rules: 10
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0

View File

@@ -5,7 +5,15 @@
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
@@ -15,7 +23,7 @@ agent_framework: CUSTOM
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: DUMMY
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
@@ -23,92 +31,128 @@ agent_identifier: DUMMY
# False
random_red_agent: True
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes for training to run per session
num_train_episodes: 2
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 15
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 2
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1
num_eval_steps: 256
# Type of session to be run (TRAINING or EVALUATION)
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -5
green_ier_blocked: -10
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS

View File

@@ -0,0 +1,163 @@
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: RLLIB
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time take to repair the file system
file_system_restoring_limit: 5 # The time take to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -5,7 +5,7 @@ import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict, Tuple, Union
from unittest.mock import patch
import pytest
@@ -13,7 +13,7 @@ import pytest
from primaite import getLogger
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
ACTION_SPACE_NODE_VALUES = 1
@@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession):
super().__init__(training_config_path, lay_down_config_path)
self.setup()
def learn_av_reward_per_episode(self) -> Dict[int, float]:
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the learn av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.learning_path / csv_file)
def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the eval av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.evaluation_path / csv_file)
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the learn all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.learning_path / csv_file)
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the eval all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.evaluation_path / csv_file)
def metadata_file_as_dict(self) -> Dict[str, Any]:
"""Read the session_metadata.json file and return as a dict."""
with open(self.session_path / "session_metadata.json", "r") as file:
@@ -113,7 +123,7 @@ def temp_primaite_session(request):
"""
training_config_path = request.param[0]
lay_down_config_path = request.param[1]
with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck:
with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck:
mck.session_timestamp = datetime.now()
return TempPrimaiteSession(training_config_path, lay_down_config_path)

View File

@@ -48,5 +48,5 @@ def test_rewards_are_being_penalised_at_each_step_function(
"""
with temp_primaite_session as session:
session.evaluate()
ev_rewards = session.eval_av_reward_per_episode_csv()
ev_rewards = session.eval_av_reward_per_episode_dict()
assert ev_rewards[1] == -8.0

23
tests/test_rllib_agent.py Normal file
View File

@@ -0,0 +1,23 @@
import pytest
from primaite import getLogger
from primaite.config.lay_down_config import dos_very_basic_config_path
from tests import TEST_CONFIG_ROOT
_LOGGER = getLogger(__name__)
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "training_config_main_rllib.yaml", dos_very_basic_config_path()]],
indirect=True,
)
def test_primaite_session(temp_primaite_session):
"""Test the training_config_main_rllib.yaml training config file."""
with temp_primaite_session as session:
session_path = session.session_path
assert session_path.exists()
session.learn()
assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
assert len(session.learn_all_transactions_dict().keys()) == 10 * 256

View File

@@ -43,11 +43,9 @@ def test_seeded_learning(temp_primaite_session):
session._training_config.seed == 67890
), "Expected output is based upon a agent that was trained with seed 67890"
session.learn()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
print("\n")
print(session.learn_av_reward_per_episode())
assert expected_mean_reward_per_episode == session.learn_av_reward_per_episode()
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL knowledge to investigate further.")
@@ -62,5 +60,5 @@ def test_deterministic_evaluation(temp_primaite_session):
# do stuff
session.learn()
session.evaluate()
eval_mean_reward = session.eval_av_reward_per_episode_csv()
eval_mean_reward = session.eval_av_reward_per_episode_dict()
assert len(set(eval_mean_reward.values())) == 1

View File

@@ -0,0 +1,188 @@
import os.path
import shutil
import tempfile
from pathlib import Path
from typing import Union
from uuid import uuid4
from primaite import getLogger
from primaite.agents.sb3 import SB3Agent
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.main import run
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
from tests import TEST_ASSETS_ROOT
_LOGGER = getLogger(__name__)
def copy_session_asset(asset_path: Union[str, Path]) -> str:
"""Copies the asset into a temporary test folder."""
if asset_path is None:
raise Exception("No path provided")
if isinstance(asset_path, Path):
asset_path = str(os.path.normpath(asset_path))
copy_path = str(Path(tempfile.gettempdir()) / "primaite" / str(uuid4()))
# copy the asset into a temp path
try:
shutil.copytree(asset_path, copy_path)
except Exception as e:
msg = f"Unable to copy directory: {asset_path}"
_LOGGER.error(msg, e)
print(msg, e)
_LOGGER.debug(f"Copied test asset to: {copy_path}")
# return the copied assets path
return copy_path
def test_load_sb3_session():
"""Test that loading an SB3 agent works."""
expected_learn_mean_reward_per_episode = {
10: 0,
11: -0.008037109374999995,
12: -0.007978515624999988,
13: -0.008191406249999991,
14: -0.00817578124999999,
15: -0.008085937499999998,
16: -0.007837890624999982,
17: -0.007798828124999992,
18: -0.007777343749999998,
19: -0.007958984374999988,
20: -0.0077499999999999835,
}
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
loaded_agent = SB3Agent(session_path=test_path)
# loaded agent should have the same UUID as the previous agent
assert loaded_agent.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
assert loaded_agent._training_config.deterministic
assert loaded_agent._training_config.seed == 12345
assert str(loaded_agent.session_path) == str(test_path)
# run another learn session
loaded_agent.learn()
learn_mean_rewards = av_rewards_dict(
loaded_agent.learning_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
# run an evaluation
loaded_agent.evaluate()
# load the evaluation average reward csv file
eval_mean_reward = av_rewards_dict(
loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
)
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
# delete the test directory
shutil.rmtree(test_path)
def test_load_primaite_session():
"""Test that loading a Primaite session works."""
expected_learn_mean_reward_per_episode = {
10: 0,
11: -0.008037109374999995,
12: -0.007978515624999988,
13: -0.008191406249999991,
14: -0.00817578124999999,
15: -0.008085937499999998,
16: -0.007837890624999982,
17: -0.007798828124999992,
18: -0.007777343749999998,
19: -0.007958984374999988,
20: -0.0077499999999999835,
}
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
# create loaded session
session = PrimaiteSession(session_path=test_path)
# run setup on session
session.setup()
# make sure that the session was loaded correctly
assert session._agent_session.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
assert session._agent_session._training_config.agent_framework == AgentFramework.SB3.name
assert session._agent_session._training_config.agent_identifier == AgentIdentifier.PPO.name
assert session._agent_session._training_config.deterministic
assert session._agent_session._training_config.seed == 12345
assert str(session._agent_session.session_path) == str(test_path)
# run another learn session
session.learn()
learn_mean_rewards = av_rewards_dict(
session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
# run an evaluation
session.evaluate()
# load the evaluation average reward csv file
eval_mean_reward = av_rewards_dict(
session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
)
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988
# delete the test directory
shutil.rmtree(test_path)
def test_run_loading():
"""Test loading session via main.run."""
expected_learn_mean_reward_per_episode = {
10: 0,
11: -0.008037109374999995,
12: -0.007978515624999988,
13: -0.008191406249999991,
14: -0.00817578124999999,
15: -0.008085937499999998,
16: -0.007837890624999982,
17: -0.007798828124999992,
18: -0.007777343749999998,
19: -0.007958984374999988,
20: -0.0077499999999999835,
}
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
# create loaded session
run(session_path=test_path)
learn_mean_rewards = av_rewards_dict(
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
# delete the test directory
shutil.rmtree(test_path)