Merge remote-tracking branch 'origin/dev' into feature/2502-placeholders-docs
@@ -26,7 +26,7 @@ jobs:
|
||||
displayName: 'Install build dependencies'
|
||||
|
||||
- script: |
|
||||
pip install -e .[dev]
|
||||
pip install -e .[dev,rl]
|
||||
displayName: 'Install PrimAITE for docs autosummary'
|
||||
|
||||
- script: |
|
||||
|
||||
@@ -14,31 +14,31 @@ parameters:
|
||||
- name: matrix
|
||||
type: object
|
||||
default:
|
||||
- job_name: 'UbuntuPython38'
|
||||
py: '3.8'
|
||||
img: 'ubuntu-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
# - job_name: 'UbuntuPython38'
|
||||
# py: '3.8'
|
||||
# img: 'ubuntu-latest'
|
||||
# every_time: false
|
||||
# publish_coverage: false
|
||||
- job_name: 'UbuntuPython311'
|
||||
py: '3.11'
|
||||
img: 'ubuntu-latest'
|
||||
every_time: true
|
||||
publish_coverage: true
|
||||
- job_name: 'WindowsPython38'
|
||||
py: '3.8'
|
||||
img: 'windows-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
# - job_name: 'WindowsPython38'
|
||||
# py: '3.8'
|
||||
# img: 'windows-latest'
|
||||
# every_time: false
|
||||
# publish_coverage: false
|
||||
- job_name: 'WindowsPython311'
|
||||
py: '3.11'
|
||||
img: 'windows-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'MacOSPython38'
|
||||
py: '3.8'
|
||||
img: 'macOS-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
# - job_name: 'MacOSPython38'
|
||||
# py: '3.8'
|
||||
# img: 'macOS-latest'
|
||||
# every_time: false
|
||||
# publish_coverage: false
|
||||
- job_name: 'MacOSPython311'
|
||||
py: '3.11'
|
||||
img: 'macOS-latest'
|
||||
@@ -82,12 +82,12 @@ stages:
|
||||
|
||||
- script: |
|
||||
PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
|
||||
python -m pip install $PRIMAITE_WHEEL[dev]
|
||||
python -m pip install $PRIMAITE_WHEEL[dev,rl]
|
||||
displayName: 'Install PrimAITE'
|
||||
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
|
||||
|
||||
- script: |
|
||||
forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
|
||||
forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev,rl]"
|
||||
displayName: 'Install PrimAITE'
|
||||
condition: eq( variables['Agent.OS'], 'Windows_NT' )
|
||||
|
||||
|
||||
2
.github/workflows/build-sphinx.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
- name: Install PrimAITE for docs autosummary
|
||||
run: |
|
||||
set -x
|
||||
python -m pip install -e .[dev]
|
||||
python -m pip install -e .[dev,rl]
|
||||
|
||||
- name: Run build script for Sphinx pages
|
||||
env:
|
||||
|
||||
2
.github/workflows/python-package.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
||||
- name: Install PrimAITE
|
||||
run: |
|
||||
PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
|
||||
python -m pip install $PRIMAITE_WHEEL[dev]
|
||||
python -m pip install $PRIMAITE_WHEEL[dev,rl]
|
||||
|
||||
- name: Perform PrimAITE Setup
|
||||
run: |
|
||||
|
||||
4
.gitignore
vendored
@@ -164,3 +164,7 @@ src/primaite/notebooks/scratch.py
|
||||
sandbox.py
|
||||
sandbox/
|
||||
sandbox.ipynb
|
||||
|
||||
# benchmarking
|
||||
**/benchmark/sessions/
|
||||
**/benchmark/output/
|
||||
|
||||
@@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added support for SQL INSERT command.
|
||||
- Added ability to log each agent's action choices in each step to a JSON file.
|
||||
- Removal of Link bandwidth hardcoding. This can now be configured via the network configuraiton yaml. Will default to 100 if not present.
|
||||
- Added NMAP application to all host and layer-3 network nodes.
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ cd ~\primaite
|
||||
python3 -m venv .venv
|
||||
attrib +h .venv /s /d # Hides the .venv directory
|
||||
.\.venv\Scripts\activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
pip install primaite-3.0.0-py3-none-any.whl[rl]
|
||||
primaite setup
|
||||
```
|
||||
|
||||
@@ -66,7 +66,7 @@ mkdir ~/primaite
|
||||
cd ~/primaite
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
pip install primaite-3.0.0-py3-none-any.whl[rl]
|
||||
primaite setup
|
||||
```
|
||||
|
||||
@@ -105,7 +105,7 @@ source venv/bin/activate
|
||||
#### 5. Install `primaite` with the dev extra into the venv along with all of it's dependencies
|
||||
|
||||
```bash
|
||||
python3 -m pip install -e .[dev]
|
||||
python3 -m pip install -e .[dev,rl]
|
||||
```
|
||||
|
||||
#### 6. Perform the PrimAITE setup:
|
||||
@@ -114,6 +114,9 @@ python3 -m pip install -e .[dev]
|
||||
primaite setup
|
||||
```
|
||||
|
||||
#### Note
|
||||
*It is possible to install PrimAITE without Ray RLLib, StableBaselines3, or any deep learning libraries by omitting the `rl` flag in the pip install command.*
|
||||
|
||||
### Running PrimAITE
|
||||
|
||||
Use the provided jupyter notebooks as a starting point to try running PrimAITE. They are automatically copied to your PrimAITE notebook folder when you run `primaite setup`.
|
||||
|
||||
21
benchmark/benchmark.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
|
||||
|
||||
class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
|
||||
"""
|
||||
Class that extends the PrimaiteGymEnv.
|
||||
|
||||
The reset method is extended so that the average rewards per episode are recorded.
|
||||
"""
|
||||
|
||||
total_time_steps: int = 0
|
||||
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Overrides the PrimAITEGymEnv reset so that the total timesteps is saved."""
|
||||
self.total_time_steps += self.game.step_counter
|
||||
|
||||
return super().reset(seed=seed)
|
||||
@@ -1,211 +1,93 @@
|
||||
# flake8: noqa
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
from typing import Any, Dict, Final, Tuple
|
||||
|
||||
import GPUtil
|
||||
import plotly.graph_objects as go
|
||||
import polars as pl
|
||||
import psutil
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
from pylatex import Command, Document
|
||||
from pylatex import Figure as LatexFigure
|
||||
from pylatex import Section, Subsection, Tabular
|
||||
from pylatex.utils import bold
|
||||
from report import build_benchmark_latex_report
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
import primaite
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
from primaite.data_viz.session_plots import get_plotly_config
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
from benchmark import BenchmarkPrimaiteGymEnv
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
|
||||
_LOGGER = primaite.getLogger(__name__)
|
||||
|
||||
_MAJOR_V = primaite.__version__.split(".")[0]
|
||||
|
||||
_BENCHMARK_ROOT = Path(__file__).parent
|
||||
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results"
|
||||
_RESULTS_ROOT.mkdir(exist_ok=True, parents=True)
|
||||
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results" / f"v{_MAJOR_V}"
|
||||
_VERSION_ROOT: Final[Path] = _RESULTS_ROOT / f"v{primaite.__version__}"
|
||||
_SESSION_METADATA_ROOT: Final[Path] = _VERSION_ROOT / "session_metadata"
|
||||
|
||||
_OUTPUT_ROOT: Final[Path] = _BENCHMARK_ROOT / "output"
|
||||
# Clear and recreate the output directory
|
||||
if _OUTPUT_ROOT.exists():
|
||||
shutil.rmtree(_OUTPUT_ROOT)
|
||||
_OUTPUT_ROOT.mkdir()
|
||||
|
||||
_TRAINING_CONFIG_PATH = _BENCHMARK_ROOT / "config" / "benchmark_training_config.yaml"
|
||||
_LAY_DOWN_CONFIG_PATH = data_manipulation_config_path()
|
||||
_SESSION_METADATA_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_size(size_bytes: int) -> str:
|
||||
"""
|
||||
Scale bytes to its proper format.
|
||||
class BenchmarkSession:
|
||||
"""Benchmark Session class."""
|
||||
|
||||
e.g:
|
||||
1253656 => '1.20MB'
|
||||
1253656678 => '1.17GB'
|
||||
gym_env: BenchmarkPrimaiteGymEnv
|
||||
"""Gym environment used by the session to train."""
|
||||
|
||||
:
|
||||
"""
|
||||
factor = 1024
|
||||
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||
if size_bytes < factor:
|
||||
return f"{size_bytes:.2f}{unit}B"
|
||||
size_bytes /= factor
|
||||
num_episodes: int
|
||||
"""Number of episodes to run the training session."""
|
||||
|
||||
episode_len: int
|
||||
"""The number of steps per episode."""
|
||||
|
||||
def _get_system_info() -> Dict:
|
||||
"""Builds and returns a dict containing system info."""
|
||||
uname = platform.uname()
|
||||
cpu_freq = psutil.cpu_freq()
|
||||
virtual_mem = psutil.virtual_memory()
|
||||
swap_mem = psutil.swap_memory()
|
||||
gpus = GPUtil.getGPUs()
|
||||
return {
|
||||
"System": {
|
||||
"OS": uname.system,
|
||||
"OS Version": uname.version,
|
||||
"Machine": uname.machine,
|
||||
"Processor": uname.processor,
|
||||
},
|
||||
"CPU": {
|
||||
"Physical Cores": psutil.cpu_count(logical=False),
|
||||
"Total Cores": psutil.cpu_count(logical=True),
|
||||
"Max Frequency": f"{cpu_freq.max:.2f}Mhz",
|
||||
},
|
||||
"Memory": {"Total": get_size(virtual_mem.total), "Swap Total": get_size(swap_mem.total)},
|
||||
"GPU": [{"Name": gpu.name, "Total Memory": f"{gpu.memoryTotal}MB"} for gpu in gpus],
|
||||
}
|
||||
total_steps: int
|
||||
"""Number of steps to run the training session."""
|
||||
|
||||
batch_size: int
|
||||
"""Number of steps for each episode."""
|
||||
|
||||
def _build_benchmark_latex_report(
|
||||
benchmark_metadata_dict: Dict, this_version_plot_path: Path, all_version_plot_path: Path
|
||||
) -> None:
|
||||
geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
|
||||
data = benchmark_metadata_dict
|
||||
primaite_version = data["primaite_version"]
|
||||
learning_rate: float
|
||||
"""Learning rate for the model."""
|
||||
|
||||
# Create a new document
|
||||
doc = Document("report", geometry_options=geometry_options)
|
||||
# Title
|
||||
doc.preamble.append(Command("title", f"PrimAITE {primaite_version} Learning Benchmark"))
|
||||
doc.preamble.append(Command("author", "PrimAITE Dev Team"))
|
||||
doc.preamble.append(Command("date", datetime.now().date()))
|
||||
doc.append(Command("maketitle"))
|
||||
start_time: datetime
|
||||
"""Start time for the session."""
|
||||
|
||||
sessions = data["total_sessions"]
|
||||
episodes = data["training_config"]["num_train_episodes"]
|
||||
steps = data["training_config"]["num_train_steps"]
|
||||
|
||||
# Body
|
||||
with doc.create(Section("Introduction")):
|
||||
doc.append(
|
||||
f"PrimAITE v{primaite_version} was benchmarked automatically upon release. Learning rate metrics "
|
||||
f"were captured to be referenced during system-level testing and user acceptance testing (UAT)."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe benchmarking process consists of running {sessions} training session using the same "
|
||||
f"training and lay down config files. Each session trains an agent for {episodes} episodes, "
|
||||
f"with each episode consisting of {steps} steps."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe mean reward per episode from each session is captured. This is then used to calculate a "
|
||||
f"combined average reward per episode from the {sessions} individual sessions for smoothing. "
|
||||
f"Finally, a 25-widow rolling average of the combined average reward per session is calculated for "
|
||||
f"further smoothing."
|
||||
)
|
||||
|
||||
with doc.create(Section("System Information")):
|
||||
with doc.create(Subsection("Python")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
table.add_row((bold("Version"), sys.version))
|
||||
table.add_hline()
|
||||
for section, section_data in data["system_info"].items():
|
||||
if section_data:
|
||||
with doc.create(Subsection(section)):
|
||||
if isinstance(section_data, dict):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for key, value in section_data.items():
|
||||
table.add_row((bold(key), value))
|
||||
table.add_hline()
|
||||
elif isinstance(section_data, list):
|
||||
headers = section_data[0].keys()
|
||||
tabs_str = "|".join(["l" for _ in range(len(headers))])
|
||||
tabs_str = f"|{tabs_str}|"
|
||||
with doc.create(Tabular(tabs_str)) as table:
|
||||
table.add_hline()
|
||||
table.add_row([bold(h) for h in headers])
|
||||
table.add_hline()
|
||||
for item in section_data:
|
||||
table.add_row(item.values())
|
||||
table.add_hline()
|
||||
|
||||
headers_map = {
|
||||
"total_sessions": "Total Sessions",
|
||||
"total_episodes": "Total Episodes",
|
||||
"total_time_steps": "Total Steps",
|
||||
"av_s_per_session": "Av Session Duration (s)",
|
||||
"av_s_per_step": "Av Step Duration (s)",
|
||||
"av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
|
||||
}
|
||||
with doc.create(Section("Stats")):
|
||||
with doc.create(Subsection("Benchmark Results")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for section, header in headers_map.items():
|
||||
if section.startswith("av_"):
|
||||
table.add_row((bold(header), f"{data[section]:.4f}"))
|
||||
else:
|
||||
table.add_row((bold(header), data[section]))
|
||||
table.add_hline()
|
||||
|
||||
with doc.create(Section("Graphs")):
|
||||
with doc.create(Subsection(f"PrimAITE {primaite_version} Learning Benchmark Plot")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(this_version_plot_path))
|
||||
pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")
|
||||
|
||||
with doc.create(Subsection("PrimAITE All Versions Learning Benchmark Plot")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(all_version_plot_path))
|
||||
pic.add_caption("PrimAITE All Versions Learning Benchmark Plot")
|
||||
|
||||
doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)
|
||||
|
||||
|
||||
class BenchmarkPrimaiteSession(PrimaiteSession):
|
||||
"""A benchmarking primaite session."""
|
||||
end_time: datetime
|
||||
"""End time for the session."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
) -> None:
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self.setup()
|
||||
gym_env: BenchmarkPrimaiteGymEnv,
|
||||
episode_len: int,
|
||||
num_episodes: int,
|
||||
n_steps: int,
|
||||
batch_size: int,
|
||||
learning_rate: float,
|
||||
):
|
||||
"""Initialise the BenchmarkSession."""
|
||||
self.gym_env = gym_env
|
||||
self.episode_len = episode_len
|
||||
self.n_steps = n_steps
|
||||
self.num_episodes = num_episodes
|
||||
self.total_steps = self.num_episodes * self.episode_len
|
||||
self.batch_size = batch_size
|
||||
self.learning_rate = learning_rate
|
||||
|
||||
@property
|
||||
def env(self) -> Primaite:
|
||||
"""Direct access to the env for ease of testing."""
|
||||
return self._agent_session._env # noqa
|
||||
def train(self):
|
||||
"""Run the training session."""
|
||||
# start timer for session
|
||||
self.start_time = datetime.now()
|
||||
model = PPO(
|
||||
policy="MlpPolicy",
|
||||
env=self.gym_env,
|
||||
learning_rate=self.learning_rate,
|
||||
n_steps=self.n_steps,
|
||||
batch_size=self.batch_size,
|
||||
verbose=0,
|
||||
tensorboard_log="./PPO_UC2/",
|
||||
)
|
||||
model.learn(total_timesteps=self.total_steps)
|
||||
|
||||
def __enter__(self) -> "BenchmarkPrimaiteSession":
|
||||
return self
|
||||
# end timer for session
|
||||
self.end_time = datetime.now()
|
||||
|
||||
# TODO: typehints uncertain
|
||||
def __exit__(self, type: Any, value: Any, tb: Any) -> None:
|
||||
shutil.rmtree(self.session_path)
|
||||
_LOGGER.debug(f"Deleted benchmark session directory: {self.session_path}")
|
||||
self.session_metadata = self.generate_learn_metadata_dict()
|
||||
|
||||
def _learn_benchmark_durations(self) -> Tuple[float, float, float]:
|
||||
"""
|
||||
@@ -219,235 +101,99 @@ class BenchmarkPrimaiteSession(PrimaiteSession):
|
||||
:return: The learning benchmark durations as a Tuple of three floats:
|
||||
Tuple[total_s, s_per_step, s_per_100_steps_10_nodes].
|
||||
"""
|
||||
data = self.metadata_file_as_dict()
|
||||
start_dt = datetime.fromisoformat(data["start_datetime"])
|
||||
end_dt = datetime.fromisoformat(data["end_datetime"])
|
||||
delta = end_dt - start_dt
|
||||
delta = self.end_time - self.start_time
|
||||
total_s = delta.total_seconds()
|
||||
|
||||
total_steps = data["learning"]["total_time_steps"]
|
||||
total_steps = self.batch_size * self.num_episodes
|
||||
s_per_step = total_s / total_steps
|
||||
|
||||
num_nodes = self.env.num_nodes
|
||||
num_nodes = len(self.gym_env.game.simulation.network.nodes)
|
||||
num_intervals = total_steps / 100
|
||||
av_interval_time = total_s / num_intervals
|
||||
s_per_100_steps_10_nodes = av_interval_time / (num_nodes / 10)
|
||||
|
||||
return total_s, s_per_step, s_per_100_steps_10_nodes
|
||||
|
||||
def learn_metadata_dict(self) -> Dict[str, Any]:
|
||||
def generate_learn_metadata_dict(self) -> Dict[str, Any]:
|
||||
"""Metadata specific to the learning session."""
|
||||
total_s, s_per_step, s_per_100_steps_10_nodes = self._learn_benchmark_durations()
|
||||
self.gym_env.average_reward_per_episode.pop(0) # remove episode 0
|
||||
return {
|
||||
"total_episodes": self.env.actual_episode_count,
|
||||
"total_time_steps": self.env.total_step_count,
|
||||
"total_episodes": self.gym_env.episode_counter,
|
||||
"total_time_steps": self.gym_env.total_time_steps,
|
||||
"total_s": total_s,
|
||||
"s_per_step": s_per_step,
|
||||
"s_per_100_steps_10_nodes": s_per_100_steps_10_nodes,
|
||||
"av_reward_per_episode": self.learn_av_reward_per_episode_dict(),
|
||||
"av_reward_per_episode": self.gym_env.average_reward_per_episode,
|
||||
}
|
||||
|
||||
|
||||
def _get_benchmark_session_path(session_timestamp: datetime) -> Path:
|
||||
return _OUTPUT_ROOT / session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
|
||||
def _get_benchmark_primaite_session() -> BenchmarkPrimaiteSession:
|
||||
with patch("primaite.agents.agent_abc.get_session_path", _get_benchmark_session_path) as mck:
|
||||
mck.session_timestamp = datetime.now()
|
||||
return BenchmarkPrimaiteSession(_TRAINING_CONFIG_PATH, _LAY_DOWN_CONFIG_PATH)
|
||||
|
||||
|
||||
def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict) -> dict:
|
||||
n = len(metadata_dict)
|
||||
with open(_TRAINING_CONFIG_PATH, "r") as file:
|
||||
training_config_dict = yaml.safe_load(file)
|
||||
with open(_LAY_DOWN_CONFIG_PATH, "r") as file:
|
||||
lay_down_config_dict = yaml.safe_load(file)
|
||||
averaged_data = {
|
||||
"start_timestamp": start_datetime.isoformat(),
|
||||
"end_datetime": datetime.now().isoformat(),
|
||||
"primaite_version": primaite.__version__,
|
||||
"system_info": _get_system_info(),
|
||||
"total_sessions": n,
|
||||
"total_episodes": sum(d["total_episodes"] for d in metadata_dict.values()),
|
||||
"total_time_steps": sum(d["total_time_steps"] for d in metadata_dict.values()),
|
||||
"av_s_per_session": sum(d["total_s"] for d in metadata_dict.values()) / n,
|
||||
"av_s_per_step": sum(d["s_per_step"] for d in metadata_dict.values()) / n,
|
||||
"av_s_per_100_steps_10_nodes": sum(d["s_per_100_steps_10_nodes"] for d in metadata_dict.values()) / n,
|
||||
"combined_av_reward_per_episode": {},
|
||||
"session_av_reward_per_episode": {k: v["av_reward_per_episode"] for k, v in metadata_dict.items()},
|
||||
"training_config": training_config_dict,
|
||||
"lay_down_config": lay_down_config_dict,
|
||||
}
|
||||
|
||||
episodes = metadata_dict[1]["av_reward_per_episode"].keys()
|
||||
|
||||
for episode in episodes:
|
||||
combined_av_reward = sum(metadata_dict[k]["av_reward_per_episode"][episode] for k in metadata_dict.keys()) / n
|
||||
averaged_data["combined_av_reward_per_episode"][episode] = combined_av_reward
|
||||
|
||||
return averaged_data
|
||||
|
||||
|
||||
def _get_df_from_episode_av_reward_dict(data: Dict) -> pl.DataFrame:
|
||||
data: Dict = {"episode": data.keys(), "av_reward": data.values()}
|
||||
|
||||
return (
|
||||
pl.from_dict(data)
|
||||
.with_columns(rolling_mean=pl.col("av_reward").rolling_mean(window_size=25))
|
||||
.rename({"rolling_mean": "rolling_av_reward"})
|
||||
)
|
||||
|
||||
|
||||
def _plot_benchmark_metadata(
|
||||
benchmark_metadata_dict: Dict,
|
||||
title: Optional[str] = None,
|
||||
subtitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
|
||||
config = get_plotly_config()
|
||||
layout = go.Layout(
|
||||
autosize=config["size"]["auto_size"],
|
||||
width=config["size"]["width"],
|
||||
height=config["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=config["template"])
|
||||
|
||||
for session, av_reward_dict in benchmark_metadata_dict["session_av_reward_per_episode"].items():
|
||||
df = _get_df_from_episode_av_reward_dict(av_reward_dict)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"],
|
||||
y=df["av_reward"],
|
||||
mode="lines",
|
||||
name=f"Session {session}",
|
||||
opacity=0.25,
|
||||
line={"color": "#a6a6a6"},
|
||||
)
|
||||
)
|
||||
|
||||
df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_av_reward_per_episode"])
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"], y=df["av_reward"], mode="lines", name="Combined Session Av", line={"color": "#FF0000"}
|
||||
)
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"],
|
||||
y=df["rolling_av_reward"],
|
||||
mode="lines",
|
||||
name="Rolling Av (Combined Session Av)",
|
||||
line={"color": "#4CBB17"},
|
||||
)
|
||||
)
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
title=title,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _plot_all_benchmarks_combined_session_av() -> Figure:
|
||||
def _get_benchmark_primaite_environment() -> BenchmarkPrimaiteGymEnv:
|
||||
"""
|
||||
Plot the Benchmark results for each released version of PrimAITE.
|
||||
Create an instance of the BenchmarkPrimaiteGymEnv.
|
||||
|
||||
Does this by iterating over the ``benchmark/results`` directory and
|
||||
extracting the benchmark metadata json for each version that has been
|
||||
benchmarked. The combined_av_reward_per_episode is extracted from each,
|
||||
converted into a polars dataframe, and plotted as a scatter line in plotly.
|
||||
This environment will be used to train the agents on.
|
||||
"""
|
||||
title = "PrimAITE Versions Learning Benchmark"
|
||||
subtitle = "Rolling Av (Combined Session Av)"
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
config = get_plotly_config()
|
||||
layout = go.Layout(
|
||||
autosize=config["size"]["auto_size"],
|
||||
width=config["size"]["width"],
|
||||
height=config["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=config["template"])
|
||||
|
||||
for dir in _RESULTS_ROOT.iterdir():
|
||||
if dir.is_dir():
|
||||
metadata_file = dir / f"{dir.name}_benchmark_metadata.json"
|
||||
with open(metadata_file, "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_av_reward_per_episode"])
|
||||
|
||||
fig.add_trace(go.Scatter(x=df["episode"], y=df["rolling_av_reward"], mode="lines", name=dir.name))
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
title=title,
|
||||
)
|
||||
fig["data"][0]["showlegend"] = True
|
||||
|
||||
return fig
|
||||
env = BenchmarkPrimaiteGymEnv(env_config=data_manipulation_config_path())
|
||||
return env
|
||||
|
||||
|
||||
def run() -> None:
|
||||
def _prepare_session_directory():
|
||||
"""Prepare the session directory so that it is easier to clean up after the benchmarking is done."""
|
||||
# override session path
|
||||
session_path = _BENCHMARK_ROOT / "sessions"
|
||||
|
||||
if session_path.is_dir():
|
||||
shutil.rmtree(session_path)
|
||||
|
||||
primaite.PRIMAITE_PATHS.user_sessions_path = session_path
|
||||
primaite.PRIMAITE_PATHS.user_sessions_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
def run(
|
||||
number_of_sessions: int = 5,
|
||||
num_episodes: int = 1000,
|
||||
episode_len: int = 128,
|
||||
n_steps: int = 1280,
|
||||
batch_size: int = 32,
|
||||
learning_rate: float = 3e-4,
|
||||
) -> None:
|
||||
"""Run the PrimAITE benchmark."""
|
||||
start_datetime = datetime.now()
|
||||
av_reward_per_episode_dicts = {}
|
||||
for i in range(1, 11):
|
||||
benchmark_start_time = datetime.now()
|
||||
|
||||
session_metadata_dict = {}
|
||||
|
||||
_prepare_session_directory()
|
||||
|
||||
# run training
|
||||
for i in range(1, number_of_sessions + 1):
|
||||
print(f"Starting Benchmark Session: {i}")
|
||||
with _get_benchmark_primaite_session() as session:
|
||||
session.learn()
|
||||
av_reward_per_episode_dicts[i] = session.learn_metadata_dict()
|
||||
|
||||
benchmark_metadata = _build_benchmark_results_dict(
|
||||
start_datetime=start_datetime, metadata_dict=av_reward_per_episode_dicts
|
||||
with _get_benchmark_primaite_environment() as gym_env:
|
||||
session = BenchmarkSession(
|
||||
gym_env=gym_env,
|
||||
num_episodes=num_episodes,
|
||||
n_steps=n_steps,
|
||||
episode_len=episode_len,
|
||||
batch_size=batch_size,
|
||||
learning_rate=learning_rate,
|
||||
)
|
||||
session.train()
|
||||
|
||||
# Dump the session metadata so that we're not holding it in memory as it's large
|
||||
with open(_SESSION_METADATA_ROOT / f"{i}.json", "w") as file:
|
||||
json.dump(session.session_metadata, file, indent=4)
|
||||
|
||||
for i in range(1, number_of_sessions + 1):
|
||||
with open(_SESSION_METADATA_ROOT / f"{i}.json", "r") as file:
|
||||
session_metadata_dict[i] = json.load(file)
|
||||
# generate report
|
||||
build_benchmark_latex_report(
|
||||
benchmark_start_time=benchmark_start_time,
|
||||
session_metadata=session_metadata_dict,
|
||||
config_path=data_manipulation_config_path(),
|
||||
results_root_path=_RESULTS_ROOT,
|
||||
)
|
||||
v_str = f"v{primaite.__version__}"
|
||||
|
||||
version_result_dir = _RESULTS_ROOT / v_str
|
||||
if version_result_dir.exists():
|
||||
shutil.rmtree(version_result_dir)
|
||||
version_result_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
with open(version_result_dir / f"{v_str}_benchmark_metadata.json", "w") as file:
|
||||
json.dump(benchmark_metadata, file, indent=4)
|
||||
title = f"PrimAITE v{primaite.__version__.strip()} Learning Benchmark"
|
||||
fig = _plot_benchmark_metadata(benchmark_metadata, title=title)
|
||||
this_version_plot_path = version_result_dir / f"{title}.png"
|
||||
fig.write_image(this_version_plot_path)
|
||||
|
||||
fig = _plot_all_benchmarks_combined_session_av()
|
||||
|
||||
all_version_plot_path = _RESULTS_ROOT / "PrimAITE Versions Learning Benchmark.png"
|
||||
fig.write_image(all_version_plot_path)
|
||||
|
||||
_build_benchmark_latex_report(benchmark_metadata, this_version_plot_path, all_version_plot_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
305
benchmark/report.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import plotly.graph_objects as go
|
||||
import polars as pl
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
from pylatex import Command, Document
|
||||
from pylatex import Figure as LatexFigure
|
||||
from pylatex import Section, Subsection, Tabular
|
||||
from pylatex.utils import bold
|
||||
from utils import _get_system_info
|
||||
|
||||
import primaite
|
||||
|
||||
PLOT_CONFIG = {
|
||||
"size": {"auto_size": False, "width": 1500, "height": 900},
|
||||
"template": "plotly_white",
|
||||
"range_slider": False,
|
||||
}
|
||||
|
||||
|
||||
def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict, config: Dict) -> dict:
|
||||
num_sessions = len(metadata_dict) # number of sessions
|
||||
|
||||
averaged_data = {
|
||||
"start_timestamp": start_datetime.isoformat(),
|
||||
"end_datetime": datetime.now().isoformat(),
|
||||
"primaite_version": primaite.__version__,
|
||||
"system_info": _get_system_info(),
|
||||
"total_sessions": num_sessions,
|
||||
"total_episodes": sum(d["total_episodes"] for d in metadata_dict.values()),
|
||||
"total_time_steps": sum(d["total_time_steps"] for d in metadata_dict.values()),
|
||||
"av_s_per_session": sum(d["total_s"] for d in metadata_dict.values()) / num_sessions,
|
||||
"av_s_per_step": sum(d["s_per_step"] for d in metadata_dict.values()) / num_sessions,
|
||||
"av_s_per_100_steps_10_nodes": sum(d["s_per_100_steps_10_nodes"] for d in metadata_dict.values())
|
||||
/ num_sessions,
|
||||
"combined_av_reward_per_episode": {},
|
||||
"session_av_reward_per_episode": {k: v["av_reward_per_episode"] for k, v in metadata_dict.items()},
|
||||
"config": config,
|
||||
}
|
||||
|
||||
# find the average of each episode across all sessions
|
||||
episodes = metadata_dict[1]["av_reward_per_episode"].keys()
|
||||
|
||||
for episode in episodes:
|
||||
combined_av_reward = (
|
||||
sum(metadata_dict[k]["av_reward_per_episode"][episode] for k in metadata_dict.keys()) / num_sessions
|
||||
)
|
||||
averaged_data["combined_av_reward_per_episode"][episode] = combined_av_reward
|
||||
|
||||
return averaged_data
|
||||
|
||||
|
||||
def _get_df_from_episode_av_reward_dict(data: Dict) -> pl.DataFrame:
|
||||
data: Dict = {"episode": data.keys(), "av_reward": data.values()}
|
||||
|
||||
return (
|
||||
pl.from_dict(data)
|
||||
.with_columns(rolling_mean=pl.col("av_reward").rolling_mean(window_size=25))
|
||||
.rename({"rolling_mean": "rolling_av_reward"})
|
||||
)
|
||||
|
||||
|
||||
def _plot_benchmark_metadata(
|
||||
benchmark_metadata_dict: Dict,
|
||||
title: Optional[str] = None,
|
||||
subtitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
|
||||
layout = go.Layout(
|
||||
autosize=PLOT_CONFIG["size"]["auto_size"],
|
||||
width=PLOT_CONFIG["size"]["width"],
|
||||
height=PLOT_CONFIG["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=PLOT_CONFIG["template"])
|
||||
|
||||
for session, av_reward_dict in benchmark_metadata_dict["session_av_reward_per_episode"].items():
|
||||
df = _get_df_from_episode_av_reward_dict(av_reward_dict)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"],
|
||||
y=df["av_reward"],
|
||||
mode="lines",
|
||||
name=f"Session {session}",
|
||||
opacity=0.25,
|
||||
line={"color": "#a6a6a6"},
|
||||
)
|
||||
)
|
||||
|
||||
df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_av_reward_per_episode"])
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"], y=df["av_reward"], mode="lines", name="Combined Session Av", line={"color": "#FF0000"}
|
||||
)
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"],
|
||||
y=df["rolling_av_reward"],
|
||||
mode="lines",
|
||||
name="Rolling Av (Combined Session Av)",
|
||||
line={"color": "#4CBB17"},
|
||||
)
|
||||
)
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Total Reward"},
|
||||
title=title,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
|
||||
"""
|
||||
Plot the Benchmark results for each released version of PrimAITE.
|
||||
|
||||
Does this by iterating over the ``benchmark/results`` directory and
|
||||
extracting the benchmark metadata json for each version that has been
|
||||
benchmarked. The combined_av_reward_per_episode is extracted from each,
|
||||
converted into a polars dataframe, and plotted as a scatter line in plotly.
|
||||
"""
|
||||
major_v = primaite.__version__.split(".")[0]
|
||||
title = f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*"
|
||||
subtitle = "Rolling Av (Combined Session Av)"
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
layout = go.Layout(
|
||||
autosize=PLOT_CONFIG["size"]["auto_size"],
|
||||
width=PLOT_CONFIG["size"]["width"],
|
||||
height=PLOT_CONFIG["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=PLOT_CONFIG["template"])
|
||||
|
||||
for dir in results_directory.iterdir():
|
||||
if dir.is_dir():
|
||||
metadata_file = dir / f"{dir.name}_benchmark_metadata.json"
|
||||
with open(metadata_file, "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_av_reward_per_episode"])
|
||||
|
||||
fig.add_trace(go.Scatter(x=df["episode"], y=df["rolling_av_reward"], mode="lines", name=dir.name))
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Total Reward"},
|
||||
title=title,
|
||||
)
|
||||
fig["data"][0]["showlegend"] = True
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def build_benchmark_latex_report(
|
||||
benchmark_start_time: datetime, session_metadata: Dict, config_path: Path, results_root_path: Path
|
||||
) -> None:
|
||||
"""Generates a latex report of the benchmark run."""
|
||||
# generate report folder
|
||||
v_str = f"v{primaite.__version__}"
|
||||
|
||||
version_result_dir = results_root_path / v_str
|
||||
version_result_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# load the config file as dict
|
||||
with open(config_path, "r") as f:
|
||||
cfg_data = yaml.safe_load(f)
|
||||
|
||||
# generate the benchmark metadata dict
|
||||
benchmark_metadata_dict = _build_benchmark_results_dict(
|
||||
start_datetime=benchmark_start_time, metadata_dict=session_metadata, config=cfg_data
|
||||
)
|
||||
major_v = primaite.__version__.split(".")[0]
|
||||
with open(version_result_dir / f"{v_str}_benchmark_metadata.json", "w") as file:
|
||||
json.dump(benchmark_metadata_dict, file, indent=4)
|
||||
title = f"PrimAITE v{primaite.__version__.strip()} Learning Benchmark"
|
||||
fig = _plot_benchmark_metadata(benchmark_metadata_dict, title=title)
|
||||
this_version_plot_path = version_result_dir / f"{title}.png"
|
||||
fig.write_image(this_version_plot_path)
|
||||
|
||||
fig = _plot_all_benchmarks_combined_session_av(results_directory=results_root_path)
|
||||
|
||||
all_version_plot_path = results_root_path / "PrimAITE Versions Learning Benchmark.png"
|
||||
fig.write_image(all_version_plot_path)
|
||||
|
||||
geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
|
||||
data = benchmark_metadata_dict
|
||||
primaite_version = data["primaite_version"]
|
||||
|
||||
# Create a new document
|
||||
doc = Document("report", geometry_options=geometry_options)
|
||||
# Title
|
||||
doc.preamble.append(Command("title", f"PrimAITE {primaite_version} Learning Benchmark"))
|
||||
doc.preamble.append(Command("author", "PrimAITE Dev Team"))
|
||||
doc.preamble.append(Command("date", datetime.now().date()))
|
||||
doc.append(Command("maketitle"))
|
||||
|
||||
sessions = data["total_sessions"]
|
||||
episodes = session_metadata[1]["total_episodes"] - 1
|
||||
steps = data["config"]["game"]["max_episode_length"]
|
||||
|
||||
# Body
|
||||
with doc.create(Section("Introduction")):
|
||||
doc.append(
|
||||
f"PrimAITE v{primaite_version} was benchmarked automatically upon release. Learning rate metrics "
|
||||
f"were captured to be referenced during system-level testing and user acceptance testing (UAT)."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe benchmarking process consists of running {sessions} training session using the same "
|
||||
f"config file. Each session trains an agent for {episodes} episodes, "
|
||||
f"with each episode consisting of {steps} steps."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe total reward per episode from each session is captured. This is then used to calculate an "
|
||||
f"caverage total reward per episode from the {sessions} individual sessions for smoothing. "
|
||||
f"Finally, a 25-widow rolling average of the average total reward per session is calculated for "
|
||||
f"further smoothing."
|
||||
)
|
||||
|
||||
with doc.create(Section("System Information")):
|
||||
with doc.create(Subsection("Python")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
table.add_row((bold("Version"), sys.version))
|
||||
table.add_hline()
|
||||
for section, section_data in data["system_info"].items():
|
||||
if section_data:
|
||||
with doc.create(Subsection(section)):
|
||||
if isinstance(section_data, dict):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for key, value in section_data.items():
|
||||
table.add_row((bold(key), value))
|
||||
table.add_hline()
|
||||
elif isinstance(section_data, list):
|
||||
headers = section_data[0].keys()
|
||||
tabs_str = "|".join(["l" for _ in range(len(headers))])
|
||||
tabs_str = f"|{tabs_str}|"
|
||||
with doc.create(Tabular(tabs_str)) as table:
|
||||
table.add_hline()
|
||||
table.add_row([bold(h) for h in headers])
|
||||
table.add_hline()
|
||||
for item in section_data:
|
||||
table.add_row(item.values())
|
||||
table.add_hline()
|
||||
|
||||
headers_map = {
|
||||
"total_sessions": "Total Sessions",
|
||||
"total_episodes": "Total Episodes",
|
||||
"total_time_steps": "Total Steps",
|
||||
"av_s_per_session": "Av Session Duration (s)",
|
||||
"av_s_per_step": "Av Step Duration (s)",
|
||||
"av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
|
||||
}
|
||||
with doc.create(Section("Stats")):
|
||||
with doc.create(Subsection("Benchmark Results")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for section, header in headers_map.items():
|
||||
if section.startswith("av_"):
|
||||
table.add_row((bold(header), f"{data[section]:.4f}"))
|
||||
else:
|
||||
table.add_row((bold(header), data[section]))
|
||||
table.add_hline()
|
||||
|
||||
with doc.create(Section("Graphs")):
|
||||
with doc.create(Subsection(f"v{primaite_version} Learning Benchmark Plot")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(this_version_plot_path))
|
||||
pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")
|
||||
|
||||
with doc.create(Subsection(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(all_version_plot_path))
|
||||
pic.add_caption(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")
|
||||
|
||||
doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)
|
||||
|
Before Width: | Height: | Size: 79 KiB After Width: | Height: | Size: 79 KiB |
|
Before Width: | Height: | Size: 225 KiB After Width: | Height: | Size: 225 KiB |
BIN
benchmark/results/v3/PrimAITE Versions Learning Benchmark.png
Normal file
|
After Width: | Height: | Size: 91 KiB |
|
After Width: | Height: | Size: 295 KiB |
7436
benchmark/results/v3/v3.0.0/v3.0.0_benchmark_metadata.json
Normal file
47
benchmark/utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import platform
|
||||
from typing import Dict
|
||||
|
||||
import psutil
|
||||
from GPUtil import GPUtil
|
||||
|
||||
|
||||
def get_size(size_bytes: int) -> str:
|
||||
"""
|
||||
Scale bytes to its proper format.
|
||||
|
||||
e.g:
|
||||
1253656 => '1.20MB'
|
||||
1253656678 => '1.17GB'
|
||||
|
||||
:
|
||||
"""
|
||||
factor = 1024
|
||||
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||
if size_bytes < factor:
|
||||
return f"{size_bytes:.2f}{unit}B"
|
||||
size_bytes /= factor
|
||||
|
||||
|
||||
def _get_system_info() -> Dict:
|
||||
"""Builds and returns a dict containing system info."""
|
||||
uname = platform.uname()
|
||||
cpu_freq = psutil.cpu_freq()
|
||||
virtual_mem = psutil.virtual_memory()
|
||||
swap_mem = psutil.swap_memory()
|
||||
gpus = GPUtil.getGPUs()
|
||||
return {
|
||||
"System": {
|
||||
"OS": uname.system,
|
||||
"OS Version": uname.version,
|
||||
"Machine": uname.machine,
|
||||
"Processor": uname.processor,
|
||||
},
|
||||
"CPU": {
|
||||
"Physical Cores": psutil.cpu_count(logical=False),
|
||||
"Total Cores": psutil.cpu_count(logical=True),
|
||||
"Max Frequency": f"{cpu_freq.max:.2f}Mhz",
|
||||
},
|
||||
"Memory": {"Total": get_size(virtual_mem.total), "Swap Total": get_size(swap_mem.total)},
|
||||
"GPU": [{"Name": gpu.name, "Total Memory": f"{gpu.memoryTotal}MB"} for gpu in gpus],
|
||||
}
|
||||
@@ -29,6 +29,5 @@ clean:
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile | clean
|
||||
pip-licenses --format=rst --with-urls --output-file=source/primaite-dependencies.rst
|
||||
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
BIN
docs/_static/firewall_acl.png
vendored
|
Before Width: | Height: | Size: 35 KiB After Width: | Height: | Size: 23 KiB |
BIN
docs/_static/primAITE_architecture.png
vendored
Normal file
|
After Width: | Height: | Size: 106 KiB |
@@ -19,4 +19,3 @@
|
||||
:recursive:
|
||||
|
||||
primaite
|
||||
tests
|
||||
|
||||
102
docs/index.rst
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Welcome to PrimAITE's documentation
|
||||
====================================
|
||||
@@ -11,66 +11,79 @@ What is PrimAITE?
|
||||
Overview
|
||||
^^^^^^^^
|
||||
|
||||
The ARCD Primary-level AI Training Environment (**PrimAITE**) provides an effective simulation capability for the purposes of training and evaluating AI in a cyber-defensive role. It incorporates the functionality required of a primary-level ARCD environment, which includes:
|
||||
The ARCD Primary-level AI Training Environment (**PrimAITE**) provides an effective simulation capability for training and evaluating AI in a cyber-defensive role. It incorporates the functionality required of a primary-level ARCD environment:
|
||||
|
||||
- The ability to model a relevant platform / system context;
|
||||
- The ability to model a relevant system context;
|
||||
- Modelling an adversarial agent that the defensive agent can be trained and evaluated against;
|
||||
- The ability to model key characteristics of a platform / system by representing connections, IP addresses, ports, operating systems, services and traffic loading on links;
|
||||
- Modelling background pattern-of-life;
|
||||
- Operates at machine-speed to enable fast training cycles.
|
||||
- The ability to model key characteristics of a system by representing hosts, servers, network devices, IP addresses, ports, operating systems, folders / files, applications, services and links;
|
||||
- Modelling background (green) pattern-of-life;
|
||||
- Operates at machine-speed to enable fast training cycles via Reinforcement Learning (RL).
|
||||
|
||||
Features
|
||||
^^^^^^^^
|
||||
|
||||
PrimAITE incorporates the following features:
|
||||
|
||||
- Highly configurable (via YAML files) to provide the means to model a variety of platform / system laydowns and adversarial attack scenarios;
|
||||
- A Reinforcement Learning (RL) reward function based on (a) the ability to counter the modelled adversarial cyber-attack, and (b) the ability to ensure success;
|
||||
- Provision of logging to support AI performance / effectiveness assessment;
|
||||
- Uses the concept of Information Exchange Requirements (IERs) to model background pattern of life and adversarial behaviour;
|
||||
- An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP address, destination IP address, protocol and port);
|
||||
- Application of traffic to the links of the platform / system laydown adheres to the ACL ruleset;
|
||||
- Presents both a Gymnasium and Ray RLLib interface to the environment, allowing integration with any compliant defensive agents;
|
||||
- Allows for the saving and loading of trained defensive agents;
|
||||
- Stochastic adversarial agent behaviour;
|
||||
- Full capture of discrete logs relating to agent training or evaluation (system state, agent actions taken, instantaneous and average reward for every step of every episode);
|
||||
- Distinct control over running a training and / or evaluation session;
|
||||
- NetworkX provides laydown visualisation capability.
|
||||
- Architected with a separate Simulation layer and Game layer. This separation of concerns defines a clear path towards transfer learning with environments of differing fidelity;
|
||||
- Ability to reconfigure an RL reward function based on (a) the ability to counter the modelled adversarial cyber-attack, and (b) the ability to ensure success for green agents;
|
||||
- Access Control List (ACL) functions for network devices (routers and firewalls), following standard ACL rule format (e.g., DENY / ALLOW, source / destination IP addresses, protocol and port);
|
||||
- Application of traffic to the links of the system laydown adheres to the ACL rulesets and routing tables contained within each network device;
|
||||
- Provides RL environments adherent to the Farama Foundation Gymnasium (Previously OpenAI Gym) API, allowing integration with any compliant RL Agent frameworks;
|
||||
- Provides RL environments adherent to Ray RLlib environment specifications for single-agent and multi-agent scenarios;
|
||||
- Assessed for compatibility with Stable-Baselines3 (SB3), Ray RLlib, and bespoke agents;
|
||||
- Persona-based adversarial (Red) agent behaviour; several out-the-box personas are provided, and more can be developed to suit the needs of the task. Stochastic variations in Red agent behaviour are also included as required;
|
||||
- A robust system logging tool, automatically enabled at the node level and featuring various log levels and terminal output options, enables PrimAITE users to conduct in-depth network simulations;
|
||||
- A PCAP service is seamlessly integrated within the simulation, automatically capturing and logging frames for both
|
||||
inbound and outbound traffic at the network interface level. This automatic functionality, combined with the ability
|
||||
to separate traffic directions, significantly enhances network analysis and troubleshooting capabilities;
|
||||
- Agent action logs provide a description of every action taken by each agent during the episode. This includes timestep, action, parameters, request and response, for all Blue agent activity, which is aligned with the Track 2 Common Action / Observation Space (CAOS) format. Action logs also details of all scripted / stochastic red / green agent actions;
|
||||
- Environment ground truth is provided at every timestep, providing a full description of the environment’s true state;
|
||||
- Alignment with CAOS provides the ability to transfer agents between CAOS compliant environments.
|
||||
|
||||
Architecture
|
||||
^^^^^^^^^^^^
|
||||
|
||||
PrimAITE is a Python application and is therefore Operating System agnostic. The Gymnasium and Ray RLLib frameworks are employed to provide an interface and source for AI agents. Configuration of PrimAITE is achieved via included YAML files which support full control over the platform / system laydown being modelled, background pattern of life, adversarial (red agent) behaviour, and step and episode count. NetworkX based nodes and links host Python classes to present attributes and methods, and hence a more representative platform / system can be modelled within the simulation.
|
||||
PrimAITE is a Python application and will operate on multiple Operating Systems (Windows, Linux and Mac);
|
||||
a comprehensive installation and user guide is provided with each release to support its usage.
|
||||
|
||||
Configuration of PrimAITE is achieved via included YAML files which support full control over the network / system laydown being modelled, background pattern of life, adversarial (red agent) behaviour, and step and episode count.
|
||||
A Simulation Controller layer manages the overall running of the simulation, keeping track of all low-level objects.
|
||||
|
||||
It is agnostic to the number of agents, their action / observation spaces, and the RL library being used.
|
||||
|
||||
It presents a public API providing a method for describing the current state of the simulation, a method that accepts action requests and provides responses, and a method that triggers a timestep advancement.
|
||||
The Game Layer converts the simulation into a playable game for the agent(s).
|
||||
|
||||
It translates between simulation state and Gymnasium.Spaces to pass action / observation data between the agent(s) and the simulation. It is responsible for calculating rewards, managing Multi-Agent RL (MARL) action turns, and via a single agent interface can interact with Blue, Red and Green agents.
|
||||
|
||||
Agents can either generate their own scripted behaviour or accept input behaviour from an RL agent.
|
||||
|
||||
Finally, a Gymnasium / Ray RLlib Environment Layer forwards requests to the Game Layer as the agent sends them. This layer also manages most of the I/O, such as reading in the configuration files and saving agent logs.
|
||||
|
||||
.. image:: ../../_static/primAITE_architecture.png
|
||||
:width: 500
|
||||
:align: center
|
||||
|
||||
|
||||
Training & Evaluation Capability
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
PrimAITE provides a training and evaluation capability to AI agents in the context of cyber-attack, via its Gymnasium and RLLib compliant interface. Scenarios can be constructed to reflect platform / system laydowns consisting of any configuration of nodes (e.g. PCs, servers, switches etc.) and network links between them. All nodes can be configured to model services (and their status) and the traffic loading between them over the network links. Traffic loading is broken down into a per service granularity, relating directly to a protocol (e.g. Service A would be configured as a TCP service, and TCP traffic then flows between instances of Service A under the direction of a tailored IER). Highlights of PrimAITE’s training and evaluation capability are:
|
||||
PrimAITE provides a training and evaluation capability to AI agents in the context of cyber-attack, via its Gymnasium / Ray RLlib compliant interface.
|
||||
|
||||
Scenarios can be constructed to reflect network / system laydowns consisting of any configuration of nodes (e.g., PCs, servers etc.) and the networking equipment and links between them.
|
||||
|
||||
All nodes can be configured to contain applications, services, folders and files (and their status).
|
||||
|
||||
Traffic flows between services and applications as directed by an ‘execution definition,’ with the traffic flow on the network governed by the network equipment (switches, routers and firewalls) and the ACL rules and routing tables they employ.
|
||||
|
||||
Highlights of PrimAITE’s training and evaluation capability are:
|
||||
|
||||
- The scenario is not bound to a representation of any platform, system, or technology;
|
||||
- Fully configurable (network / system laydown, IERs, node pattern-of-life, ACL, number of episodes, steps per episode) and repeatable to suit the requirements of AI agents;
|
||||
- Can integrate with any Gymnasium or RLLib compliant AI agent.
|
||||
|
||||
Use of PrimAITE default scenarios within ARCD is supported by a “Use Case Profile” tailored to the scenario.
|
||||
|
||||
AI Assessment Capability
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
PrimAITE includes the capability to support in-depth assessment of cyber defence AI by outputting logs of the environment state and AI behaviour throughout both training and evaluation sessions. These logs include the following data:
|
||||
|
||||
- Timestamp;
|
||||
- Episode and step number;
|
||||
- Agent identifier;
|
||||
- Observation space;
|
||||
- Action taken (by defensive AI);
|
||||
- Reward value.
|
||||
|
||||
Logs are available in CSV format and provide coverage of the above data for every step of every episode.
|
||||
|
||||
- Fully configurable (network / system laydown, green pattern-of-life, red personas, reward function, ACL rules for each device, number of episodes / steps, action / observation space) and repeatable to suit the requirements of AI agents;
|
||||
- Can integrate with any Gymnasium / Ray RLlib compliant AI agent .
|
||||
|
||||
|
||||
PrimAITE provides a number of use cases (network and red/green action configurations) by default which the user is able to extend and modify as required.
|
||||
|
||||
What is PrimAITE built with
|
||||
---------------------------
|
||||
@@ -109,6 +122,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
source/config
|
||||
source/environment
|
||||
source/customising_scenarios
|
||||
source/varying_config_files
|
||||
|
||||
.. toctree::
|
||||
:caption: Notebooks:
|
||||
@@ -126,13 +140,3 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
source/request_system
|
||||
PrimAITE API <source/_autosummary/primaite>
|
||||
PrimAITE Tests <source/_autosummary/tests>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:caption: Project Links:
|
||||
:hidden:
|
||||
|
||||
Code <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE>
|
||||
Issues <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues>
|
||||
Pull Requests <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/pulls>
|
||||
Discussions <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/discussions>
|
||||
|
||||
@@ -36,11 +36,6 @@ IF EXIST %AUTOSUMMARYDIR% (
|
||||
RMDIR %AUTOSUMMARYDIR% /s /q
|
||||
)
|
||||
|
||||
REM print the YT licenses
|
||||
set LICENSEBUILD=pip-licenses --format=rst --with-urls
|
||||
set DEPS="%cd%\source\primaite-dependencies.rst"
|
||||
|
||||
%LICENSEBUILD% --output-file=%DEPS%
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _about:
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
PrimAITE |VERSION| Configuration
|
||||
********************************
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. role:: raw-html(raw)
|
||||
:format: html
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _Developer Tools:
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Example Jupyter Notebooks
|
||||
=========================
|
||||
@@ -77,6 +77,6 @@ The following extensions should now be installed
|
||||
:width: 300
|
||||
:align: center
|
||||
|
||||
VSCode will then ask for a Python environment version to use. PrimAITE is compatible with Python versions 3.8 - 3.10
|
||||
VSCode will then ask for a Python environment version to use. PrimAITE is compatible with Python versions 3.8 - 3.11
|
||||
|
||||
You should now be able to interact with the notebook.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _getting-started:
|
||||
|
||||
@@ -82,7 +82,7 @@ Install PrimAITE
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
pip install path/to/your/primaite.whl
|
||||
pip install path/to/your/primaite.whl[rl]
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
@@ -107,7 +107,9 @@ Clone & Install PrimAITE for Development
|
||||
To be able to extend PrimAITE further, or to build wheels manually before install, clone the repository to a location
|
||||
of your choice:
|
||||
|
||||
1. Clone the repository
|
||||
1. Clone the repository.
|
||||
|
||||
For example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@@ -133,12 +135,12 @@ of your choice:
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
pip install -e .[dev]
|
||||
pip install -e .[dev,rl]
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
pip install -e .[dev]
|
||||
pip install -e .[dev,rl]
|
||||
|
||||
To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`).
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Glossary
|
||||
=============
|
||||
@@ -38,14 +38,11 @@ Glossary
|
||||
Blue Agent
|
||||
A defensive agent that protects the network from Red Agent attacks to minimise disruption to green agents and protect data.
|
||||
|
||||
Information Exchange Requirement (IER)
|
||||
Simulates network traffic by sending data from one network node to another via links for a specified amount of time. IERs can be part of green agent behaviour or red agent behaviour. PrimAITE can be configured to apply a penalty for green agents' IERs being blocked and a reward for red agents' IERs being blocked.
|
||||
|
||||
Pattern-of-Life (PoL)
|
||||
PoLs allow agents to change the current hardware, OS, file system, or service statuses of nodes during the course of an episode. For example, a green agent may restart a server node to represent scheduled maintainance. A red agent's Pattern-of-Life can be used to attack nodes by changing their states to CORRUPTED or COMPROMISED.
|
||||
|
||||
Reward
|
||||
The reward is a single number used by the blue agent to understand whether it's performing well or poorly. RL agents change their behaviour in an attempt to increase the expected reward each episode. The reward is generated based on the current states of the environment / :term:`reference environment` and is impacted positively by things like green IERS running successfully and negatively by things like nodes being compromised.
|
||||
The reward is a single number used by the blue agent to understand whether it's performing well or poorly. RL agents change their behaviour in an attempt to increase the expected reward each episode. The reward is generated based on the current states of the environment and is impacted positively by things like green PoL running successfully and negatively by things like nodes being compromised.
|
||||
|
||||
Observation
|
||||
An observation is a representation of the current state of the environment that is given to the learning agent so it can decide on which action to perform. If the environment is 'fully observable', the observation contains information about every possible aspect of the environment. More commonly, the environment is 'partially observable' which means the learning agent has to make decisions without knowing every detail of the current environment state.
|
||||
@@ -65,12 +62,6 @@ Glossary
|
||||
Episode
|
||||
When an episode starts, the network simulation is reset to an initial state. The agents take actions on each step of the episode until it reaches a terminal state, which usually happens after a predetermined number of steps. After the terminal state is reached, a new episode starts and the RL agent has another opportunity to protect the network.
|
||||
|
||||
Reference environment
|
||||
While the network simulation is unfolding, a parallel simulation takes place which is identical to the main one except that blue and red agent actions are not applied. This reference environment essentially shows what would be happening to the network if there had been no cyberattack or defense. The reference environment is used to calculate rewards.
|
||||
|
||||
Transaction
|
||||
PrimAITE records the decisions of the learning agent by saving its observation, action, and reward at every time step. During each session, this data is saved to disk to allow for full inspection.
|
||||
|
||||
Laydown
|
||||
The laydown is a file which defines the training scenario. It contains the network topology, firewall rules, services, protocols, and details about green and red agent behaviours.
|
||||
|
||||
@@ -78,4 +69,4 @@ Glossary
|
||||
PrimAITE uses the Gymnasium reinforcement learning framework API to create a training environment and interface with RL agents. Gymnasium defines a common way of creating observations, actions, and rewards.
|
||||
|
||||
User app home
|
||||
PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite<version>` on linux/darwin and `C:\\Users\\<username>\\primaite\\<version>` on Windows.
|
||||
PrimAITE supports upgrading the software version while retaining user data. The user data directory is where configs, notebooks, and results are stored; this location is `~/primaite<version>/` on Linux/Darwin and `C:\\Users\\<username>\\primaite<version>` on Windows.
|
||||
|
||||
37
docs/source/primaite-dependencies.rst
Normal file
@@ -0,0 +1,37 @@
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| Name | Version | License | Description | URL |
|
||||
+===================+=========+====================================+=======================================================================================================+==============================================+
|
||||
| gymnasium | 0.28.1 | MIT License | A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym). | https://farama.org |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| ipywidgets | 8.1.3 | BSD License | Jupyter interactive widgets | http://jupyter.org |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| jupyterlab | 3.6.1 | BSD License | JupyterLab computational environment | https://jupyter.org |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| kaleido | 0.2.1 | MIT | Static image export for web-based visualization libraries with zero dependencies | https://github.com/plotly/Kaleido |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| matplotlib | 3.7.1 | Python Software Foundation License | Python plotting package | https://matplotlib.org |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| networkx | 3.1 | BSD License | Python package for creating and manipulating graphs and networks | https://networkx.org/ |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| numpy | 1.23.5 | BSD License | NumPy is the fundamental package for array computing with Python. | https://www.numpy.org |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| platformdirs | 3.5.1 | MIT License | A small Python package for determining appropriate platform-specific dirs, e.g. a "user data dir". | https://github.com/platformdirs/platformdirs |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| plotly | 5.15.0 | MIT License | An open-source, interactive data visualization library for Python | https://plotly.com/python/ |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| polars | 0.18.4 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| prettytable | 3.8.0 | BSD License (BSD (3 clause)) | A simple Python library for easily displaying tabular data in a visually appealing ASCII table format | https://github.com/jazzband/prettytable |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| pydantic | 2.7.0 | MIT License | Data validation using Python type hints | https://github.com/pydantic/pydantic |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| PyYAML | 6.0 | MIT License | YAML parser and emitter for Python | https://pyyaml.org/ |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| ray | 2.23.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| stable-baselines3 | 2.1.0 | MIT | Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms. | https://github.com/DLR-RM/stable-baselines3 |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| tensorflow | 2.12.0 | Apache Software License | TensorFlow is an open source machine learning framework for everyone. | https://www.tensorflow.org/ |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
| typer | 0.9.0 | MIT License | Typer, build great CLIs. Easy to code. Based on Python type hints. | https://github.com/tiangolo/typer |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Request System
|
||||
**************
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
|
||||
Simulation
|
||||
|
||||
347
docs/source/simulation_components/system/applications/nmap.rst
Normal file
@@ -0,0 +1,347 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _NMAP:
|
||||
|
||||
NMAP
|
||||
====
|
||||
|
||||
Overview
|
||||
--------
|
||||
|
||||
The NMAP application is used to simulate network scanning activities. NMAP is a powerful tool that helps in discovering
|
||||
hosts and services on a network. It provides functionalities such as ping scans to discover active hosts and port scans
|
||||
to detect open ports on those hosts.
|
||||
|
||||
The NMAP application is essential for network administrators and security professionals to map out a network's
|
||||
structure, identify active devices, and find potential vulnerabilities by discovering open ports and running services.
|
||||
However, it is also a tool frequently used by attackers during the reconnaissance stage of a cyber attack to gather
|
||||
information about the target network.
|
||||
|
||||
Scan Types
|
||||
----------
|
||||
|
||||
Ping Scan
|
||||
^^^^^^^^^
|
||||
|
||||
A ping scan is used to identify which hosts on a network are active and reachable. This is achieved by sending ICMP
|
||||
Echo Request packets (ping) to the target IP addresses. If a host responds with an ICMP Echo Reply, it is considered
|
||||
active. Ping scans are useful for quickly mapping out live hosts in a network.
|
||||
|
||||
Port Scan
|
||||
^^^^^^^^^
|
||||
|
||||
A port scan is used to detect open ports on a target host or range of hosts. Open ports can indicate running services
|
||||
that might be exploitable or require securing. Port scans help in understanding the services available on a network and
|
||||
identifying potential entry points for attacks. There are three types of port scans based on the scope:
|
||||
|
||||
- **Horizontal Port Scan**: This scan targets a specific port across a range of IP addresses. It helps in identifying
|
||||
which hosts have a particular service running.
|
||||
|
||||
- **Vertical Port Scan**: This scan targets multiple ports on a single IP address. It provides detailed information
|
||||
about the services running on a specific host.
|
||||
|
||||
- **Box Scan**: This combines both horizontal and vertical scans, targeting multiple ports across multiple IP addresses.
|
||||
It gives a comprehensive view of the network's service landscape.
|
||||
|
||||
Example Usage
|
||||
-------------
|
||||
|
||||
The network we use for these examples is defined below:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ipaddress import IPv4Network
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.network.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.system.applications.nmap import NMAP
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
|
||||
# Initialize the network
|
||||
network = Network()
|
||||
|
||||
# Set up the router
|
||||
router = Router(hostname="router", start_up_duration=0)
|
||||
router.power_on()
|
||||
router.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
|
||||
|
||||
# Set up PC 1
|
||||
pc_1 = Computer(
|
||||
hostname="pc_1",
|
||||
ip_address="192.168.1.11",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0
|
||||
)
|
||||
pc_1.power_on()
|
||||
|
||||
# Set up PC 2
|
||||
pc_2 = Computer(
|
||||
hostname="pc_2",
|
||||
ip_address="192.168.1.12",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0
|
||||
)
|
||||
pc_2.power_on()
|
||||
pc_2.software_manager.install(DatabaseService)
|
||||
pc_2.software_manager.software["DatabaseService"].start() # start the postgres server
|
||||
|
||||
# Set up PC 3
|
||||
pc_3 = Computer(
|
||||
hostname="pc_3",
|
||||
ip_address="192.168.1.13",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0
|
||||
)
|
||||
# Don't power on PC 3
|
||||
|
||||
# Set up the switch
|
||||
switch = Switch(hostname="switch", start_up_duration=0)
|
||||
switch.power_on()
|
||||
|
||||
# Connect devices
|
||||
network.connect(router.network_interface[1], switch.network_interface[24])
|
||||
network.connect(switch.network_interface[1], pc_1.network_interface[1])
|
||||
network.connect(switch.network_interface[2], pc_2.network_interface[1])
|
||||
network.connect(switch.network_interface[3], pc_3.network_interface[1])
|
||||
|
||||
|
||||
pc_1_nmap: NMAP = pc_1.software_manager.software["NMAP"]
|
||||
|
||||
|
||||
Ping Scan
|
||||
^^^^^^^^^
|
||||
|
||||
Perform a ping scan to find active hosts in the `192.168.1.0/24` subnet:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Ping Scan Code
|
||||
|
||||
active_hosts = pc_1_nmap.ping_scan(target_ip_address=IPv4Network("192.168.1.0/24"))
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Ping Scan Return Value
|
||||
|
||||
[
|
||||
IPv4Address('192.168.1.11'),
|
||||
IPv4Address('192.168.1.12'),
|
||||
IPv4Address('192.168.1.1')
|
||||
]
|
||||
|
||||
.. code-block:: text
|
||||
:caption: Ping Scan Output
|
||||
|
||||
+-------------------------+
|
||||
| pc_1 NMAP Ping Scan |
|
||||
+--------------+----------+
|
||||
| IP Address | Can Ping |
|
||||
+--------------+----------+
|
||||
| 192.168.1.1 | True |
|
||||
| 192.168.1.11 | True |
|
||||
| 192.168.1.12 | True |
|
||||
+--------------+----------+
|
||||
|
||||
Horizontal Port Scan
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Perform a horizontal port scan on port 5432 across multiple IP addresses:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Horizontal Port Scan Code
|
||||
|
||||
horizontal_scan_results = pc_1_nmap.port_scan(
|
||||
target_ip_address=[IPv4Address("192.168.1.12"), IPv4Address("192.168.1.13")],
|
||||
target_port=Port(5432)
|
||||
)
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Horizontal Port Scan Return Value
|
||||
|
||||
{
|
||||
IPv4Address('192.168.1.12'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.POSTGRES_SERVER: 5432>
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
.. code-block:: text
|
||||
:caption: Horizontal Port Scan Output
|
||||
|
||||
+--------------------------------------------------+
|
||||
| pc_1 NMAP Port Scan (Horizontal) |
|
||||
+--------------+------+-----------------+----------+
|
||||
| IP Address | Port | Name | Protocol |
|
||||
+--------------+------+-----------------+----------+
|
||||
| 192.168.1.12 | 5432 | POSTGRES_SERVER | TCP |
|
||||
+--------------+------+-----------------+----------+
|
||||
|
||||
Vertical Port Scan
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Perform a vertical port scan on multiple ports on a single IP address:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Vertical Port Scan Code
|
||||
|
||||
vertical_scan_results = pc_1_nmap.port_scan(
|
||||
target_ip_address=[IPv4Address("192.168.1.12")],
|
||||
target_port=[Port(21), Port(22), Port(80), Port(443)]
|
||||
)
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Vertical Port Scan Return Value
|
||||
|
||||
{
|
||||
IPv4Address('192.168.1.12'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.FTP: 21>,
|
||||
<Port.HTTP: 80>
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
.. code-block:: text
|
||||
:caption: Vertical Port Scan Output
|
||||
|
||||
+---------------------------------------+
|
||||
| pc_1 NMAP Port Scan (Vertical) |
|
||||
+--------------+------+------+----------+
|
||||
| IP Address | Port | Name | Protocol |
|
||||
+--------------+------+------+----------+
|
||||
| 192.168.1.12 | 21 | FTP | TCP |
|
||||
| 192.168.1.12 | 80 | HTTP | TCP |
|
||||
+--------------+------+------+----------+
|
||||
|
||||
Box Scan
|
||||
^^^^^^^^
|
||||
|
||||
Perform a box scan on multiple ports across multiple IP addresses:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Box Port Scan Code
|
||||
|
||||
# Power PC 3 on before performing the box scan
|
||||
pc_3.power_on()
|
||||
|
||||
|
||||
box_scan_results = pc_1_nmap.port_scan(
|
||||
target_ip_address=[IPv4Address("192.168.1.12"), IPv4Address("192.168.1.13")],
|
||||
target_port=[Port(21), Port(22), Port(80), Port(443)]
|
||||
)
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Box Port Scan Return Value
|
||||
|
||||
{
|
||||
IPv4Address('192.168.1.13'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.FTP: 21>,
|
||||
<Port.HTTP: 80>
|
||||
]
|
||||
},
|
||||
IPv4Address('192.168.1.12'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.FTP: 21>,
|
||||
<Port.HTTP: 80>
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
.. code-block:: text
|
||||
:caption: Box Port Scan Output
|
||||
|
||||
+---------------------------------------+
|
||||
| pc_1 NMAP Port Scan (Box) |
|
||||
+--------------+------+------+----------+
|
||||
| IP Address | Port | Name | Protocol |
|
||||
+--------------+------+------+----------+
|
||||
| 192.168.1.12 | 21 | FTP | TCP |
|
||||
| 192.168.1.12 | 80 | HTTP | TCP |
|
||||
| 192.168.1.13 | 21 | FTP | TCP |
|
||||
| 192.168.1.13 | 80 | HTTP | TCP |
|
||||
+--------------+------+------+----------+
|
||||
|
||||
Full Box Scan
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Perform a full box scan on all ports, over both TCP and UDP, on a whole subnet:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Box Port Scan Code
|
||||
|
||||
# Power PC 3 on before performing the full box scan
|
||||
pc_3.power_on()
|
||||
|
||||
|
||||
full_box_scan_results = pc_1_nmap.port_scan(
|
||||
target_ip_address=IPv4Network("192.168.1.0/24"),
|
||||
)
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Box Port Scan Return Value
|
||||
|
||||
{
|
||||
IPv4Address('192.168.1.11'): {
|
||||
<IPProtocol.UDP: 'udp'>: [
|
||||
<Port.ARP: 219>
|
||||
]
|
||||
},
|
||||
IPv4Address('192.168.1.1'): {
|
||||
<IPProtocol.UDP: 'udp'>: [
|
||||
<Port.ARP: 219>
|
||||
]
|
||||
},
|
||||
IPv4Address('192.168.1.12'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.HTTP: 80>,
|
||||
<Port.DNS: 53>,
|
||||
<Port.POSTGRES_SERVER: 5432>,
|
||||
<Port.FTP: 21>
|
||||
],
|
||||
<IPProtocol.UDP: 'udp'>: [
|
||||
<Port.NTP: 123>,
|
||||
<Port.ARP: 219>
|
||||
]
|
||||
},
|
||||
IPv4Address('192.168.1.13'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.HTTP: 80>,
|
||||
<Port.DNS: 53>,
|
||||
<Port.FTP: 21>
|
||||
],
|
||||
<IPProtocol.UDP: 'udp'>: [
|
||||
<Port.NTP: 123>,
|
||||
<Port.ARP: 219>
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
.. code-block:: text
|
||||
:caption: Box Port Scan Output
|
||||
|
||||
+--------------------------------------------------+
|
||||
| pc_1 NMAP Port Scan (Box) |
|
||||
+--------------+------+-----------------+----------+
|
||||
| IP Address | Port | Name | Protocol |
|
||||
+--------------+------+-----------------+----------+
|
||||
| 192.168.1.1 | 219 | ARP | UDP |
|
||||
| 192.168.1.11 | 219 | ARP | UDP |
|
||||
| 192.168.1.12 | 21 | FTP | TCP |
|
||||
| 192.168.1.12 | 53 | DNS | TCP |
|
||||
| 192.168.1.12 | 80 | HTTP | TCP |
|
||||
| 192.168.1.12 | 123 | NTP | UDP |
|
||||
| 192.168.1.12 | 219 | ARP | UDP |
|
||||
| 192.168.1.12 | 5432 | POSTGRES_SERVER | TCP |
|
||||
| 192.168.1.13 | 21 | FTP | TCP |
|
||||
| 192.168.1.13 | 53 | DNS | TCP |
|
||||
| 192.168.1.13 | 80 | HTTP | TCP |
|
||||
| 192.168.1.13 | 123 | NTP | UDP |
|
||||
| 192.168.1.13 | 219 | ARP | UDP |
|
||||
+--------------+------+-----------------+----------+
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
|
||||
Simulation Structure
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Simulation State
|
||||
================
|
||||
|
||||
49
docs/source/varying_config_files.rst
Normal file
@@ -0,0 +1,49 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Defining variations in the config files
|
||||
=======================================
|
||||
|
||||
PrimAITE supports using different variations of a scenario in different episodes. This can be used to increase domain randomisation to prevent overfitting, or to set up curriculum learning to train agents to perform more complicated tasks.
|
||||
|
||||
When using a fixed scenario, a single yaml config file is used. However, to use episode schedules, PrimAITE uses a directory with several config files that work together.
|
||||
The sections below describe how these variations are defined in the config files.
|
||||
|
||||
Base scenario
|
||||
*************
|
||||
|
||||
The base scenario is essentially the same as a fixed YAML configuration, but it can contain placeholders that are populated with episode-specific data at runtime. The base scenario contains any network, agent, or settings that remain fixed for the entire training/evaluation session.
|
||||
|
||||
The placeholders are defined as YAML Aliases, denoted by an asterisk (``*placeholder``).
|
||||
|
||||
Variations
|
||||
**********
|
||||
|
||||
For each variation that could be used in a placeholder, there is a separate yaml file that contains the data that should populate the placeholder.
|
||||
|
||||
The data that fills the placeholder is defined as a YAML Anchor in a separate file, denoted by an ampersand ``&anchor``.
|
||||
|
||||
Learn more about YAML Aliases and Anchors `here <https://yaml.org/spec/1.2.2/#3222-anchors-and-aliases>`_.
|
||||
|
||||
Schedule
|
||||
********
|
||||
|
||||
Users must define which combination of scenario variations should be loaded in each episode. This takes the form of a YAML file with a relative path to the base scenario and a list of paths to be loaded in during each episode.
|
||||
|
||||
It takes the following format:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
base_scenario: base.yaml
|
||||
schedule:
|
||||
0: # list of variations to load in at episode 0 (before the first call to env.reset() happens)
|
||||
- laydown_1.yaml
|
||||
- attack_1.yaml
|
||||
1: # list of variations to load in at episode 1 (after the first env.reset() call)
|
||||
- laydown_2.yaml
|
||||
- attack_2.yaml
|
||||
|
||||
For more information please refer to the ``Using Episode Schedules`` notebook in either :ref:`Executed Notebooks` or run the notebook interactively in ``notebooks/example_notebooks/``.
|
||||
|
||||
For further information around notebooks in general refer to the :ref:`Example Jupyter Notebooks`.
|
||||
@@ -33,15 +33,13 @@ dependencies = [
|
||||
"numpy==1.23.5",
|
||||
"platformdirs==3.5.1",
|
||||
"plotly==5.15.0",
|
||||
"polars==0.18.4",
|
||||
"polars==0.20.30",
|
||||
"prettytable==3.8.0",
|
||||
"PyYAML==6.0",
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
"tensorflow==2.12.0",
|
||||
"typer[all]==0.9.0",
|
||||
"pydantic==2.7.0",
|
||||
"ray[rllib] >= 2.9, < 3",
|
||||
"ipywidgets"
|
||||
"ipywidgets",
|
||||
"deepdiff"
|
||||
]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
@@ -55,6 +53,11 @@ license-files = ["LICENSE"]
|
||||
|
||||
|
||||
[project.optional-dependencies]
|
||||
rl = [
|
||||
"ray[rllib] >= 2.20.0, < 3",
|
||||
"tensorflow==2.12.0",
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
]
|
||||
dev = [
|
||||
"build==0.10.0",
|
||||
"flake8==6.0.0",
|
||||
|
||||
@@ -1 +1 @@
|
||||
3.0.0b9
|
||||
3.0.0
|
||||
|
||||
@@ -10,7 +10,7 @@ AbstractAction. The ActionManager is responsible for:
|
||||
"""
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
@@ -870,6 +870,74 @@ class NetworkPortDisableAction(AbstractAction):
|
||||
return ["network", "node", target_nodename, "network_interface", port_id, "disable"]
|
||||
|
||||
|
||||
class NodeNMAPPingScanAction(AbstractAction):
|
||||
"""Action which performs an NMAP ping scan."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(self, source_node: str, target_ip_address: Union[str, List[str]]) -> List[str]: # noqa
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"ping_scan",
|
||||
{"target_ip_address": target_ip_address},
|
||||
]
|
||||
|
||||
|
||||
class NodeNMAPPortScanAction(AbstractAction):
|
||||
"""Action which performs an NMAP port scan."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(
|
||||
self,
|
||||
source_node: str,
|
||||
target_ip_address: Union[str, List[str]],
|
||||
target_protocol: Optional[Union[str, List[str]]] = None,
|
||||
target_port: Optional[Union[str, List[str]]] = None,
|
||||
) -> List[str]: # noqa
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"port_scan",
|
||||
{"target_ip_address": target_ip_address, "target_port": target_port, "target_protocol": target_protocol},
|
||||
]
|
||||
|
||||
|
||||
class NodeNetworkServiceReconAction(AbstractAction):
|
||||
"""Action which performs an NMAP network service recon (ping scan followed by port scan)."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(
|
||||
self,
|
||||
source_node: str,
|
||||
target_ip_address: Union[str, List[str]],
|
||||
target_protocol: Optional[Union[str, List[str]]] = None,
|
||||
target_port: Optional[Union[str, List[str]]] = None,
|
||||
) -> List[str]: # noqa
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"network_service_recon",
|
||||
{"target_ip_address": target_ip_address, "target_port": target_port, "target_protocol": target_protocol},
|
||||
]
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
@@ -915,6 +983,9 @@ class ActionManager:
|
||||
"HOST_NIC_DISABLE": HostNICDisableAction,
|
||||
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
|
||||
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
|
||||
"NODE_NMAP_PING_SCAN": NodeNMAPPingScanAction,
|
||||
"NODE_NMAP_PORT_SCAN": NodeNMAPPortScanAction,
|
||||
"NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction,
|
||||
}
|
||||
"""Dictionary which maps action type strings to the corresponding action class."""
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AgentActionHistoryItem(BaseModel):
|
||||
class AgentHistoryItem(BaseModel):
|
||||
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
|
||||
|
||||
timestep: int
|
||||
@@ -32,6 +32,8 @@ class AgentActionHistoryItem(BaseModel):
|
||||
response: RequestResponse
|
||||
"""The response sent back by the simulator for this action."""
|
||||
|
||||
reward: Optional[float] = None
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
@@ -110,7 +112,7 @@ class AbstractAgent(ABC):
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
self.action_history: List[AgentActionHistoryItem] = []
|
||||
self.history: List[AgentHistoryItem] = []
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -130,7 +132,7 @@ class AbstractAgent(ABC):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
|
||||
return self.reward_function.update(state=state, last_action_response=self.history[-1])
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -161,12 +163,16 @@ class AbstractAgent(ABC):
|
||||
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
|
||||
) -> None:
|
||||
"""Process the response from the most recent action."""
|
||||
self.action_history.append(
|
||||
AgentActionHistoryItem(
|
||||
self.history.append(
|
||||
AgentHistoryItem(
|
||||
timestep=timestep, action=action, parameters=parameters, request=request, response=response
|
||||
)
|
||||
)
|
||||
|
||||
def save_reward_to_history(self) -> None:
|
||||
"""Update the most recent history item with the reward value."""
|
||||
self.history[-1].reward = self.reward_function.current_reward
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
@@ -34,7 +34,7 @@ from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
WhereType = Optional[Iterable[Union[str, int]]]
|
||||
@@ -44,7 +44,7 @@ class AbstractReward:
|
||||
"""Base class for reward function components."""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -64,7 +64,7 @@ class AbstractReward:
|
||||
class DummyReward(AbstractReward):
|
||||
"""Dummy reward function component which always returns 0."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -104,7 +104,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
file_name,
|
||||
]
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -159,7 +159,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
"""
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -213,7 +213,7 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
@@ -273,7 +273,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
@@ -343,7 +343,7 @@ class SharedReward(AbstractReward):
|
||||
self.callback: Callable[[str], float] = default_callback
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Simply access the other agent's reward and return it."""
|
||||
return self.callback(self.agent_name)
|
||||
|
||||
@@ -389,7 +389,7 @@ class RewardFunction:
|
||||
"""
|
||||
self.reward_components.append((component, weight))
|
||||
|
||||
def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def update(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the overall reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
|
||||
@@ -160,6 +160,7 @@ class PrimaiteGame:
|
||||
agent = self.agents[agent_name]
|
||||
if self.step_counter > 0: # can't get reward before first action
|
||||
agent.update_reward(state=state)
|
||||
agent.save_reward_to_history()
|
||||
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
|
||||
agent.reward_function.total_reward += agent.reward_function.current_reward
|
||||
|
||||
@@ -359,11 +360,6 @@ class PrimaiteGame:
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload", "ENCRYPT"),
|
||||
c2_beacon_p_of_success=float(opt.get("c2_beacon_p_of_success", "0.5")),
|
||||
target_scan_p_of_success=float(opt.get("target_scan_p_of_success", "0.1")),
|
||||
ransomware_encrypt_p_of_success=float(
|
||||
opt.get("ransomware_encrypt_p_of_success", "0.1")
|
||||
),
|
||||
)
|
||||
elif application_type == "DatabaseClient":
|
||||
if "options" in application_cfg:
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Dict, ForwardRef, List, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
|
||||
|
||||
RequestFormat = List[Union[str, int, float]]
|
||||
RequestFormat = List[Union[str, int, float, Dict]]
|
||||
|
||||
RequestResponse = ForwardRef("RequestResponse")
|
||||
"""This makes it possible to type-hint RequestResponse.from_bool return type."""
|
||||
|
||||
@@ -4,13 +4,15 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Customising red agents\n",
|
||||
"# Customising UC2 Red Agents\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will go over some examples of how red agent behaviour can be varied by changing its configuration parameters.\n",
|
||||
"\n",
|
||||
"First, let's load the standard Data Manipulation config file, and see what the red agent does.\n",
|
||||
"\n",
|
||||
"*(For a full explanation of the Data Manipulation scenario, check out the notebook `Data-Manipulation-E2E-Demonstration.ipynb`)*"
|
||||
"*(For a full explanation of the Data Manipulation scenario, check out the data manipulation scenario notebook)*"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -22,7 +24,7 @@
|
||||
"# Imports\n",
|
||||
"\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
|
||||
"from primaite.game.agent.interface import AgentHistoryItem\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"import yaml\n",
|
||||
"from pprint import pprint"
|
||||
@@ -63,7 +65,7 @@
|
||||
"source": [
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Data Manipulation Scenario\n"
|
||||
"# Data Manipulation Scenario\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -59,7 +61,7 @@
|
||||
"\n",
|
||||
"At the start of every episode, the red agent randomly chooses either client 1 or client 2 to login to. It waits a bit then sends a DELETE query to the database from its chosen client. If the delete is successful, the database file is flagged as compromised to signal that data is not available.\n",
|
||||
"\n",
|
||||
"[<img src=\"_package_data/uc2_attack.png\" width=\"500\"/>](_package_data/uc2_attack.png)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"_(click image to enlarge)_"
|
||||
]
|
||||
@@ -79,7 +81,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Reinforcement learning details"
|
||||
"## Reinforcement learning details"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -180,15 +182,15 @@
|
||||
"| link_id | endpoint_a | endpoint_b |\n",
|
||||
"|---------|------------------|-------------------|\n",
|
||||
"| 1 | router_1 | switch_1 |\n",
|
||||
"| 1 | router_1 | switch_2 |\n",
|
||||
"| 1 | switch_1 | domain_controller |\n",
|
||||
"| 1 | switch_1 | web_server |\n",
|
||||
"| 1 | switch_1 | database_server |\n",
|
||||
"| 1 | switch_1 | backup_server |\n",
|
||||
"| 1 | switch_1 | security_suite |\n",
|
||||
"| 1 | switch_2 | client_1 |\n",
|
||||
"| 1 | switch_2 | client_2 |\n",
|
||||
"| 1 | switch_2 | security_suite |\n",
|
||||
"| 2 | router_1 | switch_2 |\n",
|
||||
"| 3 | switch_1 | domain_controller |\n",
|
||||
"| 4 | switch_1 | web_server |\n",
|
||||
"| 5 | switch_1 | database_server |\n",
|
||||
"| 6 | switch_1 | backup_server |\n",
|
||||
"| 7 | switch_1 | security_suite |\n",
|
||||
"| 8 | switch_2 | client_1 |\n",
|
||||
"| 9 | switch_2 | client_2 |\n",
|
||||
"| 10 | switch_2 | security_suite |\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The ACL rules in the observation space appear in the same order that they do in the actual ACL. Though, only the first 10 rules are shown, there are default rules lower down that cannot be changed by the agent. The extra rules just allow the network to function normally, by allowing pings, ARP traffic, etc.\n",
|
||||
@@ -392,7 +394,7 @@
|
||||
"# Imports\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
|
||||
"from primaite.game.agent.interface import AgentHistoryItem\n",
|
||||
"import yaml\n",
|
||||
"from pprint import pprint\n"
|
||||
]
|
||||
@@ -401,7 +403,8 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Instantiate the environment. We also disable the agent observation flattening.\n",
|
||||
"Instantiate the environment. \n",
|
||||
"We will also disable the agent observation flattening.\n",
|
||||
"\n",
|
||||
"This cell will print the observation when the network is healthy. You should be able to verify Node file and service statuses against the description above."
|
||||
]
|
||||
@@ -444,7 +447,7 @@
|
||||
"source": [
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
@@ -691,7 +694,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -705,7 +708,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
"version": "3.10.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
170
src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb
Normal file
@@ -0,0 +1,170 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Getting information out of PrimAITE\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Imports\n",
|
||||
"import yaml\n",
|
||||
"from primaite import PRIMAITE_CONFIG\n",
|
||||
"\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.simulator.network.hardware.nodes.host.computer import Computer\n",
|
||||
"from notebook.services.config import ConfigManager\n",
|
||||
"\n",
|
||||
"cm = ConfigManager().update('notebook', {'limit_output': 50}) # limit output lines to 50 - for neatness\n",
|
||||
"\n",
|
||||
"# create the env\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(env_config=cfg)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Visualising the Simulation Network"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The network can be visualised by running the code below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.game.simulation.network.draw()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Getting the state of a simulation object\n",
|
||||
"\n",
|
||||
"The state of the simulation object is used to determine the observation space used by agents.\n",
|
||||
"\n",
|
||||
"Any object created using the ``SimComponent`` class has a ``describe_state`` method which can show the state of the object.\n",
|
||||
"\n",
|
||||
"An example of such an object is ``Computer`` which inherits from ``SimComponent``. In the default network configuration, ``client_1`` is a Computer object."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n",
|
||||
"client_1.describe_state()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### More specific describe_state\n",
|
||||
"\n",
|
||||
"As you can see, the output from the ``describe_state`` method for the ``Computer`` object includes the describe state for all its components. This can cause a large describe state output.\n",
|
||||
"\n",
|
||||
"As stated, the ``describe_state`` can be called on any object that inherits ``SimComponent``. This can allow you retrieve the state of a specific item."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"client_1.file_system.describe_state()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## System Logs\n",
|
||||
"\n",
|
||||
"Objects that inherit from the ``Node`` class will inherit the ``sys_log`` attribute.\n",
|
||||
"\n",
|
||||
"This is to simulate the idea that items such as Computer, Routers, Servers, etc. have a logging system used to diagnose problems."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# store config\n",
|
||||
"# this is to prevent the notebook from breaking your local settings\n",
|
||||
"was_enabled = PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"]\n",
|
||||
"was_syslogs_enabled = PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"]\n",
|
||||
"\n",
|
||||
"# enable dev mode so that the default config outputs are overridden for this demo\n",
|
||||
"PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = True\n",
|
||||
"PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = True\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Remake the environment\n",
|
||||
"env = PrimaiteGymEnv(env_config=cfg)\n",
|
||||
"\n",
|
||||
"# get the example computer\n",
|
||||
"client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n",
|
||||
"\n",
|
||||
"# show sys logs on terminal\n",
|
||||
"client_1.sys_log.show()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# restore config\n",
|
||||
"PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = was_enabled\n",
|
||||
"PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = was_syslogs_enabled"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -6,7 +6,9 @@
|
||||
"source": [
|
||||
"# Requests and Responses\n",
|
||||
"\n",
|
||||
"Agents interact with the PrimAITE simulation via the Request system.\n"
|
||||
"Agents interact with the PrimAITE simulation via the Request system.\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train a Multi agent system using RLLIB\n",
|
||||
"# Train a Multi agent system using RLLIB\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will demonstrate how to use the `PrimaiteRayMARLEnv` to train a very basic system with two PPO agents."
|
||||
]
|
||||
@@ -25,13 +27,13 @@
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"import yaml\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from primaite import PRIMAITE_PATHS\n",
|
||||
"\n",
|
||||
"import ray\n",
|
||||
"from ray import air, tune\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
|
||||
"\n",
|
||||
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
|
||||
"# to copy the files to your user data path.\n",
|
||||
@@ -60,8 +62,8 @@
|
||||
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
|
||||
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
|
||||
" )\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
" )\n"
|
||||
]
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train a Single agent system using RLLib\n",
|
||||
"# Train a Single agent system using RLLib\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
|
||||
]
|
||||
},
|
||||
@@ -18,8 +21,7 @@
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv\n",
|
||||
"from ray.rllib.algorithms import ppo\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from ray import air, tune\n",
|
||||
"import ray\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
@@ -52,8 +54,8 @@
|
||||
"\n",
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=env_config)\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
")\n"
|
||||
]
|
||||
@@ -74,7 +76,7 @@
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 5 * 128}\n",
|
||||
" stop={\"timesteps_total\": 512}\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()\n"
|
||||
@@ -97,7 +99,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
"source": [
|
||||
"# Training an SB3 Agent\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will demonstrate how to use primaite to create and train a PPO agent, using a pre-defined configuration file."
|
||||
]
|
||||
},
|
||||
@@ -43,7 +45,10 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)"
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"for agent in cfg['agents']:\n",
|
||||
" if agent['ref'] == 'defender':\n",
|
||||
" agent['agent_settings']['flatten_obs']=True"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -163,7 +168,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -177,7 +182,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.10.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
"source": [
|
||||
"# Using Episode Schedules\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"PrimAITE supports the ability to use different variations on a scenario at different episodes. This can be used to increase \n",
|
||||
"domain randomisation to prevent overfitting, or to set up curriculum learning to train agents to perform more complicated tasks.\n",
|
||||
"\n",
|
||||
@@ -13,50 +15,6 @@
|
||||
"directory with several config files that work together."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Defining variations in the config file.\n",
|
||||
"\n",
|
||||
"### Base scenario\n",
|
||||
"The base scenario is essentially the same as a fixed YAML configuration, but it can contain placeholders that are \n",
|
||||
"populated with episode-specific data at runtime. The base scenario contains any network, agent, or settings that\n",
|
||||
"remain fixed for the entire training/evaluation session.\n",
|
||||
"\n",
|
||||
"The placeholders are defined as YAML Aliases and they are denoted by an asterisk (`*placeholder`).\n",
|
||||
"\n",
|
||||
"### Variations\n",
|
||||
"For each variation that could be used in a placeholder, there is a separate yaml file that contains the data that should populate the placeholder.\n",
|
||||
"\n",
|
||||
"The data that fills the placeholder is defined as a YAML Anchor in a separate file, denoted by an ampersand (`&anchor`).\n",
|
||||
"\n",
|
||||
"[Learn more about YAML Aliases and Anchors here.](https://www.educative.io/blog/advanced-yaml-syntax-cheatsheet#:~:text=YAML%20Anchors%20and%20Alias)\n",
|
||||
"\n",
|
||||
"### Schedule\n",
|
||||
"Users must define which combination of scenario variations should be loaded in each episode. This takes the form of a\n",
|
||||
"YAML file with a relative path to the base scenario and a list of paths to be loaded in during each episode.\n",
|
||||
"\n",
|
||||
"It takes the following format:\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```yaml\n",
|
||||
"base_scenario: base.yaml\n",
|
||||
"schedule:\n",
|
||||
" 0: # list of variations to load in at episode 0 (before the first call to env.reset() happens)\n",
|
||||
" - laydown_1.yaml\n",
|
||||
" - attack_1.yaml\n",
|
||||
" 1: # list of variations to load in at episode 1 (after the first env.reset() call)\n",
|
||||
" - laydown_2.yaml\n",
|
||||
" - attack_2.yaml\n",
|
||||
"```\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -298,8 +256,8 @@
|
||||
"table = PrettyTable()\n",
|
||||
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
|
||||
"for i in range(21):\n",
|
||||
" green_action = env.game.agents['green_A'].action_history[i].action\n",
|
||||
" red_action = env.game.agents['red_A'].action_history[i].action\n",
|
||||
" green_action = env.game.agents['green_A'].history[i].action\n",
|
||||
" red_action = env.game.agents['red_A'].history[i].action\n",
|
||||
" table.add_row([i, green_action, red_action])\n",
|
||||
"print(table)"
|
||||
]
|
||||
@@ -329,8 +287,8 @@
|
||||
"table = PrettyTable()\n",
|
||||
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
|
||||
"for i in range(21):\n",
|
||||
" green_action = env.game.agents['green_B'].action_history[i].action\n",
|
||||
" red_action = env.game.agents['red_B'].action_history[i].action\n",
|
||||
" green_action = env.game.agents['green_B'].history[i].action\n",
|
||||
" red_action = env.game.agents['red_B'].history[i].action\n",
|
||||
" table.add_row([i, green_action, red_action])\n",
|
||||
"print(table)"
|
||||
]
|
||||
|
||||
@@ -4,8 +4,11 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Simple multi-processing demo using SubprocVecEnv from SB3\n",
|
||||
"Based on a code example provided by Rachael Proctor."
|
||||
"# Simple multi-processing demonstration\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook uses SubprocVecEnv from SB3."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -140,7 +143,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
@@ -12,6 +11,7 @@ from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -37,6 +37,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.total_reward_per_episode: Dict[int, float] = {}
|
||||
"""Average rewards of agents per episode."""
|
||||
|
||||
@property
|
||||
def agent(self) -> ProxyAgent:
|
||||
@@ -61,7 +63,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {
|
||||
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
|
||||
"agent_actions": {name: agent.history[-1] for name, agent in self.game.agents.items()}
|
||||
} # tell us what all the agents did for convenience.
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(step, action, state, reward)
|
||||
@@ -83,16 +85,19 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
with open(path, "w") as file:
|
||||
json.dump(data, file)
|
||||
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
_LOGGER.info(
|
||||
f"Resetting environment, episode {self.episode_counter}, "
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
PacketCapture.clear()
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
@@ -126,166 +131,5 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
return self.env.reset(seed=seed)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
return self.env.step(action)
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
self.env.close()
|
||||
|
||||
@property
|
||||
def game(self) -> PrimaiteGame:
|
||||
"""Pass through game from env."""
|
||||
return self.env.game
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
"""Reference to the primaite game"""
|
||||
self._agent_ids = list(self.game.rl_agents.keys())
|
||||
"""Agent ids. This is a list of strings of agent names."""
|
||||
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
self.observation_space = gymnasium.spaces.Dict(
|
||||
{
|
||||
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
|
||||
for name, agent in self.agents.items()
|
||||
}
|
||||
)
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def agents(self) -> Dict[str, ProxyAgent]:
|
||||
"""Grab a fresh reference to the agents from this episode's game object."""
|
||||
return {name: self.game.rl_agents[name] for name in self._agent_ids}
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
self.episode_counter += 1
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
def step(
    self, actions: Dict[str, ActType]
) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
    """Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.

    :param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
    :type actions: Dict[str, ActType]
    :return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
        identifier.
    :rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
    """
    step = self.game.step_counter

    # 1. Store each agent's chosen action, then apply them to the simulation.
    for agent_name, action in actions.items():
        self.agents[agent_name].store_action(action)
    self.game.pre_timestep()
    self.game.apply_agent_actions()

    # 2. Advance timestep
    self.game.advance_timestep()

    # 3. Refresh agents with the new simulation state and collect observations.
    state = self.game.get_sim_state()
    self.game.update_agents(state)
    next_obs = self._get_obs()

    # 4. Get rewards
    rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
    _LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")

    # calculate_truncated() is invariant within a single step, so evaluate it once
    # instead of once per agent (the original called it len(agents)+1 times).
    truncated = self.game.calculate_truncated()
    terminateds = {name: False for name in self.agents}
    truncateds = {name: truncated for name in self.agents}
    infos = {name: {} for name in self.agents}
    # NOTE(review): self.terminateds is never populated anywhere visible, so this is
    # always False while any agents exist — confirm whether per-agent termination
    # tracking was intended.
    terminateds["__all__"] = len(self.terminateds) == len(self.agents)
    truncateds["__all__"] = truncated

    if self.game.save_step_metadata:
        self._write_step_metadata_json(step, actions, state, rewards)
    return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
    """Dump this step's actions, rewards, and simulation state to a per-episode JSON file."""
    metadata_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
    metadata_dir.mkdir(parents=True, exist_ok=True)
    out_path = metadata_dir / f"step_{step}.json"

    record = {
        "episode": self.episode_counter,
        "step": step,
        # Gymnasium discrete actions arrive as numpy ints; coerce for JSON.
        "actions": {agent_name: int(action) for agent_name, action in actions.items()},
        "reward": rewards,
        "state": state,
    }
    with open(out_path, "w") as fp:
        json.dump(record, fp)
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
    """Return the current observation."""
    flattened: Dict[str, ObsType] = {}
    for name in self._agent_ids:
        manager = self.game.rl_agents[name].observation_manager
        # Flatten each agent's structured observation into the flat space
        # declared in __init__.
        flattened[name] = gymnasium.spaces.flatten(manager.space, manager.current_observation)
    return flattened
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
@@ -87,7 +87,7 @@ class PrimaiteIO:
|
||||
"""Return the path where agent actions will be saved."""
|
||||
return self.session_path / "agent_actions" / f"episode_{episode}.json"
|
||||
|
||||
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
|
||||
def write_agent_log(self, agent_actions: Dict[str, List], episode: int) -> None:
|
||||
"""Take the contents of the agent action log and write it to a file.
|
||||
|
||||
:param episode: Episode number
|
||||
|
||||
177
src/primaite/session/ray_envs.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import json
|
||||
from typing import Dict, SupportsFloat, Tuple
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import _LOGGER, PrimaiteGymEnv
|
||||
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
    """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""

    def __init__(self, env_config: Dict) -> None:
        """Initialise the environment.

        :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
            which is the PrimaiteGame instance.
        :type env_config: Dict
        """
        # Current episode number.
        self.episode_counter: int = 0
        # Object that returns a config corresponding to the current episode.
        self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
        # Handles IO for the environment. This produces sys logs, agent logs, etc.
        self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
        # Reference to the primaite game.
        self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
        # Agent ids. This is a list of strings of agent names.
        self._agent_ids = list(self.game.rl_agents.keys())

        self.terminateds = set()
        self.truncateds = set()
        # Per-agent dict spaces, as Ray's MultiAgentEnv expects.
        self.observation_space = gymnasium.spaces.Dict(
            {
                name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
                for name, agent in self.agents.items()
            }
        )
        self.action_space = gymnasium.spaces.Dict(
            {name: agent.action_manager.space for name, agent in self.agents.items()}
        )
        # Tell Ray the spaces above are already keyed per agent id.
        self._obs_space_in_preferred_format = True
        self._action_space_in_preferred_format = True
        super().__init__()

    @property
    def agents(self) -> Dict[str, ProxyAgent]:
        """Grab a fresh reference to the agents from this episode's game object."""
        return {name: self.game.rl_agents[name] for name in self._agent_ids}

    def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
        """Reset the environment."""
        rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
        _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")

        # Persist the finished episode's agent logs when configured.
        if self.io.settings.save_agent_actions:
            all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
            self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)

        self.episode_counter += 1
        # Drop packet-capture state from the previous episode before rebuilding.
        PacketCapture.clear()
        self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
        self.game.setup_for_episode(episode=self.episode_counter)
        state = self.game.get_sim_state()
        self.game.update_agents(state)
        next_obs = self._get_obs()
        info = {}
        return next_obs, info

    def step(
        self, actions: Dict[str, ActType]
    ) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
        """Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.

        :param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
        :type actions: Dict[str, ActType]
        :return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
            identifier.
        :rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
        """
        step = self.game.step_counter

        # 1. Perform actions
        for agent_name, action in actions.items():
            self.agents[agent_name].store_action(action)
        self.game.pre_timestep()
        self.game.apply_agent_actions()

        # 2. Advance timestep
        self.game.advance_timestep()

        # 3. Get next observations
        state = self.game.get_sim_state()
        self.game.update_agents(state)
        next_obs = self._get_obs()

        # 4. Get rewards
        rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
        _LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")

        # calculate_truncated() is invariant within a step: evaluate once instead of
        # once per agent plus once for "__all__" as before.
        truncated = self.game.calculate_truncated()
        terminateds = {name: False for name in self.agents}
        truncateds = {name: truncated for name in self.agents}
        infos = {name: {} for name in self.agents}
        # NOTE(review): self.terminateds is never populated in this class, so this is
        # always False while agents exist — confirm intended semantics.
        terminateds["__all__"] = len(self.terminateds) == len(self.agents)
        truncateds["__all__"] = truncated

        if self.game.save_step_metadata:
            self._write_step_metadata_json(step, actions, state, rewards)
        return next_obs, rewards, terminateds, truncateds, infos

    def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
        """Dump this step's actions, rewards, and simulation state to a per-episode JSON file."""
        output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
        output_dir.mkdir(parents=True, exist_ok=True)
        path = output_dir / f"step_{step}.json"

        data = {
            "episode": self.episode_counter,
            "step": step,
            # Gymnasium discrete actions arrive as numpy ints; coerce for JSON.
            "actions": {agent_name: int(action) for agent_name, action in actions.items()},
            "reward": rewards,
            "state": state,
        }
        with open(path, "w") as file:
            json.dump(data, file)

    def _get_obs(self) -> Dict[str, ObsType]:
        """Return the current observation."""
        obs = {}
        for agent_name in self._agent_ids:
            manager = self.game.rl_agents[agent_name].observation_manager
            # Flatten each agent's structured observation into its flat space.
            obs[agent_name] = gymnasium.spaces.flatten(manager.space, manager.current_observation)
        return obs

    def close(self):
        """Close the simulation."""
        if self.io.settings.save_agent_actions:
            all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
            self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
    """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""

    def __init__(self, env_config: Dict) -> None:
        """Initialise the environment.

        :param env_config: A dictionary containing the environment configuration.
        :type env_config: Dict
        """
        self.env = PrimaiteGymEnv(env_config=env_config)
        # Mirror the wrapped env's spaces so Ray can inspect them directly.
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

    def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
        """Reset the environment.

        NOTE(review): `options` is accepted for gymnasium API compatibility but is not
        forwarded to the wrapped env — confirm this is intentional.
        """
        return self.env.reset(seed=seed)

    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
        """Perform a step in the environment."""
        return self.env.step(action)

    def close(self):
        """Close the simulation."""
        self.env.close()

    @property
    def game(self) -> PrimaiteGame:
        """Pass through game from env."""
        return self.env.game
|
||||
@@ -6,7 +6,9 @@
|
||||
"source": [
|
||||
"# Build a simulation using the Python API\n",
|
||||
"\n",
|
||||
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects.\n"
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
"source": [
|
||||
"# PrimAITE Router Simulation Demo\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This demo uses a modified version of the ARCD Use Case 2 Network (seen below) to demonstrate the capabilities of the Network simulator in PrimAITE."
|
||||
]
|
||||
},
|
||||
|
||||
@@ -221,7 +221,7 @@ class SimComponent(BaseModel):
|
||||
return state
|
||||
|
||||
@validate_call
|
||||
def apply_request(self, request: RequestFormat, context: Dict = {}) -> RequestResponse:
|
||||
def apply_request(self, request: RequestFormat, context: Optional[Dict] = None) -> RequestResponse:
|
||||
"""
|
||||
Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings.
|
||||
|
||||
@@ -239,6 +239,8 @@ class SimComponent(BaseModel):
|
||||
:param: context: Dict containing context for requests
|
||||
:type context: Dict
|
||||
"""
|
||||
if not context:
|
||||
context = None
|
||||
if self._request_manager is None:
|
||||
return
|
||||
return self._request_manager(request, context)
|
||||
|
||||
@@ -28,6 +28,7 @@ from primaite.simulator.network.nmne import (
|
||||
NMNE_CAPTURE_KEYWORDS,
|
||||
)
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
from primaite.simulator.system.core.session_manager import SessionManager
|
||||
@@ -36,6 +37,7 @@ from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.processes.process import Process
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import IOSoftware
|
||||
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
|
||||
from primaite.utils.validators import IPV4Address
|
||||
|
||||
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
|
||||
@@ -107,10 +109,14 @@ class NetworkInterface(SimComponent, ABC):
|
||||
nmne: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the number of malicious network events captured."
|
||||
|
||||
traffic: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the inbound and outbound traffic by port and protocol."
|
||||
|
||||
def setup_for_episode(self, episode: int):
    """Reset the original state of the SimComponent."""
    super().setup_for_episode(episode=episode)
    # Discard the per-episode traffic and malicious-network-event counters.
    self.nmne = {}
    self.traffic = {}
    # Point the packet capture at the new episode and re-open its log file,
    # but only when pcap logging is enabled for this run.
    if episode and self.pcap and SIM_OUTPUT.save_pcap_logs:
        self.pcap.current_episode = episode
        self.pcap.setup_logger()
|
||||
@@ -146,6 +152,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
)
|
||||
if CAPTURE_NMNE:
|
||||
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
|
||||
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
|
||||
return state
|
||||
|
||||
@abstractmethod
|
||||
@@ -236,6 +243,47 @@ class NetworkInterface(SimComponent, ABC):
|
||||
# Increment a generic counter if keyword capturing is not enabled
|
||||
keyword_level["*"] = keyword_level.get("*", 0) + 1
|
||||
|
||||
def _capture_traffic(self, frame: Frame, inbound: bool = True):
    """
    Capture traffic statistics at the Network Interface.

    :param frame: The network frame containing the traffic data.
    :type frame: Frame
    :param inbound: Flag indicating if the traffic is inbound or outbound. Defaults to True.
    :type inbound: bool
    """
    # Determine the direction of the traffic
    direction = "inbound" if inbound else "outbound"

    # Identify the protocol and destination port from the frame
    protocol = None
    port = None
    if frame.tcp:
        protocol = IPProtocol.TCP
        port = frame.tcp.dst_port
    elif frame.udp:
        protocol = IPProtocol.UDP
        port = frame.udp.dst_port
    elif frame.icmp:
        protocol = IPProtocol.ICMP

    # BUG FIX: a frame with no TCP/UDP/ICMP layer previously fell through and
    # recorded counters under a `None` protocol (and `None` port) key. Such
    # frames are now ignored rather than polluting the traffic dict.
    if protocol is None:
        return

    # Ensure the protocol is in the capture dict
    if protocol not in self.traffic:
        self.traffic[protocol] = {}

    if protocol != IPProtocol.ICMP:
        # Port-based protocols: counters are nested per destination port.
        if port not in self.traffic[protocol]:
            self.traffic[protocol][port] = {"inbound": 0, "outbound": 0}
        self.traffic[protocol][port][direction] += frame.size
    else:
        # ICMP has no ports: keep a single pair of direction counters.
        if not self.traffic[protocol]:
            self.traffic[protocol] = {"inbound": 0, "outbound": 0}
        self.traffic[protocol][direction] += frame.size
||||
|
||||
@abstractmethod
|
||||
def send_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
@@ -245,6 +293,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
:return: A boolean indicating whether the frame was successfully sent.
|
||||
"""
|
||||
self._capture_nmne(frame, inbound=False)
|
||||
self._capture_traffic(frame, inbound=False)
|
||||
|
||||
@abstractmethod
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
@@ -255,6 +304,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
:return: A boolean indicating whether the frame was successfully received.
|
||||
"""
|
||||
self._capture_nmne(frame, inbound=True)
|
||||
self._capture_traffic(frame, inbound=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
@@ -766,6 +816,24 @@ class Node(SimComponent):
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
self._install_system_software()
|
||||
|
||||
def ip_is_network_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool:
    """
    Checks if a given IP address belongs to any of the nodes interfaces.

    :param ip_address: The IP address to check.
    :param enabled_only: If True, only considers enabled network interfaces.
    :return: True if the IP address is assigned to one of the nodes interfaces; False otherwise.
    """
    for interface in self.network_interface.values():
        # Layer-2-only interfaces carry no IP address; skip them.
        if getattr(interface, "ip_address", None) != ip_address:
            continue
        # First interface holding the address decides the result.
        return interface.enabled if enabled_only else True
    return False
|
||||
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().setup_for_episode(episode=episode)
|
||||
|
||||
@@ -7,6 +7,8 @@ from primaite import getLogger
|
||||
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, Link, Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.nmap import NMAP
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.arp.arp import ARP, ARPPacket
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
@@ -302,6 +304,7 @@ class HostNode(Node):
|
||||
"DNSClient": DNSClient,
|
||||
"NTPClient": NTPClient,
|
||||
"WebBrowser": WebBrowser,
|
||||
"NMAP": NMAP,
|
||||
}
|
||||
"""List of system software that is automatically installed on nodes."""
|
||||
|
||||
@@ -314,6 +317,16 @@ class HostNode(Node):
|
||||
super().__init__(**kwargs)
|
||||
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
|
||||
|
||||
@property
|
||||
def nmap(self) -> Optional[NMAP]:
|
||||
"""
|
||||
Return the NMAP application installed on the Node.
|
||||
|
||||
:return: NMAP application installed on the Node.
|
||||
:rtype: Optional[NMAP]
|
||||
"""
|
||||
return self.software_manager.software.get("NMAP")
|
||||
|
||||
@property
|
||||
def arp(self) -> Optional[ARP]:
|
||||
"""
|
||||
@@ -365,8 +378,15 @@ class HostNode(Node):
|
||||
elif frame.udp:
|
||||
dst_port = frame.udp.dst_port
|
||||
|
||||
can_accept_nmap = False
|
||||
if self.software_manager.software.get("NMAP"):
|
||||
if self.software_manager.software["NMAP"].operating_state == ApplicationOperatingState.RUNNING:
|
||||
can_accept_nmap = True
|
||||
|
||||
accept_nmap = can_accept_nmap and frame.payload.__class__.__name__ == "PortScanPayload"
|
||||
|
||||
accept_frame = False
|
||||
if frame.icmp or dst_port in self.software_manager.get_open_ports():
|
||||
if frame.icmp or dst_port in self.software_manager.get_open_ports() or accept_nmap:
|
||||
# accept the frame as the port is open or if it's an ICMP frame
|
||||
accept_frame = True
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.nmap import NMAP
|
||||
from primaite.simulator.system.core.session_manager import SessionManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.arp.arp import ARP
|
||||
@@ -1238,6 +1239,7 @@ class Router(NetworkNode):
|
||||
icmp.router = self
|
||||
self.software_manager.install(RouterARP)
|
||||
self.arp.router = self
|
||||
self.software_manager.install(NMAP)
|
||||
|
||||
def _set_default_acl(self):
|
||||
"""
|
||||
|
||||
@@ -270,9 +270,16 @@ class DatabaseClient(Application):
|
||||
|
||||
Calls disconnect on all client connections to ensure that both client and server connections are killed.
|
||||
"""
|
||||
while self.client_connections.values():
|
||||
client_connection = self.client_connections[next(iter(self.client_connections.keys()))]
|
||||
client_connection.disconnect()
|
||||
while self.client_connections:
|
||||
conn_key = next(iter(self.client_connections.keys()))
|
||||
conn_obj: DatabaseClientConnection = self.client_connections[conn_key]
|
||||
conn_obj.disconnect()
|
||||
if conn_obj.is_active or conn_key in self.client_connections:
|
||||
self.sys_log.error(
|
||||
"Attempted to uninstall database client but could not drop active connections. "
|
||||
"Forcing uninstall anyway."
|
||||
)
|
||||
self.client_connections.pop(conn_key, None)
|
||||
super().uninstall()
|
||||
|
||||
def get_new_connection(self) -> Optional[DatabaseClientConnection]:
|
||||
|
||||
451
src/primaite/simulator/system/applications/nmap.py
Normal file
@@ -0,0 +1,451 @@
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union
|
||||
|
||||
from prettytable import PrettyTable
|
||||
from pydantic import validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.utils.validators import IPV4Address
|
||||
|
||||
|
||||
class PortScanPayload(SimComponent):
    """
    Payload carried in frames exchanged during an NMAP port scan.

    :ivar ip_address: The target IP address for the port scan.
    :ivar port: The target port for the port scan.
    :ivar protocol: The protocol used for the port scan.
    :ivar request: Flag to indicate whether this is a request or not.
    """

    ip_address: IPV4Address
    port: Port
    protocol: IPProtocol
    request: bool = True

    def describe_state(self) -> Dict:
        """
        Describe the state of the port scan payload.

        :return: A dictionary representation of the port scan payload state.
        :rtype: Dict
        """
        state = super().describe_state()
        state.update(
            {
                "ip_address": str(self.ip_address),
                "port": self.port.value,
                "protocol": self.protocol.value,
                "request": self.request,
            }
        )
        return state
|
||||
|
||||
|
||||
class NMAP(Application):
|
||||
"""
|
||||
A class representing the NMAP application for network scanning.
|
||||
|
||||
NMAP is a network scanning tool used to discover hosts and services on a network. It provides functionalities such
|
||||
as ping scans to discover active hosts and port scans to detect open ports on those hosts.
|
||||
"""
|
||||
|
||||
_active_port_scans: Dict[str, PortScanPayload] = {}
|
||||
_port_scan_responses: Dict[str, PortScanPayload] = {}
|
||||
|
||||
_PORT_SCAN_TYPE_MAP: Final[Dict[Tuple[bool, bool], str]] = {
|
||||
(True, True): "Box",
|
||||
(True, False): "Horizontal",
|
||||
(False, True): "Vertical",
|
||||
(False, False): "Port",
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
    """Create the NMAP application with its fixed name; it binds no port or protocol."""
    kwargs.update({"name": "NMAP", "port": Port.NONE, "protocol": IPProtocol.NONE})
    super().__init__(**kwargs)
|
||||
|
||||
def _can_perform_network_action(self) -> bool:
    """
    Checks if the NMAP application can perform outbound network actions.

    This is done by checking the parent application can_per_action functionality. Then checking if there is an
    enabled NIC that can be used for outbound traffic.

    :return: True if outbound network actions can be performed, otherwise False.
    """
    if not super()._can_perform_action():
        return False
    # At least one enabled interface is required for outbound traffic.
    return any(nic.enabled for nic in self.software_manager.node.network_interface.values())
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
    """
    Build the request manager exposing NMAP's scan actions.

    BUG FIX: each handler now verifies `_can_perform_network_action()` BEFORE running
    the scan. Previously the scan executed first, so scan traffic was generated and
    results computed even when the request had to be refused.

    :return: The populated request manager.
    :rtype: RequestManager
    """

    def _ping_scan_action(request: List[Any], context: Any) -> RequestResponse:
        # Refuse before doing any work: scanning generates simulated traffic.
        if not self._can_perform_network_action():
            return RequestResponse.from_bool(False)
        results = self.ping_scan(target_ip_address=request[0]["target_ip_address"], json_serializable=True)
        return RequestResponse(
            status="success",
            data={"live_hosts": results},
        )

    def _port_scan_action(request: List[Any], context: Any) -> RequestResponse:
        if not self._can_perform_network_action():
            return RequestResponse.from_bool(False)
        results = self.port_scan(**request[0], json_serializable=True)
        return RequestResponse(
            status="success",
            data=results,
        )

    def _network_service_recon_action(request: List[Any], context: Any) -> RequestResponse:
        if not self._can_perform_network_action():
            return RequestResponse.from_bool(False)
        results = self.network_service_recon(**request[0], json_serializable=True)
        return RequestResponse(
            status="success",
            data=results,
        )

    rm = RequestManager()
    rm.add_request(
        name="ping_scan",
        request_type=RequestType(func=_ping_scan_action),
    )
    rm.add_request(
        name="port_scan",
        request_type=RequestType(func=_port_scan_action),
    )
    rm.add_request(
        name="network_service_recon",
        request_type=RequestType(func=_network_service_recon_action),
    )
    return rm
|
||||
|
||||
def describe_state(self) -> Dict:
    """
    Describe the state of the NMAP application.

    NMAP adds no serialisable state beyond the generic application state.

    :return: A dictionary representation of the NMAP application's state.
    :rtype: Dict
    """
    return super().describe_state()
|
||||
|
||||
@staticmethod
|
||||
def _explode_ip_address_network_array(
|
||||
target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
|
||||
) -> Set[IPv4Address]:
|
||||
"""
|
||||
Explode a mixed array of IP addresses and networks into a set of individual IP addresses.
|
||||
|
||||
This method takes a combination of single and lists of IPv4 addresses and IPv4 networks, expands any networks
|
||||
into their constituent subnet useable IP addresses, and returns a set of unique IP addresses. Broadcast and
|
||||
network addresses are excluded from the result.
|
||||
|
||||
:param target_ip_address: A single or list of IPv4 addresses and networks.
|
||||
:type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
|
||||
:return: A set of unique IPv4 addresses expanded from the input.
|
||||
:rtype: Set[IPv4Address]
|
||||
"""
|
||||
if isinstance(target_ip_address, IPv4Address) or isinstance(target_ip_address, IPv4Network):
|
||||
target_ip_address = [target_ip_address]
|
||||
ip_addresses: List[IPV4Address] = []
|
||||
for ip_address in target_ip_address:
|
||||
if isinstance(ip_address, IPv4Network):
|
||||
ip_addresses += [
|
||||
ip
|
||||
for ip in ip_address.hosts()
|
||||
if not ip == ip_address.broadcast_address and not ip == ip_address.network_address
|
||||
]
|
||||
else:
|
||||
ip_addresses.append(ip_address)
|
||||
return set(ip_addresses)
|
||||
|
||||
@validate_call()
def ping_scan(
    self,
    target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]],
    show: bool = True,
    show_online_only: bool = True,
    json_serializable: bool = False,
) -> Union[List[IPV4Address], List[str]]:
    """
    Perform a ping scan on the target IP address(es).

    :param target_ip_address: The target IP address(es) or network(s) for the ping scan.
    :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
    :param show: Flag indicating whether to display the scan results. Defaults to True.
    :type show: bool
    :param show_online_only: Flag indicating whether to show only the online hosts. Defaults to True.
    :type show_online_only: bool
    :param json_serializable: Flag indicating whether the return value should be json serializable. Defaults to
        False.
    :type json_serializable: bool

    :return: A list of active IP addresses that responded to the ping.
    :rtype: Union[List[IPV4Address], List[str]]
    """
    responsive_hosts = []
    results_table = None
    if show:
        results_table = PrettyTable(["IP Address", "Can Ping"])
        results_table.align = "l"
        results_table.title = f"{self.software_manager.node.hostname} NMAP Ping Scan"

    for target in self._explode_ip_address_network_array(target_ip_address):
        # Never ping-scan one of this node's own interfaces.
        if self.software_manager.node.ip_is_network_interface(ip_address=target):
            continue
        reachable = self.software_manager.icmp.ping(target)
        if reachable:
            responsive_hosts.append(str(target) if json_serializable else target)
        if results_table is not None and (reachable or not show_online_only):
            results_table.add_row([target, reachable])

    if results_table is not None:
        print(results_table.get_string(sortby="IP Address"))
    return responsive_hosts
|
||||
|
||||
def _determine_port_scan_type(self, target_ip_addresses: List[IPV4Address], target_ports: List[Port]) -> str:
|
||||
"""
|
||||
Determine the type of port scan based on the number of target IP addresses and ports.
|
||||
|
||||
:param target_ip_addresses: The list of target IP addresses.
|
||||
:type target_ip_addresses: List[IPV4Address]
|
||||
:param target_ports: The list of target ports.
|
||||
:type target_ports: List[Port]
|
||||
|
||||
:return: The type of port scan.
|
||||
:rtype: str
|
||||
"""
|
||||
vertical_scan = len(target_ports) > 1
|
||||
horizontal_scan = len(target_ip_addresses) > 1
|
||||
|
||||
return self._PORT_SCAN_TYPE_MAP[horizontal_scan, vertical_scan]
|
||||
|
||||
def _check_port_open_on_ip_address(
|
||||
self,
|
||||
ip_address: IPv4Address,
|
||||
port: Port,
|
||||
protocol: IPProtocol,
|
||||
is_re_attempt: bool = False,
|
||||
port_scan_uuid: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a port is open on a specific IP address.
|
||||
|
||||
:param ip_address: The target IP address.
|
||||
:type ip_address: IPv4Address
|
||||
:param port: The target port.
|
||||
:type port: Port
|
||||
:param protocol: The protocol used for the port scan.
|
||||
:type protocol: IPProtocol
|
||||
:param is_re_attempt: Flag indicating if this is a reattempt. Defaults to False.
|
||||
:type is_re_attempt: bool
|
||||
:param port_scan_uuid: The UUID of the port scan payload. Defaults to None.
|
||||
:type port_scan_uuid: Optional[str]
|
||||
|
||||
:return: True if the port is open, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
# The recursive base case
|
||||
if is_re_attempt:
|
||||
# Return True if a response has been received, otherwise return False
|
||||
if port_scan_uuid in self._port_scan_responses:
|
||||
self._port_scan_responses.pop(port_scan_uuid)
|
||||
return True
|
||||
return False
|
||||
|
||||
# Send the port scan request
|
||||
payload = PortScanPayload(ip_address=ip_address, port=port, protocol=protocol)
|
||||
self._active_port_scans[payload.uuid] = payload
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Sending port scan request over {payload.protocol.name} on port {payload.port.value} "
|
||||
f"({payload.port.name}) to {payload.ip_address}"
|
||||
)
|
||||
self.software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=ip_address, src_port=port, dest_port=port, ip_protocol=protocol
|
||||
)
|
||||
|
||||
# Recursively call this function with as a reattempt
|
||||
return self._check_port_open_on_ip_address(
|
||||
ip_address=ip_address, port=port, protocol=protocol, is_re_attempt=True, port_scan_uuid=payload.uuid
|
||||
)
|
||||
|
||||
def _process_port_scan_response(self, payload: PortScanPayload):
|
||||
"""
|
||||
Process the response to a port scan request.
|
||||
|
||||
:param payload: The port scan payload received in response.
|
||||
:type payload: PortScanPayload
|
||||
"""
|
||||
if payload.uuid in self._active_port_scans:
|
||||
self._active_port_scans.pop(payload.uuid)
|
||||
self._port_scan_responses[payload.uuid] = payload
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Received port scan response from {payload.ip_address} on port {payload.port.value} "
|
||||
f"({payload.port.name}) over {payload.protocol.name}"
|
||||
)
|
||||
|
||||
def _process_port_scan_request(self, payload: PortScanPayload, session_id: str) -> None:
|
||||
"""
|
||||
Process a port scan request.
|
||||
|
||||
:param payload: The port scan payload received in the request.
|
||||
:type payload: PortScanPayload
|
||||
:param session_id: The session ID for the port scan request.
|
||||
:type session_id: str
|
||||
"""
|
||||
if self.software_manager.check_port_is_open(port=payload.port, protocol=payload.protocol):
|
||||
payload.request = False
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Responding to port scan request for port {payload.port.value} "
|
||||
f"({payload.port.name}) over {payload.protocol.name}",
|
||||
True,
|
||||
)
|
||||
self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
|
||||
|
||||
@validate_call()
|
||||
def port_scan(
|
||||
self,
|
||||
target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]],
|
||||
target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None,
|
||||
target_port: Optional[Union[Port, List[Port]]] = None,
|
||||
show: bool = True,
|
||||
json_serializable: bool = False,
|
||||
) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]:
|
||||
"""
|
||||
Perform a port scan on the target IP address(es).
|
||||
|
||||
:param target_ip_address: The target IP address(es) or network(s) for the port scan.
|
||||
:type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
|
||||
:param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP.
|
||||
:type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]]
|
||||
:param target_port: The port(s) to scan. Defaults to None, which includes all valid ports.
|
||||
:type target_port: Optional[Union[Port, List[Port]]]
|
||||
:param show: Flag indicating whether to display the scan results. Defaults to True.
|
||||
:type show: bool
|
||||
:param json_serializable: Flag indicating whether the return value should be JSON serializable. Defaults to
|
||||
False.
|
||||
:type json_serializable: bool
|
||||
|
||||
:return: A dictionary mapping IP addresses to protocols and lists of open ports.
|
||||
:rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]]
|
||||
"""
|
||||
ip_addresses = self._explode_ip_address_network_array(target_ip_address)
|
||||
|
||||
if isinstance(target_port, Port):
|
||||
target_port = [target_port]
|
||||
elif target_port is None:
|
||||
target_port = [port for port in Port if port not in {Port.NONE, Port.UNUSED}]
|
||||
|
||||
if isinstance(target_protocol, IPProtocol):
|
||||
target_protocol = [target_protocol]
|
||||
elif target_protocol is None:
|
||||
target_protocol = [IPProtocol.TCP, IPProtocol.UDP]
|
||||
|
||||
scan_type = self._determine_port_scan_type(list(ip_addresses), target_port)
|
||||
active_ports = {}
|
||||
if show:
|
||||
table = PrettyTable(["IP Address", "Port", "Name", "Protocol"])
|
||||
table.align = "l"
|
||||
table.title = f"{self.software_manager.node.hostname} NMAP Port Scan ({scan_type})"
|
||||
self.sys_log.info(f"{self.name}: Starting port scan")
|
||||
for ip_address in ip_addresses:
|
||||
# Prevent port scan on this node
|
||||
if self.software_manager.node.ip_is_network_interface(ip_address=ip_address):
|
||||
continue
|
||||
for protocol in target_protocol:
|
||||
for port in set(target_port):
|
||||
port_open = self._check_port_open_on_ip_address(ip_address=ip_address, port=port, protocol=protocol)
|
||||
|
||||
if port_open:
|
||||
table.add_row([ip_address, port.value, port.name, protocol.name])
|
||||
_ip_address = ip_address if not json_serializable else str(ip_address)
|
||||
_protocol = protocol if not json_serializable else protocol.value
|
||||
_port = port if not json_serializable else port.value
|
||||
if _ip_address not in active_ports:
|
||||
active_ports[_ip_address] = dict()
|
||||
if _protocol not in active_ports[_ip_address]:
|
||||
active_ports[_ip_address][_protocol] = []
|
||||
active_ports[_ip_address][_protocol].append(_port)
|
||||
|
||||
if show:
|
||||
print(table.get_string(sortby="IP Address"))
|
||||
|
||||
return active_ports
|
||||
|
||||
def network_service_recon(
|
||||
self,
|
||||
target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]],
|
||||
target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None,
|
||||
target_port: Optional[Union[Port, List[Port]]] = None,
|
||||
show: bool = True,
|
||||
show_online_only: bool = True,
|
||||
json_serializable: bool = False,
|
||||
) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]:
|
||||
"""
|
||||
Perform a network service reconnaissance which includes a ping scan followed by a port scan.
|
||||
|
||||
This method combines the functionalities of a ping scan and a port scan to provide a comprehensive
|
||||
overview of the services on the network. It first identifies active hosts in the target IP range by performing
|
||||
a ping scan. Once the active hosts are identified, it performs a port scan on these hosts to identify open
|
||||
ports and running services. This two-step process ensures that the port scan is performed only on live hosts,
|
||||
optimising the scanning process and providing accurate results.
|
||||
|
||||
:param target_ip_address: The target IP address(es) or network(s) for the port scan.
|
||||
:type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
|
||||
:param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP.
|
||||
:type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]]
|
||||
:param target_port: The port(s) to scan. Defaults to None, which includes all valid ports.
|
||||
:type target_port: Optional[Union[Port, List[Port]]]
|
||||
:param show: Flag indicating whether to display the scan results. Defaults to True.
|
||||
:type show: bool
|
||||
:param show_online_only: Flag indicating whether to show only the online hosts. Defaults to True.
|
||||
:type show_online_only: bool
|
||||
:param json_serializable: Flag indicating whether the return value should be JSON serializable. Defaults to
|
||||
False.
|
||||
:type json_serializable: bool
|
||||
|
||||
:return: A dictionary mapping IP addresses to protocols and lists of open ports.
|
||||
:rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]]
|
||||
"""
|
||||
ping_scan_results = self.ping_scan(
|
||||
target_ip_address=target_ip_address, show=show, show_online_only=show_online_only, json_serializable=False
|
||||
)
|
||||
return self.port_scan(
|
||||
target_ip_address=ping_scan_results,
|
||||
target_protocol=target_protocol,
|
||||
target_port=target_port,
|
||||
show=show,
|
||||
json_serializable=json_serializable,
|
||||
)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
Receive and process a payload.
|
||||
|
||||
:param payload: The payload to be processed.
|
||||
:type payload: Any
|
||||
:param session_id: The session ID associated with the payload.
|
||||
:type session_id: str
|
||||
|
||||
:return: True if the payload was successfully processed, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
if isinstance(payload, PortScanPayload):
|
||||
if payload.request:
|
||||
self._process_port_scan_request(payload=payload, session_id=session_id)
|
||||
else:
|
||||
self._process_port_scan_response(payload=payload)
|
||||
|
||||
return True
|
||||
@@ -1,8 +1,6 @@
|
||||
from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
@@ -11,43 +9,10 @@ from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
|
||||
|
||||
class RansomwareAttackStage(IntEnum):
|
||||
"""
|
||||
Enumeration representing different attack stages of the ransomware script.
|
||||
|
||||
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle
|
||||
in the simulation.
|
||||
Each stage represents a specific phase in the attack process.
|
||||
"""
|
||||
|
||||
NOT_STARTED = 0
|
||||
"Indicates that the attack has not started yet."
|
||||
DOWNLOAD = 1
|
||||
"Installing the Encryption Script - Testing"
|
||||
INSTALL = 2
|
||||
"The stage where logon procedures are simulated."
|
||||
ACTIVATE = 3
|
||||
"Operating Status Changes"
|
||||
PROPAGATE = 4
|
||||
"Represents the stage of performing a horizontal port scan on the target."
|
||||
COMMAND_AND_CONTROL = 5
|
||||
"Represents the stage of setting up a rely C2 Beacon (Not Implemented)"
|
||||
PAYLOAD = 6
|
||||
"Stage of actively attacking the target."
|
||||
SUCCEEDED = 7
|
||||
"Indicates the attack has been successfully completed."
|
||||
FAILED = 8
|
||||
"Signifies that the attack has failed."
|
||||
|
||||
|
||||
class RansomwareScript(Application):
|
||||
"""Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network.
|
||||
|
||||
:ivar payload: The attack stage query payload. (Default Corrupt)
|
||||
:ivar target_scan_p_of_success: The probability of success for the target scan stage.
|
||||
:ivar c2_beacon_p_of_success: The probability of success for the c2_beacon stage
|
||||
:ivar ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
|
||||
:ivar repeat: Whether to repeat attacking once finished.
|
||||
:ivar payload: The attack stage query payload. (Default ENCRYPT)
|
||||
"""
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
@@ -56,16 +21,6 @@ class RansomwareScript(Application):
|
||||
"""Password required to access the database."""
|
||||
payload: Optional[str] = "ENCRYPT"
|
||||
"Payload String for the payload stage"
|
||||
target_scan_p_of_success: float = 0.9
|
||||
"Probability of the target scan succeeding: Default 0.9"
|
||||
c2_beacon_p_of_success: float = 0.9
|
||||
"Probability of the c2 beacon setup stage succeeding: Default 0.9"
|
||||
ransomware_encrypt_p_of_success: float = 0.9
|
||||
"Probability of the ransomware attack succeeding: Default 0.9"
|
||||
repeat: bool = False
|
||||
"If true, the Denial of Service bot will keep performing the attack."
|
||||
attack_stage: RansomwareAttackStage = RansomwareAttackStage.NOT_STARTED
|
||||
"The ransomware attack stage. See RansomwareAttackStage Class"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "RansomwareScript"
|
||||
@@ -90,7 +45,7 @@ class RansomwareScript(Application):
|
||||
@property
|
||||
def _host_db_client(self) -> DatabaseClient:
|
||||
"""Return the database client that is installed on the same machine as the Ransomware Script."""
|
||||
db_client = self.software_manager.software.get("DatabaseClient")
|
||||
db_client: DatabaseClient = self.software_manager.software.get("DatabaseClient")
|
||||
if db_client is None:
|
||||
self.sys_log.warning(f"{self.__class__.__name__} cannot find a database client on its host.")
|
||||
return db_client
|
||||
@@ -108,16 +63,6 @@ class RansomwareScript(Application):
|
||||
)
|
||||
return rm
|
||||
|
||||
def _activate(self):
|
||||
"""
|
||||
Simulate the install process as the initial stage of the attack.
|
||||
|
||||
Advances the attack stage to 'ACTIVATE' attack state.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.INSTALL:
|
||||
self.sys_log.info(f"{self.name}: Activated!")
|
||||
self.attack_stage = RansomwareAttackStage.ACTIVATE
|
||||
|
||||
def run(self) -> bool:
|
||||
"""Calls the parent classes execute method before starting the application loop."""
|
||||
super().run()
|
||||
@@ -133,20 +78,9 @@ class RansomwareScript(Application):
|
||||
return False
|
||||
if self.server_ip_address and self.payload:
|
||||
self.sys_log.info(f"{self.name}: Running")
|
||||
self.attack_stage = RansomwareAttackStage.NOT_STARTED
|
||||
self._local_download()
|
||||
self._install()
|
||||
self._activate()
|
||||
self._perform_target_scan()
|
||||
self._setup_beacon()
|
||||
self._perform_ransomware_encrypt()
|
||||
|
||||
if self.repeat and self.attack_stage in (
|
||||
RansomwareAttackStage.SUCCEEDED,
|
||||
RansomwareAttackStage.FAILED,
|
||||
):
|
||||
self.attack_stage = RansomwareAttackStage.NOT_STARTED
|
||||
return True
|
||||
if self._perform_ransomware_encrypt():
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
self.sys_log.warning(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
|
||||
return False
|
||||
@@ -156,10 +90,6 @@ class RansomwareScript(Application):
|
||||
server_ip_address: IPv4Address,
|
||||
server_password: Optional[str] = None,
|
||||
payload: Optional[str] = None,
|
||||
target_scan_p_of_success: Optional[float] = None,
|
||||
c2_beacon_p_of_success: Optional[float] = None,
|
||||
ransomware_encrypt_p_of_success: Optional[float] = None,
|
||||
repeat: bool = True,
|
||||
):
|
||||
"""
|
||||
Configure the Ransomware Script to communicate with a DatabaseService.
|
||||
@@ -167,10 +97,6 @@ class RansomwareScript(Application):
|
||||
:param server_ip_address: The IP address of the Node the DatabaseService is on.
|
||||
:param server_password: The password on the DatabaseService.
|
||||
:param payload: The attack stage query (Encrypt / Delete)
|
||||
:param target_scan_p_of_success: The probability of success for the target scan stage.
|
||||
:param c2_beacon_p_of_success: The probability of success for the c2_beacon stage
|
||||
:param ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
|
||||
:param repeat: Whether to repeat attacking once finished.
|
||||
"""
|
||||
if server_ip_address:
|
||||
self.server_ip_address = server_ip_address
|
||||
@@ -178,74 +104,15 @@ class RansomwareScript(Application):
|
||||
self.server_password = server_password
|
||||
if payload:
|
||||
self.payload = payload
|
||||
if target_scan_p_of_success:
|
||||
self.target_scan_p_of_success = target_scan_p_of_success
|
||||
if c2_beacon_p_of_success:
|
||||
self.c2_beacon_p_of_success = c2_beacon_p_of_success
|
||||
if ransomware_encrypt_p_of_success:
|
||||
self.ransomware_encrypt_p_of_success = ransomware_encrypt_p_of_success
|
||||
if repeat:
|
||||
self.repeat = repeat
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, "
|
||||
f"{repeat=}."
|
||||
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}."
|
||||
)
|
||||
|
||||
def _install(self):
|
||||
"""
|
||||
Simulate the install stage in the kill-chain.
|
||||
|
||||
Advances the attack stage to 'ACTIVATE' if successful.
|
||||
|
||||
From this attack stage onwards.
|
||||
the ransomware application is now visible from this point onwardin the observation space.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.DOWNLOAD:
|
||||
self.sys_log.info(f"{self.name}: Malware installed on the local file system")
|
||||
downloads_folder = self.file_system.get_folder(folder_name="downloads")
|
||||
ransomware_file = downloads_folder.get_file(file_name="ransom_script.pdf")
|
||||
ransomware_file.num_access += 1
|
||||
self.attack_stage = RansomwareAttackStage.INSTALL
|
||||
|
||||
def _setup_beacon(self):
|
||||
"""
|
||||
Simulates setting up a c2 beacon; currently a pseudo step for increasing red variance.
|
||||
|
||||
Advances the attack stage to 'COMMAND AND CONTROL` if successful.
|
||||
|
||||
:param p_of_sucess: Probability of a successful c2 setup (Advancing this step),
|
||||
by default the success rate is 0.5
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.PROPAGATE:
|
||||
self.sys_log.info(f"{self.name} Attempting to set up C&C Beacon - Scan 1/2")
|
||||
if simulate_trial(self.c2_beacon_p_of_success):
|
||||
self.sys_log.info(f"{self.name} C&C Successful setup - Scan 2/2")
|
||||
c2c_setup = True # TODO Implement the c2c step via an FTP Application/Service
|
||||
if c2c_setup:
|
||||
self.attack_stage = RansomwareAttackStage.COMMAND_AND_CONTROL
|
||||
|
||||
def _perform_target_scan(self):
|
||||
"""
|
||||
Perform a simulated port scan to check for open SQL ports.
|
||||
|
||||
Advances the attack stage to `PROPAGATE` if successful.
|
||||
|
||||
:param p_of_success: Probability of successful port scan, by default 0.1.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.ACTIVATE:
|
||||
# perform a port scan to identify that the SQL port is open on the server
|
||||
self.sys_log.info(f"{self.name}: Scanning for vulnerable databases - Scan 0/2")
|
||||
if simulate_trial(self.target_scan_p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Found a target database! Scan 1/2")
|
||||
port_is_open = True # TODO Implement a NNME Triggering scan as a seperate Red Application
|
||||
if port_is_open:
|
||||
self.attack_stage = RansomwareAttackStage.PROPAGATE
|
||||
|
||||
def attack(self) -> bool:
|
||||
"""Perform the attack steps after opening the application."""
|
||||
self.run()
|
||||
if not self._can_perform_action():
|
||||
self.sys_log.warning("Ransomware application is unable to perform it's actions.")
|
||||
self.run()
|
||||
self.num_executions += 1
|
||||
return self._application_loop()
|
||||
|
||||
@@ -254,57 +121,30 @@ class RansomwareScript(Application):
|
||||
self._db_connection = self._host_db_client.get_new_connection()
|
||||
return True if self._db_connection else False
|
||||
|
||||
def _perform_ransomware_encrypt(self):
|
||||
def _perform_ransomware_encrypt(self) -> bool:
|
||||
"""
|
||||
Execute the Ransomware Encrypt payload on the target.
|
||||
|
||||
Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful.
|
||||
:param p_of_success: Probability of successfully performing ransomware encryption, by default 0.1.
|
||||
"""
|
||||
if self._host_db_client is None:
|
||||
self.sys_log.info(f"{self.name}: Failed to connect to db_client - Ransomware Script")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
return
|
||||
return False
|
||||
|
||||
self._host_db_client.server_ip_address = self.server_ip_address
|
||||
self._host_db_client.server_password = self.server_password
|
||||
if self.attack_stage == RansomwareAttackStage.COMMAND_AND_CONTROL:
|
||||
if simulate_trial(self.ransomware_encrypt_p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Attempting to launch payload")
|
||||
if not self._db_connection:
|
||||
self._establish_db_connection()
|
||||
if self._db_connection:
|
||||
attack_successful = self._db_connection.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
|
||||
if attack_successful:
|
||||
self.sys_log.info(f"{self.name}: Payload Successful")
|
||||
self.attack_stage = RansomwareAttackStage.SUCCEEDED
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Payload failed")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
self.sys_log.info(f"{self.name}: Attempting to launch payload")
|
||||
if not self._db_connection:
|
||||
self._establish_db_connection()
|
||||
if self._db_connection:
|
||||
attack_successful = self._db_connection.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
|
||||
if attack_successful:
|
||||
self.sys_log.info(f"{self.name}: Payload Successful")
|
||||
return True
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Payload failed")
|
||||
return False
|
||||
else:
|
||||
self.sys_log.warning("Attack Attempted to launch too quickly")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
|
||||
def _local_download(self):
|
||||
"""Downloads itself via the onto the local file_system."""
|
||||
if self.attack_stage == RansomwareAttackStage.NOT_STARTED:
|
||||
if self._local_download_verify():
|
||||
self.attack_stage = RansomwareAttackStage.DOWNLOAD
|
||||
else:
|
||||
self.sys_log.info("Malware failed to create a installation location")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
else:
|
||||
self.sys_log.info("Malware failed to download")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
|
||||
def _local_download_verify(self) -> bool:
|
||||
"""Verifies a download location - Creates one if needed."""
|
||||
for folder in self.file_system.folders:
|
||||
if self.file_system.folders[folder].name == "downloads":
|
||||
self.file_system.num_file_creations += 1
|
||||
return True
|
||||
|
||||
self.file_system.create_folder("downloads")
|
||||
self.file_system.create_file(folder_name="downloads", file_name="ransom_script.pdf")
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -21,6 +21,8 @@ class PacketCapture:
|
||||
The PCAPs are logged to: <simulation output directory>/<hostname>/<hostname>_<ip address>_pcap.log
|
||||
"""
|
||||
|
||||
_logger_instances: List[logging.Logger] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hostname: str,
|
||||
@@ -65,10 +67,12 @@ class PacketCapture:
|
||||
|
||||
if outbound:
|
||||
self.outbound_logger = logging.getLogger(self._get_logger_name(outbound))
|
||||
PacketCapture._logger_instances.append(self.outbound_logger)
|
||||
logger = self.outbound_logger
|
||||
else:
|
||||
self.inbound_logger = logging.getLogger(self._get_logger_name(outbound))
|
||||
logger = self.inbound_logger
|
||||
PacketCapture._logger_instances.append(self.inbound_logger)
|
||||
|
||||
logger.setLevel(60) # Custom log level > CRITICAL to prevent any unwanted standard DEBUG-CRITICAL logs
|
||||
logger.addHandler(file_handler)
|
||||
@@ -122,3 +126,13 @@ class PacketCapture:
|
||||
if SIM_OUTPUT.save_pcap_logs:
|
||||
msg = frame.model_dump_json()
|
||||
self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
|
||||
@staticmethod
|
||||
def clear():
|
||||
"""Close all open PCAP file handlers."""
|
||||
for logger in PacketCapture._logger_instances:
|
||||
handlers = logger.handlers[:]
|
||||
for handler in handlers:
|
||||
logger.removeHandler(handler)
|
||||
handler.close()
|
||||
PacketCapture._logger_instances = []
|
||||
|
||||
@@ -78,6 +78,31 @@ class SoftwareManager:
|
||||
open_ports.append(software.port)
|
||||
return open_ports
|
||||
|
||||
def check_port_is_open(self, port: Port, protocol: IPProtocol) -> bool:
|
||||
"""
|
||||
Check if a specific port is open and running a service using the specified protocol.
|
||||
|
||||
This method iterates through all installed software on the node and checks if any of them
|
||||
are using the specified port and protocol and are currently in a running state. It returns True if any software
|
||||
is found running on the specified port and protocol, otherwise False.
|
||||
|
||||
|
||||
:param port: The port to check.
|
||||
:type port: Port
|
||||
:param protocol: The protocol to check (e.g., TCP, UDP).
|
||||
:type protocol: IPProtocol
|
||||
:return: True if the port is open and a service is running on it using the specified protocol, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
for software in self.software.values():
|
||||
if (
|
||||
software.port == port
|
||||
and software.protocol == protocol
|
||||
and software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def install(self, software_class: Type[IOSoftwareClass]):
|
||||
"""
|
||||
Install an Application or Service.
|
||||
@@ -150,6 +175,7 @@ class SoftwareManager:
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
src_port: Optional[Port] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
ip_protocol: IPProtocol = IPProtocol.TCP,
|
||||
session_id: Optional[str] = None,
|
||||
@@ -170,6 +196,7 @@ class SoftwareManager:
|
||||
return self.session_manager.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dst_ip_address=dest_ip_address,
|
||||
src_port=src_port,
|
||||
dst_port=dest_port,
|
||||
ip_protocol=ip_protocol,
|
||||
session_id=session_id,
|
||||
@@ -190,6 +217,9 @@ class SoftwareManager:
|
||||
:param payload: The payload being received.
|
||||
:param session: The transport session the payload originates from.
|
||||
"""
|
||||
if payload.__class__.__name__ == "PortScanPayload":
|
||||
self.software.get("NMAP").receive(payload=payload, session_id=session_id)
|
||||
return
|
||||
receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
|
||||
if receiver:
|
||||
receiver.receive(
|
||||
|
||||
25
src/primaite/utils/converters.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def convert_dict_enum_keys_to_enum_values(d: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
"""
|
||||
Convert dictionary keys from enums to their corresponding values.
|
||||
|
||||
:param d: dict
|
||||
The dictionary with enum keys to be converted.
|
||||
:return: dict
|
||||
The dictionary with enum values as keys.
|
||||
"""
|
||||
result = {}
|
||||
for key, value in d.items():
|
||||
if isinstance(key, Enum):
|
||||
new_key = key.value
|
||||
else:
|
||||
new_key = key
|
||||
|
||||
if isinstance(value, dict):
|
||||
result[new_key] = convert_dict_enum_keys_to_enum_values(value)
|
||||
else:
|
||||
result[new_key] = value
|
||||
return result
|
||||
@@ -2,7 +2,7 @@
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
@@ -11,16 +11,16 @@ from typing import Any, Dict, Tuple, Union
|
||||
import polars as pl
|
||||
|
||||
|
||||
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
def total_rewards_dict(total_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
"""
|
||||
Read an average rewards per episode csv file and return as a dict.
|
||||
|
||||
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
||||
|
||||
:param av_rewards_csv_file: The average rewards per episode csv file path.
|
||||
:param total_rewards_csv_file: The average rewards per episode csv file path.
|
||||
:return: The average rewards per episode csv as a dict.
|
||||
"""
|
||||
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
|
||||
df_dict = pl.read_csv(total_rewards_csv_file).to_dict()
|
||||
|
||||
return {int(v): df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import csv
|
||||
from logging import Logger
|
||||
from typing import Final, List, Tuple, TYPE_CHECKING, Union
|
||||
@@ -26,9 +26,9 @@ class SessionOutputWriter:
|
||||
Is used to write session outputs to csv file.
|
||||
"""
|
||||
|
||||
_AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
|
||||
_TOTAL_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
|
||||
"Episode",
|
||||
"Average Reward",
|
||||
"Total Reward",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -43,7 +43,7 @@ class SessionOutputWriter:
|
||||
:param env: PrimAITE gym environment.
|
||||
:type env: Primaite
|
||||
:param transaction_writer: If `true`, this will output a full account of every transaction taken by the agent.
|
||||
If `false` it will output the average reward per episode, defaults to False
|
||||
If `false` it will output the total reward per episode, defaults to False
|
||||
:type transaction_writer: bool, optional
|
||||
:param learning_session: Set to `true` to indicate that the current session is a training session. This
|
||||
determines the name of the folder which contains the final output csv. Defaults to True
|
||||
@@ -56,7 +56,7 @@ class SessionOutputWriter:
|
||||
if self.transaction_writer:
|
||||
fn = f"all_transactions_{self._env.timestamp_str}.csv"
|
||||
else:
|
||||
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
|
||||
fn = f"total_reward_per_episode_{self._env.timestamp_str}.csv"
|
||||
|
||||
self._csv_file_path: "Path"
|
||||
if self.learning_session:
|
||||
@@ -94,7 +94,7 @@ class SessionOutputWriter:
|
||||
if isinstance(data, Transaction):
|
||||
header, data = data.as_csv_data()
|
||||
else:
|
||||
header = self._AV_REWARD_PER_EPISODE_HEADER
|
||||
header = self._TOTAL_REWARD_PER_EPISODE_HEADER
|
||||
|
||||
if self._first_write:
|
||||
self._init_csv_writer()
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_red_nmap
|
||||
team: RED
|
||||
type: ProbabilisticAgent
|
||||
observation_space: null
|
||||
action_space:
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: NMAP
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
action_list:
|
||||
- type: NODE_NMAP_NETWORK_SERVICE_RECON
|
||||
action_map:
|
||||
0:
|
||||
action: NODE_NMAP_NETWORK_SERVICE_RECON
|
||||
options:
|
||||
source_node: client_1
|
||||
target_ip_address: 192.168.10.0/24
|
||||
target_port: 80
|
||||
target_protocol: tcp
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 1.0
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
- hostname: switch_1
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: switch_2
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: router_1
|
||||
type: router
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.10.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
1:
|
||||
action: PERMIT
|
||||
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: server_1
|
||||
type: server
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
- hostname: server_2
|
||||
type: server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
|
||||
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: client_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: client_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 2
|
||||
|
||||
- endpoint_a_hostname: server_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: server_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 2
|
||||
135
tests/assets/configs/nmap_ping_scan_red_agent_config.yaml
Normal file
@@ -0,0 +1,135 @@
|
||||
io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_red_nmap
|
||||
team: RED
|
||||
type: ProbabilisticAgent
|
||||
observation_space: null
|
||||
action_space:
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: NMAP
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
action_list:
|
||||
- type: NODE_NMAP_PING_SCAN
|
||||
action_map:
|
||||
0:
|
||||
action: NODE_NMAP_PING_SCAN
|
||||
options:
|
||||
source_node: client_1
|
||||
target_ip_address: 192.168.1.0/24
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 1.0
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
- hostname: switch_1
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: switch_2
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: router_1
|
||||
type: router
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.10.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
1:
|
||||
action: PERMIT
|
||||
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: server_1
|
||||
type: server
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
- hostname: server_2
|
||||
type: server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
|
||||
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: client_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: client_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 2
|
||||
|
||||
- endpoint_a_hostname: server_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: server_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 2
|
||||
135
tests/assets/configs/nmap_port_scan_red_agent_config.yaml
Normal file
@@ -0,0 +1,135 @@
|
||||
io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_red_nmap
|
||||
team: RED
|
||||
type: ProbabilisticAgent
|
||||
observation_space: null
|
||||
action_space:
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: NMAP
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
action_list:
|
||||
- type: NODE_NMAP_PORT_SCAN
|
||||
action_map:
|
||||
0:
|
||||
action: NODE_NMAP_PORT_SCAN
|
||||
options:
|
||||
source_node: client_1
|
||||
target_ip_address: 192.168.10.0/24
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 1.0
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
- hostname: switch_1
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: switch_2
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: router_1
|
||||
type: router
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.10.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
1:
|
||||
action: PERMIT
|
||||
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: server_1
|
||||
type: server
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
- hostname: server_2
|
||||
type: server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
|
||||
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: client_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: client_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 2
|
||||
|
||||
- endpoint_a_hostname: server_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: server_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 2
|
||||
@@ -3,7 +3,7 @@ import yaml
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite.session.environment import PrimaiteRayMARLEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayMARLEnv
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray.rllib.algorithms import ppo
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteRayEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayEnv
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Slow, reenable later")
|
||||
|
||||
@@ -4,7 +4,8 @@ import yaml
|
||||
from gymnasium.core import ObsType
|
||||
from numpy import ndarray
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayMARLEnv
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Printer
|
||||
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from tests.conftest import TEST_ASSETS_ROOT
|
||||
|
||||
folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
@@ -75,7 +75,7 @@ def test_uc2_rewards(game_and_agent):
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentActionHistoryItem(
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
)
|
||||
@@ -91,7 +91,7 @@ def test_uc2_rewards(game_and_agent):
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentActionHistoryItem(
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
)
|
||||
|
||||
@@ -9,12 +9,8 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import (
|
||||
RansomwareAttackStage,
|
||||
RansomwareScript,
|
||||
)
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
@@ -85,54 +81,24 @@ def ransomware_script_db_server_green_client(example_network) -> Network:
|
||||
return network
|
||||
|
||||
|
||||
def test_repeating_ransomware_script_attack(ransomware_script_and_db_server):
|
||||
def test_ransomware_script_attack(ransomware_script_and_db_server):
|
||||
"""Test a repeating data manipulation attack."""
|
||||
RansomwareScript, computer, db_server_service, server = ransomware_script_and_db_server
|
||||
|
||||
computer.apply_timestep(timestep=0)
|
||||
server.apply_timestep(timestep=0)
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
assert computer.file_system.num_file_creations == 0
|
||||
assert server.file_system.num_file_creations == 1
|
||||
|
||||
RansomwareScript.target_scan_p_of_success = 1
|
||||
RansomwareScript.c2_beacon_p_of_success = 1
|
||||
RansomwareScript.ransomware_encrypt_p_of_success = 1
|
||||
RansomwareScript.repeat = True
|
||||
RansomwareScript.attack()
|
||||
|
||||
assert RansomwareScript.attack_stage == RansomwareAttackStage.NOT_STARTED
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
|
||||
assert computer.file_system.num_file_creations == 1
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
|
||||
assert server.file_system.num_file_creations == 2
|
||||
|
||||
computer.apply_timestep(timestep=1)
|
||||
server.apply_timestep(timestep=1)
|
||||
|
||||
assert RansomwareScript.attack_stage == RansomwareAttackStage.NOT_STARTED
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
|
||||
|
||||
|
||||
def test_repeating_ransomware_script_attack(ransomware_script_and_db_server):
|
||||
"""Test a repeating ransowmare script attack."""
|
||||
RansomwareScript, computer, db_server_service, server = ransomware_script_and_db_server
|
||||
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
RansomwareScript.target_scan_p_of_success = 1
|
||||
RansomwareScript.c2_beacon_p_of_success = 1
|
||||
RansomwareScript.ransomware_encrypt_p_of_success = 1
|
||||
RansomwareScript.repeat = False
|
||||
RansomwareScript.attack()
|
||||
|
||||
assert RansomwareScript.attack_stage == RansomwareAttackStage.SUCCEEDED
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
|
||||
assert computer.file_system.num_file_creations == 1
|
||||
|
||||
computer.apply_timestep(timestep=1)
|
||||
computer.pre_timestep(timestep=1)
|
||||
server.apply_timestep(timestep=1)
|
||||
server.pre_timestep(timestep=1)
|
||||
|
||||
assert RansomwareScript.attack_stage == RansomwareAttackStage.SUCCEEDED
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
|
||||
assert computer.file_system.num_file_creations == 0
|
||||
|
||||
|
||||
def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_green_client):
|
||||
@@ -153,10 +119,6 @@ def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_
|
||||
assert green_db_client_connection.query("SELECT")
|
||||
assert green_db_client.last_query_response.get("status_code") == 200
|
||||
|
||||
ransomware_script_application.target_scan_p_of_success = 1
|
||||
ransomware_script_application.ransomware_encrypt_p_of_success = 1
|
||||
ransomware_script_application.c2_beacon_p_of_success = 1
|
||||
ransomware_script_application.repeat = False
|
||||
ransomware_script_application.attack()
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
185
tests/integration_tests/system/test_nmap.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.nmap import NMAP
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
|
||||
def test_ping_scan_all_on(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
expected_result = [IPv4Address("192.168.1.10"), IPv4Address("192.168.1.14")]
|
||||
actual_result = client_1_nmap.ping_scan(target_ip_address=["192.168.1.10", "192.168.1.14"])
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_ping_scan_all_on_full_network(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
expected_result = [IPv4Address("192.168.1.1"), IPv4Address("192.168.1.10"), IPv4Address("192.168.1.14")]
|
||||
actual_result = client_1_nmap.ping_scan(target_ip_address=IPv4Network("192.168.1.0/24"))
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_ping_scan_some_on(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
network.get_node_by_hostname("server_2").power_off()
|
||||
|
||||
expected_result = [IPv4Address("192.168.1.1"), IPv4Address("192.168.1.10")]
|
||||
actual_result = client_1_nmap.ping_scan(target_ip_address=IPv4Network("192.168.1.0/24"))
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_ping_scan_all_off(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
network.get_node_by_hostname("server_1").power_off()
|
||||
network.get_node_by_hostname("server_2").power_off()
|
||||
|
||||
expected_result = []
|
||||
actual_result = client_1_nmap.ping_scan(target_ip_address=["192.168.1.10", "192.168.1.14"])
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_port_scan_one_node_one_port(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
client_2 = network.get_node_by_hostname("client_2")
|
||||
|
||||
actual_result = client_1_nmap.port_scan(
|
||||
target_ip_address=client_2.network_interface[1].ip_address, target_port=Port.DNS, target_protocol=IPProtocol.TCP
|
||||
)
|
||||
|
||||
expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.DNS]}}
|
||||
|
||||
assert actual_result == expected_result
|
||||
|
||||
|
||||
def sort_dict(d):
|
||||
"""Recursively sorts a dictionary."""
|
||||
if isinstance(d, dict):
|
||||
return {k: sort_dict(v) for k, v in sorted(d.items(), key=lambda item: str(item[0]))}
|
||||
elif isinstance(d, list):
|
||||
return sorted(d, key=lambda item: str(item) if isinstance(item, Enum) else item)
|
||||
elif isinstance(d, Enum):
|
||||
return str(d)
|
||||
else:
|
||||
return d
|
||||
|
||||
|
||||
def test_port_scan_full_subnet_all_ports_and_protocols(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
actual_result = client_1_nmap.port_scan(
|
||||
target_ip_address=IPv4Network("192.168.10.0/24"),
|
||||
)
|
||||
|
||||
expected_result = {
|
||||
IPv4Address("192.168.10.1"): {IPProtocol.UDP: [Port.ARP]},
|
||||
IPv4Address("192.168.10.22"): {
|
||||
IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS],
|
||||
IPProtocol.UDP: [Port.ARP, Port.NTP],
|
||||
},
|
||||
}
|
||||
|
||||
assert sort_dict(actual_result) == sort_dict(expected_result)
|
||||
|
||||
|
||||
def test_network_service_recon_all_ports_and_protocols(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
actual_result = client_1_nmap.network_service_recon(
|
||||
target_ip_address=IPv4Network("192.168.10.0/24"), target_port=Port.HTTP, target_protocol=IPProtocol.TCP
|
||||
)
|
||||
|
||||
expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.HTTP]}}
|
||||
|
||||
assert sort_dict(actual_result) == sort_dict(expected_result)
|
||||
|
||||
|
||||
def test_ping_scan_red_agent():
|
||||
with open(TEST_ASSETS_ROOT / "configs/nmap_ping_scan_red_agent_config.yaml", "r") as file:
|
||||
cfg = yaml.safe_load(file)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
game.step()
|
||||
|
||||
expected_result = ["192.168.1.1", "192.168.1.10", "192.168.1.14"]
|
||||
|
||||
action_history = game.agents["client_1_red_nmap"].history
|
||||
assert len(action_history) == 1
|
||||
actual_result = action_history[0].response.data["live_hosts"]
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_port_scan_red_agent():
|
||||
with open(TEST_ASSETS_ROOT / "configs/nmap_port_scan_red_agent_config.yaml", "r") as file:
|
||||
cfg = yaml.safe_load(file)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
game.step()
|
||||
|
||||
expected_result = {
|
||||
"192.168.10.1": {"udp": [219]},
|
||||
"192.168.10.22": {
|
||||
"tcp": [80, 21, 53],
|
||||
"udp": [219, 123],
|
||||
},
|
||||
}
|
||||
|
||||
action_history = game.agents["client_1_red_nmap"].history
|
||||
assert len(action_history) == 1
|
||||
actual_result = action_history[0].response.data
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_network_service_recon_red_agent():
|
||||
with open(TEST_ASSETS_ROOT / "configs/nmap_network_service_recon_red_agent_config.yaml", "r") as file:
|
||||
cfg = yaml.safe_load(file)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
game.step()
|
||||
|
||||
expected_result = {"192.168.10.22": {"tcp": [80]}}
|
||||
|
||||
action_history = game.agents["client_1_red_nmap"].history
|
||||
assert len(action_history) == 1
|
||||
actual_result = action_history[0].response.data
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
0
tests/unit_tests/_primaite/_utils/__init__.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
|
||||
|
||||
|
||||
def test_simple_conversion():
|
||||
"""
|
||||
Test conversion of a simple dictionary with enum keys to enum values.
|
||||
|
||||
The original dictionary contains one level of nested dictionary with enums as keys.
|
||||
The expected output should have string values of enums as keys.
|
||||
"""
|
||||
original_dict = {IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0}}}
|
||||
expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0}}}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||
|
||||
|
||||
def test_no_enums():
|
||||
"""
|
||||
Test conversion of a dictionary with no enum keys.
|
||||
|
||||
The original dictionary contains only string keys.
|
||||
The expected output should be identical to the original dictionary.
|
||||
"""
|
||||
original_dict = {"protocol": {"port": {"inbound": 0, "outbound": 1016.0}}}
|
||||
expected_dict = {"protocol": {"port": {"inbound": 0, "outbound": 1016.0}}}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||
|
||||
|
||||
def test_mixed_keys():
|
||||
"""
|
||||
Test conversion of a dictionary with a mix of enum and string keys.
|
||||
|
||||
The original dictionary contains both enums and strings as keys.
|
||||
The expected output should have string values of enums and original string keys.
|
||||
"""
|
||||
original_dict = {
|
||||
IPProtocol.TCP: {"port": {"inbound": 0, "outbound": 1016.0}},
|
||||
"protocol": {Port.HTTP: {"inbound": 10, "outbound": 2020.0}},
|
||||
}
|
||||
expected_dict = {
|
||||
"tcp": {"port": {"inbound": 0, "outbound": 1016.0}},
|
||||
"protocol": {80: {"inbound": 10, "outbound": 2020.0}},
|
||||
}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||
|
||||
|
||||
def test_empty_dict():
|
||||
"""
|
||||
Test conversion of an empty dictionary.
|
||||
|
||||
The original dictionary is empty.
|
||||
The expected output should also be an empty dictionary.
|
||||
"""
|
||||
original_dict = {}
|
||||
expected_dict = {}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||
|
||||
|
||||
def test_nested_dicts():
|
||||
"""
|
||||
Test conversion of a nested dictionary with multiple levels of nested dictionaries and enums as keys.
|
||||
|
||||
The original dictionary contains nested dictionaries with enums as keys at different levels.
|
||||
The expected output should have string values of enums as keys at all levels.
|
||||
"""
|
||||
original_dict = {
|
||||
IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol.TCP: {"latency": "low"}}}}
|
||||
}
|
||||
expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0, "details": {"tcp": {"latency": "low"}}}}}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||
|
||||
|
||||
def test_non_dict_values():
|
||||
"""
|
||||
Test conversion of a dictionary where some values are not dictionaries.
|
||||
|
||||
The original dictionary contains lists and tuples as values.
|
||||
The expected output should preserve these non-dictionary values while converting enum keys to string values.
|
||||
"""
|
||||
original_dict = {IPProtocol.UDP: [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)}
|
||||
expected_dict = {"udp": [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||