Merge remote-tracking branch 'origin/dev' into feature/2646_Update-pre-commit-to-check-for-valid-copyright

This commit is contained in:
Marek Wolan
2024-06-13 12:52:09 +01:00
54 changed files with 9740 additions and 647 deletions

4
.gitignore vendored
View File

@@ -164,3 +164,7 @@ src/primaite/notebooks/scratch.py
sandbox.py
sandbox/
sandbox.ipynb
# benchmarking
**/benchmark/sessions/
**/benchmark/output/

View File

@@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for SQL INSERT command.
- Added ability to log each agent's action choices in each step to a JSON file.
- Removal of Link bandwidth hardcoding. This can now be configured via the network configuration yaml. Will default to 100 if not present.
- Added NMAP application to all host and layer-3 network nodes.
### Bug Fixes

22
benchmark/benchmark.py Normal file
View File

@@ -0,0 +1,22 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Any, Dict, Optional, Tuple
from gymnasium.core import ObsType
from primaite.session.environment import PrimaiteGymEnv
class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
    """
    PrimaiteGymEnv subclass used by the benchmark.

    ``reset`` is overridden so that, every time the environment is reset, the game's
    current step counter is added to a running total of time steps.
    """

    # Running total of steps, accumulated on each call to ``reset``.
    total_time_steps: int = 0

    def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
        """Add the game's step counter to ``total_time_steps`` and then delegate to the parent reset."""
        steps_so_far = self.game.step_counter
        self.total_time_steps = self.total_time_steps + steps_so_far
        return super().reset(seed=seed)

View File

@@ -3,210 +3,95 @@
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import json
import platform
import shutil
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, Optional, Tuple, Union
from unittest.mock import patch
from typing import Any, Dict, Final, Tuple
import GPUtil
import plotly.graph_objects as go
import polars as pl
import psutil
import yaml
from plotly.graph_objs import Figure
from pylatex import Command, Document
from pylatex import Figure as LatexFigure
from pylatex import Section, Subsection, Tabular
from pylatex.utils import bold
from report import build_benchmark_latex_report
from stable_baselines3 import PPO
import primaite
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.data_viz.session_plots import get_plotly_config
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from benchmark import BenchmarkPrimaiteGymEnv
from primaite.config.load import data_manipulation_config_path
_LOGGER = primaite.getLogger(__name__)
_MAJOR_V = primaite.__version__.split(".")[0]
_BENCHMARK_ROOT = Path(__file__).parent
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results"
_RESULTS_ROOT.mkdir(exist_ok=True, parents=True)
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results" / f"v{_MAJOR_V}"
_VERSION_ROOT: Final[Path] = _RESULTS_ROOT / f"v{primaite.__version__}"
_SESSION_METADATA_ROOT: Final[Path] = _VERSION_ROOT / "session_metadata"
_OUTPUT_ROOT: Final[Path] = _BENCHMARK_ROOT / "output"
# Clear and recreate the output directory
if _OUTPUT_ROOT.exists():
shutil.rmtree(_OUTPUT_ROOT)
_OUTPUT_ROOT.mkdir()
_TRAINING_CONFIG_PATH = _BENCHMARK_ROOT / "config" / "benchmark_training_config.yaml"
_LAY_DOWN_CONFIG_PATH = data_manipulation_config_path()
_SESSION_METADATA_ROOT.mkdir(parents=True, exist_ok=True)
def get_size(size_bytes: int) -> str:
"""
Scale bytes to its proper format.
class BenchmarkSession:
"""Benchmark Session class."""
e.g:
1253656 => '1.20MB'
1253656678 => '1.17GB'
gym_env: BenchmarkPrimaiteGymEnv
"""Gym environment used by the session to train."""
:
"""
factor = 1024
for unit in ["", "K", "M", "G", "T", "P"]:
if size_bytes < factor:
return f"{size_bytes:.2f}{unit}B"
size_bytes /= factor
num_episodes: int
"""Number of episodes to run the training session."""
episode_len: int
"""The number of steps per episode."""
def _get_system_info() -> Dict:
"""Builds and returns a dict containing system info."""
uname = platform.uname()
cpu_freq = psutil.cpu_freq()
virtual_mem = psutil.virtual_memory()
swap_mem = psutil.swap_memory()
gpus = GPUtil.getGPUs()
return {
"System": {
"OS": uname.system,
"OS Version": uname.version,
"Machine": uname.machine,
"Processor": uname.processor,
},
"CPU": {
"Physical Cores": psutil.cpu_count(logical=False),
"Total Cores": psutil.cpu_count(logical=True),
"Max Frequency": f"{cpu_freq.max:.2f}Mhz",
},
"Memory": {"Total": get_size(virtual_mem.total), "Swap Total": get_size(swap_mem.total)},
"GPU": [{"Name": gpu.name, "Total Memory": f"{gpu.memoryTotal}MB"} for gpu in gpus],
}
total_steps: int
"""Number of steps to run the training session."""
batch_size: int
"""Number of steps for each episode."""
def _build_benchmark_latex_report(
benchmark_metadata_dict: Dict, this_version_plot_path: Path, all_version_plot_path: Path
) -> None:
geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
data = benchmark_metadata_dict
primaite_version = data["primaite_version"]
learning_rate: float
"""Learning rate for the model."""
# Create a new document
doc = Document("report", geometry_options=geometry_options)
# Title
doc.preamble.append(Command("title", f"PrimAITE {primaite_version} Learning Benchmark"))
doc.preamble.append(Command("author", "PrimAITE Dev Team"))
doc.preamble.append(Command("date", datetime.now().date()))
doc.append(Command("maketitle"))
start_time: datetime
"""Start time for the session."""
sessions = data["total_sessions"]
episodes = data["training_config"]["num_train_episodes"]
steps = data["training_config"]["num_train_steps"]
# Body
with doc.create(Section("Introduction")):
doc.append(
f"PrimAITE v{primaite_version} was benchmarked automatically upon release. Learning rate metrics "
f"were captured to be referenced during system-level testing and user acceptance testing (UAT)."
)
doc.append(
f"\nThe benchmarking process consists of running {sessions} training session using the same "
f"training and lay down config files. Each session trains an agent for {episodes} episodes, "
f"with each episode consisting of {steps} steps."
)
doc.append(
f"\nThe mean reward per episode from each session is captured. This is then used to calculate a "
f"combined average reward per episode from the {sessions} individual sessions for smoothing. "
f"Finally, a 25-widow rolling average of the combined average reward per session is calculated for "
f"further smoothing."
)
with doc.create(Section("System Information")):
with doc.create(Subsection("Python")):
with doc.create(Tabular("|l|l|")) as table:
table.add_hline()
table.add_row((bold("Version"), sys.version))
table.add_hline()
for section, section_data in data["system_info"].items():
if section_data:
with doc.create(Subsection(section)):
if isinstance(section_data, dict):
with doc.create(Tabular("|l|l|")) as table:
table.add_hline()
for key, value in section_data.items():
table.add_row((bold(key), value))
table.add_hline()
elif isinstance(section_data, list):
headers = section_data[0].keys()
tabs_str = "|".join(["l" for _ in range(len(headers))])
tabs_str = f"|{tabs_str}|"
with doc.create(Tabular(tabs_str)) as table:
table.add_hline()
table.add_row([bold(h) for h in headers])
table.add_hline()
for item in section_data:
table.add_row(item.values())
table.add_hline()
headers_map = {
"total_sessions": "Total Sessions",
"total_episodes": "Total Episodes",
"total_time_steps": "Total Steps",
"av_s_per_session": "Av Session Duration (s)",
"av_s_per_step": "Av Step Duration (s)",
"av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
}
with doc.create(Section("Stats")):
with doc.create(Subsection("Benchmark Results")):
with doc.create(Tabular("|l|l|")) as table:
table.add_hline()
for section, header in headers_map.items():
if section.startswith("av_"):
table.add_row((bold(header), f"{data[section]:.4f}"))
else:
table.add_row((bold(header), data[section]))
table.add_hline()
with doc.create(Section("Graphs")):
with doc.create(Subsection(f"PrimAITE {primaite_version} Learning Benchmark Plot")):
with doc.create(LatexFigure(position="h!")) as pic:
pic.add_image(str(this_version_plot_path))
pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")
with doc.create(Subsection("PrimAITE All Versions Learning Benchmark Plot")):
with doc.create(LatexFigure(position="h!")) as pic:
pic.add_image(str(all_version_plot_path))
pic.add_caption("PrimAITE All Versions Learning Benchmark Plot")
doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)
class BenchmarkPrimaiteSession(PrimaiteSession):
"""A benchmarking primaite session."""
end_time: datetime
"""End time for the session."""
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
) -> None:
super().__init__(training_config_path, lay_down_config_path)
self.setup()
gym_env: BenchmarkPrimaiteGymEnv,
episode_len: int,
num_episodes: int,
n_steps: int,
batch_size: int,
learning_rate: float,
):
"""Initialise the BenchmarkSession."""
self.gym_env = gym_env
self.episode_len = episode_len
self.n_steps = n_steps
self.num_episodes = num_episodes
self.total_steps = self.num_episodes * self.episode_len
self.batch_size = batch_size
self.learning_rate = learning_rate
@property
def env(self) -> Primaite:
"""Direct access to the env for ease of testing."""
return self._agent_session._env # noqa
def train(self):
"""Run the training session."""
# start timer for session
self.start_time = datetime.now()
model = PPO(
policy="MlpPolicy",
env=self.gym_env,
learning_rate=self.learning_rate,
n_steps=self.n_steps,
batch_size=self.batch_size,
verbose=0,
tensorboard_log="./PPO_UC2/",
)
model.learn(total_timesteps=self.total_steps)
def __enter__(self) -> "BenchmarkPrimaiteSession":
return self
# end timer for session
self.end_time = datetime.now()
# TODO: typehints uncertain
def __exit__(self, type: Any, value: Any, tb: Any) -> None:
shutil.rmtree(self.session_path)
_LOGGER.debug(f"Deleted benchmark session directory: {self.session_path}")
self.session_metadata = self.generate_learn_metadata_dict()
def _learn_benchmark_durations(self) -> Tuple[float, float, float]:
"""
@@ -220,235 +105,99 @@ class BenchmarkPrimaiteSession(PrimaiteSession):
:return: The learning benchmark durations as a Tuple of three floats:
Tuple[total_s, s_per_step, s_per_100_steps_10_nodes].
"""
data = self.metadata_file_as_dict()
start_dt = datetime.fromisoformat(data["start_datetime"])
end_dt = datetime.fromisoformat(data["end_datetime"])
delta = end_dt - start_dt
delta = self.end_time - self.start_time
total_s = delta.total_seconds()
total_steps = data["learning"]["total_time_steps"]
total_steps = self.batch_size * self.num_episodes
s_per_step = total_s / total_steps
num_nodes = self.env.num_nodes
num_nodes = len(self.gym_env.game.simulation.network.nodes)
num_intervals = total_steps / 100
av_interval_time = total_s / num_intervals
s_per_100_steps_10_nodes = av_interval_time / (num_nodes / 10)
return total_s, s_per_step, s_per_100_steps_10_nodes
def learn_metadata_dict(self) -> Dict[str, Any]:
def generate_learn_metadata_dict(self) -> Dict[str, Any]:
"""Metadata specific to the learning session."""
total_s, s_per_step, s_per_100_steps_10_nodes = self._learn_benchmark_durations()
self.gym_env.average_reward_per_episode.pop(0) # remove episode 0
return {
"total_episodes": self.env.actual_episode_count,
"total_time_steps": self.env.total_step_count,
"total_episodes": self.gym_env.episode_counter,
"total_time_steps": self.gym_env.total_time_steps,
"total_s": total_s,
"s_per_step": s_per_step,
"s_per_100_steps_10_nodes": s_per_100_steps_10_nodes,
"av_reward_per_episode": self.learn_av_reward_per_episode_dict(),
"av_reward_per_episode": self.gym_env.average_reward_per_episode,
}
def _get_benchmark_session_path(session_timestamp: datetime) -> Path:
return _OUTPUT_ROOT / session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
def _get_benchmark_primaite_session() -> BenchmarkPrimaiteSession:
with patch("primaite.agents.agent_abc.get_session_path", _get_benchmark_session_path) as mck:
mck.session_timestamp = datetime.now()
return BenchmarkPrimaiteSession(_TRAINING_CONFIG_PATH, _LAY_DOWN_CONFIG_PATH)
def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict) -> dict:
n = len(metadata_dict)
with open(_TRAINING_CONFIG_PATH, "r") as file:
training_config_dict = yaml.safe_load(file)
with open(_LAY_DOWN_CONFIG_PATH, "r") as file:
lay_down_config_dict = yaml.safe_load(file)
averaged_data = {
"start_timestamp": start_datetime.isoformat(),
"end_datetime": datetime.now().isoformat(),
"primaite_version": primaite.__version__,
"system_info": _get_system_info(),
"total_sessions": n,
"total_episodes": sum(d["total_episodes"] for d in metadata_dict.values()),
"total_time_steps": sum(d["total_time_steps"] for d in metadata_dict.values()),
"av_s_per_session": sum(d["total_s"] for d in metadata_dict.values()) / n,
"av_s_per_step": sum(d["s_per_step"] for d in metadata_dict.values()) / n,
"av_s_per_100_steps_10_nodes": sum(d["s_per_100_steps_10_nodes"] for d in metadata_dict.values()) / n,
"combined_av_reward_per_episode": {},
"session_av_reward_per_episode": {k: v["av_reward_per_episode"] for k, v in metadata_dict.items()},
"training_config": training_config_dict,
"lay_down_config": lay_down_config_dict,
}
episodes = metadata_dict[1]["av_reward_per_episode"].keys()
for episode in episodes:
combined_av_reward = sum(metadata_dict[k]["av_reward_per_episode"][episode] for k in metadata_dict.keys()) / n
averaged_data["combined_av_reward_per_episode"][episode] = combined_av_reward
return averaged_data
def _get_df_from_episode_av_reward_dict(data: Dict) -> pl.DataFrame:
data: Dict = {"episode": data.keys(), "av_reward": data.values()}
return (
pl.from_dict(data)
.with_columns(rolling_mean=pl.col("av_reward").rolling_mean(window_size=25))
.rename({"rolling_mean": "rolling_av_reward"})
)
def _plot_benchmark_metadata(
benchmark_metadata_dict: Dict,
title: Optional[str] = None,
subtitle: Optional[str] = None,
) -> Figure:
if title:
if subtitle:
title = f"{title} <br>{subtitle}</sup>"
else:
if subtitle:
title = subtitle
config = get_plotly_config()
layout = go.Layout(
autosize=config["size"]["auto_size"],
width=config["size"]["width"],
height=config["size"]["height"],
)
# Create the line graph with a colored line
fig = go.Figure(layout=layout)
fig.update_layout(template=config["template"])
for session, av_reward_dict in benchmark_metadata_dict["session_av_reward_per_episode"].items():
df = _get_df_from_episode_av_reward_dict(av_reward_dict)
fig.add_trace(
go.Scatter(
x=df["episode"],
y=df["av_reward"],
mode="lines",
name=f"Session {session}",
opacity=0.25,
line={"color": "#a6a6a6"},
)
)
df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_av_reward_per_episode"])
fig.add_trace(
go.Scatter(
x=df["episode"], y=df["av_reward"], mode="lines", name="Combined Session Av", line={"color": "#FF0000"}
)
)
fig.add_trace(
go.Scatter(
x=df["episode"],
y=df["rolling_av_reward"],
mode="lines",
name="Rolling Av (Combined Session Av)",
line={"color": "#4CBB17"},
)
)
# Set the layout of the graph
fig.update_layout(
xaxis={
"title": "Episode",
"type": "linear",
},
yaxis={"title": "Average Reward"},
title=title,
)
return fig
def _plot_all_benchmarks_combined_session_av() -> Figure:
def _get_benchmark_primaite_environment() -> BenchmarkPrimaiteGymEnv:
"""
Plot the Benchmark results for each released version of PrimAITE.
Create an instance of the BenchmarkPrimaiteGymEnv.
Does this by iterating over the ``benchmark/results`` directory and
extracting the benchmark metadata json for each version that has been
benchmarked. The combined_av_reward_per_episode is extracted from each,
converted into a polars dataframe, and plotted as a scatter line in plotly.
This environment will be used to train the agents on.
"""
title = "PrimAITE Versions Learning Benchmark"
subtitle = "Rolling Av (Combined Session Av)"
if title:
if subtitle:
title = f"{title} <br>{subtitle}</sup>"
else:
if subtitle:
title = subtitle
config = get_plotly_config()
layout = go.Layout(
autosize=config["size"]["auto_size"],
width=config["size"]["width"],
height=config["size"]["height"],
)
# Create the line graph with a colored line
fig = go.Figure(layout=layout)
fig.update_layout(template=config["template"])
for dir in _RESULTS_ROOT.iterdir():
if dir.is_dir():
metadata_file = dir / f"{dir.name}_benchmark_metadata.json"
with open(metadata_file, "r") as file:
metadata_dict = json.load(file)
df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_av_reward_per_episode"])
fig.add_trace(go.Scatter(x=df["episode"], y=df["rolling_av_reward"], mode="lines", name=dir.name))
# Set the layout of the graph
fig.update_layout(
xaxis={
"title": "Episode",
"type": "linear",
},
yaxis={"title": "Average Reward"},
title=title,
)
fig["data"][0]["showlegend"] = True
return fig
env = BenchmarkPrimaiteGymEnv(env_config=data_manipulation_config_path())
return env
def run() -> None:
def _prepare_session_directory():
    """Point PrimAITE's user sessions path at a fresh ``sessions`` folder under the benchmark root."""
    benchmark_sessions_dir = _BENCHMARK_ROOT / "sessions"

    # remove whatever is already in the folder before reusing it
    if benchmark_sessions_dir.is_dir():
        shutil.rmtree(benchmark_sessions_dir)

    # override the session path, then make sure the directory exists
    primaite.PRIMAITE_PATHS.user_sessions_path = benchmark_sessions_dir
    primaite.PRIMAITE_PATHS.user_sessions_path.mkdir(exist_ok=True, parents=True)
def run(
number_of_sessions: int = 5,
num_episodes: int = 1000,
episode_len: int = 128,
n_steps: int = 1280,
batch_size: int = 32,
learning_rate: float = 3e-4,
) -> None:
"""Run the PrimAITE benchmark."""
start_datetime = datetime.now()
av_reward_per_episode_dicts = {}
for i in range(1, 11):
benchmark_start_time = datetime.now()
session_metadata_dict = {}
_prepare_session_directory()
# run training
for i in range(1, number_of_sessions + 1):
print(f"Starting Benchmark Session: {i}")
with _get_benchmark_primaite_session() as session:
session.learn()
av_reward_per_episode_dicts[i] = session.learn_metadata_dict()
benchmark_metadata = _build_benchmark_results_dict(
start_datetime=start_datetime, metadata_dict=av_reward_per_episode_dicts
with _get_benchmark_primaite_environment() as gym_env:
session = BenchmarkSession(
gym_env=gym_env,
num_episodes=num_episodes,
n_steps=n_steps,
episode_len=episode_len,
batch_size=batch_size,
learning_rate=learning_rate,
)
session.train()
# Dump the session metadata so that we're not holding it in memory as it's large
with open(_SESSION_METADATA_ROOT / f"{i}.json", "w") as file:
json.dump(session.session_metadata, file, indent=4)
for i in range(1, number_of_sessions + 1):
with open(_SESSION_METADATA_ROOT / f"{i}.json", "r") as file:
session_metadata_dict[i] = json.load(file)
# generate report
build_benchmark_latex_report(
benchmark_start_time=benchmark_start_time,
session_metadata=session_metadata_dict,
config_path=data_manipulation_config_path(),
results_root_path=_RESULTS_ROOT,
)
v_str = f"v{primaite.__version__}"
version_result_dir = _RESULTS_ROOT / v_str
if version_result_dir.exists():
shutil.rmtree(version_result_dir)
version_result_dir.mkdir(exist_ok=True, parents=True)
with open(version_result_dir / f"{v_str}_benchmark_metadata.json", "w") as file:
json.dump(benchmark_metadata, file, indent=4)
title = f"PrimAITE v{primaite.__version__.strip()} Learning Benchmark"
fig = _plot_benchmark_metadata(benchmark_metadata, title=title)
this_version_plot_path = version_result_dir / f"{title}.png"
fig.write_image(this_version_plot_path)
fig = _plot_all_benchmarks_combined_session_av()
all_version_plot_path = _RESULTS_ROOT / "PrimAITE Versions Learning Benchmark.png"
fig.write_image(all_version_plot_path)
_build_benchmark_latex_report(benchmark_metadata, this_version_plot_path, all_version_plot_path)
if __name__ == "__main__":

305
benchmark/report.py Normal file
View File

@@ -0,0 +1,305 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional
import plotly.graph_objects as go
import polars as pl
import yaml
from plotly.graph_objs import Figure
from pylatex import Command, Document
from pylatex import Figure as LatexFigure
from pylatex import Section, Subsection, Tabular
from pylatex.utils import bold
from utils import _get_system_info
import primaite
PLOT_CONFIG = {
"size": {"auto_size": False, "width": 1500, "height": 900},
"template": "plotly_white",
"range_slider": False,
}
def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict, config: Dict) -> dict:
    """
    Combine the per-session metadata into a single benchmark results dict.

    :param start_datetime: When the benchmark run started.
    :param metadata_dict: Per-session metadata keyed by session number; session ``1`` supplies the episode keys.
    :param config: The config used for the benchmark, stored as-is under ``"config"``.
    :return: Totals, per-session averages and per-episode rewards averaged across all sessions.
    """
    session_count = len(metadata_dict)

    def _session_total(field: str):
        """Sum a single metadata field over every session."""
        return sum(session[field] for session in metadata_dict.values())

    averaged_data = {
        "start_timestamp": start_datetime.isoformat(),
        "end_datetime": datetime.now().isoformat(),
        "primaite_version": primaite.__version__,
        "system_info": _get_system_info(),
        "total_sessions": session_count,
        "total_episodes": _session_total("total_episodes"),
        "total_time_steps": _session_total("total_time_steps"),
        "av_s_per_session": _session_total("total_s") / session_count,
        "av_s_per_step": _session_total("s_per_step") / session_count,
        "av_s_per_100_steps_10_nodes": _session_total("s_per_100_steps_10_nodes") / session_count,
        "combined_av_reward_per_episode": {},
        "session_av_reward_per_episode": {
            session_id: session["av_reward_per_episode"] for session_id, session in metadata_dict.items()
        },
        "config": config,
    }

    # average the reward of every episode over all of the sessions
    combined = averaged_data["combined_av_reward_per_episode"]
    for episode in metadata_dict[1]["av_reward_per_episode"].keys():
        reward_sum = sum(session["av_reward_per_episode"][episode] for session in metadata_dict.values())
        combined[episode] = reward_sum / session_count

    return averaged_data
def _get_df_from_episode_av_reward_dict(data: Dict) -> pl.DataFrame:
    """
    Convert an ``{episode: reward}`` mapping into a polars DataFrame.

    :param data: Mapping whose keys become the ``episode`` column and whose values become ``av_reward``.
    :return: DataFrame with ``episode``, ``av_reward`` and a 25-window rolling mean named ``rolling_av_reward``.
    """
    columns: Dict = {"episode": data.keys(), "av_reward": data.values()}
    frame = pl.from_dict(columns)
    rolling = pl.col("av_reward").rolling_mean(window_size=25)
    with_rolling = frame.with_columns(rolling_mean=rolling)
    return with_rolling.rename({"rolling_mean": "rolling_av_reward"})
def _plot_benchmark_metadata(
    benchmark_metadata_dict: Dict,
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
) -> Figure:
    """
    Plot the per-session, combined and rolling-average rewards of one benchmark run.

    :param benchmark_metadata_dict: Benchmark results containing the ``session_av_reward_per_episode`` and
        ``combined_av_reward_per_episode`` entries.
    :param title: Optional plot title.
    :param subtitle: Optional subtitle. Shown on a second line under the title, or used as the title when no
        title is given.
    :return: The plotly figure.
    """
    if title:
        if subtitle:
            # The subtitle goes on its own line in a balanced <sup> element (the opening tag used to be missing).
            title = f"{title} <br><sup>{subtitle}</sup>"
    elif subtitle:
        title = subtitle

    layout = go.Layout(
        autosize=PLOT_CONFIG["size"]["auto_size"],
        width=PLOT_CONFIG["size"]["width"],
        height=PLOT_CONFIG["size"]["height"],
    )
    # Create the line graph with a colored line
    fig = go.Figure(layout=layout)
    fig.update_layout(template=PLOT_CONFIG["template"])

    # One faint grey line per individual session
    for session, av_reward_dict in benchmark_metadata_dict["session_av_reward_per_episode"].items():
        df = _get_df_from_episode_av_reward_dict(av_reward_dict)
        fig.add_trace(
            go.Scatter(
                x=df["episode"],
                y=df["av_reward"],
                mode="lines",
                name=f"Session {session}",
                opacity=0.25,
                line={"color": "#a6a6a6"},
            )
        )

    # The average across sessions, and its rolling average, drawn on top
    df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_av_reward_per_episode"])
    fig.add_trace(
        go.Scatter(
            x=df["episode"], y=df["av_reward"], mode="lines", name="Combined Session Av", line={"color": "#FF0000"}
        )
    )
    fig.add_trace(
        go.Scatter(
            x=df["episode"],
            y=df["rolling_av_reward"],
            mode="lines",
            name="Rolling Av (Combined Session Av)",
            line={"color": "#4CBB17"},
        )
    )

    # Set the layout of the graph
    fig.update_layout(
        xaxis={
            "title": "Episode",
            "type": "linear",
        },
        yaxis={"title": "Total Reward"},
        title=title,
    )

    return fig
def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
    """
    Plot the Benchmark results for each released version of PrimAITE.

    Does this by iterating over the sub-directories of ``results_directory`` and
    extracting the benchmark metadata json for each version that has been
    benchmarked. The combined_av_reward_per_episode is extracted from each,
    converted into a polars dataframe, and plotted as a scatter line in plotly.

    Sub-directories that do not contain a ``<dir name>_benchmark_metadata.json`` file are skipped.

    :param results_directory: Directory holding one sub-directory per benchmarked version.
    :return: The plotly figure with one rolling-average line per version.
    """
    major_v = primaite.__version__.split(".")[0]
    # The subtitle goes on its own line in a balanced <sup> element (the opening tag used to be missing).
    title = (
        f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.* "
        "<br><sup>Rolling Av (Combined Session Av)</sup>"
    )

    layout = go.Layout(
        autosize=PLOT_CONFIG["size"]["auto_size"],
        width=PLOT_CONFIG["size"]["width"],
        height=PLOT_CONFIG["size"]["height"],
    )
    # Create the line graph with a colored line
    fig = go.Figure(layout=layout)
    fig.update_layout(template=PLOT_CONFIG["template"])

    # Sorted so that the trace (and legend) order does not depend on filesystem ordering
    for version_dir in sorted(results_directory.iterdir()):
        if not version_dir.is_dir():
            continue
        metadata_file = version_dir / f"{version_dir.name}_benchmark_metadata.json"
        if not metadata_file.is_file():
            # not a benchmark results directory (or the run never finished); nothing to plot
            continue
        with open(metadata_file, "r") as file:
            metadata_dict = json.load(file)
        df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_av_reward_per_episode"])

        fig.add_trace(go.Scatter(x=df["episode"], y=df["rolling_av_reward"], mode="lines", name=version_dir.name))

    # Set the layout of the graph
    fig.update_layout(
        xaxis={
            "title": "Episode",
            "type": "linear",
        },
        yaxis={"title": "Total Reward"},
        title=title,
    )
    # Force the legend on for the first trace; guarded because there may be no traces at all
    if fig["data"]:
        fig["data"][0]["showlegend"] = True

    return fig
def build_benchmark_latex_report(
    benchmark_start_time: datetime, session_metadata: Dict, config_path: Path, results_root_path: Path
) -> None:
    """
    Generates a latex report of the benchmark run.

    Writes the following for the current PrimAITE version:
        - ``<results_root_path>/v<version>/v<version>_benchmark_metadata.json``
        - the learning benchmark plot (png) for this version, in the same directory
        - the all-versions plot (png) in ``results_root_path``
        - the PDF report, next to this version's plot

    :param benchmark_start_time: When the benchmark run started.
    :param session_metadata: Per-session metadata keyed by session number; session ``1`` is used to report the
        number of episodes per session.
    :param config_path: Path to the yaml config that was used for the benchmark.
    :param results_root_path: Directory under which the per-version results directories live.
    """
    # generate report folder
    v_str = f"v{primaite.__version__}"

    version_result_dir = results_root_path / v_str
    version_result_dir.mkdir(exist_ok=True, parents=True)

    # load the config file as dict
    with open(config_path, "r") as f:
        cfg_data = yaml.safe_load(f)

    # generate the benchmark metadata dict
    benchmark_metadata_dict = _build_benchmark_results_dict(
        start_datetime=benchmark_start_time, metadata_dict=session_metadata, config=cfg_data
    )
    major_v = primaite.__version__.split(".")[0]
    with open(version_result_dir / f"{v_str}_benchmark_metadata.json", "w") as file:
        json.dump(benchmark_metadata_dict, file, indent=4)

    # plots have to exist on disk before the LaTeX document can include them
    title = f"PrimAITE v{primaite.__version__.strip()} Learning Benchmark"
    fig = _plot_benchmark_metadata(benchmark_metadata_dict, title=title)
    this_version_plot_path = version_result_dir / f"{title}.png"
    fig.write_image(this_version_plot_path)

    fig = _plot_all_benchmarks_combined_session_av(results_directory=results_root_path)

    all_version_plot_path = results_root_path / "PrimAITE Versions Learning Benchmark.png"
    fig.write_image(all_version_plot_path)

    geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
    data = benchmark_metadata_dict
    primaite_version = data["primaite_version"]

    # Create a new document
    doc = Document("report", geometry_options=geometry_options)
    # Title
    doc.preamble.append(Command("title", f"PrimAITE {primaite_version} Learning Benchmark"))
    doc.preamble.append(Command("author", "PrimAITE Dev Team"))
    doc.preamble.append(Command("date", datetime.now().date()))
    doc.append(Command("maketitle"))

    sessions = data["total_sessions"]
    # one less than the recorded episode count of session 1
    episodes = session_metadata[1]["total_episodes"] - 1
    steps = data["config"]["game"]["max_episode_length"]

    # Body
    with doc.create(Section("Introduction")):
        doc.append(
            f"PrimAITE v{primaite_version} was benchmarked automatically upon release. Learning rate metrics "
            f"were captured to be referenced during system-level testing and user acceptance testing (UAT)."
        )
        doc.append(
            f"\nThe benchmarking process consists of running {sessions} training sessions using the same "
            f"config file. Each session trains an agent for {episodes} episodes, "
            f"with each episode consisting of {steps} steps."
        )
        doc.append(
            f"\nThe total reward per episode from each session is captured. This is then used to calculate an "
            f"average total reward per episode from the {sessions} individual sessions for smoothing. "
            f"Finally, a 25-window rolling average of the average total reward per session is calculated for "
            f"further smoothing."
        )

    with doc.create(Section("System Information")):
        with doc.create(Subsection("Python")):
            with doc.create(Tabular("|l|l|")) as table:
                table.add_hline()
                table.add_row((bold("Version"), sys.version))
                table.add_hline()
        for section, section_data in data["system_info"].items():
            # empty sections (e.g. an empty GPU list) get no subsection
            if section_data:
                with doc.create(Subsection(section)):
                    if isinstance(section_data, dict):
                        # dict -> two column key/value table
                        with doc.create(Tabular("|l|l|")) as table:
                            table.add_hline()
                            for key, value in section_data.items():
                                table.add_row((bold(key), value))
                                table.add_hline()
                    elif isinstance(section_data, list):
                        # list of dicts -> one column per key of the first item, one row per item
                        headers = section_data[0].keys()
                        tabs_str = "|".join(["l" for _ in range(len(headers))])
                        tabs_str = f"|{tabs_str}|"
                        with doc.create(Tabular(tabs_str)) as table:
                            table.add_hline()
                            table.add_row([bold(h) for h in headers])
                            table.add_hline()
                            for item in section_data:
                                table.add_row(item.values())
                                table.add_hline()

    headers_map = {
        "total_sessions": "Total Sessions",
        "total_episodes": "Total Episodes",
        "total_time_steps": "Total Steps",
        "av_s_per_session": "Av Session Duration (s)",
        "av_s_per_step": "Av Step Duration (s)",
        "av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
    }
    with doc.create(Section("Stats")):
        with doc.create(Subsection("Benchmark Results")):
            with doc.create(Tabular("|l|l|")) as table:
                table.add_hline()
                for section, header in headers_map.items():
                    # averages are floats and are shown to 4 decimal places; totals are shown as-is
                    if section.startswith("av_"):
                        table.add_row((bold(header), f"{data[section]:.4f}"))
                    else:
                        table.add_row((bold(header), data[section]))
                    table.add_hline()

    with doc.create(Section("Graphs")):
        with doc.create(Subsection(f"v{primaite_version} Learning Benchmark Plot")):
            with doc.create(LatexFigure(position="h!")) as pic:
                pic.add_image(str(this_version_plot_path))
                pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")

        with doc.create(Subsection(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")):
            with doc.create(LatexFigure(position="h!")) as pic:
                pic.add_image(str(all_version_plot_path))
                pic.add_caption(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")

    # the PDF is written next to this version's plot, using the plot's file name without the extension
    doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)

View File

Before

Width:  |  Height:  |  Size: 79 KiB

After

Width:  |  Height:  |  Size: 79 KiB

View File

Before

Width:  |  Height:  |  Size: 225 KiB

After

Width:  |  Height:  |  Size: 225 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 295 KiB

File diff suppressed because it is too large Load Diff

47
benchmark/utils.py Normal file
View File

@@ -0,0 +1,47 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import platform
from typing import Dict
import psutil
from GPUtil import GPUtil
def get_size(size_bytes: int) -> str:
    """
    Scale bytes to its proper format.

    e.g:
        1253656 => '1.20MB'
        1253656678 => '1.17GB'

    :param size_bytes: The size in bytes.
    :return: The size to two decimal places followed by its unit (B, KB, MB, GB, TB, PB or EB).
    """
    factor = 1024
    for unit in ["", "K", "M", "G", "T", "P"]:
        if size_bytes < factor:
            return f"{size_bytes:.2f}{unit}B"
        size_bytes /= factor
    # Still >= 1024 after the petabyte step: report in exabytes instead of falling off the loop and
    # implicitly returning None.
    return f"{size_bytes:.2f}EB"
def _get_system_info() -> Dict:
    """
    Builds and returns a dict containing system info.

    :return: A dict with ``System``, ``CPU`` and ``Memory`` sub-dicts and a ``GPU`` list holding one dict per
        GPU reported by GPUtil.
    """
    uname = platform.uname()
    cpu_freq = psutil.cpu_freq()
    virtual_mem = psutil.virtual_memory()
    swap_mem = psutil.swap_memory()
    gpus = GPUtil.getGPUs()
    # psutil.cpu_freq() returns None when the frequency cannot be determined, so do not assume ``.max`` exists
    max_frequency = f"{cpu_freq.max:.2f}Mhz" if cpu_freq is not None else "N/A"
    return {
        "System": {
            "OS": uname.system,
            "OS Version": uname.version,
            "Machine": uname.machine,
            "Processor": uname.processor,
        },
        "CPU": {
            "Physical Cores": psutil.cpu_count(logical=False),
            "Total Cores": psutil.cpu_count(logical=True),
            "Max Frequency": max_frequency,
        },
        "Memory": {"Total": get_size(virtual_mem.total), "Swap Total": get_size(swap_mem.total)},
        "GPU": [{"Name": gpu.name, "Total Memory": f"{gpu.memoryTotal}MB"} for gpu in gpus],
    }

View File

@@ -54,7 +54,7 @@ It is agnostic to the number of agents, their action / observation spaces, and t
It presents a public API providing a method for describing the current state of the simulation, a method that accepts action requests and provides responses, and a method that triggers a timestep advancement.
The Game Layer converts the simulation into a playable game for the agent(s).
it translates between simulation state and Gymnasium.Spaces to pass action / observation data between the agent(s) and the simulation. It is responsible for calculating rewards, managing Multi-Agent RL (MARL) action turns, and via a single agent interface can interact with Blue, Red and Green agents.
It translates between simulation state and Gymnasium.Spaces to pass action / observation data between the agent(s) and the simulation. It is responsible for calculating rewards, managing Multi-Agent RL (MARL) action turns, and via a single agent interface can interact with Blue, Red and Green agents.
Agents can either generate their own scripted behaviour or accept input behaviour from an RL agent.

View File

@@ -1,6 +1,6 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
.. _about:

View File

@@ -77,6 +77,6 @@ The following extensions should now be installed
:width: 300
:align: center
VSCode will then ask for a Python environment version to use. PrimAITE is compatible with Python versions 3.8 - 3.10
VSCode will then ask for a Python environment version to use. PrimAITE is compatible with Python versions 3.8 - 3.11
You should now be able to interact with the notebook.

View File

@@ -0,0 +1,347 @@
.. only:: comment
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
.. _NMAP:
NMAP
====
Overview
--------
The NMAP application is used to simulate network scanning activities. NMAP is a powerful tool that helps in discovering
hosts and services on a network. It provides functionalities such as ping scans to discover active hosts and port scans
to detect open ports on those hosts.
The NMAP application is essential for network administrators and security professionals to map out a network's
structure, identify active devices, and find potential vulnerabilities by discovering open ports and running services.
However, it is also a tool frequently used by attackers during the reconnaissance stage of a cyber attack to gather
information about the target network.
Scan Types
----------
Ping Scan
^^^^^^^^^
A ping scan is used to identify which hosts on a network are active and reachable. This is achieved by sending ICMP
Echo Request packets (ping) to the target IP addresses. If a host responds with an ICMP Echo Reply, it is considered
active. Ping scans are useful for quickly mapping out live hosts in a network.
Port Scan
^^^^^^^^^
A port scan is used to detect open ports on a target host or range of hosts. Open ports can indicate running services
that might be exploitable or require securing. Port scans help in understanding the services available on a network and
identifying potential entry points for attacks. There are three types of port scans based on the scope:
- **Horizontal Port Scan**: This scan targets a specific port across a range of IP addresses. It helps in identifying
which hosts have a particular service running.
- **Vertical Port Scan**: This scan targets multiple ports on a single IP address. It provides detailed information
about the services running on a specific host.
- **Box Scan**: This combines both horizontal and vertical scans, targeting multiple ports across multiple IP addresses.
It gives a comprehensive view of the network's service landscape.
Example Usage
-------------
The network we use for these examples is defined below:
.. code-block:: python
from ipaddress import IPv4Address, IPv4Network
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.nmap import NMAP
from primaite.simulator.system.services.database.database_service import DatabaseService
# Initialize the network
network = Network()
# Set up the router
router = Router(hostname="router", start_up_duration=0)
router.power_on()
router.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
# Set up PC 1
pc_1 = Computer(
hostname="pc_1",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0
)
pc_1.power_on()
# Set up PC 2
pc_2 = Computer(
hostname="pc_2",
ip_address="192.168.1.12",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0
)
pc_2.power_on()
pc_2.software_manager.install(DatabaseService)
pc_2.software_manager.software["DatabaseService"].start() # start the postgres server
# Set up PC 3
pc_3 = Computer(
hostname="pc_3",
ip_address="192.168.1.13",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0
)
# Don't power on PC 3
# Set up the switch
switch = Switch(hostname="switch", start_up_duration=0)
switch.power_on()
# Connect devices
network.connect(router.network_interface[1], switch.network_interface[24])
network.connect(switch.network_interface[1], pc_1.network_interface[1])
network.connect(switch.network_interface[2], pc_2.network_interface[1])
network.connect(switch.network_interface[3], pc_3.network_interface[1])
pc_1_nmap: NMAP = pc_1.software_manager.software["NMAP"]
Ping Scan
^^^^^^^^^
Perform a ping scan to find active hosts in the `192.168.1.0/24` subnet:
.. code-block:: python
:caption: Ping Scan Code
active_hosts = pc_1_nmap.ping_scan(target_ip_address=IPv4Network("192.168.1.0/24"))
.. code-block:: python
:caption: Ping Scan Return Value
[
IPv4Address('192.168.1.11'),
IPv4Address('192.168.1.12'),
IPv4Address('192.168.1.1')
]
.. code-block:: text
:caption: Ping Scan Output
+-------------------------+
| pc_1 NMAP Ping Scan |
+--------------+----------+
| IP Address | Can Ping |
+--------------+----------+
| 192.168.1.1 | True |
| 192.168.1.11 | True |
| 192.168.1.12 | True |
+--------------+----------+
Horizontal Port Scan
^^^^^^^^^^^^^^^^^^^^
Perform a horizontal port scan on port 5432 across multiple IP addresses:
.. code-block:: python
:caption: Horizontal Port Scan Code
horizontal_scan_results = pc_1_nmap.port_scan(
target_ip_address=[IPv4Address("192.168.1.12"), IPv4Address("192.168.1.13")],
target_port=Port(5432)
)
.. code-block:: python
:caption: Horizontal Port Scan Return Value
{
IPv4Address('192.168.1.12'): {
<IPProtocol.TCP: 'tcp'>: [
<Port.POSTGRES_SERVER: 5432>
]
}
}
.. code-block:: text
:caption: Horizontal Port Scan Output
+--------------------------------------------------+
| pc_1 NMAP Port Scan (Horizontal) |
+--------------+------+-----------------+----------+
| IP Address | Port | Name | Protocol |
+--------------+------+-----------------+----------+
| 192.168.1.12 | 5432 | POSTGRES_SERVER | TCP |
+--------------+------+-----------------+----------+
Vertical Port Scan
^^^^^^^^^^^^^^^^^^
Perform a vertical port scan on multiple ports on a single IP address:
.. code-block:: python
:caption: Vertical Port Scan Code
vertical_scan_results = pc_1_nmap.port_scan(
target_ip_address=[IPv4Address("192.168.1.12")],
target_port=[Port(21), Port(22), Port(80), Port(443)]
)
.. code-block:: python
:caption: Vertical Port Scan Return Value
{
IPv4Address('192.168.1.12'): {
<IPProtocol.TCP: 'tcp'>: [
<Port.FTP: 21>,
<Port.HTTP: 80>
]
}
}
.. code-block:: text
:caption: Vertical Port Scan Output
+---------------------------------------+
| pc_1 NMAP Port Scan (Vertical) |
+--------------+------+------+----------+
| IP Address | Port | Name | Protocol |
+--------------+------+------+----------+
| 192.168.1.12 | 21 | FTP | TCP |
| 192.168.1.12 | 80 | HTTP | TCP |
+--------------+------+------+----------+
Box Scan
^^^^^^^^
Perform a box scan on multiple ports across multiple IP addresses:
.. code-block:: python
:caption: Box Port Scan Code
# Power PC 3 on before performing the box scan
pc_3.power_on()
box_scan_results = pc_1_nmap.port_scan(
target_ip_address=[IPv4Address("192.168.1.12"), IPv4Address("192.168.1.13")],
target_port=[Port(21), Port(22), Port(80), Port(443)]
)
.. code-block:: python
:caption: Box Port Scan Return Value
{
IPv4Address('192.168.1.13'): {
<IPProtocol.TCP: 'tcp'>: [
<Port.FTP: 21>,
<Port.HTTP: 80>
]
},
IPv4Address('192.168.1.12'): {
<IPProtocol.TCP: 'tcp'>: [
<Port.FTP: 21>,
<Port.HTTP: 80>
]
}
}
.. code-block:: text
:caption: Box Port Scan Output
+---------------------------------------+
| pc_1 NMAP Port Scan (Box) |
+--------------+------+------+----------+
| IP Address | Port | Name | Protocol |
+--------------+------+------+----------+
| 192.168.1.12 | 21 | FTP | TCP |
| 192.168.1.12 | 80 | HTTP | TCP |
| 192.168.1.13 | 21 | FTP | TCP |
| 192.168.1.13 | 80 | HTTP | TCP |
+--------------+------+------+----------+
Full Box Scan
^^^^^^^^^^^^^
Perform a full box scan on all ports, over both TCP and UDP, on a whole subnet:
.. code-block:: python
:caption: Full Box Port Scan Code
# Power PC 3 on before performing the full box scan
pc_3.power_on()
full_box_scan_results = pc_1_nmap.port_scan(
target_ip_address=IPv4Network("192.168.1.0/24"),
)
.. code-block:: python
:caption: Full Box Port Scan Return Value
{
IPv4Address('192.168.1.11'): {
<IPProtocol.UDP: 'udp'>: [
<Port.ARP: 219>
]
},
IPv4Address('192.168.1.1'): {
<IPProtocol.UDP: 'udp'>: [
<Port.ARP: 219>
]
},
IPv4Address('192.168.1.12'): {
<IPProtocol.TCP: 'tcp'>: [
<Port.HTTP: 80>,
<Port.DNS: 53>,
<Port.POSTGRES_SERVER: 5432>,
<Port.FTP: 21>
],
<IPProtocol.UDP: 'udp'>: [
<Port.NTP: 123>,
<Port.ARP: 219>
]
},
IPv4Address('192.168.1.13'): {
<IPProtocol.TCP: 'tcp'>: [
<Port.HTTP: 80>,
<Port.DNS: 53>,
<Port.FTP: 21>
],
<IPProtocol.UDP: 'udp'>: [
<Port.NTP: 123>,
<Port.ARP: 219>
]
}
}
.. code-block:: text
:caption: Full Box Port Scan Output
+--------------------------------------------------+
| pc_1 NMAP Port Scan (Box) |
+--------------+------+-----------------+----------+
| IP Address | Port | Name | Protocol |
+--------------+------+-----------------+----------+
| 192.168.1.1 | 219 | ARP | UDP |
| 192.168.1.11 | 219 | ARP | UDP |
| 192.168.1.12 | 21 | FTP | TCP |
| 192.168.1.12 | 53 | DNS | TCP |
| 192.168.1.12 | 80 | HTTP | TCP |
| 192.168.1.12 | 123 | NTP | UDP |
| 192.168.1.12 | 219 | ARP | UDP |
| 192.168.1.12 | 5432 | POSTGRES_SERVER | TCP |
| 192.168.1.13 | 21 | FTP | TCP |
| 192.168.1.13 | 53 | DNS | TCP |
| 192.168.1.13 | 80 | HTTP | TCP |
| 192.168.1.13 | 123 | NTP | UDP |
| 192.168.1.13 | 219 | ARP | UDP |
+--------------+------+-----------------+----------+

View File

@@ -24,7 +24,7 @@ For each variation that could be used in a placeholder, there is a separate yaml
The data that fills the placeholder is defined as a YAML Anchor in a separate file, denoted by an ampersand ``&anchor``.
Learn more about YAML Aliases and Anchors here.
Learn more about YAML Aliases and Anchors `here <https://yaml.org/spec/1.2.2/#3222-anchors-and-aliases>`_.
Schedule
********

View File

@@ -33,7 +33,7 @@ dependencies = [
"numpy==1.23.5",
"platformdirs==3.5.1",
"plotly==5.15.0",
"polars==0.18.4",
"polars==0.20.30",
"prettytable==3.8.0",
"PyYAML==6.0",
"typer[all]==0.9.0",

View File

@@ -1 +1 @@
3.0.0b9
3.0.0

View File

@@ -11,7 +11,7 @@ AbstractAction. The ActionManager is responsible for:
"""
import itertools
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from gymnasium import spaces
@@ -871,6 +871,74 @@ class NetworkPortDisableAction(AbstractAction):
return ["network", "node", target_nodename, "network_interface", port_id, "disable"]
class NodeNMAPPingScanAction(AbstractAction):
    """Agent action that requests a ping scan from the NMAP application on a node."""

    def __init__(self, manager: "ActionManager", **kwargs) -> None:
        # Extra keyword arguments are accepted but only the manager is forwarded to the base class.
        super().__init__(manager=manager)

    def form_request(self, source_node: str, target_ip_address: Union[str, List[str]]) -> List[str]:  # noqa
        """
        Return the action formatted as a request which can be ingested by the PrimAITE simulation.

        :param source_node: The node identifier placed in the request path after "node".
        :param target_ip_address: A single target or a list of targets, passed through unchanged.
        :return: The request path ending in "ping_scan", followed by a dict holding the scan options.
        """
        scan_options = {"target_ip_address": target_ip_address}
        return ["network", "node", source_node, "application", "NMAP", "ping_scan", scan_options]
class NodeNMAPPortScanAction(AbstractAction):
    """Agent action that requests a port scan from the NMAP application on a node."""

    def __init__(self, manager: "ActionManager", **kwargs) -> None:
        # Extra keyword arguments are accepted but only the manager is forwarded to the base class.
        super().__init__(manager=manager)

    def form_request(
        self,
        source_node: str,
        target_ip_address: Union[str, List[str]],
        target_protocol: Optional[Union[str, List[str]]] = None,
        target_port: Optional[Union[str, List[str]]] = None,
    ) -> List[str]:  # noqa
        """
        Return the action formatted as a request which can be ingested by the PrimAITE simulation.

        :param source_node: The node identifier placed in the request path after "node".
        :param target_ip_address: A single target or a list of targets, passed through unchanged.
        :param target_protocol: Optional protocol filter (single value or list); None is passed through as-is.
        :param target_port: Optional port filter (single value or list); None is passed through as-is.
        :return: The request path ending in "port_scan", followed by a dict holding the scan options.
        """
        scan_options = {
            "target_ip_address": target_ip_address,
            "target_port": target_port,
            "target_protocol": target_protocol,
        }
        return ["network", "node", source_node, "application", "NMAP", "port_scan", scan_options]
class NodeNetworkServiceReconAction(AbstractAction):
    """Agent action that requests an NMAP network service recon (ping scan followed by port scan)."""

    def __init__(self, manager: "ActionManager", **kwargs) -> None:
        # Extra keyword arguments are accepted but only the manager is forwarded to the base class.
        super().__init__(manager=manager)

    def form_request(
        self,
        source_node: str,
        target_ip_address: Union[str, List[str]],
        target_protocol: Optional[Union[str, List[str]]] = None,
        target_port: Optional[Union[str, List[str]]] = None,
    ) -> List[str]:  # noqa
        """
        Return the action formatted as a request which can be ingested by the PrimAITE simulation.

        :param source_node: The node identifier placed in the request path after "node".
        :param target_ip_address: A single target or a list of targets, passed through unchanged.
        :param target_protocol: Optional protocol filter (single value or list); None is passed through as-is.
        :param target_port: Optional port filter (single value or list); None is passed through as-is.
        :return: The request path ending in "network_service_recon", followed by a dict holding the scan options.
        """
        recon_options = {
            "target_ip_address": target_ip_address,
            "target_port": target_port,
            "target_protocol": target_protocol,
        }
        return ["network", "node", source_node, "application", "NMAP", "network_service_recon", recon_options]
class ActionManager:
"""Class which manages the action space for an agent."""
@@ -916,6 +984,9 @@ class ActionManager:
"HOST_NIC_DISABLE": HostNICDisableAction,
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
"NODE_NMAP_PING_SCAN": NodeNMAPPingScanAction,
"NODE_NMAP_PORT_SCAN": NodeNMAPPortScanAction,
"NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction,
}
"""Dictionary which maps action type strings to the corresponding action class."""

View File

@@ -361,11 +361,6 @@ class PrimaiteGame:
server_ip_address=IPv4Address(opt.get("server_ip")),
server_password=opt.get("server_password"),
payload=opt.get("payload", "ENCRYPT"),
c2_beacon_p_of_success=float(opt.get("c2_beacon_p_of_success", "0.5")),
target_scan_p_of_success=float(opt.get("target_scan_p_of_success", "0.1")),
ransomware_encrypt_p_of_success=float(
opt.get("ransomware_encrypt_p_of_success", "0.1")
),
)
elif application_type == "DatabaseClient":
if "options" in application_cfg:

View File

@@ -3,7 +3,7 @@ from typing import Dict, ForwardRef, List, Literal, Union
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
RequestFormat = List[Union[str, int, float]]
RequestFormat = List[Union[str, int, float, Dict]]
RequestResponse = ForwardRef("RequestResponse")
"""This makes it possible to type-hint RequestResponse.from_bool return type."""

View File

@@ -4,13 +4,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Customising red agents\n",
"# Customising UC2 Red Agents\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook will go over some examples of how red agent behaviour can be varied by changing its configuration parameters.\n",
"\n",
"First, let's load the standard Data Manipulation config file, and see what the red agent does.\n",
"\n",
"*(For a full explanation of the Data Manipulation scenario, check out the notebook `Data-Manipulation-E2E-Demonstration.ipynb`)*"
"*(For a full explanation of the Data Manipulation scenario, check out the data manipulation scenario notebook)*"
]
},
{

View File

@@ -4,7 +4,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Manipulation Scenario\n"
"# Data Manipulation Scenario\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK"
]
},
{
@@ -79,7 +81,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reinforcement learning details"
"## Reinforcement learning details"
]
},
{
@@ -692,7 +694,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},

View File

@@ -4,7 +4,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Getting information out of PrimAITE"
"# Getting information out of PrimAITE\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
]
},
{
@@ -160,7 +162,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.8"
}
},
"nbformat": 4,

View File

@@ -6,7 +6,9 @@
"source": [
"# Requests and Responses\n",
"\n",
"Agents interact with the PrimAITE simulation via the Request system.\n"
"Agents interact with the PrimAITE simulation via the Request system.\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
]
},
{

View File

@@ -4,7 +4,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a Multi agent system using RLLIB\n",
"# Train a Multi agent system using RLLIB\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook will demonstrate how to use the `PrimaiteRayMARLEnv` to train a very basic system with two PPO agents."
]

View File

@@ -4,7 +4,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a Single agent system using RLLib\n",
"# Train a Single agent system using RLLib\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
]
},

View File

@@ -6,6 +6,8 @@
"source": [
"# Training an SB3 Agent\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook will demonstrate how to use primaite to create and train a PPO agent, using a pre-defined configuration file."
]
},
@@ -166,7 +168,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -180,7 +182,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.8"
}
},
"nbformat": 4,

View File

@@ -6,6 +6,8 @@
"source": [
"# Using Episode Schedules\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"PrimAITE supports the ability to use different variations on a scenario at different episodes. This can be used to increase \n",
"domain randomisation to prevent overfitting, or to set up curriculum learning to train agents to perform more complicated tasks.\n",
"\n",

View File

@@ -4,7 +4,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simple multi-processing demo using SubprocVecEnv from SB3"
"# Simple multi-processing demonstration\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook uses SubprocVecEnv from SB3."
]
},
{
@@ -139,7 +143,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -38,6 +38,8 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
self.episode_counter: int = 0
"""Current episode number."""
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
@property
def agent(self) -> ProxyAgent:
@@ -90,6 +92,8 @@ class PrimaiteGymEnv(gymnasium.Env):
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)

View File

@@ -6,7 +6,9 @@
"source": [
"# Build a simulation using the Python API\n",
"\n",
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects.\n"
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects."
]
},
{

View File

@@ -7,6 +7,8 @@
"source": [
"# PrimAITE Router Simulation Demo\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This demo uses a modified version of the ARCD Use Case 2 Network (seen below) to demonstrate the capabilities of the Network simulator in PrimAITE."
]
},

View File

@@ -222,7 +222,7 @@ class SimComponent(BaseModel):
return state
@validate_call
def apply_request(self, request: RequestFormat, context: Dict = {}) -> RequestResponse:
def apply_request(self, request: RequestFormat, context: Optional[Dict] = None) -> RequestResponse:
"""
Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings.
@@ -240,6 +240,8 @@ class SimComponent(BaseModel):
:param: context: Dict containing context for requests
:type context: Dict
"""
if not context:
context = None
if self._request_manager is None:
return
return self._request_manager(request, context)

View File

@@ -29,6 +29,7 @@ from primaite.simulator.network.nmne import (
NMNE_CAPTURE_KEYWORDS,
)
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.packet_capture import PacketCapture
from primaite.simulator.system.core.session_manager import SessionManager
@@ -37,6 +38,7 @@ from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.processes.process import Process
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import IOSoftware
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
from primaite.utils.validators import IPV4Address
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
@@ -108,10 +110,14 @@ class NetworkInterface(SimComponent, ABC):
nmne: Dict = Field(default_factory=lambda: {})
"A dict containing details of the number of malicious network events captured."
traffic: Dict = Field(default_factory=lambda: {})
"A dict containing details of the inbound and outbound traffic by port and protocol."
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().setup_for_episode(episode=episode)
self.nmne = {}
self.traffic = {}
if episode and self.pcap and SIM_OUTPUT.save_pcap_logs:
self.pcap.current_episode = episode
self.pcap.setup_logger()
@@ -147,6 +153,7 @@ class NetworkInterface(SimComponent, ABC):
)
if CAPTURE_NMNE:
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
return state
@abstractmethod
@@ -237,6 +244,47 @@ class NetworkInterface(SimComponent, ABC):
# Increment a generic counter if keyword capturing is not enabled
keyword_level["*"] = keyword_level.get("*", 0) + 1
def _capture_traffic(self, frame: Frame, inbound: bool = True):
    """
    Accumulate per-protocol traffic byte counts for a frame seen at this Network Interface.

    TCP and UDP traffic is recorded under ``self.traffic[protocol][dst_port][direction]``; ICMP traffic, which has
    no ports, is recorded under ``self.traffic[protocol][direction]``. A frame carrying none of the three is
    recorded under a ``None`` protocol and ``None`` port.

    :param frame: The network frame containing the traffic data.
    :type frame: Frame
    :param inbound: Flag indicating if the traffic is inbound or outbound. Defaults to True.
    :type inbound: bool
    """
    direction = "inbound" if inbound else "outbound"

    # Work out which transport the frame carries, and the destination port where one exists.
    if frame.tcp:
        protocol, port = IPProtocol.TCP, frame.tcp.dst_port
    elif frame.udp:
        protocol, port = IPProtocol.UDP, frame.udp.dst_port
    elif frame.icmp:
        protocol, port = IPProtocol.ICMP, None
    else:
        protocol, port = None, None

    # Make sure there is an entry for this protocol before counting.
    protocol_stats = self.traffic.setdefault(protocol, {})

    if protocol == IPProtocol.ICMP:
        # ICMP has no ports, so the direction counters sit directly under the protocol key.
        if not protocol_stats:
            protocol_stats = {"inbound": 0, "outbound": 0}
            self.traffic[protocol] = protocol_stats
        protocol_stats[direction] += frame.size
        return

    # Port-based protocols keep one pair of direction counters per destination port.
    port_stats = protocol_stats.setdefault(port, {"inbound": 0, "outbound": 0})
    port_stats[direction] += frame.size
@abstractmethod
def send_frame(self, frame: Frame) -> bool:
"""
@@ -246,6 +294,7 @@ class NetworkInterface(SimComponent, ABC):
:return: A boolean indicating whether the frame was successfully sent.
"""
self._capture_nmne(frame, inbound=False)
self._capture_traffic(frame, inbound=False)
@abstractmethod
def receive_frame(self, frame: Frame) -> bool:
@@ -256,6 +305,7 @@ class NetworkInterface(SimComponent, ABC):
:return: A boolean indicating whether the frame was successfully received.
"""
self._capture_nmne(frame, inbound=True)
self._capture_traffic(frame, inbound=True)
def __str__(self) -> str:
"""
@@ -767,6 +817,24 @@ class Node(SimComponent):
self.session_manager.software_manager = self.software_manager
self._install_system_software()
def ip_is_network_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool:
    """
    Report whether an IP address is assigned to one of this node's network interfaces.

    Interfaces without an ``ip_address`` attribute are ignored. The answer is taken from the first interface whose
    address matches.

    :param ip_address: The IP address to check.
    :param enabled_only: If True, a matching interface only counts when it is enabled.
    :return: True if the IP address is assigned to one of the nodes interfaces; False otherwise.
    """
    for nic in self.network_interface.values():
        if hasattr(nic, "ip_address") and nic.ip_address == ip_address:
            # First match decides the result; with enabled_only its enabled flag is the answer.
            return nic.enabled if enabled_only else True
    return False
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().setup_for_episode(episode=episode)

View File

@@ -8,6 +8,8 @@ from primaite import getLogger
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, Link, Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.nmap import NMAP
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.arp.arp import ARP, ARPPacket
from primaite.simulator.system.services.dns.dns_client import DNSClient
@@ -303,6 +305,7 @@ class HostNode(Node):
"DNSClient": DNSClient,
"NTPClient": NTPClient,
"WebBrowser": WebBrowser,
"NMAP": NMAP,
}
"""List of system software that is automatically installed on nodes."""
@@ -315,6 +318,16 @@ class HostNode(Node):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
@property
def nmap(self) -> Optional[NMAP]:
    """
    Look up the NMAP application in this node's software manager.

    :return: The software registered under the name "NMAP", or None when there is no such entry.
    :rtype: Optional[NMAP]
    """
    installed_software = self.software_manager.software
    return installed_software.get("NMAP")
@property
def arp(self) -> Optional[ARP]:
"""
@@ -366,8 +379,15 @@ class HostNode(Node):
elif frame.udp:
dst_port = frame.udp.dst_port
can_accept_nmap = False
if self.software_manager.software.get("NMAP"):
if self.software_manager.software["NMAP"].operating_state == ApplicationOperatingState.RUNNING:
can_accept_nmap = True
accept_nmap = can_accept_nmap and frame.payload.__class__.__name__ == "PortScanPayload"
accept_frame = False
if frame.icmp or dst_port in self.software_manager.get_open_ports():
if frame.icmp or dst_port in self.software_manager.get_open_ports() or accept_nmap:
# accept the frame as the port is open or if it's an ICMP frame
accept_frame = True

View File

@@ -19,6 +19,7 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.nmap import NMAP
from primaite.simulator.system.core.session_manager import SessionManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.arp.arp import ARP
@@ -1239,6 +1240,7 @@ class Router(NetworkNode):
icmp.router = self
self.software_manager.install(RouterARP)
self.arp.router = self
self.software_manager.install(NMAP)
def _set_default_acl(self):
"""

View File

@@ -271,9 +271,16 @@ class DatabaseClient(Application):
Calls disconnect on all client connections to ensure that both client and server connections are killed.
"""
while self.client_connections.values():
client_connection = self.client_connections[next(iter(self.client_connections.keys()))]
client_connection.disconnect()
while self.client_connections:
conn_key = next(iter(self.client_connections.keys()))
conn_obj: DatabaseClientConnection = self.client_connections[conn_key]
conn_obj.disconnect()
if conn_obj.is_active or conn_key in self.client_connections:
self.sys_log.error(
"Attempted to uninstall database client but could not drop active connections. "
"Forcing uninstall anyway."
)
self.client_connections.pop(conn_key, None)
super().uninstall()
def get_new_connection(self) -> Optional[DatabaseClientConnection]:

View File

@@ -0,0 +1,452 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union
from prettytable import PrettyTable
from pydantic import validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.utils.validators import IPV4Address
class PortScanPayload(SimComponent):
    """
    The payload carried by an NMAP port scan.

    :ivar ip_address: The target IP address for the port scan.
    :ivar port: The target port for the port scan.
    :ivar protocol: The protocol used for the port scan.
    :ivar request: Flag to indicate whether this is a request or not.
    """

    ip_address: IPV4Address
    port: Port
    protocol: IPProtocol
    # Defaults to True, i.e. a freshly created payload is treated as a request.
    request: bool = True

    def describe_state(self) -> Dict:
        """
        Describe the state of the port scan payload.

        Extends the parent component state with the payload fields, using the string form of the IP address and
        the underlying values of the port and protocol enums.

        :return: A dictionary representation of the port scan payload state.
        :rtype: Dict
        """
        state = super().describe_state()
        state.update(
            {
                "ip_address": str(self.ip_address),
                "port": self.port.value,
                "protocol": self.protocol.value,
                "request": self.request,
            }
        )
        return state
class NMAP(Application):
"""
A class representing the NMAP application for network scanning.
NMAP is a network scanning tool used to discover hosts and services on a network. It provides functionalities such
as ping scans to discover active hosts and port scans to detect open ports on those hosts.
"""
_active_port_scans: Dict[str, PortScanPayload] = {}
_port_scan_responses: Dict[str, PortScanPayload] = {}
_PORT_SCAN_TYPE_MAP: Final[Dict[Tuple[bool, bool], str]] = {
(True, True): "Box",
(True, False): "Horizontal",
(False, True): "Vertical",
(False, False): "Port",
}
def __init__(self, **kwargs):
    """
    Initialise the NMAP application.

    The ``name``, ``port`` and ``protocol`` entries are fixed here; any values supplied for them in ``kwargs`` are
    overwritten before the parent initialiser is called.
    """
    kwargs.update(name="NMAP", port=Port.NONE, protocol=IPProtocol.NONE)
    super().__init__(**kwargs)
def _can_perform_network_action(self) -> bool:
    """
    Report whether the NMAP application is able to send outbound network traffic.

    Two conditions must hold: the parent class ``_can_perform_action`` check passes, and at least one network
    interface on the host node is enabled.

    :return: True if outbound network actions can be performed, otherwise False.
    """
    if super()._can_perform_action():
        interfaces = self.software_manager.node.network_interface.values()
        return any(nic.enabled for nic in interfaces)
    return False
def _init_request_manager(self) -> RequestManager:
    """
    Build the request manager and register the NMAP scan requests.

    Three requests are registered: ``ping_scan``, ``port_scan`` and ``network_service_recon``. Every handler first
    confirms that the application can perform outbound network actions and only then runs the scan, so a request
    that is going to be rejected never triggers a scan.

    :return: The request manager with the NMAP requests registered.
    :rtype: RequestManager
    """

    def _ping_scan_action(request: List[Any], context: Any) -> RequestResponse:
        # Guard first: running the scan before this check would execute it even for a rejected request.
        if not self._can_perform_network_action():
            return RequestResponse.from_bool(False)
        results = self.ping_scan(target_ip_address=request[0]["target_ip_address"], json_serializable=True)
        return RequestResponse(
            status="success",
            data={"live_hosts": results},
        )

    def _port_scan_action(request: List[Any], context: Any) -> RequestResponse:
        if not self._can_perform_network_action():
            return RequestResponse.from_bool(False)
        # request[0] carries the keyword arguments for port_scan.
        results = self.port_scan(**request[0], json_serializable=True)
        return RequestResponse(
            status="success",
            data=results,
        )

    def _network_service_recon_action(request: List[Any], context: Any) -> RequestResponse:
        if not self._can_perform_network_action():
            return RequestResponse.from_bool(False)
        # request[0] carries the keyword arguments for network_service_recon.
        results = self.network_service_recon(**request[0], json_serializable=True)
        return RequestResponse(
            status="success",
            data=results,
        )

    # Extend the parent's request manager instead of starting from an empty one, so inherited requests are kept.
    # NOTE(review): assumes the parent implementation returns a RequestManager - confirm.
    rm = super()._init_request_manager()

    rm.add_request(
        name="ping_scan",
        request_type=RequestType(func=_ping_scan_action),
    )
    rm.add_request(
        name="port_scan",
        request_type=RequestType(func=_port_scan_action),
    )
    rm.add_request(
        name="network_service_recon",
        request_type=RequestType(func=_network_service_recon_action),
    )

    return rm
def describe_state(self) -> Dict:
    """
    Produce the state dictionary for the NMAP application.

    No NMAP-specific entries are added; the parent class state is returned as-is.

    :return: A dictionary representation of the NMAP application's state.
    :rtype: Dict
    """
    state = super().describe_state()
    return state
@staticmethod
def _explode_ip_address_network_array(
target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
) -> Set[IPv4Address]:
"""
Explode a mixed array of IP addresses and networks into a set of individual IP addresses.
This method takes a combination of single and lists of IPv4 addresses and IPv4 networks, expands any networks
into their constituent subnet useable IP addresses, and returns a set of unique IP addresses. Broadcast and
network addresses are excluded from the result.
:param target_ip_address: A single or list of IPv4 addresses and networks.
:type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
:return: A set of unique IPv4 addresses expanded from the input.
:rtype: Set[IPv4Address]
"""
if isinstance(target_ip_address, IPv4Address) or isinstance(target_ip_address, IPv4Network):
target_ip_address = [target_ip_address]
ip_addresses: List[IPV4Address] = []
for ip_address in target_ip_address:
if isinstance(ip_address, IPv4Network):
ip_addresses += [
ip
for ip in ip_address.hosts()
if not ip == ip_address.broadcast_address and not ip == ip_address.network_address
]
else:
ip_addresses.append(ip_address)
return set(ip_addresses)
@validate_call()
def ping_scan(
    self,
    target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]],
    show: bool = True,
    show_online_only: bool = True,
    json_serializable: bool = False,
) -> Union[List[IPV4Address], List[str]]:
    """
    Ping every target address and collect the ones that respond.

    Targets are expanded into individual addresses first. Addresses that belong to one of this node's own network
    interfaces are skipped.

    :param target_ip_address: The target IP address(es) or network(s) for the ping scan.
    :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
    :param show: Whether to print a table of the scan results. Defaults to True.
    :type show: bool
    :param show_online_only: Whether the printed table lists only hosts that responded. Defaults to True.
    :type show_online_only: bool
    :param json_serializable: Whether responding addresses are returned as strings. Defaults to False.
    :type json_serializable: bool
    :return: A list of active IP addresses that responded to the ping.
    :rtype: Union[List[IPV4Address], List[str]]
    """
    responsive_hosts = []
    if show:
        table = PrettyTable(["IP Address", "Can Ping"])
        table.align = "l"
        table.title = f"{self.software_manager.node.hostname} NMAP Ping Scan"

    for candidate in self._explode_ip_address_network_array(target_ip_address):
        # Never ping one of this node's own interfaces.
        if self.software_manager.node.ip_is_network_interface(ip_address=candidate):
            continue

        is_reachable = self.software_manager.icmp.ping(candidate)
        if is_reachable:
            responsive_hosts.append(str(candidate) if json_serializable else candidate)

        # No table row when not showing, or when offline hosts are hidden and this host did not answer.
        if not show or (show_online_only and not is_reachable):
            continue
        table.add_row([candidate, is_reachable])

    if show:
        print(table.get_string(sortby="IP Address"))
    return responsive_hosts
def _determine_port_scan_type(self, target_ip_addresses: List[IPV4Address], target_ports: List[Port]) -> str:
"""
Determine the type of port scan based on the number of target IP addresses and ports.
:param target_ip_addresses: The list of target IP addresses.
:type target_ip_addresses: List[IPV4Address]
:param target_ports: The list of target ports.
:type target_ports: List[Port]
:return: The type of port scan.
:rtype: str
"""
vertical_scan = len(target_ports) > 1
horizontal_scan = len(target_ip_addresses) > 1
return self._PORT_SCAN_TYPE_MAP[horizontal_scan, vertical_scan]
def _check_port_open_on_ip_address(
    self,
    ip_address: IPv4Address,
    port: Port,
    protocol: IPProtocol,
    is_re_attempt: bool = False,
    port_scan_uuid: Optional[str] = None,
) -> bool:
    """
    Probe a single port on a single IP address and report whether it is open.

    On the first call a port scan request payload is recorded as active and sent, and the method then calls
    itself once more as a re-attempt. The re-attempt only checks whether a response for that payload has been
    recorded, removing the response if so.

    :param ip_address: The target IP address.
    :type ip_address: IPv4Address
    :param port: The target port.
    :type port: Port
    :param protocol: The protocol used for the port scan.
    :type protocol: IPProtocol
    :param is_re_attempt: Flag indicating if this is a reattempt. Defaults to False.
    :type is_re_attempt: bool
    :param port_scan_uuid: The UUID of the port scan payload. Defaults to None.
    :type port_scan_uuid: Optional[str]
    :return: True if the port is open, False otherwise.
    :rtype: bool
    """
    if is_re_attempt:
        # Base case of the recursion: a stored response means the port answered; consume it.
        return self._port_scan_responses.pop(port_scan_uuid, None) is not None

    # First pass: register the request as active, then send it to the target.
    payload = PortScanPayload(ip_address=ip_address, port=port, protocol=protocol)
    self._active_port_scans[payload.uuid] = payload
    self.sys_log.info(
        f"{self.name}: Sending port scan request over {payload.protocol.name} on port {payload.port.value} "
        f"({payload.port.name}) to {payload.ip_address}"
    )
    self.software_manager.send_payload_to_session_manager(
        payload=payload, dest_ip_address=ip_address, src_port=port, dest_port=port, ip_protocol=protocol
    )

    # Second pass: look for the response to the request that was just sent.
    return self._check_port_open_on_ip_address(
        ip_address=ip_address, port=port, protocol=protocol, is_re_attempt=True, port_scan_uuid=payload.uuid
    )
def _process_port_scan_response(self, payload: PortScanPayload):
    """
    Record the response to a port scan request sent by this application.

    Responses whose UUID does not match an active port scan are ignored. Otherwise the scan is removed from the
    active scans and the payload is stored as its response.

    :param payload: The port scan payload received in response.
    :type payload: PortScanPayload
    """
    if payload.uuid not in self._active_port_scans:
        return

    del self._active_port_scans[payload.uuid]
    self._port_scan_responses[payload.uuid] = payload
    self.sys_log.info(
        f"{self.name}: Received port scan response from {payload.ip_address} on port {payload.port.value} "
        f"({payload.port.name}) over {payload.protocol.name}"
    )
def _process_port_scan_request(self, payload: PortScanPayload, session_id: str) -> None:
    """
    Answer a port scan request received from another node.

    A reply is sent only when the software manager reports the requested port and protocol as open; otherwise
    nothing is sent back. The reply is the same payload with its ``request`` flag cleared, sent on the session
    the request arrived on.

    :param payload: The port scan payload received in the request.
    :type payload: PortScanPayload
    :param session_id: The session ID for the port scan request.
    :type session_id: str
    """
    port_is_open = self.software_manager.check_port_is_open(port=payload.port, protocol=payload.protocol)
    if not port_is_open:
        return

    # Turn the request into a response before sending it back.
    payload.request = False
    self.sys_log.info(
        f"{self.name}: Responding to port scan request for port {payload.port.value} "
        f"({payload.port.name}) over {payload.protocol.name}",
        True,
    )
    self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
@validate_call()
def port_scan(
    self,
    target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]],
    target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None,
    target_port: Optional[Union[Port, List[Port]]] = None,
    show: bool = True,
    json_serializable: bool = False,
) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]:
    """
    Perform a port scan on the target IP address(es).

    Targets are expanded into individual addresses; addresses that belong to one of this node's own network
    interfaces are skipped. Every remaining address is probed on each requested protocol and port.

    :param target_ip_address: The target IP address(es) or network(s) for the port scan.
    :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
    :param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP.
    :type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]]
    :param target_port: The port(s) to scan. Defaults to None, which includes every ``Port`` member except
        ``Port.NONE`` and ``Port.UNUSED``.
    :type target_port: Optional[Union[Port, List[Port]]]
    :param show: Flag indicating whether to display the scan results. Defaults to True.
    :type show: bool
    :param json_serializable: Flag indicating whether the return value should be JSON serializable. When True the
        IP addresses are returned as strings and the protocols and ports as their enum values. Defaults to False.
    :type json_serializable: bool
    :return: A dictionary mapping IP addresses to protocols and lists of open ports. Only addresses and protocols
        with at least one open port appear.
    :rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]]
    """
    ip_addresses = self._explode_ip_address_network_array(target_ip_address)

    if isinstance(target_port, Port):
        target_port = [target_port]
    elif target_port is None:
        target_port = [port for port in Port if port not in {Port.NONE, Port.UNUSED}]

    if isinstance(target_protocol, IPProtocol):
        target_protocol = [target_protocol]
    elif target_protocol is None:
        target_protocol = [IPProtocol.TCP, IPProtocol.UDP]

    scan_type = self._determine_port_scan_type(list(ip_addresses), target_port)
    active_ports = {}

    # The table exists only when results are displayed; every use below is guarded by ``show``.
    table = None
    if show:
        table = PrettyTable(["IP Address", "Port", "Name", "Protocol"])
        table.align = "l"
        table.title = f"{self.software_manager.node.hostname} NMAP Port Scan ({scan_type})"

    # De-duplicate the ports once rather than on every protocol iteration.
    unique_ports = set(target_port)

    self.sys_log.info(f"{self.name}: Starting port scan")
    for ip_address in ip_addresses:
        # Prevent port scan on this node
        if self.software_manager.node.ip_is_network_interface(ip_address=ip_address):
            continue
        for protocol in target_protocol:
            for port in unique_ports:
                if not self._check_port_open_on_ip_address(ip_address=ip_address, port=port, protocol=protocol):
                    continue

                # Previously this row was added even when ``show`` was False, when no table had been created.
                if show:
                    table.add_row([ip_address, port.value, port.name, protocol.name])

                ip_key = str(ip_address) if json_serializable else ip_address
                protocol_key = protocol.value if json_serializable else protocol
                port_entry = port.value if json_serializable else port
                active_ports.setdefault(ip_key, {}).setdefault(protocol_key, []).append(port_entry)

    if show:
        print(table.get_string(sortby="IP Address"))

    return active_ports
def network_service_recon(
    self,
    target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]],
    target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None,
    target_port: Optional[Union[Port, List[Port]]] = None,
    show: bool = True,
    show_online_only: bool = True,
    json_serializable: bool = False,
) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]:
    """
    Run a ping scan and then port scan only the hosts that responded.

    The ping scan is always run with ``json_serializable=False`` so that its results can be passed straight to
    the port scan as target addresses. The ``json_serializable`` argument applies to the port scan result only.

    :param target_ip_address: The target IP address(es) or network(s) for the ping scan.
    :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
    :param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP.
    :type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]]
    :param target_port: The port(s) to scan. Defaults to None, which includes all valid ports.
    :type target_port: Optional[Union[Port, List[Port]]]
    :param show: Flag indicating whether to display the results of both scans. Defaults to True.
    :type show: bool
    :param show_online_only: Flag indicating whether the ping scan shows only the online hosts. Defaults to True.
    :type show_online_only: bool
    :param json_serializable: Flag indicating whether the return value should be JSON serializable. Defaults to
        False.
    :type json_serializable: bool
    :return: A dictionary mapping IP addresses to protocols and lists of open ports.
    :rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]]
    """
    live_hosts = self.ping_scan(
        target_ip_address=target_ip_address,
        show=show,
        show_online_only=show_online_only,
        json_serializable=False,
    )
    port_scan_kwargs = {
        "target_ip_address": live_hosts,
        "target_protocol": target_protocol,
        "target_port": target_port,
        "show": show,
        "json_serializable": json_serializable,
    }
    return self.port_scan(**port_scan_kwargs)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
    """
    Receive a payload and dispatch it if it is a port scan payload.

    Port scan requests and responses are passed to their respective handlers; any other payload is left
    untouched. True is returned in every case.

    :param payload: The payload to be processed.
    :type payload: Any
    :param session_id: The session ID associated with the payload.
    :type session_id: str
    :return: Always True.
    :rtype: bool
    """
    if not isinstance(payload, PortScanPayload):
        return True

    if payload.request:
        self._process_port_scan_request(payload=payload, session_id=session_id)
        return True

    self._process_port_scan_response(payload=payload)
    return True

View File

@@ -1,9 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import IntEnum
from ipaddress import IPv4Address
from typing import Dict, Optional
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -12,43 +10,10 @@ from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
class RansomwareAttackStage(IntEnum):
    """
    Stages of the ransomware script kill chain.

    Each member marks a phase the ransomware attack can be in during its lifecycle in the simulation. The integer
    values follow the order in which the stages are reached.
    """

    # The attack has not started yet.
    NOT_STARTED = 0
    # A download location on the local file system has been verified.
    DOWNLOAD = 1
    # The downloaded script file has been accessed on the local file system.
    INSTALL = 2
    # The installed script has been activated.
    ACTIVATE = 3
    # The simulated target scan has succeeded.
    PROPAGATE = 4
    # The simulated C2 beacon setup has succeeded (the beacon itself is not implemented).
    COMMAND_AND_CONTROL = 5
    # Actively attacking the target.
    PAYLOAD = 6
    # The attack has been successfully completed.
    SUCCEEDED = 7
    # The attack has failed.
    FAILED = 8
class RansomwareScript(Application):
"""Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network.
:ivar payload: The attack stage query payload. (Default Corrupt)
:ivar target_scan_p_of_success: The probability of success for the target scan stage.
:ivar c2_beacon_p_of_success: The probability of success for the c2_beacon stage
:ivar ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
:ivar repeat: Whether to repeat attacking once finished.
:ivar payload: The attack stage query payload. (Default ENCRYPT)
"""
server_ip_address: Optional[IPv4Address] = None
@@ -57,16 +22,6 @@ class RansomwareScript(Application):
"""Password required to access the database."""
payload: Optional[str] = "ENCRYPT"
"Payload String for the payload stage"
target_scan_p_of_success: float = 0.9
"Probability of the target scan succeeding: Default 0.9"
c2_beacon_p_of_success: float = 0.9
"Probability of the c2 beacon setup stage succeeding: Default 0.9"
ransomware_encrypt_p_of_success: float = 0.9
"Probability of the ransomware attack succeeding: Default 0.9"
repeat: bool = False
"If true, the Denial of Service bot will keep performing the attack."
attack_stage: RansomwareAttackStage = RansomwareAttackStage.NOT_STARTED
"The ransomware attack stage. See RansomwareAttackStage Class"
def __init__(self, **kwargs):
kwargs["name"] = "RansomwareScript"
@@ -91,7 +46,7 @@ class RansomwareScript(Application):
@property
def _host_db_client(self) -> DatabaseClient:
"""Return the database client that is installed on the same machine as the Ransomware Script."""
db_client = self.software_manager.software.get("DatabaseClient")
db_client: DatabaseClient = self.software_manager.software.get("DatabaseClient")
if db_client is None:
self.sys_log.warning(f"{self.__class__.__name__} cannot find a database client on its host.")
return db_client
@@ -109,16 +64,6 @@ class RansomwareScript(Application):
)
return rm
def _activate(self):
    """
    Simulate the activation step of the kill chain.

    Advances the attack stage from 'INSTALL' to 'ACTIVATE'. Does nothing in any other stage.
    """
    if self.attack_stage != RansomwareAttackStage.INSTALL:
        return
    self.sys_log.info(f"{self.name}: Activated!")
    self.attack_stage = RansomwareAttackStage.ACTIVATE
def run(self) -> bool:
"""Calls the parent classes execute method before starting the application loop."""
super().run()
@@ -134,20 +79,9 @@ class RansomwareScript(Application):
return False
if self.server_ip_address and self.payload:
self.sys_log.info(f"{self.name}: Running")
self.attack_stage = RansomwareAttackStage.NOT_STARTED
self._local_download()
self._install()
self._activate()
self._perform_target_scan()
self._setup_beacon()
self._perform_ransomware_encrypt()
if self.repeat and self.attack_stage in (
RansomwareAttackStage.SUCCEEDED,
RansomwareAttackStage.FAILED,
):
self.attack_stage = RansomwareAttackStage.NOT_STARTED
return True
if self._perform_ransomware_encrypt():
return True
return False
else:
self.sys_log.warning(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
return False
@@ -157,10 +91,6 @@ class RansomwareScript(Application):
server_ip_address: IPv4Address,
server_password: Optional[str] = None,
payload: Optional[str] = None,
target_scan_p_of_success: Optional[float] = None,
c2_beacon_p_of_success: Optional[float] = None,
ransomware_encrypt_p_of_success: Optional[float] = None,
repeat: bool = True,
):
"""
Configure the Ransomware Script to communicate with a DatabaseService.
@@ -168,10 +98,6 @@ class RansomwareScript(Application):
:param server_ip_address: The IP address of the Node the DatabaseService is on.
:param server_password: The password on the DatabaseService.
:param payload: The attack stage query (Encrypt / Delete)
:param target_scan_p_of_success: The probability of success for the target scan stage.
:param c2_beacon_p_of_success: The probability of success for the c2_beacon stage
:param ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
:param repeat: Whether to repeat attacking once finished.
"""
if server_ip_address:
self.server_ip_address = server_ip_address
@@ -179,74 +105,15 @@ class RansomwareScript(Application):
self.server_password = server_password
if payload:
self.payload = payload
if target_scan_p_of_success:
self.target_scan_p_of_success = target_scan_p_of_success
if c2_beacon_p_of_success:
self.c2_beacon_p_of_success = c2_beacon_p_of_success
if ransomware_encrypt_p_of_success:
self.ransomware_encrypt_p_of_success = ransomware_encrypt_p_of_success
if repeat:
self.repeat = repeat
self.sys_log.info(
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, "
f"{repeat=}."
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}."
)
def _install(self):
    """
    Simulate the install step of the kill chain.

    Runs only from the 'DOWNLOAD' stage: the access count of ``ransom_script.pdf`` in the ``downloads`` folder is
    incremented and the attack stage advances to 'INSTALL'. Does nothing in any other stage.
    """
    if self.attack_stage != RansomwareAttackStage.DOWNLOAD:
        return
    self.sys_log.info(f"{self.name}: Malware installed on the local file system")
    ransomware_file = self.file_system.get_folder(folder_name="downloads").get_file(file_name="ransom_script.pdf")
    ransomware_file.num_access += 1
    self.attack_stage = RansomwareAttackStage.INSTALL
def _setup_beacon(self):
    """
    Simulate setting up a C2 beacon; currently a pseudo step for increasing red variance.

    Runs only from the 'PROPAGATE' stage. The stage advances to 'COMMAND_AND_CONTROL' when
    ``simulate_trial(self.c2_beacon_p_of_success)`` succeeds; otherwise the stage is left unchanged.
    """
    if self.attack_stage != RansomwareAttackStage.PROPAGATE:
        return
    self.sys_log.info(f"{self.name} Attempting to set up C&C Beacon - Scan 1/2")
    if not simulate_trial(self.c2_beacon_p_of_success):
        return
    self.sys_log.info(f"{self.name} C&C Successful setup - Scan 2/2")
    c2c_setup = True  # TODO Implement the c2c step via an FTP Application/Service
    if c2c_setup:
        self.attack_stage = RansomwareAttackStage.COMMAND_AND_CONTROL
def _perform_target_scan(self):
    """
    Perform a simulated port scan to check for open SQL ports.

    Runs only from the 'ACTIVATE' stage. The stage advances to 'PROPAGATE' when
    ``simulate_trial(self.target_scan_p_of_success)`` succeeds; otherwise the stage is left unchanged.
    """
    if self.attack_stage != RansomwareAttackStage.ACTIVATE:
        return
    # perform a port scan to identify that the SQL port is open on the server
    self.sys_log.info(f"{self.name}: Scanning for vulnerable databases - Scan 0/2")
    if not simulate_trial(self.target_scan_p_of_success):
        return
    self.sys_log.info(f"{self.name}: Found a target database! Scan 1/2")
    port_is_open = True  # TODO Implement a NNME Triggering scan as a seperate Red Application
    if port_is_open:
        self.attack_stage = RansomwareAttackStage.PROPAGATE
def attack(self) -> bool:
"""Perform the attack steps after opening the application."""
self.run()
if not self._can_perform_action():
self.sys_log.warning("Ransomware application is unable to perform it's actions.")
self.run()
self.num_executions += 1
return self._application_loop()
@@ -255,57 +122,30 @@ class RansomwareScript(Application):
self._db_connection = self._host_db_client.get_new_connection()
return True if self._db_connection else False
def _perform_ransomware_encrypt(self):
def _perform_ransomware_encrypt(self) -> bool:
"""
Execute the Ransomware Encrypt payload on the target.
Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful.
:param p_of_success: Probability of successfully performing ransomware encryption, by default 0.1.
"""
if self._host_db_client is None:
self.sys_log.info(f"{self.name}: Failed to connect to db_client - Ransomware Script")
self.attack_stage = RansomwareAttackStage.FAILED
return
return False
self._host_db_client.server_ip_address = self.server_ip_address
self._host_db_client.server_password = self.server_password
if self.attack_stage == RansomwareAttackStage.COMMAND_AND_CONTROL:
if simulate_trial(self.ransomware_encrypt_p_of_success):
self.sys_log.info(f"{self.name}: Attempting to launch payload")
if not self._db_connection:
self._establish_db_connection()
if self._db_connection:
attack_successful = self._db_connection.query(self.payload)
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
if attack_successful:
self.sys_log.info(f"{self.name}: Payload Successful")
self.attack_stage = RansomwareAttackStage.SUCCEEDED
else:
self.sys_log.info(f"{self.name}: Payload failed")
self.attack_stage = RansomwareAttackStage.FAILED
self.sys_log.info(f"{self.name}: Attempting to launch payload")
if not self._db_connection:
self._establish_db_connection()
if self._db_connection:
attack_successful = self._db_connection.query(self.payload)
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
if attack_successful:
self.sys_log.info(f"{self.name}: Payload Successful")
return True
else:
self.sys_log.info(f"{self.name}: Payload failed")
return False
else:
self.sys_log.warning("Attack Attempted to launch too quickly")
self.attack_stage = RansomwareAttackStage.FAILED
def _local_download(self):
    """
    Simulate downloading the script onto the local file system.

    From the 'NOT_STARTED' stage the attack advances to 'DOWNLOAD' when the download location is verified, and
    to 'FAILED' when it is not. From any other stage the attack is marked 'FAILED'.
    """
    if self.attack_stage != RansomwareAttackStage.NOT_STARTED:
        self.sys_log.info("Malware failed to download")
        self.attack_stage = RansomwareAttackStage.FAILED
        return

    if self._local_download_verify():
        self.attack_stage = RansomwareAttackStage.DOWNLOAD
        return

    self.sys_log.info("Malware failed to create a installation location")
    self.attack_stage = RansomwareAttackStage.FAILED
def _local_download_verify(self) -> bool:
    """
    Verify that a download location exists, creating one if needed.

    When a folder named ``downloads`` already exists the file system's file creation counter is incremented.
    Otherwise the ``downloads`` folder and the ``ransom_script.pdf`` file inside it are created.

    :return: True in both cases.
    """
    folders = self.file_system.folders
    if any(folders[folder_key].name == "downloads" for folder_key in folders):
        self.file_system.num_file_creations += 1
        return True

    self.file_system.create_folder("downloads")
    self.file_system.create_file(folder_name="downloads", file_name="ransom_script.pdf")
    return True
return False

View File

@@ -79,6 +79,31 @@ class SoftwareManager:
open_ports.append(software.port)
return open_ports
def check_port_is_open(self, port: Port, protocol: IPProtocol) -> bool:
    """
    Check whether any installed software is running on the given port and protocol.

    Every piece of installed software is inspected. The port counts as open when at least one of them uses the
    given port and protocol and its operating state is the application or service ``RUNNING`` state.

    :param port: The port to check.
    :type port: Port
    :param protocol: The protocol to check (e.g., TCP, UDP).
    :type protocol: IPProtocol
    :return: True if the port is open and a service is running on it using the specified protocol, False otherwise.
    :rtype: bool
    """
    running_states = {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}
    return any(
        software.port == port and software.protocol == protocol and software.operating_state in running_states
        for software in self.software.values()
    )
def install(self, software_class: Type[IOSoftwareClass]):
"""
Install an Application or Service.
@@ -151,6 +176,7 @@ class SoftwareManager:
self,
payload: Any,
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
src_port: Optional[Port] = None,
dest_port: Optional[Port] = None,
ip_protocol: IPProtocol = IPProtocol.TCP,
session_id: Optional[str] = None,
@@ -171,6 +197,7 @@ class SoftwareManager:
return self.session_manager.receive_payload_from_software_manager(
payload=payload,
dst_ip_address=dest_ip_address,
src_port=src_port,
dst_port=dest_port,
ip_protocol=ip_protocol,
session_id=session_id,
@@ -191,6 +218,9 @@ class SoftwareManager:
:param payload: The payload being received.
:param session: The transport session the payload originates from.
"""
if payload.__class__.__name__ == "PortScanPayload":
self.software.get("NMAP").receive(payload=payload, session_id=session_id)
return
receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
if receiver:
receiver.receive(

View File

@@ -0,0 +1,26 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import Enum
from typing import Any, Dict
def convert_dict_enum_keys_to_enum_values(d: Dict[Any, Any]) -> Dict[Any, Any]:
    """
    Build a copy of a dictionary in which enum keys are replaced by their values.

    Nested dictionaries found as values are converted recursively. Keys that are not enums, and values that are
    not dictionaries, are carried over unchanged. The input dictionary is not modified.

    :param d: dict
        The dictionary with enum keys to be converted.
    :return: dict
        The dictionary with enum values as keys.
    """

    def _plain_key(key: Any) -> Any:
        # Enum members are swapped for their underlying value; everything else passes through.
        return key.value if isinstance(key, Enum) else key

    return {
        _plain_key(key): convert_dict_enum_keys_to_enum_values(value) if isinstance(value, dict) else value
        for key, value in d.items()
    }

View File

@@ -3,7 +3,7 @@
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Any, Dict, Tuple, Union
@@ -12,16 +12,16 @@ from typing import Any, Dict, Tuple, Union
import polars as pl
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
def total_rewards_dict(total_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
"""
Read an average rewards per episode csv file and return as a dict.
The dictionary keys are the episode number, and the values are the mean reward that episode.
:param av_rewards_csv_file: The average rewards per episode csv file path.
:param total_rewards_csv_file: The average rewards per episode csv file path.
:return: The average rewards per episode csv as a dict.
"""
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
df_dict = pl.read_csv(total_rewards_csv_file).to_dict()
return {int(v): df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}

View File

@@ -3,7 +3,7 @@
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import csv
from logging import Logger
from typing import Final, List, Tuple, TYPE_CHECKING, Union
@@ -27,9 +27,9 @@ class SessionOutputWriter:
Is used to write session outputs to csv file.
"""
_AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
_TOTAL_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
"Episode",
"Average Reward",
"Total Reward",
]
def __init__(
@@ -44,7 +44,7 @@ class SessionOutputWriter:
:param env: PrimAITE gym environment.
:type env: Primaite
:param transaction_writer: If `true`, this will output a full account of every transaction taken by the agent.
If `false` it will output the average reward per episode, defaults to False
If `false` it will output the total reward per episode, defaults to False
:type transaction_writer: bool, optional
:param learning_session: Set to `true` to indicate that the current session is a training session. This
determines the name of the folder which contains the final output csv. Defaults to True
@@ -57,7 +57,7 @@ class SessionOutputWriter:
if self.transaction_writer:
fn = f"all_transactions_{self._env.timestamp_str}.csv"
else:
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
fn = f"total_reward_per_episode_{self._env.timestamp_str}.csv"
self._csv_file_path: "Path"
if self.learning_session:
@@ -95,7 +95,7 @@ class SessionOutputWriter:
if isinstance(data, Transaction):
header, data = data.as_csv_data()
else:
header = self._AV_REWARD_PER_EPISODE_HEADER
header = self._TOTAL_REWARD_PER_EPISODE_HEADER
if self._first_write:
self._init_csv_writer()

View File

@@ -0,0 +1,137 @@
# Output switches for this scenario: step metadata is not saved, PCAP logs and
# sys logs are saved, and sys_log_level is set to WARNING.
io_settings:
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true
sys_log_level: WARNING
# Game-level settings: max_episode_length of 256 plus the port and protocol
# names used by this scenario.
# NOTE(review): presumably these lists restrict what the game tracks - confirm against the config schema.
game:
max_episode_length: 256
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
# A single RED-team ProbabilisticAgent (ref: client_1_red_nmap) with no observation space.
# Its action space references the NMAP application on client_1 and maps exactly one action.
agents:
- ref: client_1_red_nmap
team: RED
type: ProbabilisticAgent
observation_space: null
action_space:
options:
nodes:
- node_name: client_1
applications:
- application_name: NMAP
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
action_list:
- type: NODE_NMAP_NETWORK_SERVICE_RECON
# Action 0: network service recon from client_1 over 192.168.10.0/24 for TCP port 80.
action_map:
0:
action: NODE_NMAP_NETWORK_SERVICE_RECON
options:
source_node: client_1
target_ip_address: 192.168.10.0/24
target_port: 80
target_protocol: tcp
reward_function:
reward_components:
- type: DUMMY
agent_settings:
# Action 0 is the only entry and carries probability 1.0.
action_probabilities:
0: 1.0
# Network topology with two /24 subnets joined by router_1:
#   router_1 port 1 (192.168.1.1)  -> switch_1 port 8; server_1 (.10) and server_2 (.14) on switch_1 ports 1 and 2
#   router_1 port 2 (192.168.10.1) -> switch_2 port 8; client_1 (.21) and client_2 (.22) on switch_2 ports 1 and 2
simulation:
network:
nodes:
- hostname: switch_1
num_ports: 8
type: switch
- hostname: switch_2
num_ports: 8
type: switch
- hostname: router_1
type: router
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
# ACL rule 1 is PERMIT with no match fields given here.
acl:
1:
action: PERMIT
- hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
- hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
- hostname: server_1
type: server
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
- hostname: server_2
type: server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
# Physical links; each entry joins port endpoint_a_port of endpoint_a to port endpoint_b_port of endpoint_b.
links:
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 8
- endpoint_a_hostname: router_1
endpoint_a_port: 2
endpoint_b_hostname: switch_2
endpoint_b_port: 8
- endpoint_a_hostname: client_1
endpoint_a_port: 1
endpoint_b_hostname: switch_2
endpoint_b_port: 1
- endpoint_a_hostname: client_2
endpoint_a_port: 1
endpoint_b_hostname: switch_2
endpoint_b_port: 2
- endpoint_a_hostname: server_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 1
- endpoint_a_hostname: server_2
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 2

View File

@@ -0,0 +1,135 @@
# Output switches for this scenario: step metadata is not saved, PCAP logs and
# sys logs are saved, and sys_log_level is set to WARNING.
io_settings:
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true
sys_log_level: WARNING
# Game-level settings: max_episode_length of 256 plus the port and protocol
# names used by this scenario.
# NOTE(review): presumably these lists restrict what the game tracks - confirm against the config schema.
game:
max_episode_length: 256
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
# A single RED-team ProbabilisticAgent (ref: client_1_red_nmap) with no observation space.
# Its action space references the NMAP application on client_1 and maps exactly one action.
agents:
- ref: client_1_red_nmap
team: RED
type: ProbabilisticAgent
observation_space: null
action_space:
options:
nodes:
- node_name: client_1
applications:
- application_name: NMAP
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
action_list:
- type: NODE_NMAP_PING_SCAN
# Action 0: ping scan from client_1 over 192.168.1.0/24 (a different subnet from client_1's own 192.168.10.x address).
action_map:
0:
action: NODE_NMAP_PING_SCAN
options:
source_node: client_1
target_ip_address: 192.168.1.0/24
reward_function:
reward_components:
- type: DUMMY
agent_settings:
# Action 0 is the only entry and carries probability 1.0.
action_probabilities:
0: 1.0
# Network topology with two /24 subnets joined by router_1:
#   router_1 port 1 (192.168.1.1)  -> switch_1 port 8; server_1 (.10) and server_2 (.14) on switch_1 ports 1 and 2
#   router_1 port 2 (192.168.10.1) -> switch_2 port 8; client_1 (.21) and client_2 (.22) on switch_2 ports 1 and 2
simulation:
network:
nodes:
- hostname: switch_1
num_ports: 8
type: switch
- hostname: switch_2
num_ports: 8
type: switch
- hostname: router_1
type: router
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
# ACL rule 1 is PERMIT with no match fields given here.
acl:
1:
action: PERMIT
- hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
- hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
- hostname: server_1
type: server
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
- hostname: server_2
type: server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
# Physical links; each entry joins port endpoint_a_port of endpoint_a to port endpoint_b_port of endpoint_b.
links:
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 8
- endpoint_a_hostname: router_1
endpoint_a_port: 2
endpoint_b_hostname: switch_2
endpoint_b_port: 8
- endpoint_a_hostname: client_1
endpoint_a_port: 1
endpoint_b_hostname: switch_2
endpoint_b_port: 1
- endpoint_a_hostname: client_2
endpoint_a_port: 1
endpoint_b_hostname: switch_2
endpoint_b_port: 2
- endpoint_a_hostname: server_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 1
- endpoint_a_hostname: server_2
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 2

View File

@@ -0,0 +1,135 @@
# Output switches for this scenario: step metadata is not saved, PCAP logs and
# sys logs are saved, and sys_log_level is set to WARNING.
io_settings:
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true
sys_log_level: WARNING
# Game-level settings: max_episode_length of 256 plus the port and protocol
# names used by this scenario.
# NOTE(review): presumably these lists restrict what the game tracks - confirm against the config schema.
game:
max_episode_length: 256
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
# A single RED-team ProbabilisticAgent (ref: client_1_red_nmap) with no observation space.
# Its action space references the NMAP application on client_1 and maps exactly one action.
agents:
- ref: client_1_red_nmap
team: RED
type: ProbabilisticAgent
observation_space: null
action_space:
options:
nodes:
- node_name: client_1
applications:
- application_name: NMAP
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
action_list:
- type: NODE_NMAP_PORT_SCAN
# Action 0: port scan from client_1 over 192.168.10.0/24; no target port or protocol is given.
action_map:
0:
action: NODE_NMAP_PORT_SCAN
options:
source_node: client_1
target_ip_address: 192.168.10.0/24
reward_function:
reward_components:
- type: DUMMY
agent_settings:
# Action 0 is the only entry and carries probability 1.0.
action_probabilities:
0: 1.0
# Network topology with two /24 subnets joined by router_1:
#   router_1 port 1 (192.168.1.1)  -> switch_1 port 8; server_1 (.10) and server_2 (.14) on switch_1 ports 1 and 2
#   router_1 port 2 (192.168.10.1) -> switch_2 port 8; client_1 (.21) and client_2 (.22) on switch_2 ports 1 and 2
simulation:
network:
nodes:
- hostname: switch_1
num_ports: 8
type: switch
- hostname: switch_2
num_ports: 8
type: switch
- hostname: router_1
type: router
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
# ACL rule 1 is PERMIT with no match fields given here.
acl:
1:
action: PERMIT
- hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
- hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
- hostname: server_1
type: server
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
- hostname: server_2
type: server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
# Physical links; each entry joins port endpoint_a_port of endpoint_a to port endpoint_b_port of endpoint_b.
links:
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 8
- endpoint_a_hostname: router_1
endpoint_a_port: 2
endpoint_b_hostname: switch_2
endpoint_b_port: 8
- endpoint_a_hostname: client_1
endpoint_a_port: 1
endpoint_b_hostname: switch_2
endpoint_b_port: 1
- endpoint_a_hostname: client_2
endpoint_a_port: 1
endpoint_b_hostname: switch_2
endpoint_b_port: 2
- endpoint_a_hostname: server_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 1
- endpoint_a_hostname: server_2
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 2

View File

@@ -10,12 +10,8 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
from primaite.simulator.system.applications.red_applications.ransomware_script import (
RansomwareAttackStage,
RansomwareScript,
)
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.software import SoftwareHealthState
@@ -86,54 +82,24 @@ def ransomware_script_db_server_green_client(example_network) -> Network:
return network
def test_repeating_ransomware_script_attack(ransomware_script_and_db_server):
def test_ransomware_script_attack(ransomware_script_and_db_server):
"""Test a repeating data manipulation attack."""
RansomwareScript, computer, db_server_service, server = ransomware_script_and_db_server
computer.apply_timestep(timestep=0)
server.apply_timestep(timestep=0)
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
assert computer.file_system.num_file_creations == 0
assert server.file_system.num_file_creations == 1
RansomwareScript.target_scan_p_of_success = 1
RansomwareScript.c2_beacon_p_of_success = 1
RansomwareScript.ransomware_encrypt_p_of_success = 1
RansomwareScript.repeat = True
RansomwareScript.attack()
assert RansomwareScript.attack_stage == RansomwareAttackStage.NOT_STARTED
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
assert computer.file_system.num_file_creations == 1
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
assert server.file_system.num_file_creations == 2
computer.apply_timestep(timestep=1)
server.apply_timestep(timestep=1)
assert RansomwareScript.attack_stage == RansomwareAttackStage.NOT_STARTED
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
def test_repeating_ransomware_script_attack(ransomware_script_and_db_server):
"""Test a repeating ransowmare script attack."""
RansomwareScript, computer, db_server_service, server = ransomware_script_and_db_server
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
RansomwareScript.target_scan_p_of_success = 1
RansomwareScript.c2_beacon_p_of_success = 1
RansomwareScript.ransomware_encrypt_p_of_success = 1
RansomwareScript.repeat = False
RansomwareScript.attack()
assert RansomwareScript.attack_stage == RansomwareAttackStage.SUCCEEDED
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
assert computer.file_system.num_file_creations == 1
computer.apply_timestep(timestep=1)
computer.pre_timestep(timestep=1)
server.apply_timestep(timestep=1)
server.pre_timestep(timestep=1)
assert RansomwareScript.attack_stage == RansomwareAttackStage.SUCCEEDED
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
assert computer.file_system.num_file_creations == 0
def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_green_client):
@@ -154,10 +120,6 @@ def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_
assert green_db_client_connection.query("SELECT")
assert green_db_client.last_query_response.get("status_code") == 200
ransomware_script_application.target_scan_p_of_success = 1
ransomware_script_application.ransomware_encrypt_p_of_success = 1
ransomware_script_application.c2_beacon_p_of_success = 1
ransomware_script_application.repeat = False
ransomware_script_application.attack()
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT

View File

@@ -0,0 +1,186 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
import yaml
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.nmap import NMAP
from tests import TEST_ASSETS_ROOT
def test_ping_scan_all_on(example_network):
    """Ping scan two explicitly listed addresses from client_1 and expect both back as IPv4Address objects."""
    scanner: NMAP = example_network.get_node_by_hostname("client_1").software_manager.software["NMAP"]  # noqa

    targets = ["192.168.1.10", "192.168.1.14"]
    expected_hosts = [IPv4Address(address) for address in targets]

    live_hosts = scanner.ping_scan(target_ip_address=targets)

    assert sorted(live_hosts) == sorted(expected_hosts)
def test_ping_scan_all_on_full_network(example_network):
    """Ping scan the whole 192.168.1.0/24 network from client_1; hosts .1, .10 and .14 are expected to respond."""
    scanner: NMAP = example_network.get_node_by_hostname("client_1").software_manager.software["NMAP"]  # noqa

    expected_hosts = [IPv4Address(f"192.168.1.{host}") for host in (1, 10, 14)]

    live_hosts = scanner.ping_scan(target_ip_address=IPv4Network("192.168.1.0/24"))

    assert sorted(live_hosts) == sorted(expected_hosts)
def test_ping_scan_some_on(example_network):
    """Power off server_2, ping scan 192.168.1.0/24 from client_1, and expect only .1 and .10 to respond."""
    scanner: NMAP = example_network.get_node_by_hostname("client_1").software_manager.software["NMAP"]  # noqa

    # Take one of the servers down before scanning.
    example_network.get_node_by_hostname("server_2").power_off()

    expected_hosts = [IPv4Address(f"192.168.1.{host}") for host in (1, 10)]

    live_hosts = scanner.ping_scan(target_ip_address=IPv4Network("192.168.1.0/24"))

    assert sorted(live_hosts) == sorted(expected_hosts)
def test_ping_scan_all_off(example_network):
    """Power off both servers, ping scan their two addresses from client_1, and expect no live hosts."""
    scanner: NMAP = example_network.get_node_by_hostname("client_1").software_manager.software["NMAP"]  # noqa

    # Same order as before: server_1 goes down first, then server_2.
    for hostname in ("server_1", "server_2"):
        example_network.get_node_by_hostname(hostname).power_off()

    live_hosts = scanner.ping_scan(target_ip_address=["192.168.1.10", "192.168.1.14"])

    assert sorted(live_hosts) == []
def test_port_scan_one_node_one_port(example_network):
    """Port scan client_2's first network interface for TCP/DNS from client_1 and expect that single port reported."""
    scanner: NMAP = example_network.get_node_by_hostname("client_1").software_manager.software["NMAP"]  # noqa
    target_node = example_network.get_node_by_hostname("client_2")

    scan_result = scanner.port_scan(
        target_ip_address=target_node.network_interface[1].ip_address,
        target_port=Port.DNS,
        target_protocol=IPProtocol.TCP,
    )

    assert scan_result == {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.DNS]}}
def sort_dict(d):
    """
    Return an order-normalised copy of ``d`` so nested results can be compared regardless of ordering.

    * dict: keys are ordered by their ``str()`` form and every value is normalised recursively.
    * list: returned sorted; Enum members are ordered by their ``str()`` form, other items by their own value.
      List items themselves are not converted or recursed into.
    * Enum member: replaced by its ``str()`` form.
    * anything else: returned unchanged.
    """

    def _list_key(item):
        # Enum members are ordered through their string form; everything else by itself.
        if isinstance(item, Enum):
            return str(item)
        return item

    if isinstance(d, dict):
        return {key: sort_dict(d[key]) for key in sorted(d, key=str)}
    if isinstance(d, list):
        return sorted(d, key=_list_key)
    if isinstance(d, Enum):
        return str(d)
    return d
def test_port_scan_full_subnet_all_ports_and_protocols(example_network):
    """
    Port scan the whole 192.168.10.0/24 subnet from client_1 without restricting port or protocol.

    The result is compared order-insensitively (via ``sort_dict``) against the ports expected on
    192.168.10.1 and 192.168.10.22.
    """
    scanner: NMAP = example_network.get_node_by_hostname("client_1").software_manager.software["NMAP"]  # noqa

    scan_result = scanner.port_scan(target_ip_address=IPv4Network("192.168.10.0/24"))

    router_ports = {IPProtocol.UDP: [Port.ARP]}
    client_2_ports = {
        IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS],
        IPProtocol.UDP: [Port.ARP, Port.NTP],
    }
    expected_ports = {
        IPv4Address("192.168.10.1"): router_ports,
        IPv4Address("192.168.10.22"): client_2_ports,
    }

    assert sort_dict(scan_result) == sort_dict(expected_ports)
def test_network_service_recon_all_ports_and_protocols(example_network):
    """Run network service recon for TCP/HTTP over 192.168.10.0/24 from client_1; only 192.168.10.22 is expected."""
    scanner: NMAP = example_network.get_node_by_hostname("client_1").software_manager.software["NMAP"]  # noqa

    recon_result = scanner.network_service_recon(
        target_ip_address=IPv4Network("192.168.10.0/24"),
        target_port=Port.HTTP,
        target_protocol=IPProtocol.TCP,
    )

    expected_services = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.HTTP]}}
    assert sort_dict(recon_result) == sort_dict(expected_services)
def test_ping_scan_red_agent():
    """
    Step a game built from the ping-scan red agent config once and inspect the agent's recorded response.

    Exactly one history entry is expected, whose response data lists the live hosts as strings.
    """
    config_path = TEST_ASSETS_ROOT / "configs/nmap_ping_scan_red_agent_config.yaml"
    with open(config_path, "r") as config_file:
        game = PrimaiteGame.from_config(yaml.safe_load(config_file))

    game.step()

    history = game.agents["client_1_red_nmap"].history
    assert len(history) == 1

    live_hosts = history[0].response.data["live_hosts"]
    assert sorted(live_hosts) == sorted(["192.168.1.1", "192.168.1.10", "192.168.1.14"])
def test_port_scan_red_agent():
    """
    Step a game built from the port-scan red agent config once and verify the recorded scan result.

    Exactly one history entry is expected. Its response data must match the expected mapping of
    IP address -> protocol -> ports, ignoring the order of keys and of the port lists.
    """
    with open(TEST_ASSETS_ROOT / "configs/nmap_port_scan_red_agent_config.yaml", "r") as file:
        cfg = yaml.safe_load(file)

    game = PrimaiteGame.from_config(cfg)

    game.step()

    expected_result = {
        "192.168.10.1": {"udp": [219]},
        "192.168.10.22": {
            "tcp": [80, 21, 53],
            "udp": [219, 123],
        },
    }

    action_history = game.agents["client_1_red_nmap"].history
    assert len(action_history) == 1
    actual_result = action_history[0].response.data

    # sorted() over a dict yields only its keys, so comparing sorted(dict) values would check the IP
    # addresses alone. sort_dict normalises key and list order recursively, so the protocols and ports
    # under each address are compared as well.
    assert sort_dict(actual_result) == sort_dict(expected_result)
def test_network_service_recon_red_agent():
    """
    Step a game built from the network-service-recon red agent config once and verify the recorded result.

    Exactly one history entry is expected. Its response data must match the expected mapping of
    IP address -> protocol -> ports.
    """
    with open(TEST_ASSETS_ROOT / "configs/nmap_network_service_recon_red_agent_config.yaml", "r") as file:
        cfg = yaml.safe_load(file)

    game = PrimaiteGame.from_config(cfg)

    game.step()

    expected_result = {"192.168.10.22": {"tcp": [80]}}

    action_history = game.agents["client_1_red_nmap"].history
    assert len(action_history) == 1
    actual_result = action_history[0].response.data

    # sorted() over a dict yields only its keys, so comparing sorted(dict) values would check the IP
    # address alone. sort_dict normalises the nested structure so the protocol and port are compared too.
    assert sort_dict(actual_result) == sort_dict(expected_result)

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1,84 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
def test_simple_conversion():
    """
    Convert a dictionary whose keys are enums at two nesting levels.

    ``IPProtocol.UDP`` and ``Port.ARP`` are expected to be replaced by their values (``"udp"`` and ``219``),
    while the innermost string-keyed dictionary is expected to come back unchanged.
    """
    enum_keyed = {IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0}}}

    converted = convert_dict_enum_keys_to_enum_values(enum_keyed)

    assert converted == {"udp": {219: {"inbound": 0, "outbound": 1016.0}}}
def test_no_enums():
    """
    Convert a dictionary that contains no enum keys at all.

    Every key is a plain string, so the converted dictionary is expected to equal the input.
    """
    string_keyed = {"protocol": {"port": {"inbound": 0, "outbound": 1016.0}}}

    converted = convert_dict_enum_keys_to_enum_values(string_keyed)

    assert converted == {"protocol": {"port": {"inbound": 0, "outbound": 1016.0}}}
def test_mixed_keys():
    """
    Convert a dictionary that mixes enum keys and string keys.

    Enum keys are expected to be replaced by their values (``IPProtocol.TCP`` -> ``"tcp"``,
    ``Port.HTTP`` -> ``80``) while string keys are expected to stay as they are.
    """
    mixed = {
        IPProtocol.TCP: {"port": {"inbound": 0, "outbound": 1016.0}},
        "protocol": {Port.HTTP: {"inbound": 10, "outbound": 2020.0}},
    }

    converted = convert_dict_enum_keys_to_enum_values(mixed)

    assert converted == {
        "tcp": {"port": {"inbound": 0, "outbound": 1016.0}},
        "protocol": {80: {"inbound": 10, "outbound": 2020.0}},
    }
def test_empty_dict():
    """
    Convert an empty dictionary.

    With nothing to convert, the result is expected to be an empty dictionary as well.
    """
    converted = convert_dict_enum_keys_to_enum_values({})

    assert converted == {}
def test_nested_dicts():
    """
    Convert a dictionary with enum keys at several nesting depths.

    Besides the two outer enum keys, an enum key sits inside the ``"details"`` sub-dictionary.
    All of them are expected to be replaced by their values.
    """
    deeply_nested = {
        IPProtocol.UDP: {
            Port.ARP: {
                "inbound": 0,
                "outbound": 1016.0,
                "details": {IPProtocol.TCP: {"latency": "low"}},
            }
        }
    }

    converted = convert_dict_enum_keys_to_enum_values(deeply_nested)

    assert converted == {
        "udp": {
            219: {
                "inbound": 0,
                "outbound": 1016.0,
                "details": {"tcp": {"latency": "low"}},
            }
        }
    }
def test_non_dict_values():
    """
    Convert a dictionary whose values are a list and a tuple rather than dictionaries.

    Only the enum key is expected to change (``IPProtocol.UDP`` -> ``"udp"``). The enum members held
    inside the list and tuple values are expected to be left untouched.
    """
    with_sequences = {
        IPProtocol.UDP: [Port.ARP, Port.HTTP],
        "protocols": (IPProtocol.TCP, IPProtocol.UDP),
    }

    converted = convert_dict_enum_keys_to_enum_values(with_sequences)

    assert converted == {
        "udp": [Port.ARP, Port.HTTP],
        "protocols": (IPProtocol.TCP, IPProtocol.UDP),
    }