Merge branch 'dev' into feature/1623-typehints

This commit is contained in:
Marek Wolan
2023-07-18 10:03:48 +01:00
109 changed files with 1321 additions and 348 deletions

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import logging
import logging.config
import sys

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Access Control List. Models firewall functionality."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements the access control list implementation for the network."""
from typing import Dict

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A class that implements an access control list rule."""

View File

@@ -1 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Common interface between RL agents from different libraries and PrimAITE."""

View File

@@ -1,28 +1,24 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
import json
import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, TYPE_CHECKING, Union
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from uuid import uuid4
import yaml
import primaite
from primaite import getLogger, SESSIONS_DIR
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
from primaite.utils.session_metadata_parser import parse_session_metadata
if TYPE_CHECKING:
from logging import Logger
import numpy as np
_LOGGER: "Logger" = getLogger(__name__)
@@ -53,38 +49,63 @@ class AgentSessionABC(ABC):
"""
@abstractmethod
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Initialise an agent session from config files.
Initialise an agent session from config files, or load a previous session.
If training configuration and laydown configuration are provided with a session path,
the session path will be used.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path, str]] = training_config_path
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# initialise variables
self._env: Primaite
self._agent = None
self._can_learn: bool = False
self._can_evaluate: bool = False
self.is_eval = False
self._uuid = str(uuid4())
self.session_timestamp: datetime = datetime.now()
"The session timestamp"
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
# convert session to path
if session_path is not None:
if not isinstance(session_path, Path):
session_path = Path(session_path)
# if a session path is provided, load it
if not session_path.exists():
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
# load session
self.load(session_path)
else:
# set training config path
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Union[Path, str] = training_config_path
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# set random UUID for session
self._uuid = str(uuid4())
"The session timestamp"
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
@property
def timestamp_str(self) -> str:
@@ -233,51 +254,27 @@ class AgentSessionABC(ABC):
def _get_latest_checkpoint(self) -> None:
pass
@classmethod
@abstractmethod
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
def load(self, path: Union[str, Path]):
"""Load an agent from file."""
if not isinstance(path, Path):
path = Path(path)
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
if path.exists():
# Unpack the session_metadata.json file
md_file = path / "session_metadata.json"
with open(md_file, "r") as file:
md_dict = json.load(file)
# set training config path
self._training_config_path: Union[Path, str] = training_config_path
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
self._lay_down_config_path: Union[Path, str] = laydown_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# Create a temp directory and dump the training and lay down
# configs into it
temp_dir = path / ".temp"
temp_dir.mkdir(exist_ok=True)
# set random UUID for session
self._uuid = md_dict["uuid"]
temp_tc = temp_dir / "tc.yaml"
with open(temp_tc, "w") as file:
yaml.dump(md_dict["env"]["training_config"], file)
temp_ldc = temp_dir / "ldc.yaml"
with open(temp_ldc, "w") as file:
yaml.dump(md_dict["env"]["lay_down_config"], file)
agent = cls(temp_tc, temp_ldc)
agent.session_path = path
return agent
else:
# Session path does not exist
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
# set the session path
self.session_path = path
"The Session path"
@property
def _saved_agent_path(self) -> Path:
file_name = (
f"{self._training_config.agent_framework}_"
f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip"
)
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
return self.learning_path / file_name
@abstractmethod
@@ -313,104 +310,3 @@ class AgentSessionABC(ABC):
fig = plot_av_reward_per_episode(path, title, subtitle)
fig.write_image(image_path)
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluation deterministic agents.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
"""
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
"""
Initialise a hardcoded agent session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
"""
super().__init__(training_config_path, lay_down_config_path)
self._setup()
def _setup(self) -> None:
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self) -> None:
pass
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs: Any,
) -> None:
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs: np.ndarray) -> None:
pass
def evaluate(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
super().evaluate()
@classmethod
def load(cls) -> None:
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self) -> None:
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self) -> None:
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -0,0 +1,116 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import time
from abc import abstractmethod
from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluation deterministic agents.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
"""
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
"""
Initialise a hardcoded agent session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
"""
super().__init__(training_config_path, lay_down_config_path, session_path)
self._setup()
def _setup(self):
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
pass
def _get_latest_checkpoint(self):
pass
def learn(
self,
**kwargs,
):
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
pass
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
@classmethod
def load(cls, path=None):
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -1,10 +1,11 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import Dict, List, Union
import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
get_node_of_ip,

View File

@@ -1,6 +1,7 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import numpy as np
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable

View File

@@ -1,10 +1,11 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
@@ -14,7 +15,7 @@ from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
@@ -48,7 +49,12 @@ def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Initialise the RLLib Agent training session.
@@ -61,6 +67,13 @@ class RLlibAgent(AgentSessionABC):
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
# TODO: implement RLlib agent loading
if session_path is not None:
msg = "RLlib agent loading has not been implemented yet"
_LOGGER.error(msg)
print(msg)
raise NotImplementedError
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"

View File

@@ -1,14 +1,16 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
@@ -21,7 +23,12 @@ _LOGGER: "Logger" = getLogger(__name__)
class SB3Agent(AgentSessionABC):
"""An AgentSession class that implements a Stable Baselines3 agent."""
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Initialise the SB3 Agent training session.
@@ -34,7 +41,7 @@ class SB3Agent(AgentSessionABC):
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
super().__init__(training_config_path, lay_down_config_path)
super().__init__(training_config_path, lay_down_config_path, session_path)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
@@ -51,7 +58,7 @@ class SB3Agent(AgentSessionABC):
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
self._setup()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
@@ -61,8 +68,10 @@ class SB3Agent(AgentSessionABC):
self.is_eval = False
self._setup()
def _setup(self) -> None:
super()._setup()
"""Set up the SB3 Agent."""
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
@@ -70,14 +79,43 @@ class SB3Agent(AgentSessionABC):
timestamp_str=self.timestamp_str,
)
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_train_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
# check if there is a zip file that needs to be loaded
load_file = next(self.session_path.rglob("*.zip"), None)
if not load_file:
# create a new env and agent
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_train_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
else:
# set env values from session metadata
with open(self.session_path / "session_metadata.json", "r") as file:
md_dict = json.load(file)
# load environment values
if self.is_eval:
# evaluation always starts at 0
self._env.episode_count = 0
self._env.total_step_count = 0
else:
# carry on from previous learning sessions
self._env.episode_count = md_dict["learning"]["total_episodes"]
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
# load the file
self._agent = self._agent_class.load(load_file, env=self._env)
# set agent values
self._agent.verbose = self.sb3_output_verbose_level
self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
super()._setup()
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
@@ -149,11 +187,6 @@ class SB3Agent(AgentSessionABC):
self._env.close()
super().evaluate()
@classmethod
def load(cls, path: Union[str, Path]) -> SB3Agent:
"""Load an agent from file."""
raise NotImplementedError
def save(self) -> None:
"""Save the agent."""
self._agent.save(self._saved_agent_path)

View File

@@ -1,6 +1,7 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
if TYPE_CHECKING:

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import Dict, List, Union
import numpy as np

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Provides a CLI using Typer as an entry point."""
import logging
import os
@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True) -> None:
@app.command()
def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None:
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None) -> None:
"""
Run a PrimAITE session.
@@ -162,11 +162,19 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None:
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
load: The directory of a previous session. Optional. If no value is passed, then the session
will use the default training config and laydown config. Inversely, if a training config and laydown config
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
and laydown config.
"""
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
from primaite.main import run
if load is not None:
run(session_path=load)
if not tc:
tc = main_training_config_path()

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Objects which are shared between many PrimAITE modules."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Enumerations for APE."""
from enum import Enum, IntEnum

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The protocol class."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Service class."""
from primaite.common.enums import SoftwareState

View File

@@ -1 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Configuration parameters for running experiments."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Any, Dict, Final, TYPE_CHECKING, Union

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from __future__ import annotations
from dataclasses import dataclass, field

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utility to generate plots of sessions metrics after PrimAITE."""
from enum import Enum

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Dict, Optional, Union

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
import copy
import logging

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements reward function."""
from typing import Dict, TYPE_CHECKING, Union

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Network connections between nodes in the simulation."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The link class."""
from typing import List

View File

@@ -1,8 +1,8 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The main PrimAITE session runner module."""
import argparse
from pathlib import Path
from typing import Union
from typing import Optional, Union
from primaite import getLogger
from primaite.primaite_session import PrimaiteSession
@@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__)
def run(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Run the PrimAITE Session.
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
session = PrimaiteSession(training_config_path, lay_down_config_path)
session = PrimaiteSession(training_config_path, lay_down_config_path, session_path)
session.setup()
session.learn()
@@ -31,9 +36,14 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tc")
parser.add_argument("--ldc")
parser.add_argument("--load")
args = parser.parse_args()
if not args.tc:
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
if args.load:
run(session_path=args.load)
else:
if not args.tc:
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)

View File

@@ -1,2 +1,2 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Nodes represent network hosts in the simulation."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""An Active Node (i.e. not an actuator)."""
import logging
from typing import Final

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The base Node class."""
from typing import Final

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from typing import TYPE_CHECKING, Union

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Passive Node class (i.e. an actuator)."""
from primaite.common.enums import HardwareState, NodeType, Priority
from primaite.config.training_config import TrainingConfig

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""A Service Node (i.e. not an actuator)."""
import logging
from typing import Dict, Final

View File

@@ -1,5 +1,6 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Contains default jupyter notebooks which demonstrate PrimAITE functionality."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import importlib.util
import os
import subprocess

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Pattern of Life- Represents the actions of users on the network."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements Pattern of Life on the network (nodes and links)."""
from typing import Dict

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""
Information Exchange Requirements for APE.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
from typing import Dict

View File

@@ -1,11 +1,12 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Final, Union
from typing import Any, Dict, Final, Optional, Union
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.agent_abc import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
@@ -14,6 +15,7 @@ from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyA
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.utils.session_metadata_parser import parse_session_metadata
_LOGGER = getLogger(__name__)
@@ -27,15 +29,39 @@ class PrimaiteSession:
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
The PrimaiteSession constructor.
:param training_config_path: The training config path.
:param lay_down_config_path: The lay down config path.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = session_path # noqa
self.timestamp_str: str = None # noqa
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
# check if session path is provided
if session_path is not None:
# set load_session to true
self.is_load_session = True
if not isinstance(session_path, Path):
session_path = Path(session_path)
# if a session path is provided, load it
if not session_path.exists():
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path, str]] = training_config_path
@@ -60,11 +86,15 @@ class PrimaiteSession:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = HardCodedNodeAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = HardCodedACLAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
@@ -77,11 +107,15 @@ class PrimaiteSession:
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
if self._training_config.action_type == ActionType.NODE:
self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = DoNothingNodeAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = DoNothingACLAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
@@ -93,10 +127,14 @@ class PrimaiteSession:
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = RandomAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = DummyAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
else:
# Invalid AgentFramework AgentIdentifier combo
@@ -105,12 +143,12 @@ class PrimaiteSession:
elif self._training_config.agent_framework == AgentFramework.SB3:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
# Stable Baselines3 Agent
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path, self.session_path)
elif self._training_config.agent_framework == AgentFramework.RLLIB:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
# Ray RLlib Agent
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)
else:
# Invalid AgentFramework

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utilities to prepare the user's data folders."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from primaite import getLogger

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import filecmp
import os
import shutil

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import filecmp
import os
import shutil

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Record data of the system's state and agent's observations and actions."""
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""The Transaction class."""
from datetime import datetime
from typing import List, Optional, Tuple, TYPE_CHECKING, Union

View File

@@ -1 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Utilities for PrimAITE."""

View File

@@ -1,4 +1,4 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import os
from pathlib import Path
from typing import TYPE_CHECKING

View File

@@ -0,0 +1,59 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import json
from pathlib import Path
from typing import Union
import yaml
from primaite import getLogger
_LOGGER = getLogger(__name__)
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
"""
Loads a session metadata from the given directory path.
:param session_path: Directory where the session metadata file is in
:param dict_only: If dict_only is true, the function will only return the dict contents of session metadata
:return: Dictionary which has all the session metadata contents
:rtype: Dict
:return: Path where the YAML copy of the training config is dumped into
:rtype: str
:return: Path where the YAML copy of the laydown config is dumped into
:rtype: str
"""
if not isinstance(session_path, Path):
session_path = Path(session_path)
if not session_path.exists():
# Session path does not exist
msg = f"Failed to load PrimAITE Session, path does not exist: {session_path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
# Unpack the session_metadata.json file
md_file = session_path / "session_metadata.json"
with open(md_file, "r") as file:
md_dict = json.load(file)
# if dict only, return dict without doing anything else
if dict_only:
return md_dict
# Create a temp directory and dump the training and lay down
# configs into it
temp_dir = session_path / ".temp"
temp_dir.mkdir(exist_ok=True)
temp_tc = temp_dir / "tc.yaml"
with open(temp_tc, "w") as file:
yaml.dump(md_dict["env"]["training_config"], file)
temp_ldc = temp_dir / "ldc.yaml"
with open(temp_ldc, "w") as file:
yaml.dump(md_dict["env"]["lay_down_config"], file)
return [md_dict, temp_tc, temp_ldc]

View File

@@ -1,5 +1,6 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Dict, Union
from typing import Any, Dict, Tuple, Union
# Using polars as it's faster than Pandas; it will speed things up when
# files get big!
@@ -13,8 +14,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
The dictionary keys are the episode number, and the values are the mean reward that episode.
:param av_rewards_csv_file: The average rewards per episode csv file path.
:return: The average rewards per episode cdv as a dict.
:return: The average rewards per episode csv as a dict.
"""
df = pl.read_csv(av_rewards_csv_file).to_dict()
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])}
return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""
Read an all transactions csv file and return as a dict.
The dict keys are a tuple with the structure (episode, step). The dict
values are the remaining columns as a dict.
:param all_transactions_csv_file: The all transactions csv file path.
:return: The all transactions csv file as a dict.
"""
df_dict = pl.read_csv(all_transactions_csv_file).to_dict()
new_dict = {}
episodes = df_dict["Episode"]
steps = df_dict["Step"]
keys = list(df_dict.keys())
for i in range(len(episodes)):
key = (episodes[i], steps[i])
value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]}
new_dict[key] = value_dict
return new_dict

View File

@@ -1,3 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import csv
from logging import Logger
from typing import Final, List, Tuple, TYPE_CHECKING, Union