Merge remote-tracking branch 'origin/dev' into feature/2646_Update-pre-commit-to-check-for-valid-copyright
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -164,3 +164,7 @@ src/primaite/notebooks/scratch.py
|
||||
sandbox.py
|
||||
sandbox/
|
||||
sandbox.ipynb
|
||||
|
||||
# benchmarking
|
||||
**/benchmark/sessions/
|
||||
**/benchmark/output/
|
||||
|
||||
@@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added support for SQL INSERT command.
|
||||
- Added ability to log each agent's action choices in each step to a JSON file.
|
||||
- Removal of Link bandwidth hardcoding. This can now be configured via the network configuraiton yaml. Will default to 100 if not present.
|
||||
- Added NMAP application to all host and layer-3 network nodes.
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
|
||||
22
benchmark/benchmark.py
Normal file
22
benchmark/benchmark.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
|
||||
|
||||
class BenchmarkPrimaiteGymEnv(PrimaiteGymEnv):
|
||||
"""
|
||||
Class that extends the PrimaiteGymEnv.
|
||||
|
||||
The reset method is extended so that the average rewards per episode are recorded.
|
||||
"""
|
||||
|
||||
total_time_steps: int = 0
|
||||
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Overrides the PrimAITEGymEnv reset so that the total timesteps is saved."""
|
||||
self.total_time_steps += self.game.step_counter
|
||||
|
||||
return super().reset(seed=seed)
|
||||
@@ -3,210 +3,95 @@
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
from typing import Any, Dict, Final, Tuple
|
||||
|
||||
import GPUtil
|
||||
import plotly.graph_objects as go
|
||||
import polars as pl
|
||||
import psutil
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
from pylatex import Command, Document
|
||||
from pylatex import Figure as LatexFigure
|
||||
from pylatex import Section, Subsection, Tabular
|
||||
from pylatex.utils import bold
|
||||
from report import build_benchmark_latex_report
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
import primaite
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
from primaite.data_viz.session_plots import get_plotly_config
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
from benchmark import BenchmarkPrimaiteGymEnv
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
|
||||
_LOGGER = primaite.getLogger(__name__)
|
||||
|
||||
_MAJOR_V = primaite.__version__.split(".")[0]
|
||||
|
||||
_BENCHMARK_ROOT = Path(__file__).parent
|
||||
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results"
|
||||
_RESULTS_ROOT.mkdir(exist_ok=True, parents=True)
|
||||
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results" / f"v{_MAJOR_V}"
|
||||
_VERSION_ROOT: Final[Path] = _RESULTS_ROOT / f"v{primaite.__version__}"
|
||||
_SESSION_METADATA_ROOT: Final[Path] = _VERSION_ROOT / "session_metadata"
|
||||
|
||||
_OUTPUT_ROOT: Final[Path] = _BENCHMARK_ROOT / "output"
|
||||
# Clear and recreate the output directory
|
||||
if _OUTPUT_ROOT.exists():
|
||||
shutil.rmtree(_OUTPUT_ROOT)
|
||||
_OUTPUT_ROOT.mkdir()
|
||||
|
||||
_TRAINING_CONFIG_PATH = _BENCHMARK_ROOT / "config" / "benchmark_training_config.yaml"
|
||||
_LAY_DOWN_CONFIG_PATH = data_manipulation_config_path()
|
||||
_SESSION_METADATA_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_size(size_bytes: int) -> str:
|
||||
"""
|
||||
Scale bytes to its proper format.
|
||||
class BenchmarkSession:
|
||||
"""Benchmark Session class."""
|
||||
|
||||
e.g:
|
||||
1253656 => '1.20MB'
|
||||
1253656678 => '1.17GB'
|
||||
gym_env: BenchmarkPrimaiteGymEnv
|
||||
"""Gym environment used by the session to train."""
|
||||
|
||||
:
|
||||
"""
|
||||
factor = 1024
|
||||
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||
if size_bytes < factor:
|
||||
return f"{size_bytes:.2f}{unit}B"
|
||||
size_bytes /= factor
|
||||
num_episodes: int
|
||||
"""Number of episodes to run the training session."""
|
||||
|
||||
episode_len: int
|
||||
"""The number of steps per episode."""
|
||||
|
||||
def _get_system_info() -> Dict:
|
||||
"""Builds and returns a dict containing system info."""
|
||||
uname = platform.uname()
|
||||
cpu_freq = psutil.cpu_freq()
|
||||
virtual_mem = psutil.virtual_memory()
|
||||
swap_mem = psutil.swap_memory()
|
||||
gpus = GPUtil.getGPUs()
|
||||
return {
|
||||
"System": {
|
||||
"OS": uname.system,
|
||||
"OS Version": uname.version,
|
||||
"Machine": uname.machine,
|
||||
"Processor": uname.processor,
|
||||
},
|
||||
"CPU": {
|
||||
"Physical Cores": psutil.cpu_count(logical=False),
|
||||
"Total Cores": psutil.cpu_count(logical=True),
|
||||
"Max Frequency": f"{cpu_freq.max:.2f}Mhz",
|
||||
},
|
||||
"Memory": {"Total": get_size(virtual_mem.total), "Swap Total": get_size(swap_mem.total)},
|
||||
"GPU": [{"Name": gpu.name, "Total Memory": f"{gpu.memoryTotal}MB"} for gpu in gpus],
|
||||
}
|
||||
total_steps: int
|
||||
"""Number of steps to run the training session."""
|
||||
|
||||
batch_size: int
|
||||
"""Number of steps for each episode."""
|
||||
|
||||
def _build_benchmark_latex_report(
|
||||
benchmark_metadata_dict: Dict, this_version_plot_path: Path, all_version_plot_path: Path
|
||||
) -> None:
|
||||
geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
|
||||
data = benchmark_metadata_dict
|
||||
primaite_version = data["primaite_version"]
|
||||
learning_rate: float
|
||||
"""Learning rate for the model."""
|
||||
|
||||
# Create a new document
|
||||
doc = Document("report", geometry_options=geometry_options)
|
||||
# Title
|
||||
doc.preamble.append(Command("title", f"PrimAITE {primaite_version} Learning Benchmark"))
|
||||
doc.preamble.append(Command("author", "PrimAITE Dev Team"))
|
||||
doc.preamble.append(Command("date", datetime.now().date()))
|
||||
doc.append(Command("maketitle"))
|
||||
start_time: datetime
|
||||
"""Start time for the session."""
|
||||
|
||||
sessions = data["total_sessions"]
|
||||
episodes = data["training_config"]["num_train_episodes"]
|
||||
steps = data["training_config"]["num_train_steps"]
|
||||
|
||||
# Body
|
||||
with doc.create(Section("Introduction")):
|
||||
doc.append(
|
||||
f"PrimAITE v{primaite_version} was benchmarked automatically upon release. Learning rate metrics "
|
||||
f"were captured to be referenced during system-level testing and user acceptance testing (UAT)."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe benchmarking process consists of running {sessions} training session using the same "
|
||||
f"training and lay down config files. Each session trains an agent for {episodes} episodes, "
|
||||
f"with each episode consisting of {steps} steps."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe mean reward per episode from each session is captured. This is then used to calculate a "
|
||||
f"combined average reward per episode from the {sessions} individual sessions for smoothing. "
|
||||
f"Finally, a 25-widow rolling average of the combined average reward per session is calculated for "
|
||||
f"further smoothing."
|
||||
)
|
||||
|
||||
with doc.create(Section("System Information")):
|
||||
with doc.create(Subsection("Python")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
table.add_row((bold("Version"), sys.version))
|
||||
table.add_hline()
|
||||
for section, section_data in data["system_info"].items():
|
||||
if section_data:
|
||||
with doc.create(Subsection(section)):
|
||||
if isinstance(section_data, dict):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for key, value in section_data.items():
|
||||
table.add_row((bold(key), value))
|
||||
table.add_hline()
|
||||
elif isinstance(section_data, list):
|
||||
headers = section_data[0].keys()
|
||||
tabs_str = "|".join(["l" for _ in range(len(headers))])
|
||||
tabs_str = f"|{tabs_str}|"
|
||||
with doc.create(Tabular(tabs_str)) as table:
|
||||
table.add_hline()
|
||||
table.add_row([bold(h) for h in headers])
|
||||
table.add_hline()
|
||||
for item in section_data:
|
||||
table.add_row(item.values())
|
||||
table.add_hline()
|
||||
|
||||
headers_map = {
|
||||
"total_sessions": "Total Sessions",
|
||||
"total_episodes": "Total Episodes",
|
||||
"total_time_steps": "Total Steps",
|
||||
"av_s_per_session": "Av Session Duration (s)",
|
||||
"av_s_per_step": "Av Step Duration (s)",
|
||||
"av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
|
||||
}
|
||||
with doc.create(Section("Stats")):
|
||||
with doc.create(Subsection("Benchmark Results")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for section, header in headers_map.items():
|
||||
if section.startswith("av_"):
|
||||
table.add_row((bold(header), f"{data[section]:.4f}"))
|
||||
else:
|
||||
table.add_row((bold(header), data[section]))
|
||||
table.add_hline()
|
||||
|
||||
with doc.create(Section("Graphs")):
|
||||
with doc.create(Subsection(f"PrimAITE {primaite_version} Learning Benchmark Plot")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(this_version_plot_path))
|
||||
pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")
|
||||
|
||||
with doc.create(Subsection("PrimAITE All Versions Learning Benchmark Plot")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(all_version_plot_path))
|
||||
pic.add_caption("PrimAITE All Versions Learning Benchmark Plot")
|
||||
|
||||
doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)
|
||||
|
||||
|
||||
class BenchmarkPrimaiteSession(PrimaiteSession):
|
||||
"""A benchmarking primaite session."""
|
||||
end_time: datetime
|
||||
"""End time for the session."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
) -> None:
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self.setup()
|
||||
gym_env: BenchmarkPrimaiteGymEnv,
|
||||
episode_len: int,
|
||||
num_episodes: int,
|
||||
n_steps: int,
|
||||
batch_size: int,
|
||||
learning_rate: float,
|
||||
):
|
||||
"""Initialise the BenchmarkSession."""
|
||||
self.gym_env = gym_env
|
||||
self.episode_len = episode_len
|
||||
self.n_steps = n_steps
|
||||
self.num_episodes = num_episodes
|
||||
self.total_steps = self.num_episodes * self.episode_len
|
||||
self.batch_size = batch_size
|
||||
self.learning_rate = learning_rate
|
||||
|
||||
@property
|
||||
def env(self) -> Primaite:
|
||||
"""Direct access to the env for ease of testing."""
|
||||
return self._agent_session._env # noqa
|
||||
def train(self):
|
||||
"""Run the training session."""
|
||||
# start timer for session
|
||||
self.start_time = datetime.now()
|
||||
model = PPO(
|
||||
policy="MlpPolicy",
|
||||
env=self.gym_env,
|
||||
learning_rate=self.learning_rate,
|
||||
n_steps=self.n_steps,
|
||||
batch_size=self.batch_size,
|
||||
verbose=0,
|
||||
tensorboard_log="./PPO_UC2/",
|
||||
)
|
||||
model.learn(total_timesteps=self.total_steps)
|
||||
|
||||
def __enter__(self) -> "BenchmarkPrimaiteSession":
|
||||
return self
|
||||
# end timer for session
|
||||
self.end_time = datetime.now()
|
||||
|
||||
# TODO: typehints uncertain
|
||||
def __exit__(self, type: Any, value: Any, tb: Any) -> None:
|
||||
shutil.rmtree(self.session_path)
|
||||
_LOGGER.debug(f"Deleted benchmark session directory: {self.session_path}")
|
||||
self.session_metadata = self.generate_learn_metadata_dict()
|
||||
|
||||
def _learn_benchmark_durations(self) -> Tuple[float, float, float]:
|
||||
"""
|
||||
@@ -220,235 +105,99 @@ class BenchmarkPrimaiteSession(PrimaiteSession):
|
||||
:return: The learning benchmark durations as a Tuple of three floats:
|
||||
Tuple[total_s, s_per_step, s_per_100_steps_10_nodes].
|
||||
"""
|
||||
data = self.metadata_file_as_dict()
|
||||
start_dt = datetime.fromisoformat(data["start_datetime"])
|
||||
end_dt = datetime.fromisoformat(data["end_datetime"])
|
||||
delta = end_dt - start_dt
|
||||
delta = self.end_time - self.start_time
|
||||
total_s = delta.total_seconds()
|
||||
|
||||
total_steps = data["learning"]["total_time_steps"]
|
||||
total_steps = self.batch_size * self.num_episodes
|
||||
s_per_step = total_s / total_steps
|
||||
|
||||
num_nodes = self.env.num_nodes
|
||||
num_nodes = len(self.gym_env.game.simulation.network.nodes)
|
||||
num_intervals = total_steps / 100
|
||||
av_interval_time = total_s / num_intervals
|
||||
s_per_100_steps_10_nodes = av_interval_time / (num_nodes / 10)
|
||||
|
||||
return total_s, s_per_step, s_per_100_steps_10_nodes
|
||||
|
||||
def learn_metadata_dict(self) -> Dict[str, Any]:
|
||||
def generate_learn_metadata_dict(self) -> Dict[str, Any]:
|
||||
"""Metadata specific to the learning session."""
|
||||
total_s, s_per_step, s_per_100_steps_10_nodes = self._learn_benchmark_durations()
|
||||
self.gym_env.average_reward_per_episode.pop(0) # remove episode 0
|
||||
return {
|
||||
"total_episodes": self.env.actual_episode_count,
|
||||
"total_time_steps": self.env.total_step_count,
|
||||
"total_episodes": self.gym_env.episode_counter,
|
||||
"total_time_steps": self.gym_env.total_time_steps,
|
||||
"total_s": total_s,
|
||||
"s_per_step": s_per_step,
|
||||
"s_per_100_steps_10_nodes": s_per_100_steps_10_nodes,
|
||||
"av_reward_per_episode": self.learn_av_reward_per_episode_dict(),
|
||||
"av_reward_per_episode": self.gym_env.average_reward_per_episode,
|
||||
}
|
||||
|
||||
|
||||
def _get_benchmark_session_path(session_timestamp: datetime) -> Path:
|
||||
return _OUTPUT_ROOT / session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
|
||||
def _get_benchmark_primaite_session() -> BenchmarkPrimaiteSession:
|
||||
with patch("primaite.agents.agent_abc.get_session_path", _get_benchmark_session_path) as mck:
|
||||
mck.session_timestamp = datetime.now()
|
||||
return BenchmarkPrimaiteSession(_TRAINING_CONFIG_PATH, _LAY_DOWN_CONFIG_PATH)
|
||||
|
||||
|
||||
def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict) -> dict:
|
||||
n = len(metadata_dict)
|
||||
with open(_TRAINING_CONFIG_PATH, "r") as file:
|
||||
training_config_dict = yaml.safe_load(file)
|
||||
with open(_LAY_DOWN_CONFIG_PATH, "r") as file:
|
||||
lay_down_config_dict = yaml.safe_load(file)
|
||||
averaged_data = {
|
||||
"start_timestamp": start_datetime.isoformat(),
|
||||
"end_datetime": datetime.now().isoformat(),
|
||||
"primaite_version": primaite.__version__,
|
||||
"system_info": _get_system_info(),
|
||||
"total_sessions": n,
|
||||
"total_episodes": sum(d["total_episodes"] for d in metadata_dict.values()),
|
||||
"total_time_steps": sum(d["total_time_steps"] for d in metadata_dict.values()),
|
||||
"av_s_per_session": sum(d["total_s"] for d in metadata_dict.values()) / n,
|
||||
"av_s_per_step": sum(d["s_per_step"] for d in metadata_dict.values()) / n,
|
||||
"av_s_per_100_steps_10_nodes": sum(d["s_per_100_steps_10_nodes"] for d in metadata_dict.values()) / n,
|
||||
"combined_av_reward_per_episode": {},
|
||||
"session_av_reward_per_episode": {k: v["av_reward_per_episode"] for k, v in metadata_dict.items()},
|
||||
"training_config": training_config_dict,
|
||||
"lay_down_config": lay_down_config_dict,
|
||||
}
|
||||
|
||||
episodes = metadata_dict[1]["av_reward_per_episode"].keys()
|
||||
|
||||
for episode in episodes:
|
||||
combined_av_reward = sum(metadata_dict[k]["av_reward_per_episode"][episode] for k in metadata_dict.keys()) / n
|
||||
averaged_data["combined_av_reward_per_episode"][episode] = combined_av_reward
|
||||
|
||||
return averaged_data
|
||||
|
||||
|
||||
def _get_df_from_episode_av_reward_dict(data: Dict) -> pl.DataFrame:
|
||||
data: Dict = {"episode": data.keys(), "av_reward": data.values()}
|
||||
|
||||
return (
|
||||
pl.from_dict(data)
|
||||
.with_columns(rolling_mean=pl.col("av_reward").rolling_mean(window_size=25))
|
||||
.rename({"rolling_mean": "rolling_av_reward"})
|
||||
)
|
||||
|
||||
|
||||
def _plot_benchmark_metadata(
|
||||
benchmark_metadata_dict: Dict,
|
||||
title: Optional[str] = None,
|
||||
subtitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
|
||||
config = get_plotly_config()
|
||||
layout = go.Layout(
|
||||
autosize=config["size"]["auto_size"],
|
||||
width=config["size"]["width"],
|
||||
height=config["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=config["template"])
|
||||
|
||||
for session, av_reward_dict in benchmark_metadata_dict["session_av_reward_per_episode"].items():
|
||||
df = _get_df_from_episode_av_reward_dict(av_reward_dict)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"],
|
||||
y=df["av_reward"],
|
||||
mode="lines",
|
||||
name=f"Session {session}",
|
||||
opacity=0.25,
|
||||
line={"color": "#a6a6a6"},
|
||||
)
|
||||
)
|
||||
|
||||
df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_av_reward_per_episode"])
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"], y=df["av_reward"], mode="lines", name="Combined Session Av", line={"color": "#FF0000"}
|
||||
)
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"],
|
||||
y=df["rolling_av_reward"],
|
||||
mode="lines",
|
||||
name="Rolling Av (Combined Session Av)",
|
||||
line={"color": "#4CBB17"},
|
||||
)
|
||||
)
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
title=title,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _plot_all_benchmarks_combined_session_av() -> Figure:
|
||||
def _get_benchmark_primaite_environment() -> BenchmarkPrimaiteGymEnv:
|
||||
"""
|
||||
Plot the Benchmark results for each released version of PrimAITE.
|
||||
Create an instance of the BenchmarkPrimaiteGymEnv.
|
||||
|
||||
Does this by iterating over the ``benchmark/results`` directory and
|
||||
extracting the benchmark metadata json for each version that has been
|
||||
benchmarked. The combined_av_reward_per_episode is extracted from each,
|
||||
converted into a polars dataframe, and plotted as a scatter line in plotly.
|
||||
This environment will be used to train the agents on.
|
||||
"""
|
||||
title = "PrimAITE Versions Learning Benchmark"
|
||||
subtitle = "Rolling Av (Combined Session Av)"
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
config = get_plotly_config()
|
||||
layout = go.Layout(
|
||||
autosize=config["size"]["auto_size"],
|
||||
width=config["size"]["width"],
|
||||
height=config["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=config["template"])
|
||||
|
||||
for dir in _RESULTS_ROOT.iterdir():
|
||||
if dir.is_dir():
|
||||
metadata_file = dir / f"{dir.name}_benchmark_metadata.json"
|
||||
with open(metadata_file, "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_av_reward_per_episode"])
|
||||
|
||||
fig.add_trace(go.Scatter(x=df["episode"], y=df["rolling_av_reward"], mode="lines", name=dir.name))
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
title=title,
|
||||
)
|
||||
fig["data"][0]["showlegend"] = True
|
||||
|
||||
return fig
|
||||
env = BenchmarkPrimaiteGymEnv(env_config=data_manipulation_config_path())
|
||||
return env
|
||||
|
||||
|
||||
def run() -> None:
|
||||
def _prepare_session_directory():
|
||||
"""Prepare the session directory so that it is easier to clean up after the benchmarking is done."""
|
||||
# override session path
|
||||
session_path = _BENCHMARK_ROOT / "sessions"
|
||||
|
||||
if session_path.is_dir():
|
||||
shutil.rmtree(session_path)
|
||||
|
||||
primaite.PRIMAITE_PATHS.user_sessions_path = session_path
|
||||
primaite.PRIMAITE_PATHS.user_sessions_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
def run(
|
||||
number_of_sessions: int = 5,
|
||||
num_episodes: int = 1000,
|
||||
episode_len: int = 128,
|
||||
n_steps: int = 1280,
|
||||
batch_size: int = 32,
|
||||
learning_rate: float = 3e-4,
|
||||
) -> None:
|
||||
"""Run the PrimAITE benchmark."""
|
||||
start_datetime = datetime.now()
|
||||
av_reward_per_episode_dicts = {}
|
||||
for i in range(1, 11):
|
||||
benchmark_start_time = datetime.now()
|
||||
|
||||
session_metadata_dict = {}
|
||||
|
||||
_prepare_session_directory()
|
||||
|
||||
# run training
|
||||
for i in range(1, number_of_sessions + 1):
|
||||
print(f"Starting Benchmark Session: {i}")
|
||||
with _get_benchmark_primaite_session() as session:
|
||||
session.learn()
|
||||
av_reward_per_episode_dicts[i] = session.learn_metadata_dict()
|
||||
|
||||
benchmark_metadata = _build_benchmark_results_dict(
|
||||
start_datetime=start_datetime, metadata_dict=av_reward_per_episode_dicts
|
||||
with _get_benchmark_primaite_environment() as gym_env:
|
||||
session = BenchmarkSession(
|
||||
gym_env=gym_env,
|
||||
num_episodes=num_episodes,
|
||||
n_steps=n_steps,
|
||||
episode_len=episode_len,
|
||||
batch_size=batch_size,
|
||||
learning_rate=learning_rate,
|
||||
)
|
||||
session.train()
|
||||
|
||||
# Dump the session metadata so that we're not holding it in memory as it's large
|
||||
with open(_SESSION_METADATA_ROOT / f"{i}.json", "w") as file:
|
||||
json.dump(session.session_metadata, file, indent=4)
|
||||
|
||||
for i in range(1, number_of_sessions + 1):
|
||||
with open(_SESSION_METADATA_ROOT / f"{i}.json", "r") as file:
|
||||
session_metadata_dict[i] = json.load(file)
|
||||
# generate report
|
||||
build_benchmark_latex_report(
|
||||
benchmark_start_time=benchmark_start_time,
|
||||
session_metadata=session_metadata_dict,
|
||||
config_path=data_manipulation_config_path(),
|
||||
results_root_path=_RESULTS_ROOT,
|
||||
)
|
||||
v_str = f"v{primaite.__version__}"
|
||||
|
||||
version_result_dir = _RESULTS_ROOT / v_str
|
||||
if version_result_dir.exists():
|
||||
shutil.rmtree(version_result_dir)
|
||||
version_result_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
with open(version_result_dir / f"{v_str}_benchmark_metadata.json", "w") as file:
|
||||
json.dump(benchmark_metadata, file, indent=4)
|
||||
title = f"PrimAITE v{primaite.__version__.strip()} Learning Benchmark"
|
||||
fig = _plot_benchmark_metadata(benchmark_metadata, title=title)
|
||||
this_version_plot_path = version_result_dir / f"{title}.png"
|
||||
fig.write_image(this_version_plot_path)
|
||||
|
||||
fig = _plot_all_benchmarks_combined_session_av()
|
||||
|
||||
all_version_plot_path = _RESULTS_ROOT / "PrimAITE Versions Learning Benchmark.png"
|
||||
fig.write_image(all_version_plot_path)
|
||||
|
||||
_build_benchmark_latex_report(benchmark_metadata, this_version_plot_path, all_version_plot_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
305
benchmark/report.py
Normal file
305
benchmark/report.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import plotly.graph_objects as go
|
||||
import polars as pl
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
from pylatex import Command, Document
|
||||
from pylatex import Figure as LatexFigure
|
||||
from pylatex import Section, Subsection, Tabular
|
||||
from pylatex.utils import bold
|
||||
from utils import _get_system_info
|
||||
|
||||
import primaite
|
||||
|
||||
PLOT_CONFIG = {
|
||||
"size": {"auto_size": False, "width": 1500, "height": 900},
|
||||
"template": "plotly_white",
|
||||
"range_slider": False,
|
||||
}
|
||||
|
||||
|
||||
def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict, config: Dict) -> dict:
|
||||
num_sessions = len(metadata_dict) # number of sessions
|
||||
|
||||
averaged_data = {
|
||||
"start_timestamp": start_datetime.isoformat(),
|
||||
"end_datetime": datetime.now().isoformat(),
|
||||
"primaite_version": primaite.__version__,
|
||||
"system_info": _get_system_info(),
|
||||
"total_sessions": num_sessions,
|
||||
"total_episodes": sum(d["total_episodes"] for d in metadata_dict.values()),
|
||||
"total_time_steps": sum(d["total_time_steps"] for d in metadata_dict.values()),
|
||||
"av_s_per_session": sum(d["total_s"] for d in metadata_dict.values()) / num_sessions,
|
||||
"av_s_per_step": sum(d["s_per_step"] for d in metadata_dict.values()) / num_sessions,
|
||||
"av_s_per_100_steps_10_nodes": sum(d["s_per_100_steps_10_nodes"] for d in metadata_dict.values())
|
||||
/ num_sessions,
|
||||
"combined_av_reward_per_episode": {},
|
||||
"session_av_reward_per_episode": {k: v["av_reward_per_episode"] for k, v in metadata_dict.items()},
|
||||
"config": config,
|
||||
}
|
||||
|
||||
# find the average of each episode across all sessions
|
||||
episodes = metadata_dict[1]["av_reward_per_episode"].keys()
|
||||
|
||||
for episode in episodes:
|
||||
combined_av_reward = (
|
||||
sum(metadata_dict[k]["av_reward_per_episode"][episode] for k in metadata_dict.keys()) / num_sessions
|
||||
)
|
||||
averaged_data["combined_av_reward_per_episode"][episode] = combined_av_reward
|
||||
|
||||
return averaged_data
|
||||
|
||||
|
||||
def _get_df_from_episode_av_reward_dict(data: Dict) -> pl.DataFrame:
|
||||
data: Dict = {"episode": data.keys(), "av_reward": data.values()}
|
||||
|
||||
return (
|
||||
pl.from_dict(data)
|
||||
.with_columns(rolling_mean=pl.col("av_reward").rolling_mean(window_size=25))
|
||||
.rename({"rolling_mean": "rolling_av_reward"})
|
||||
)
|
||||
|
||||
|
||||
def _plot_benchmark_metadata(
|
||||
benchmark_metadata_dict: Dict,
|
||||
title: Optional[str] = None,
|
||||
subtitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
|
||||
layout = go.Layout(
|
||||
autosize=PLOT_CONFIG["size"]["auto_size"],
|
||||
width=PLOT_CONFIG["size"]["width"],
|
||||
height=PLOT_CONFIG["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=PLOT_CONFIG["template"])
|
||||
|
||||
for session, av_reward_dict in benchmark_metadata_dict["session_av_reward_per_episode"].items():
|
||||
df = _get_df_from_episode_av_reward_dict(av_reward_dict)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"],
|
||||
y=df["av_reward"],
|
||||
mode="lines",
|
||||
name=f"Session {session}",
|
||||
opacity=0.25,
|
||||
line={"color": "#a6a6a6"},
|
||||
)
|
||||
)
|
||||
|
||||
df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_av_reward_per_episode"])
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"], y=df["av_reward"], mode="lines", name="Combined Session Av", line={"color": "#FF0000"}
|
||||
)
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"],
|
||||
y=df["rolling_av_reward"],
|
||||
mode="lines",
|
||||
name="Rolling Av (Combined Session Av)",
|
||||
line={"color": "#4CBB17"},
|
||||
)
|
||||
)
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Total Reward"},
|
||||
title=title,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
|
||||
"""
|
||||
Plot the Benchmark results for each released version of PrimAITE.
|
||||
|
||||
Does this by iterating over the ``benchmark/results`` directory and
|
||||
extracting the benchmark metadata json for each version that has been
|
||||
benchmarked. The combined_av_reward_per_episode is extracted from each,
|
||||
converted into a polars dataframe, and plotted as a scatter line in plotly.
|
||||
"""
|
||||
major_v = primaite.__version__.split(".")[0]
|
||||
title = f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*"
|
||||
subtitle = "Rolling Av (Combined Session Av)"
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
layout = go.Layout(
|
||||
autosize=PLOT_CONFIG["size"]["auto_size"],
|
||||
width=PLOT_CONFIG["size"]["width"],
|
||||
height=PLOT_CONFIG["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=PLOT_CONFIG["template"])
|
||||
|
||||
for dir in results_directory.iterdir():
|
||||
if dir.is_dir():
|
||||
metadata_file = dir / f"{dir.name}_benchmark_metadata.json"
|
||||
with open(metadata_file, "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_av_reward_per_episode"])
|
||||
|
||||
fig.add_trace(go.Scatter(x=df["episode"], y=df["rolling_av_reward"], mode="lines", name=dir.name))
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Total Reward"},
|
||||
title=title,
|
||||
)
|
||||
fig["data"][0]["showlegend"] = True
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def build_benchmark_latex_report(
|
||||
benchmark_start_time: datetime, session_metadata: Dict, config_path: Path, results_root_path: Path
|
||||
) -> None:
|
||||
"""Generates a latex report of the benchmark run."""
|
||||
# generate report folder
|
||||
v_str = f"v{primaite.__version__}"
|
||||
|
||||
version_result_dir = results_root_path / v_str
|
||||
version_result_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# load the config file as dict
|
||||
with open(config_path, "r") as f:
|
||||
cfg_data = yaml.safe_load(f)
|
||||
|
||||
# generate the benchmark metadata dict
|
||||
benchmark_metadata_dict = _build_benchmark_results_dict(
|
||||
start_datetime=benchmark_start_time, metadata_dict=session_metadata, config=cfg_data
|
||||
)
|
||||
major_v = primaite.__version__.split(".")[0]
|
||||
with open(version_result_dir / f"{v_str}_benchmark_metadata.json", "w") as file:
|
||||
json.dump(benchmark_metadata_dict, file, indent=4)
|
||||
title = f"PrimAITE v{primaite.__version__.strip()} Learning Benchmark"
|
||||
fig = _plot_benchmark_metadata(benchmark_metadata_dict, title=title)
|
||||
this_version_plot_path = version_result_dir / f"{title}.png"
|
||||
fig.write_image(this_version_plot_path)
|
||||
|
||||
fig = _plot_all_benchmarks_combined_session_av(results_directory=results_root_path)
|
||||
|
||||
all_version_plot_path = results_root_path / "PrimAITE Versions Learning Benchmark.png"
|
||||
fig.write_image(all_version_plot_path)
|
||||
|
||||
geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
|
||||
data = benchmark_metadata_dict
|
||||
primaite_version = data["primaite_version"]
|
||||
|
||||
# Create a new document
|
||||
doc = Document("report", geometry_options=geometry_options)
|
||||
# Title
|
||||
doc.preamble.append(Command("title", f"PrimAITE {primaite_version} Learning Benchmark"))
|
||||
doc.preamble.append(Command("author", "PrimAITE Dev Team"))
|
||||
doc.preamble.append(Command("date", datetime.now().date()))
|
||||
doc.append(Command("maketitle"))
|
||||
|
||||
sessions = data["total_sessions"]
|
||||
episodes = session_metadata[1]["total_episodes"] - 1
|
||||
steps = data["config"]["game"]["max_episode_length"]
|
||||
|
||||
# Body
|
||||
with doc.create(Section("Introduction")):
|
||||
doc.append(
|
||||
f"PrimAITE v{primaite_version} was benchmarked automatically upon release. Learning rate metrics "
|
||||
f"were captured to be referenced during system-level testing and user acceptance testing (UAT)."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe benchmarking process consists of running {sessions} training session using the same "
|
||||
f"config file. Each session trains an agent for {episodes} episodes, "
|
||||
f"with each episode consisting of {steps} steps."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe total reward per episode from each session is captured. This is then used to calculate an "
|
||||
f"caverage total reward per episode from the {sessions} individual sessions for smoothing. "
|
||||
f"Finally, a 25-widow rolling average of the average total reward per session is calculated for "
|
||||
f"further smoothing."
|
||||
)
|
||||
|
||||
with doc.create(Section("System Information")):
|
||||
with doc.create(Subsection("Python")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
table.add_row((bold("Version"), sys.version))
|
||||
table.add_hline()
|
||||
for section, section_data in data["system_info"].items():
|
||||
if section_data:
|
||||
with doc.create(Subsection(section)):
|
||||
if isinstance(section_data, dict):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for key, value in section_data.items():
|
||||
table.add_row((bold(key), value))
|
||||
table.add_hline()
|
||||
elif isinstance(section_data, list):
|
||||
headers = section_data[0].keys()
|
||||
tabs_str = "|".join(["l" for _ in range(len(headers))])
|
||||
tabs_str = f"|{tabs_str}|"
|
||||
with doc.create(Tabular(tabs_str)) as table:
|
||||
table.add_hline()
|
||||
table.add_row([bold(h) for h in headers])
|
||||
table.add_hline()
|
||||
for item in section_data:
|
||||
table.add_row(item.values())
|
||||
table.add_hline()
|
||||
|
||||
headers_map = {
|
||||
"total_sessions": "Total Sessions",
|
||||
"total_episodes": "Total Episodes",
|
||||
"total_time_steps": "Total Steps",
|
||||
"av_s_per_session": "Av Session Duration (s)",
|
||||
"av_s_per_step": "Av Step Duration (s)",
|
||||
"av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
|
||||
}
|
||||
with doc.create(Section("Stats")):
|
||||
with doc.create(Subsection("Benchmark Results")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for section, header in headers_map.items():
|
||||
if section.startswith("av_"):
|
||||
table.add_row((bold(header), f"{data[section]:.4f}"))
|
||||
else:
|
||||
table.add_row((bold(header), data[section]))
|
||||
table.add_hline()
|
||||
|
||||
with doc.create(Section("Graphs")):
|
||||
with doc.create(Subsection(f"v{primaite_version} Learning Benchmark Plot")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(this_version_plot_path))
|
||||
pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")
|
||||
|
||||
with doc.create(Subsection(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(all_version_plot_path))
|
||||
pic.add_caption(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")
|
||||
|
||||
doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)
|
||||
|
Before Width: | Height: | Size: 79 KiB After Width: | Height: | Size: 79 KiB |
|
Before Width: | Height: | Size: 225 KiB After Width: | Height: | Size: 225 KiB |
BIN
benchmark/results/v3/PrimAITE Versions Learning Benchmark.png
Normal file
BIN
benchmark/results/v3/PrimAITE Versions Learning Benchmark.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 91 KiB |
Binary file not shown.
Binary file not shown.
|
After Width: | Height: | Size: 295 KiB |
7436
benchmark/results/v3/v3.0.0/v3.0.0_benchmark_metadata.json
Normal file
7436
benchmark/results/v3/v3.0.0/v3.0.0_benchmark_metadata.json
Normal file
File diff suppressed because it is too large
Load Diff
47
benchmark/utils.py
Normal file
47
benchmark/utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import platform
|
||||
from typing import Dict
|
||||
|
||||
import psutil
|
||||
from GPUtil import GPUtil
|
||||
|
||||
|
||||
def get_size(size_bytes: int) -> str:
|
||||
"""
|
||||
Scale bytes to its proper format.
|
||||
|
||||
e.g:
|
||||
1253656 => '1.20MB'
|
||||
1253656678 => '1.17GB'
|
||||
|
||||
:
|
||||
"""
|
||||
factor = 1024
|
||||
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||
if size_bytes < factor:
|
||||
return f"{size_bytes:.2f}{unit}B"
|
||||
size_bytes /= factor
|
||||
|
||||
|
||||
def _get_system_info() -> Dict:
|
||||
"""Builds and returns a dict containing system info."""
|
||||
uname = platform.uname()
|
||||
cpu_freq = psutil.cpu_freq()
|
||||
virtual_mem = psutil.virtual_memory()
|
||||
swap_mem = psutil.swap_memory()
|
||||
gpus = GPUtil.getGPUs()
|
||||
return {
|
||||
"System": {
|
||||
"OS": uname.system,
|
||||
"OS Version": uname.version,
|
||||
"Machine": uname.machine,
|
||||
"Processor": uname.processor,
|
||||
},
|
||||
"CPU": {
|
||||
"Physical Cores": psutil.cpu_count(logical=False),
|
||||
"Total Cores": psutil.cpu_count(logical=True),
|
||||
"Max Frequency": f"{cpu_freq.max:.2f}Mhz",
|
||||
},
|
||||
"Memory": {"Total": get_size(virtual_mem.total), "Swap Total": get_size(swap_mem.total)},
|
||||
"GPU": [{"Name": gpu.name, "Total Memory": f"{gpu.memoryTotal}MB"} for gpu in gpus],
|
||||
}
|
||||
@@ -54,7 +54,7 @@ It is agnostic to the number of agents, their action / observation spaces, and t
|
||||
It presents a public API providing a method for describing the current state of the simulation, a method that accepts action requests and provides responses, and a method that triggers a timestep advancement.
|
||||
The Game Layer converts the simulation into a playable game for the agent(s).
|
||||
|
||||
it translates between simulation state and Gymnasium.Spaces to pass action / observation data between the agent(s) and the simulation. It is responsible for calculating rewards, managing Multi-Agent RL (MARL) action turns, and via a single agent interface can interact with Blue, Red and Green agents.
|
||||
It translates between simulation state and Gymnasium.Spaces to pass action / observation data between the agent(s) and the simulation. It is responsible for calculating rewards, managing Multi-Agent RL (MARL) action turns, and via a single agent interface can interact with Blue, Red and Green agents.
|
||||
|
||||
Agents can either generate their own scripted behaviour or accept input behaviour from an RL agent.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _about:
|
||||
|
||||
|
||||
@@ -77,6 +77,6 @@ The following extensions should now be installed
|
||||
:width: 300
|
||||
:align: center
|
||||
|
||||
VSCode will then ask for a Python environment version to use. PrimAITE is compatible with Python versions 3.8 - 3.10
|
||||
VSCode will then ask for a Python environment version to use. PrimAITE is compatible with Python versions 3.8 - 3.11
|
||||
|
||||
You should now be able to interact with the notebook.
|
||||
|
||||
347
docs/source/simulation_components/system/applications/nmap.rst
Normal file
347
docs/source/simulation_components/system/applications/nmap.rst
Normal file
@@ -0,0 +1,347 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _NMAP:
|
||||
|
||||
NMAP
|
||||
====
|
||||
|
||||
Overview
|
||||
--------
|
||||
|
||||
The NMAP application is used to simulate network scanning activities. NMAP is a powerful tool that helps in discovering
|
||||
hosts and services on a network. It provides functionalities such as ping scans to discover active hosts and port scans
|
||||
to detect open ports on those hosts.
|
||||
|
||||
The NMAP application is essential for network administrators and security professionals to map out a network's
|
||||
structure, identify active devices, and find potential vulnerabilities by discovering open ports and running services.
|
||||
However, it is also a tool frequently used by attackers during the reconnaissance stage of a cyber attack to gather
|
||||
information about the target network.
|
||||
|
||||
Scan Types
|
||||
----------
|
||||
|
||||
Ping Scan
|
||||
^^^^^^^^^
|
||||
|
||||
A ping scan is used to identify which hosts on a network are active and reachable. This is achieved by sending ICMP
|
||||
Echo Request packets (ping) to the target IP addresses. If a host responds with an ICMP Echo Reply, it is considered
|
||||
active. Ping scans are useful for quickly mapping out live hosts in a network.
|
||||
|
||||
Port Scan
|
||||
^^^^^^^^^
|
||||
|
||||
A port scan is used to detect open ports on a target host or range of hosts. Open ports can indicate running services
|
||||
that might be exploitable or require securing. Port scans help in understanding the services available on a network and
|
||||
identifying potential entry points for attacks. There are three types of port scans based on the scope:
|
||||
|
||||
- **Horizontal Port Scan**: This scan targets a specific port across a range of IP addresses. It helps in identifying
|
||||
which hosts have a particular service running.
|
||||
|
||||
- **Vertical Port Scan**: This scan targets multiple ports on a single IP address. It provides detailed information
|
||||
about the services running on a specific host.
|
||||
|
||||
- **Box Scan**: This combines both horizontal and vertical scans, targeting multiple ports across multiple IP addresses.
|
||||
It gives a comprehensive view of the network's service landscape.
|
||||
|
||||
Example Usage
|
||||
-------------
|
||||
|
||||
The network we use for these examples is defined below:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ipaddress import IPv4Network
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.network.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.system.applications.nmap import NMAP
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
|
||||
# Initialize the network
|
||||
network = Network()
|
||||
|
||||
# Set up the router
|
||||
router = Router(hostname="router", start_up_duration=0)
|
||||
router.power_on()
|
||||
router.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
|
||||
|
||||
# Set up PC 1
|
||||
pc_1 = Computer(
|
||||
hostname="pc_1",
|
||||
ip_address="192.168.1.11",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0
|
||||
)
|
||||
pc_1.power_on()
|
||||
|
||||
# Set up PC 2
|
||||
pc_2 = Computer(
|
||||
hostname="pc_2",
|
||||
ip_address="192.168.1.12",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0
|
||||
)
|
||||
pc_2.power_on()
|
||||
pc_2.software_manager.install(DatabaseService)
|
||||
pc_2.software_manager.software["DatabaseService"].start() # start the postgres server
|
||||
|
||||
# Set up PC 3
|
||||
pc_3 = Computer(
|
||||
hostname="pc_3",
|
||||
ip_address="192.168.1.13",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0
|
||||
)
|
||||
# Don't power on PC 3
|
||||
|
||||
# Set up the switch
|
||||
switch = Switch(hostname="switch", start_up_duration=0)
|
||||
switch.power_on()
|
||||
|
||||
# Connect devices
|
||||
network.connect(router.network_interface[1], switch.network_interface[24])
|
||||
network.connect(switch.network_interface[1], pc_1.network_interface[1])
|
||||
network.connect(switch.network_interface[2], pc_2.network_interface[1])
|
||||
network.connect(switch.network_interface[3], pc_3.network_interface[1])
|
||||
|
||||
|
||||
pc_1_nmap: NMAP = pc_1.software_manager.software["NMAP"]
|
||||
|
||||
|
||||
Ping Scan
|
||||
^^^^^^^^^
|
||||
|
||||
Perform a ping scan to find active hosts in the `192.168.1.0/24` subnet:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Ping Scan Code
|
||||
|
||||
active_hosts = pc_1_nmap.ping_scan(target_ip_address=IPv4Network("192.168.1.0/24"))
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Ping Scan Return Value
|
||||
|
||||
[
|
||||
IPv4Address('192.168.1.11'),
|
||||
IPv4Address('192.168.1.12'),
|
||||
IPv4Address('192.168.1.1')
|
||||
]
|
||||
|
||||
.. code-block:: text
|
||||
:caption: Ping Scan Output
|
||||
|
||||
+-------------------------+
|
||||
| pc_1 NMAP Ping Scan |
|
||||
+--------------+----------+
|
||||
| IP Address | Can Ping |
|
||||
+--------------+----------+
|
||||
| 192.168.1.1 | True |
|
||||
| 192.168.1.11 | True |
|
||||
| 192.168.1.12 | True |
|
||||
+--------------+----------+
|
||||
|
||||
Horizontal Port Scan
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Perform a horizontal port scan on port 5432 across multiple IP addresses:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Horizontal Port Scan Code
|
||||
|
||||
horizontal_scan_results = pc_1_nmap.port_scan(
|
||||
target_ip_address=[IPv4Address("192.168.1.12"), IPv4Address("192.168.1.13")],
|
||||
target_port=Port(5432 )
|
||||
)
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Horizontal Port Scan Return Value
|
||||
|
||||
{
|
||||
IPv4Address('192.168.1.12'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.POSTGRES_SERVER: 5432>
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
.. code-block:: text
|
||||
:caption: Horizontal Port Scan Output
|
||||
|
||||
+--------------------------------------------------+
|
||||
| pc_1 NMAP Port Scan (Horizontal) |
|
||||
+--------------+------+-----------------+----------+
|
||||
| IP Address | Port | Name | Protocol |
|
||||
+--------------+------+-----------------+----------+
|
||||
| 192.168.1.12 | 5432 | POSTGRES_SERVER | TCP |
|
||||
+--------------+------+-----------------+----------+
|
||||
|
||||
Vertical Post Scan
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Perform a vertical port scan on multiple ports on a single IP address:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Vertical Port Scan Code
|
||||
|
||||
vertical_scan_results = pc_1_nmap.port_scan(
|
||||
target_ip_address=[IPv4Address("192.168.1.12")],
|
||||
target_port=[Port(21), Port(22), Port(80), Port(443)]
|
||||
)
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Vertical Port Scan Return Value
|
||||
|
||||
{
|
||||
IPv4Address('192.168.1.12'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.FTP: 21>,
|
||||
<Port.HTTP: 80>
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
.. code-block:: text
|
||||
:caption: Vertical Port Scan Output
|
||||
|
||||
+---------------------------------------+
|
||||
| pc_1 NMAP Port Scan (Vertical) |
|
||||
+--------------+------+------+----------+
|
||||
| IP Address | Port | Name | Protocol |
|
||||
+--------------+------+------+----------+
|
||||
| 192.168.1.12 | 21 | FTP | TCP |
|
||||
| 192.168.1.12 | 80 | HTTP | TCP |
|
||||
+--------------+------+------+----------+
|
||||
|
||||
Box Scan
|
||||
^^^^^^^^
|
||||
|
||||
Perform a box scan on multiple ports across multiple IP addresses:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Box Port Scan Code
|
||||
|
||||
# Power PC 3 on before performing the box scan
|
||||
pc_3.power_on()
|
||||
|
||||
|
||||
box_scan_results = pc_1_nmap.port_scan(
|
||||
target_ip_address=[IPv4Address("192.168.1.12"), IPv4Address("192.168.1.13")],
|
||||
target_port=[Port(21), Port(22), Port(80), Port(443)]
|
||||
)
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Box Port Scan Return Value
|
||||
|
||||
{
|
||||
IPv4Address('192.168.1.13'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.FTP: 21>,
|
||||
<Port.HTTP: 80>
|
||||
]
|
||||
},
|
||||
IPv4Address('192.168.1.12'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.FTP: 21>,
|
||||
<Port.HTTP: 80>
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
.. code-block:: text
|
||||
:caption: Box Port Scan Output
|
||||
|
||||
+---------------------------------------+
|
||||
| pc_1 NMAP Port Scan (Box) |
|
||||
+--------------+------+------+----------+
|
||||
| IP Address | Port | Name | Protocol |
|
||||
+--------------+------+------+----------+
|
||||
| 192.168.1.12 | 21 | FTP | TCP |
|
||||
| 192.168.1.12 | 80 | HTTP | TCP |
|
||||
| 192.168.1.13 | 21 | FTP | TCP |
|
||||
| 192.168.1.13 | 80 | HTTP | TCP |
|
||||
+--------------+------+------+----------+
|
||||
|
||||
Full Box Scan
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Perform a full box scan on all ports, over both TCP and UDP, on a whole subnet:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Box Port Scan Code
|
||||
|
||||
# Power PC 3 on before performing the full box scan
|
||||
pc_3.power_on()
|
||||
|
||||
|
||||
full_box_scan_results = pc_1_nmap.port_scan(
|
||||
target_ip_address=IPv4Network("192.168.1.0/24"),
|
||||
)
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Box Port Scan Return Value
|
||||
|
||||
{
|
||||
IPv4Address('192.168.1.11'): {
|
||||
<IPProtocol.UDP: 'udp'>: [
|
||||
<Port.ARP: 219>
|
||||
]
|
||||
},
|
||||
IPv4Address('192.168.1.1'): {
|
||||
<IPProtocol.UDP: 'udp'>: [
|
||||
<Port.ARP: 219>
|
||||
]
|
||||
},
|
||||
IPv4Address('192.168.1.12'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.HTTP: 80>,
|
||||
<Port.DNS: 53>,
|
||||
<Port.POSTGRES_SERVER: 5432>,
|
||||
<Port.FTP: 21>
|
||||
],
|
||||
<IPProtocol.UDP: 'udp'>: [
|
||||
<Port.NTP: 123>,
|
||||
<Port.ARP: 219>
|
||||
]
|
||||
},
|
||||
IPv4Address('192.168.1.13'): {
|
||||
<IPProtocol.TCP: 'tcp'>: [
|
||||
<Port.HTTP: 80>,
|
||||
<Port.DNS: 53>,
|
||||
<Port.FTP: 21>
|
||||
],
|
||||
<IPProtocol.UDP: 'udp'>: [
|
||||
<Port.NTP: 123>,
|
||||
<Port.ARP: 219>
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
.. code-block:: text
|
||||
:caption: Box Port Scan Output
|
||||
|
||||
+--------------------------------------------------+
|
||||
| pc_1 NMAP Port Scan (Box) |
|
||||
+--------------+------+-----------------+----------+
|
||||
| IP Address | Port | Name | Protocol |
|
||||
+--------------+------+-----------------+----------+
|
||||
| 192.168.1.1 | 219 | ARP | UDP |
|
||||
| 192.168.1.11 | 219 | ARP | UDP |
|
||||
| 192.168.1.12 | 21 | FTP | TCP |
|
||||
| 192.168.1.12 | 53 | DNS | TCP |
|
||||
| 192.168.1.12 | 80 | HTTP | TCP |
|
||||
| 192.168.1.12 | 123 | NTP | UDP |
|
||||
| 192.168.1.12 | 219 | ARP | UDP |
|
||||
| 192.168.1.12 | 5432 | POSTGRES_SERVER | TCP |
|
||||
| 192.168.1.13 | 21 | FTP | TCP |
|
||||
| 192.168.1.13 | 53 | DNS | TCP |
|
||||
| 192.168.1.13 | 80 | HTTP | TCP |
|
||||
| 192.168.1.13 | 123 | NTP | UDP |
|
||||
| 192.168.1.13 | 219 | ARP | UDP |
|
||||
+--------------+------+-----------------+----------+
|
||||
@@ -24,7 +24,7 @@ For each variation that could be used in a placeholder, there is a separate yaml
|
||||
|
||||
The data that fills the placeholder is defined as a YAML Anchor in a separate file, denoted by an ampersand ``&anchor``.
|
||||
|
||||
Learn more about YAML Aliases and Anchors here.
|
||||
Learn more about YAML Aliases and Anchors `here <https://yaml.org/spec/1.2.2/#3222-anchors-and-aliases>`_.
|
||||
|
||||
Schedule
|
||||
********
|
||||
|
||||
@@ -33,7 +33,7 @@ dependencies = [
|
||||
"numpy==1.23.5",
|
||||
"platformdirs==3.5.1",
|
||||
"plotly==5.15.0",
|
||||
"polars==0.18.4",
|
||||
"polars==0.20.30",
|
||||
"prettytable==3.8.0",
|
||||
"PyYAML==6.0",
|
||||
"typer[all]==0.9.0",
|
||||
|
||||
@@ -1 +1 @@
|
||||
3.0.0b9
|
||||
3.0.0
|
||||
|
||||
@@ -11,7 +11,7 @@ AbstractAction. The ActionManager is responsible for:
|
||||
"""
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
@@ -871,6 +871,74 @@ class NetworkPortDisableAction(AbstractAction):
|
||||
return ["network", "node", target_nodename, "network_interface", port_id, "disable"]
|
||||
|
||||
|
||||
class NodeNMAPPingScanAction(AbstractAction):
|
||||
"""Action which performs an NMAP ping scan."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(self, source_node: str, target_ip_address: Union[str, List[str]]) -> List[str]: # noqa
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"ping_scan",
|
||||
{"target_ip_address": target_ip_address},
|
||||
]
|
||||
|
||||
|
||||
class NodeNMAPPortScanAction(AbstractAction):
|
||||
"""Action which performs an NMAP port scan."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(
|
||||
self,
|
||||
source_node: str,
|
||||
target_ip_address: Union[str, List[str]],
|
||||
target_protocol: Optional[Union[str, List[str]]] = None,
|
||||
target_port: Optional[Union[str, List[str]]] = None,
|
||||
) -> List[str]: # noqa
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"port_scan",
|
||||
{"target_ip_address": target_ip_address, "target_port": target_port, "target_protocol": target_protocol},
|
||||
]
|
||||
|
||||
|
||||
class NodeNetworkServiceReconAction(AbstractAction):
|
||||
"""Action which performs an NMAP network service recon (ping scan followed by port scan)."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(
|
||||
self,
|
||||
source_node: str,
|
||||
target_ip_address: Union[str, List[str]],
|
||||
target_protocol: Optional[Union[str, List[str]]] = None,
|
||||
target_port: Optional[Union[str, List[str]]] = None,
|
||||
) -> List[str]: # noqa
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"network_service_recon",
|
||||
{"target_ip_address": target_ip_address, "target_port": target_port, "target_protocol": target_protocol},
|
||||
]
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
@@ -916,6 +984,9 @@ class ActionManager:
|
||||
"HOST_NIC_DISABLE": HostNICDisableAction,
|
||||
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
|
||||
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
|
||||
"NODE_NMAP_PING_SCAN": NodeNMAPPingScanAction,
|
||||
"NODE_NMAP_PORT_SCAN": NodeNMAPPortScanAction,
|
||||
"NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction,
|
||||
}
|
||||
"""Dictionary which maps action type strings to the corresponding action class."""
|
||||
|
||||
|
||||
@@ -361,11 +361,6 @@ class PrimaiteGame:
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload", "ENCRYPT"),
|
||||
c2_beacon_p_of_success=float(opt.get("c2_beacon_p_of_success", "0.5")),
|
||||
target_scan_p_of_success=float(opt.get("target_scan_p_of_success", "0.1")),
|
||||
ransomware_encrypt_p_of_success=float(
|
||||
opt.get("ransomware_encrypt_p_of_success", "0.1")
|
||||
),
|
||||
)
|
||||
elif application_type == "DatabaseClient":
|
||||
if "options" in application_cfg:
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Dict, ForwardRef, List, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
|
||||
|
||||
RequestFormat = List[Union[str, int, float]]
|
||||
RequestFormat = List[Union[str, int, float, Dict]]
|
||||
|
||||
RequestResponse = ForwardRef("RequestResponse")
|
||||
"""This makes it possible to type-hint RequestResponse.from_bool return type."""
|
||||
|
||||
@@ -4,13 +4,15 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Customising red agents\n",
|
||||
"# Customising UC2 Red Agents\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will go over some examples of how red agent behaviour can be varied by changing its configuration parameters.\n",
|
||||
"\n",
|
||||
"First, let's load the standard Data Manipulation config file, and see what the red agent does.\n",
|
||||
"\n",
|
||||
"*(For a full explanation of the Data Manipulation scenario, check out the notebook `Data-Manipulation-E2E-Demonstration.ipynb`)*"
|
||||
"*(For a full explanation of the Data Manipulation scenario, check out the data manipulation scenario notebook)*"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Data Manipulation Scenario\n"
|
||||
"# Data Manipulation Scenario\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -79,7 +81,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Reinforcement learning details"
|
||||
"## Reinforcement learning details"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -692,7 +694,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Getting information out of PrimAITE"
|
||||
"# Getting information out of PrimAITE\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -160,7 +162,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.10.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
"source": [
|
||||
"# Requests and Responses\n",
|
||||
"\n",
|
||||
"Agents interact with the PrimAITE simulation via the Request system.\n"
|
||||
"Agents interact with the PrimAITE simulation via the Request system.\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train a Multi agent system using RLLIB\n",
|
||||
"# Train a Multi agent system using RLLIB\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will demonstrate how to use the `PrimaiteRayMARLEnv` to train a very basic system with two PPO agents."
|
||||
]
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train a Single agent system using RLLib\n",
|
||||
"# Train a Single agent system using RLLib\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
|
||||
]
|
||||
},
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
"source": [
|
||||
"# Training an SB3 Agent\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will demonstrate how to use primaite to create and train a PPO agent, using a pre-defined configuration file."
|
||||
]
|
||||
},
|
||||
@@ -166,7 +168,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -180,7 +182,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
"source": [
|
||||
"# Using Episode Schedules\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"PrimAITE supports the ability to use different variations on a scenario at different episodes. This can be used to increase \n",
|
||||
"domain randomisation to prevent overfitting, or to set up curriculum learning to train agents to perform more complicated tasks.\n",
|
||||
"\n",
|
||||
|
||||
@@ -4,7 +4,11 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Simple multi-processing demo using SubprocVecEnv from SB3"
|
||||
"# Simple multi-processing demonstration\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook uses SubprocVecEnv from SB3."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -139,7 +143,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -38,6 +38,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.total_reward_per_episode: Dict[int, float] = {}
|
||||
"""Average rewards of agents per episode."""
|
||||
|
||||
@property
|
||||
def agent(self) -> ProxyAgent:
|
||||
@@ -90,6 +92,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"Resetting environment, episode {self.episode_counter}, "
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
"source": [
|
||||
"# Build a simulation using the Python API\n",
|
||||
"\n",
|
||||
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects.\n"
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
"source": [
|
||||
"# PrimAITE Router Simulation Demo\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This demo uses a modified version of the ARCD Use Case 2 Network (seen below) to demonstrate the capabilities of the Network simulator in PrimAITE."
|
||||
]
|
||||
},
|
||||
|
||||
@@ -222,7 +222,7 @@ class SimComponent(BaseModel):
|
||||
return state
|
||||
|
||||
@validate_call
|
||||
def apply_request(self, request: RequestFormat, context: Dict = {}) -> RequestResponse:
|
||||
def apply_request(self, request: RequestFormat, context: Optional[Dict] = None) -> RequestResponse:
|
||||
"""
|
||||
Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings.
|
||||
|
||||
@@ -240,6 +240,8 @@ class SimComponent(BaseModel):
|
||||
:param: context: Dict containing context for requests
|
||||
:type context: Dict
|
||||
"""
|
||||
if not context:
|
||||
context = None
|
||||
if self._request_manager is None:
|
||||
return
|
||||
return self._request_manager(request, context)
|
||||
|
||||
@@ -29,6 +29,7 @@ from primaite.simulator.network.nmne import (
|
||||
NMNE_CAPTURE_KEYWORDS,
|
||||
)
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
from primaite.simulator.system.core.session_manager import SessionManager
|
||||
@@ -37,6 +38,7 @@ from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.processes.process import Process
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import IOSoftware
|
||||
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
|
||||
from primaite.utils.validators import IPV4Address
|
||||
|
||||
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
|
||||
@@ -108,10 +110,14 @@ class NetworkInterface(SimComponent, ABC):
|
||||
nmne: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the number of malicious network events captured."
|
||||
|
||||
traffic: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the inbound and outbound traffic by port and protocol."
|
||||
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().setup_for_episode(episode=episode)
|
||||
self.nmne = {}
|
||||
self.traffic = {}
|
||||
if episode and self.pcap and SIM_OUTPUT.save_pcap_logs:
|
||||
self.pcap.current_episode = episode
|
||||
self.pcap.setup_logger()
|
||||
@@ -147,6 +153,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
)
|
||||
if CAPTURE_NMNE:
|
||||
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
|
||||
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
|
||||
return state
|
||||
|
||||
@abstractmethod
|
||||
@@ -237,6 +244,47 @@ class NetworkInterface(SimComponent, ABC):
|
||||
# Increment a generic counter if keyword capturing is not enabled
|
||||
keyword_level["*"] = keyword_level.get("*", 0) + 1
|
||||
|
||||
def _capture_traffic(self, frame: Frame, inbound: bool = True) -> None:
    """
    Capture traffic statistics at the Network Interface.

    Tallies the size of each frame passing through the interface into ``self.traffic``, keyed by
    protocol (and, for TCP/UDP, by destination port), split into inbound and outbound byte counts.

    :param frame: The network frame containing the traffic data.
    :type frame: Frame
    :param inbound: Flag indicating if the traffic is inbound or outbound. Defaults to True.
    :type inbound: bool
    """
    # Determine the direction of the traffic
    direction = "inbound" if inbound else "outbound"

    # Identify the protocol and (for TCP/UDP) the destination port from the frame
    protocol = None
    port = None
    if frame.tcp:
        protocol = IPProtocol.TCP
        port = frame.tcp.dst_port
    elif frame.udp:
        protocol = IPProtocol.UDP
        port = frame.udp.dst_port
    elif frame.icmp:
        protocol = IPProtocol.ICMP

    # A frame carrying none of TCP/UDP/ICMP cannot be attributed to a protocol. Previously such
    # frames were tallied under a `None` key, polluting the traffic dict (and breaking the
    # enum-key conversion performed when describing state), so skip them instead.
    if protocol is None:
        return

    # Ensure the protocol has an entry in the capture dict
    if protocol not in self.traffic:
        self.traffic[protocol] = {}

    if protocol != IPProtocol.ICMP:
        # TCP/UDP traffic is tallied per destination port
        if port not in self.traffic[protocol]:
            self.traffic[protocol][port] = {"inbound": 0, "outbound": 0}
        self.traffic[protocol][port][direction] += frame.size
    else:
        # ICMP does not use ports, so tally at the protocol level directly
        if not self.traffic[protocol]:
            self.traffic[protocol] = {"inbound": 0, "outbound": 0}
        self.traffic[protocol][direction] += frame.size
|
||||
|
||||
@abstractmethod
|
||||
def send_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
@@ -246,6 +294,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
:return: A boolean indicating whether the frame was successfully sent.
|
||||
"""
|
||||
self._capture_nmne(frame, inbound=False)
|
||||
self._capture_traffic(frame, inbound=False)
|
||||
|
||||
@abstractmethod
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
@@ -256,6 +305,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
:return: A boolean indicating whether the frame was successfully received.
|
||||
"""
|
||||
self._capture_nmne(frame, inbound=True)
|
||||
self._capture_traffic(frame, inbound=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
@@ -767,6 +817,24 @@ class Node(SimComponent):
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
self._install_system_software()
|
||||
|
||||
def ip_is_network_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool:
    """
    Checks if a given IP address belongs to any of the nodes interfaces.

    Interfaces without an ``ip_address`` attribute (e.g. layer-2 ports) are ignored. The first
    interface whose address matches decides the result.

    :param ip_address: The IP address to check.
    :param enabled_only: If True, only considers enabled network interfaces.
    :return: True if the IP address is assigned to one of the nodes interfaces; False otherwise.
    """
    for interface in self.network_interface.values():
        # Interfaces lacking an IP address can never match.
        if getattr(interface, "ip_address", None) != ip_address:
            continue
        return interface.enabled if enabled_only else True
    return False
|
||||
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
super().setup_for_episode(episode=episode)
|
||||
|
||||
@@ -8,6 +8,8 @@ from primaite import getLogger
|
||||
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, Link, Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.nmap import NMAP
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.arp.arp import ARP, ARPPacket
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
@@ -303,6 +305,7 @@ class HostNode(Node):
|
||||
"DNSClient": DNSClient,
|
||||
"NTPClient": NTPClient,
|
||||
"WebBrowser": WebBrowser,
|
||||
"NMAP": NMAP,
|
||||
}
|
||||
"""List of system software that is automatically installed on nodes."""
|
||||
|
||||
@@ -315,6 +318,16 @@ class HostNode(Node):
|
||||
super().__init__(**kwargs)
|
||||
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
|
||||
|
||||
@property
def nmap(self) -> Optional[NMAP]:
    """
    Return the NMAP application installed on the Node.

    :return: NMAP application installed on the Node, or None if it is not installed.
    :rtype: Optional[NMAP]
    """
    installed_software = self.software_manager.software
    return installed_software.get("NMAP")
|
||||
|
||||
@property
|
||||
def arp(self) -> Optional[ARP]:
|
||||
"""
|
||||
@@ -366,8 +379,15 @@ class HostNode(Node):
|
||||
elif frame.udp:
|
||||
dst_port = frame.udp.dst_port
|
||||
|
||||
can_accept_nmap = False
|
||||
if self.software_manager.software.get("NMAP"):
|
||||
if self.software_manager.software["NMAP"].operating_state == ApplicationOperatingState.RUNNING:
|
||||
can_accept_nmap = True
|
||||
|
||||
accept_nmap = can_accept_nmap and frame.payload.__class__.__name__ == "PortScanPayload"
|
||||
|
||||
accept_frame = False
|
||||
if frame.icmp or dst_port in self.software_manager.get_open_ports():
|
||||
if frame.icmp or dst_port in self.software_manager.get_open_ports() or accept_nmap:
|
||||
# accept the frame as the port is open or if it's an ICMP frame
|
||||
accept_frame = True
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.nmap import NMAP
|
||||
from primaite.simulator.system.core.session_manager import SessionManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.arp.arp import ARP
|
||||
@@ -1239,6 +1240,7 @@ class Router(NetworkNode):
|
||||
icmp.router = self
|
||||
self.software_manager.install(RouterARP)
|
||||
self.arp.router = self
|
||||
self.software_manager.install(NMAP)
|
||||
|
||||
def _set_default_acl(self):
|
||||
"""
|
||||
|
||||
@@ -271,9 +271,16 @@ class DatabaseClient(Application):
|
||||
|
||||
Calls disconnect on all client connections to ensure that both client and server connections are killed.
|
||||
"""
|
||||
while self.client_connections.values():
|
||||
client_connection = self.client_connections[next(iter(self.client_connections.keys()))]
|
||||
client_connection.disconnect()
|
||||
while self.client_connections:
|
||||
conn_key = next(iter(self.client_connections.keys()))
|
||||
conn_obj: DatabaseClientConnection = self.client_connections[conn_key]
|
||||
conn_obj.disconnect()
|
||||
if conn_obj.is_active or conn_key in self.client_connections:
|
||||
self.sys_log.error(
|
||||
"Attempted to uninstall database client but could not drop active connections. "
|
||||
"Forcing uninstall anyway."
|
||||
)
|
||||
self.client_connections.pop(conn_key, None)
|
||||
super().uninstall()
|
||||
|
||||
def get_new_connection(self) -> Optional[DatabaseClientConnection]:
|
||||
|
||||
452
src/primaite/simulator/system/applications/nmap.py
Normal file
452
src/primaite/simulator/system/applications/nmap.py
Normal file
@@ -0,0 +1,452 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union
|
||||
|
||||
from prettytable import PrettyTable
|
||||
from pydantic import validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.utils.validators import IPV4Address
|
||||
|
||||
|
||||
class PortScanPayload(SimComponent):
    """
    A class representing the payload for a port scan.

    :ivar ip_address: The target IP address for the port scan.
    :ivar port: The target port for the port scan.
    :ivar protocol: The protocol used for the port scan.
    :ivar request: Flag to indicate whether this is a request or not.
    """

    ip_address: IPV4Address
    port: Port
    protocol: IPProtocol
    request: bool = True

    def describe_state(self) -> Dict:
        """
        Describe the state of the port scan payload.

        :return: A dictionary representation of the port scan payload state.
        :rtype: Dict
        """
        state = super().describe_state()
        state.update(
            {
                "ip_address": str(self.ip_address),
                "port": self.port.value,
                "protocol": self.protocol.value,
                "request": self.request,
            }
        )
        return state
|
||||
|
||||
|
||||
class NMAP(Application):
    """
    A class representing the NMAP application for network scanning.

    NMAP is a network scanning tool used to discover hosts and services on a network. It provides functionalities such
    as ping scans to discover active hosts and port scans to detect open ports on those hosts.
    """

    # NOTE(review): class-level mutable defaults. Underscore-prefixed attributes are treated as
    # pydantic private attributes (SimComponent is a BaseModel), which are copied per instance —
    # confirm this against the Application/SimComponent base behaviour.
    _active_port_scans: Dict[str, PortScanPayload] = {}
    _port_scan_responses: Dict[str, PortScanPayload] = {}

    # Maps (multiple target IPs?, multiple target ports?) to a human-readable scan-type label.
    _PORT_SCAN_TYPE_MAP: Final[Dict[Tuple[bool, bool], str]] = {
        (True, True): "Box",
        (True, False): "Horizontal",
        (False, True): "Vertical",
        (False, False): "Port",
    }

    def __init__(self, **kwargs):
        """Initialise NMAP; name, port, and protocol are fixed regardless of supplied kwargs."""
        kwargs["name"] = "NMAP"
        kwargs["port"] = Port.NONE
        kwargs["protocol"] = IPProtocol.NONE
        super().__init__(**kwargs)

    def _can_perform_network_action(self) -> bool:
        """
        Checks if the NMAP application can perform outbound network actions.

        This is done by checking the parent application ``_can_perform_action`` functionality, then checking if there
        is an enabled NIC that can be used for outbound traffic.

        :return: True if outbound network actions can be performed, otherwise False.
        """
        if not super()._can_perform_action():
            return False

        for nic in self.software_manager.node.network_interface.values():
            if nic.enabled:
                return True
        return False

    def _init_request_manager(self) -> RequestManager:
        """
        Build the request manager exposing the ping_scan, port_scan, and network_service_recon requests.

        :return: The configured RequestManager.
        :rtype: RequestManager
        """

        def _ping_scan_action(request: List[Any], context: Any) -> RequestResponse:
            # Check we can act on the network *before* doing any scanning work; previously the
            # scan ran first and its results were discarded when the check failed.
            if not self._can_perform_network_action():
                return RequestResponse.from_bool(False)
            results = self.ping_scan(target_ip_address=request[0]["target_ip_address"], json_serializable=True)
            return RequestResponse(
                status="success",
                data={"live_hosts": results},
            )

        def _port_scan_action(request: List[Any], context: Any) -> RequestResponse:
            if not self._can_perform_network_action():
                return RequestResponse.from_bool(False)
            results = self.port_scan(**request[0], json_serializable=True)
            return RequestResponse(
                status="success",
                data=results,
            )

        def _network_service_recon_action(request: List[Any], context: Any) -> RequestResponse:
            if not self._can_perform_network_action():
                return RequestResponse.from_bool(False)
            results = self.network_service_recon(**request[0], json_serializable=True)
            return RequestResponse(
                status="success",
                data=results,
            )

        rm = RequestManager()

        rm.add_request(
            name="ping_scan",
            request_type=RequestType(func=_ping_scan_action),
        )

        rm.add_request(
            name="port_scan",
            request_type=RequestType(func=_port_scan_action),
        )

        rm.add_request(
            name="network_service_recon",
            request_type=RequestType(func=_network_service_recon_action),
        )

        return rm

    def describe_state(self) -> Dict:
        """
        Describe the state of the NMAP application.

        :return: A dictionary representation of the NMAP application's state.
        :rtype: Dict
        """
        return super().describe_state()

    @staticmethod
    def _explode_ip_address_network_array(
        target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
    ) -> Set[IPv4Address]:
        """
        Explode a mixed array of IP addresses and networks into a set of individual IP addresses.

        This method takes a combination of single and lists of IPv4 addresses and IPv4 networks, expands any networks
        into their constituent subnet useable IP addresses, and returns a set of unique IP addresses. Broadcast and
        network addresses are excluded from the result.

        :param target_ip_address: A single or list of IPv4 addresses and networks.
        :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
        :return: A set of unique IPv4 addresses expanded from the input.
        :rtype: Set[IPv4Address]
        """
        if isinstance(target_ip_address, (IPv4Address, IPv4Network)):
            target_ip_address = [target_ip_address]
        ip_addresses: List[IPV4Address] = []
        for ip_address in target_ip_address:
            if isinstance(ip_address, IPv4Network):
                # IPv4Network.hosts() already omits network/broadcast addresses for typical prefix
                # lengths; the explicit filter keeps the intent obvious and is harmless.
                ip_addresses += [
                    ip
                    for ip in ip_address.hosts()
                    if not ip == ip_address.broadcast_address and not ip == ip_address.network_address
                ]
            else:
                ip_addresses.append(ip_address)
        return set(ip_addresses)

    @validate_call()
    def ping_scan(
        self,
        target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]],
        show: bool = True,
        show_online_only: bool = True,
        json_serializable: bool = False,
    ) -> Union[List[IPV4Address], List[str]]:
        """
        Perform a ping scan on the target IP address(es).

        :param target_ip_address: The target IP address(es) or network(s) for the ping scan.
        :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
        :param show: Flag indicating whether to display the scan results. Defaults to True.
        :type show: bool
        :param show_online_only: Flag indicating whether to show only the online hosts. Defaults to True.
        :type show_online_only: bool
        :param json_serializable: Flag indicating whether the return value should be json serializable. Defaults to
            False.
        :type json_serializable: bool

        :return: A list of active IP addresses that responded to the ping.
        :rtype: Union[List[IPV4Address], List[str]]
        """
        active_nodes = []
        table = None
        if show:
            table = PrettyTable(["IP Address", "Can Ping"])
            table.align = "l"
            table.title = f"{self.software_manager.node.hostname} NMAP Ping Scan"

        ip_addresses = self._explode_ip_address_network_array(target_ip_address)

        for ip_address in ip_addresses:
            # Prevent ping scan on this node
            if self.software_manager.node.ip_is_network_interface(ip_address=ip_address):
                continue
            can_ping = self.software_manager.icmp.ping(ip_address)
            if can_ping:
                active_nodes.append(ip_address if not json_serializable else str(ip_address))
            if show and (can_ping or not show_online_only):
                table.add_row([ip_address, can_ping])
        if show:
            print(table.get_string(sortby="IP Address"))
        return active_nodes

    def _determine_port_scan_type(self, target_ip_addresses: List[IPV4Address], target_ports: List[Port]) -> str:
        """
        Determine the type of port scan based on the number of target IP addresses and ports.

        :param target_ip_addresses: The list of target IP addresses.
        :type target_ip_addresses: List[IPV4Address]
        :param target_ports: The list of target ports.
        :type target_ports: List[Port]

        :return: The type of port scan (Box, Horizontal, Vertical, or Port).
        :rtype: str
        """
        vertical_scan = len(target_ports) > 1
        horizontal_scan = len(target_ip_addresses) > 1

        return self._PORT_SCAN_TYPE_MAP[horizontal_scan, vertical_scan]

    def _check_port_open_on_ip_address(
        self,
        ip_address: IPv4Address,
        port: Port,
        protocol: IPProtocol,
        is_re_attempt: bool = False,
        port_scan_uuid: Optional[str] = None,
    ) -> bool:
        """
        Check if a port is open on a specific IP address.

        Sends a port-scan request payload and then recurses once (``is_re_attempt=True``) to check
        whether a response was recorded. NOTE(review): this assumes payload delivery within the
        simulation is synchronous, so any response has arrived before the re-attempt — confirm
        against the session manager's delivery model.

        :param ip_address: The target IP address.
        :type ip_address: IPv4Address
        :param port: The target port.
        :type port: Port
        :param protocol: The protocol used for the port scan.
        :type protocol: IPProtocol
        :param is_re_attempt: Flag indicating if this is a reattempt. Defaults to False.
        :type is_re_attempt: bool
        :param port_scan_uuid: The UUID of the port scan payload. Defaults to None.
        :type port_scan_uuid: Optional[str]

        :return: True if the port is open, False otherwise.
        :rtype: bool
        """
        # The recursive base case
        if is_re_attempt:
            # Return True if a response has been received, otherwise return False
            if port_scan_uuid in self._port_scan_responses:
                self._port_scan_responses.pop(port_scan_uuid)
                return True
            return False

        # Send the port scan request
        payload = PortScanPayload(ip_address=ip_address, port=port, protocol=protocol)
        self._active_port_scans[payload.uuid] = payload
        self.sys_log.info(
            f"{self.name}: Sending port scan request over {payload.protocol.name} on port {payload.port.value} "
            f"({payload.port.name}) to {payload.ip_address}"
        )
        self.software_manager.send_payload_to_session_manager(
            payload=payload, dest_ip_address=ip_address, src_port=port, dest_port=port, ip_protocol=protocol
        )

        # Recursively call this function as a reattempt to pick up any recorded response
        return self._check_port_open_on_ip_address(
            ip_address=ip_address, port=port, protocol=protocol, is_re_attempt=True, port_scan_uuid=payload.uuid
        )

    def _process_port_scan_response(self, payload: PortScanPayload):
        """
        Process the response to a port scan request.

        :param payload: The port scan payload received in response.
        :type payload: PortScanPayload
        """
        if payload.uuid in self._active_port_scans:
            self._active_port_scans.pop(payload.uuid)
            self._port_scan_responses[payload.uuid] = payload
            self.sys_log.info(
                f"{self.name}: Received port scan response from {payload.ip_address} on port {payload.port.value} "
                f"({payload.port.name}) over {payload.protocol.name}"
            )

    def _process_port_scan_request(self, payload: PortScanPayload, session_id: str) -> None:
        """
        Process a port scan request, replying only if the requested port/protocol is open locally.

        :param payload: The port scan payload received in the request.
        :type payload: PortScanPayload
        :param session_id: The session ID for the port scan request.
        :type session_id: str
        """
        if self.software_manager.check_port_is_open(port=payload.port, protocol=payload.protocol):
            payload.request = False
            self.sys_log.info(
                f"{self.name}: Responding to port scan request for port {payload.port.value} "
                f"({payload.port.name}) over {payload.protocol.name}",
                True,  # NOTE(review): presumably a to-terminal flag on sys_log.info — confirm signature
            )
            self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)

    @validate_call()
    def port_scan(
        self,
        target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]],
        target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None,
        target_port: Optional[Union[Port, List[Port]]] = None,
        show: bool = True,
        json_serializable: bool = False,
    ) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]:
        """
        Perform a port scan on the target IP address(es).

        :param target_ip_address: The target IP address(es) or network(s) for the port scan.
        :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
        :param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP.
        :type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]]
        :param target_port: The port(s) to scan. Defaults to None, which includes all valid ports.
        :type target_port: Optional[Union[Port, List[Port]]]
        :param show: Flag indicating whether to display the scan results. Defaults to True.
        :type show: bool
        :param json_serializable: Flag indicating whether the return value should be JSON serializable. Defaults to
            False.
        :type json_serializable: bool

        :return: A dictionary mapping IP addresses to protocols and lists of open ports.
        :rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]]
        """
        ip_addresses = self._explode_ip_address_network_array(target_ip_address)

        if isinstance(target_port, Port):
            target_port = [target_port]
        elif target_port is None:
            target_port = [port for port in Port if port not in {Port.NONE, Port.UNUSED}]

        if isinstance(target_protocol, IPProtocol):
            target_protocol = [target_protocol]
        elif target_protocol is None:
            target_protocol = [IPProtocol.TCP, IPProtocol.UDP]

        scan_type = self._determine_port_scan_type(list(ip_addresses), target_port)
        active_ports = {}
        table = None
        if show:
            table = PrettyTable(["IP Address", "Port", "Name", "Protocol"])
            table.align = "l"
            table.title = f"{self.software_manager.node.hostname} NMAP Port Scan ({scan_type})"
        self.sys_log.info(f"{self.name}: Starting port scan")
        for ip_address in ip_addresses:
            # Prevent port scan on this node
            if self.software_manager.node.ip_is_network_interface(ip_address=ip_address):
                continue
            for protocol in target_protocol:
                for port in set(target_port):
                    port_open = self._check_port_open_on_ip_address(ip_address=ip_address, port=port, protocol=protocol)

                    if port_open:
                        # Only touch the table when it exists; previously this raised
                        # UnboundLocalError for open ports when show=False.
                        if show:
                            table.add_row([ip_address, port.value, port.name, protocol.name])
                        _ip_address = ip_address if not json_serializable else str(ip_address)
                        _protocol = protocol if not json_serializable else protocol.value
                        _port = port if not json_serializable else port.value
                        if _ip_address not in active_ports:
                            active_ports[_ip_address] = dict()
                        if _protocol not in active_ports[_ip_address]:
                            active_ports[_ip_address][_protocol] = []
                        active_ports[_ip_address][_protocol].append(_port)

        if show:
            print(table.get_string(sortby="IP Address"))

        return active_ports

    def network_service_recon(
        self,
        target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]],
        target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None,
        target_port: Optional[Union[Port, List[Port]]] = None,
        show: bool = True,
        show_online_only: bool = True,
        json_serializable: bool = False,
    ) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]:
        """
        Perform a network service reconnaissance which includes a ping scan followed by a port scan.

        This method combines the functionalities of a ping scan and a port scan to provide a comprehensive
        overview of the services on the network. It first identifies active hosts in the target IP range by performing
        a ping scan. Once the active hosts are identified, it performs a port scan on these hosts to identify open
        ports and running services. This two-step process ensures that the port scan is performed only on live hosts,
        optimising the scanning process and providing accurate results.

        :param target_ip_address: The target IP address(es) or network(s) for the port scan.
        :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]]
        :param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP.
        :type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]]
        :param target_port: The port(s) to scan. Defaults to None, which includes all valid ports.
        :type target_port: Optional[Union[Port, List[Port]]]
        :param show: Flag indicating whether to display the scan results. Defaults to True.
        :type show: bool
        :param show_online_only: Flag indicating whether to show only the online hosts. Defaults to True.
        :type show_online_only: bool
        :param json_serializable: Flag indicating whether the return value should be JSON serializable. Defaults to
            False.
        :type json_serializable: bool

        :return: A dictionary mapping IP addresses to protocols and lists of open ports.
        :rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]]
        """
        ping_scan_results = self.ping_scan(
            target_ip_address=target_ip_address, show=show, show_online_only=show_online_only, json_serializable=False
        )
        return self.port_scan(
            target_ip_address=ping_scan_results,
            target_protocol=target_protocol,
            target_port=target_port,
            show=show,
            json_serializable=json_serializable,
        )

    def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
        """
        Receive and process a payload.

        Port-scan requests are answered (if the scanned port is open locally); port-scan responses
        are recorded for the in-flight scan. Other payload types are ignored.

        :param payload: The payload to be processed.
        :type payload: Any
        :param session_id: The session ID associated with the payload.
        :type session_id: str

        :return: True if the payload was successfully processed, False otherwise.
        :rtype: bool
        """
        if isinstance(payload, PortScanPayload):
            if payload.request:
                self._process_port_scan_request(payload=payload, session_id=session_id)
            else:
                self._process_port_scan_response(payload=payload)

        return True
|
||||
@@ -1,9 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
@@ -12,43 +10,10 @@ from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
|
||||
|
||||
class RansomwareAttackStage(IntEnum):
|
||||
"""
|
||||
Enumeration representing different attack stages of the ransomware script.
|
||||
|
||||
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle
|
||||
in the simulation.
|
||||
Each stage represents a specific phase in the attack process.
|
||||
"""
|
||||
|
||||
NOT_STARTED = 0
|
||||
"Indicates that the attack has not started yet."
|
||||
DOWNLOAD = 1
|
||||
"Installing the Encryption Script - Testing"
|
||||
INSTALL = 2
|
||||
"The stage where logon procedures are simulated."
|
||||
ACTIVATE = 3
|
||||
"Operating Status Changes"
|
||||
PROPAGATE = 4
|
||||
"Represents the stage of performing a horizontal port scan on the target."
|
||||
COMMAND_AND_CONTROL = 5
|
||||
"Represents the stage of setting up a rely C2 Beacon (Not Implemented)"
|
||||
PAYLOAD = 6
|
||||
"Stage of actively attacking the target."
|
||||
SUCCEEDED = 7
|
||||
"Indicates the attack has been successfully completed."
|
||||
FAILED = 8
|
||||
"Signifies that the attack has failed."
|
||||
|
||||
|
||||
class RansomwareScript(Application):
|
||||
"""Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network.
|
||||
|
||||
:ivar payload: The attack stage query payload. (Default Corrupt)
|
||||
:ivar target_scan_p_of_success: The probability of success for the target scan stage.
|
||||
:ivar c2_beacon_p_of_success: The probability of success for the c2_beacon stage
|
||||
:ivar ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
|
||||
:ivar repeat: Whether to repeat attacking once finished.
|
||||
:ivar payload: The attack stage query payload. (Default ENCRYPT)
|
||||
"""
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
@@ -57,16 +22,6 @@ class RansomwareScript(Application):
|
||||
"""Password required to access the database."""
|
||||
payload: Optional[str] = "ENCRYPT"
|
||||
"Payload String for the payload stage"
|
||||
target_scan_p_of_success: float = 0.9
|
||||
"Probability of the target scan succeeding: Default 0.9"
|
||||
c2_beacon_p_of_success: float = 0.9
|
||||
"Probability of the c2 beacon setup stage succeeding: Default 0.9"
|
||||
ransomware_encrypt_p_of_success: float = 0.9
|
||||
"Probability of the ransomware attack succeeding: Default 0.9"
|
||||
repeat: bool = False
|
||||
"If true, the Denial of Service bot will keep performing the attack."
|
||||
attack_stage: RansomwareAttackStage = RansomwareAttackStage.NOT_STARTED
|
||||
"The ransomware attack stage. See RansomwareAttackStage Class"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "RansomwareScript"
|
||||
@@ -91,7 +46,7 @@ class RansomwareScript(Application):
|
||||
@property
|
||||
def _host_db_client(self) -> DatabaseClient:
|
||||
"""Return the database client that is installed on the same machine as the Ransomware Script."""
|
||||
db_client = self.software_manager.software.get("DatabaseClient")
|
||||
db_client: DatabaseClient = self.software_manager.software.get("DatabaseClient")
|
||||
if db_client is None:
|
||||
self.sys_log.warning(f"{self.__class__.__name__} cannot find a database client on its host.")
|
||||
return db_client
|
||||
@@ -109,16 +64,6 @@ class RansomwareScript(Application):
|
||||
)
|
||||
return rm
|
||||
|
||||
def _activate(self):
|
||||
"""
|
||||
Simulate the install process as the initial stage of the attack.
|
||||
|
||||
Advances the attack stage to 'ACTIVATE' attack state.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.INSTALL:
|
||||
self.sys_log.info(f"{self.name}: Activated!")
|
||||
self.attack_stage = RansomwareAttackStage.ACTIVATE
|
||||
|
||||
def run(self) -> bool:
|
||||
"""Calls the parent classes execute method before starting the application loop."""
|
||||
super().run()
|
||||
@@ -134,20 +79,9 @@ class RansomwareScript(Application):
|
||||
return False
|
||||
if self.server_ip_address and self.payload:
|
||||
self.sys_log.info(f"{self.name}: Running")
|
||||
self.attack_stage = RansomwareAttackStage.NOT_STARTED
|
||||
self._local_download()
|
||||
self._install()
|
||||
self._activate()
|
||||
self._perform_target_scan()
|
||||
self._setup_beacon()
|
||||
self._perform_ransomware_encrypt()
|
||||
|
||||
if self.repeat and self.attack_stage in (
|
||||
RansomwareAttackStage.SUCCEEDED,
|
||||
RansomwareAttackStage.FAILED,
|
||||
):
|
||||
self.attack_stage = RansomwareAttackStage.NOT_STARTED
|
||||
return True
|
||||
if self._perform_ransomware_encrypt():
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
self.sys_log.warning(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
|
||||
return False
|
||||
@@ -157,10 +91,6 @@ class RansomwareScript(Application):
|
||||
server_ip_address: IPv4Address,
|
||||
server_password: Optional[str] = None,
|
||||
payload: Optional[str] = None,
|
||||
target_scan_p_of_success: Optional[float] = None,
|
||||
c2_beacon_p_of_success: Optional[float] = None,
|
||||
ransomware_encrypt_p_of_success: Optional[float] = None,
|
||||
repeat: bool = True,
|
||||
):
|
||||
"""
|
||||
Configure the Ransomware Script to communicate with a DatabaseService.
|
||||
@@ -168,10 +98,6 @@ class RansomwareScript(Application):
|
||||
:param server_ip_address: The IP address of the Node the DatabaseService is on.
|
||||
:param server_password: The password on the DatabaseService.
|
||||
:param payload: The attack stage query (Encrypt / Delete)
|
||||
:param target_scan_p_of_success: The probability of success for the target scan stage.
|
||||
:param c2_beacon_p_of_success: The probability of success for the c2_beacon stage
|
||||
:param ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
|
||||
:param repeat: Whether to repeat attacking once finished.
|
||||
"""
|
||||
if server_ip_address:
|
||||
self.server_ip_address = server_ip_address
|
||||
@@ -179,74 +105,15 @@ class RansomwareScript(Application):
|
||||
self.server_password = server_password
|
||||
if payload:
|
||||
self.payload = payload
|
||||
if target_scan_p_of_success:
|
||||
self.target_scan_p_of_success = target_scan_p_of_success
|
||||
if c2_beacon_p_of_success:
|
||||
self.c2_beacon_p_of_success = c2_beacon_p_of_success
|
||||
if ransomware_encrypt_p_of_success:
|
||||
self.ransomware_encrypt_p_of_success = ransomware_encrypt_p_of_success
|
||||
if repeat:
|
||||
self.repeat = repeat
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, "
|
||||
f"{repeat=}."
|
||||
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}."
|
||||
)
|
||||
|
||||
def _install(self):
|
||||
"""
|
||||
Simulate the install stage in the kill-chain.
|
||||
|
||||
Advances the attack stage to 'ACTIVATE' if successful.
|
||||
|
||||
From this attack stage onwards.
|
||||
the ransomware application is now visible from this point onwardin the observation space.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.DOWNLOAD:
|
||||
self.sys_log.info(f"{self.name}: Malware installed on the local file system")
|
||||
downloads_folder = self.file_system.get_folder(folder_name="downloads")
|
||||
ransomware_file = downloads_folder.get_file(file_name="ransom_script.pdf")
|
||||
ransomware_file.num_access += 1
|
||||
self.attack_stage = RansomwareAttackStage.INSTALL
|
||||
|
||||
def _setup_beacon(self):
|
||||
"""
|
||||
Simulates setting up a c2 beacon; currently a pseudo step for increasing red variance.
|
||||
|
||||
Advances the attack stage to 'COMMAND AND CONTROL` if successful.
|
||||
|
||||
:param p_of_sucess: Probability of a successful c2 setup (Advancing this step),
|
||||
by default the success rate is 0.5
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.PROPAGATE:
|
||||
self.sys_log.info(f"{self.name} Attempting to set up C&C Beacon - Scan 1/2")
|
||||
if simulate_trial(self.c2_beacon_p_of_success):
|
||||
self.sys_log.info(f"{self.name} C&C Successful setup - Scan 2/2")
|
||||
c2c_setup = True # TODO Implement the c2c step via an FTP Application/Service
|
||||
if c2c_setup:
|
||||
self.attack_stage = RansomwareAttackStage.COMMAND_AND_CONTROL
|
||||
|
||||
def _perform_target_scan(self):
|
||||
"""
|
||||
Perform a simulated port scan to check for open SQL ports.
|
||||
|
||||
Advances the attack stage to `PROPAGATE` if successful.
|
||||
|
||||
:param p_of_success: Probability of successful port scan, by default 0.1.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.ACTIVATE:
|
||||
# perform a port scan to identify that the SQL port is open on the server
|
||||
self.sys_log.info(f"{self.name}: Scanning for vulnerable databases - Scan 0/2")
|
||||
if simulate_trial(self.target_scan_p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Found a target database! Scan 1/2")
|
||||
port_is_open = True # TODO Implement a NNME Triggering scan as a seperate Red Application
|
||||
if port_is_open:
|
||||
self.attack_stage = RansomwareAttackStage.PROPAGATE
|
||||
|
||||
def attack(self) -> bool:
|
||||
"""Perform the attack steps after opening the application."""
|
||||
self.run()
|
||||
if not self._can_perform_action():
|
||||
self.sys_log.warning("Ransomware application is unable to perform it's actions.")
|
||||
self.run()
|
||||
self.num_executions += 1
|
||||
return self._application_loop()
|
||||
|
||||
@@ -255,57 +122,30 @@ class RansomwareScript(Application):
|
||||
self._db_connection = self._host_db_client.get_new_connection()
|
||||
return True if self._db_connection else False
|
||||
|
||||
def _perform_ransomware_encrypt(self):
|
||||
def _perform_ransomware_encrypt(self) -> bool:
|
||||
"""
|
||||
Execute the Ransomware Encrypt payload on the target.
|
||||
|
||||
Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful.
|
||||
:param p_of_success: Probability of successfully performing ransomware encryption, by default 0.1.
|
||||
"""
|
||||
if self._host_db_client is None:
|
||||
self.sys_log.info(f"{self.name}: Failed to connect to db_client - Ransomware Script")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
return
|
||||
return False
|
||||
|
||||
self._host_db_client.server_ip_address = self.server_ip_address
|
||||
self._host_db_client.server_password = self.server_password
|
||||
if self.attack_stage == RansomwareAttackStage.COMMAND_AND_CONTROL:
|
||||
if simulate_trial(self.ransomware_encrypt_p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Attempting to launch payload")
|
||||
if not self._db_connection:
|
||||
self._establish_db_connection()
|
||||
if self._db_connection:
|
||||
attack_successful = self._db_connection.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
|
||||
if attack_successful:
|
||||
self.sys_log.info(f"{self.name}: Payload Successful")
|
||||
self.attack_stage = RansomwareAttackStage.SUCCEEDED
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Payload failed")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
self.sys_log.info(f"{self.name}: Attempting to launch payload")
|
||||
if not self._db_connection:
|
||||
self._establish_db_connection()
|
||||
if self._db_connection:
|
||||
attack_successful = self._db_connection.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
|
||||
if attack_successful:
|
||||
self.sys_log.info(f"{self.name}: Payload Successful")
|
||||
return True
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Payload failed")
|
||||
return False
|
||||
else:
|
||||
self.sys_log.warning("Attack Attempted to launch too quickly")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
|
||||
def _local_download(self):
|
||||
"""Downloads itself via the onto the local file_system."""
|
||||
if self.attack_stage == RansomwareAttackStage.NOT_STARTED:
|
||||
if self._local_download_verify():
|
||||
self.attack_stage = RansomwareAttackStage.DOWNLOAD
|
||||
else:
|
||||
self.sys_log.info("Malware failed to create a installation location")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
else:
|
||||
self.sys_log.info("Malware failed to download")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
|
||||
def _local_download_verify(self) -> bool:
|
||||
"""Verifies a download location - Creates one if needed."""
|
||||
for folder in self.file_system.folders:
|
||||
if self.file_system.folders[folder].name == "downloads":
|
||||
self.file_system.num_file_creations += 1
|
||||
return True
|
||||
|
||||
self.file_system.create_folder("downloads")
|
||||
self.file_system.create_file(folder_name="downloads", file_name="ransom_script.pdf")
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -79,6 +79,31 @@ class SoftwareManager:
|
||||
open_ports.append(software.port)
|
||||
return open_ports
|
||||
|
||||
def check_port_is_open(self, port: Port, protocol: IPProtocol) -> bool:
|
||||
"""
|
||||
Check if a specific port is open and running a service using the specified protocol.
|
||||
|
||||
This method iterates through all installed software on the node and checks if any of them
|
||||
are using the specified port and protocol and are currently in a running state. It returns True if any software
|
||||
is found running on the specified port and protocol, otherwise False.
|
||||
|
||||
|
||||
:param port: The port to check.
|
||||
:type port: Port
|
||||
:param protocol: The protocol to check (e.g., TCP, UDP).
|
||||
:type protocol: IPProtocol
|
||||
:return: True if the port is open and a service is running on it using the specified protocol, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
for software in self.software.values():
|
||||
if (
|
||||
software.port == port
|
||||
and software.protocol == protocol
|
||||
and software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def install(self, software_class: Type[IOSoftwareClass]):
|
||||
"""
|
||||
Install an Application or Service.
|
||||
@@ -151,6 +176,7 @@ class SoftwareManager:
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
src_port: Optional[Port] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
ip_protocol: IPProtocol = IPProtocol.TCP,
|
||||
session_id: Optional[str] = None,
|
||||
@@ -171,6 +197,7 @@ class SoftwareManager:
|
||||
return self.session_manager.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dst_ip_address=dest_ip_address,
|
||||
src_port=src_port,
|
||||
dst_port=dest_port,
|
||||
ip_protocol=ip_protocol,
|
||||
session_id=session_id,
|
||||
@@ -191,6 +218,9 @@ class SoftwareManager:
|
||||
:param payload: The payload being received.
|
||||
:param session: The transport session the payload originates from.
|
||||
"""
|
||||
if payload.__class__.__name__ == "PortScanPayload":
|
||||
self.software.get("NMAP").receive(payload=payload, session_id=session_id)
|
||||
return
|
||||
receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
|
||||
if receiver:
|
||||
receiver.receive(
|
||||
|
||||
26
src/primaite/utils/converters.py
Normal file
26
src/primaite/utils/converters.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def convert_dict_enum_keys_to_enum_values(d: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
"""
|
||||
Convert dictionary keys from enums to their corresponding values.
|
||||
|
||||
:param d: dict
|
||||
The dictionary with enum keys to be converted.
|
||||
:return: dict
|
||||
The dictionary with enum values as keys.
|
||||
"""
|
||||
result = {}
|
||||
for key, value in d.items():
|
||||
if isinstance(key, Enum):
|
||||
new_key = key.value
|
||||
else:
|
||||
new_key = key
|
||||
|
||||
if isinstance(value, dict):
|
||||
result[new_key] = convert_dict_enum_keys_to_enum_values(value)
|
||||
else:
|
||||
result[new_key] = value
|
||||
return result
|
||||
@@ -3,7 +3,7 @@
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
@@ -12,16 +12,16 @@ from typing import Any, Dict, Tuple, Union
|
||||
import polars as pl
|
||||
|
||||
|
||||
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
def total_rewards_dict(total_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
"""
|
||||
Read an average rewards per episode csv file and return as a dict.
|
||||
|
||||
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
||||
|
||||
:param av_rewards_csv_file: The average rewards per episode csv file path.
|
||||
:param total_rewards_csv_file: The average rewards per episode csv file path.
|
||||
:return: The average rewards per episode csv as a dict.
|
||||
"""
|
||||
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
|
||||
df_dict = pl.read_csv(total_rewards_csv_file).to_dict()
|
||||
|
||||
return {int(v): df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import csv
|
||||
from logging import Logger
|
||||
from typing import Final, List, Tuple, TYPE_CHECKING, Union
|
||||
@@ -27,9 +27,9 @@ class SessionOutputWriter:
|
||||
Is used to write session outputs to csv file.
|
||||
"""
|
||||
|
||||
_AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
|
||||
_TOTAL_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
|
||||
"Episode",
|
||||
"Average Reward",
|
||||
"Total Reward",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -44,7 +44,7 @@ class SessionOutputWriter:
|
||||
:param env: PrimAITE gym environment.
|
||||
:type env: Primaite
|
||||
:param transaction_writer: If `true`, this will output a full account of every transaction taken by the agent.
|
||||
If `false` it will output the average reward per episode, defaults to False
|
||||
If `false` it will output the total reward per episode, defaults to False
|
||||
:type transaction_writer: bool, optional
|
||||
:param learning_session: Set to `true` to indicate that the current session is a training session. This
|
||||
determines the name of the folder which contains the final output csv. Defaults to True
|
||||
@@ -57,7 +57,7 @@ class SessionOutputWriter:
|
||||
if self.transaction_writer:
|
||||
fn = f"all_transactions_{self._env.timestamp_str}.csv"
|
||||
else:
|
||||
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
|
||||
fn = f"total_reward_per_episode_{self._env.timestamp_str}.csv"
|
||||
|
||||
self._csv_file_path: "Path"
|
||||
if self.learning_session:
|
||||
@@ -95,7 +95,7 @@ class SessionOutputWriter:
|
||||
if isinstance(data, Transaction):
|
||||
header, data = data.as_csv_data()
|
||||
else:
|
||||
header = self._AV_REWARD_PER_EPISODE_HEADER
|
||||
header = self._TOTAL_REWARD_PER_EPISODE_HEADER
|
||||
|
||||
if self._first_write:
|
||||
self._init_csv_writer()
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_red_nmap
|
||||
team: RED
|
||||
type: ProbabilisticAgent
|
||||
observation_space: null
|
||||
action_space:
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: NMAP
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
action_list:
|
||||
- type: NODE_NMAP_NETWORK_SERVICE_RECON
|
||||
action_map:
|
||||
0:
|
||||
action: NODE_NMAP_NETWORK_SERVICE_RECON
|
||||
options:
|
||||
source_node: client_1
|
||||
target_ip_address: 192.168.10.0/24
|
||||
target_port: 80
|
||||
target_protocol: tcp
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 1.0
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
- hostname: switch_1
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: switch_2
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: router_1
|
||||
type: router
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.10.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
1:
|
||||
action: PERMIT
|
||||
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: server_1
|
||||
type: server
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
- hostname: server_2
|
||||
type: server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
|
||||
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: client_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: client_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 2
|
||||
|
||||
- endpoint_a_hostname: server_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: server_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 2
|
||||
135
tests/assets/configs/nmap_ping_scan_red_agent_config.yaml
Normal file
135
tests/assets/configs/nmap_ping_scan_red_agent_config.yaml
Normal file
@@ -0,0 +1,135 @@
|
||||
io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_red_nmap
|
||||
team: RED
|
||||
type: ProbabilisticAgent
|
||||
observation_space: null
|
||||
action_space:
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: NMAP
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
action_list:
|
||||
- type: NODE_NMAP_PING_SCAN
|
||||
action_map:
|
||||
0:
|
||||
action: NODE_NMAP_PING_SCAN
|
||||
options:
|
||||
source_node: client_1
|
||||
target_ip_address: 192.168.1.0/24
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 1.0
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
- hostname: switch_1
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: switch_2
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: router_1
|
||||
type: router
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.10.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
1:
|
||||
action: PERMIT
|
||||
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: server_1
|
||||
type: server
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
- hostname: server_2
|
||||
type: server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
|
||||
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: client_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: client_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 2
|
||||
|
||||
- endpoint_a_hostname: server_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: server_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 2
|
||||
135
tests/assets/configs/nmap_port_scan_red_agent_config.yaml
Normal file
135
tests/assets/configs/nmap_port_scan_red_agent_config.yaml
Normal file
@@ -0,0 +1,135 @@
|
||||
io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_red_nmap
|
||||
team: RED
|
||||
type: ProbabilisticAgent
|
||||
observation_space: null
|
||||
action_space:
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: NMAP
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
action_list:
|
||||
- type: NODE_NMAP_PORT_SCAN
|
||||
action_map:
|
||||
0:
|
||||
action: NODE_NMAP_PORT_SCAN
|
||||
options:
|
||||
source_node: client_1
|
||||
target_ip_address: 192.168.10.0/24
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 1.0
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
- hostname: switch_1
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: switch_2
|
||||
num_ports: 8
|
||||
type: switch
|
||||
|
||||
- hostname: router_1
|
||||
type: router
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.10.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
1:
|
||||
action: PERMIT
|
||||
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
|
||||
- hostname: server_1
|
||||
type: server
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
- hostname: server_2
|
||||
type: server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
|
||||
|
||||
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 8
|
||||
|
||||
- endpoint_a_hostname: client_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: client_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 2
|
||||
|
||||
- endpoint_a_hostname: server_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: server_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 2
|
||||
@@ -10,12 +10,8 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import (
|
||||
RansomwareAttackStage,
|
||||
RansomwareScript,
|
||||
)
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
@@ -86,54 +82,24 @@ def ransomware_script_db_server_green_client(example_network) -> Network:
|
||||
return network
|
||||
|
||||
|
||||
def test_repeating_ransomware_script_attack(ransomware_script_and_db_server):
|
||||
def test_ransomware_script_attack(ransomware_script_and_db_server):
|
||||
"""Test a repeating data manipulation attack."""
|
||||
RansomwareScript, computer, db_server_service, server = ransomware_script_and_db_server
|
||||
|
||||
computer.apply_timestep(timestep=0)
|
||||
server.apply_timestep(timestep=0)
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
assert computer.file_system.num_file_creations == 0
|
||||
assert server.file_system.num_file_creations == 1
|
||||
|
||||
RansomwareScript.target_scan_p_of_success = 1
|
||||
RansomwareScript.c2_beacon_p_of_success = 1
|
||||
RansomwareScript.ransomware_encrypt_p_of_success = 1
|
||||
RansomwareScript.repeat = True
|
||||
RansomwareScript.attack()
|
||||
|
||||
assert RansomwareScript.attack_stage == RansomwareAttackStage.NOT_STARTED
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
|
||||
assert computer.file_system.num_file_creations == 1
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
|
||||
assert server.file_system.num_file_creations == 2
|
||||
|
||||
computer.apply_timestep(timestep=1)
|
||||
server.apply_timestep(timestep=1)
|
||||
|
||||
assert RansomwareScript.attack_stage == RansomwareAttackStage.NOT_STARTED
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
|
||||
|
||||
|
||||
def test_repeating_ransomware_script_attack(ransomware_script_and_db_server):
|
||||
"""Test a repeating ransowmare script attack."""
|
||||
RansomwareScript, computer, db_server_service, server = ransomware_script_and_db_server
|
||||
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
RansomwareScript.target_scan_p_of_success = 1
|
||||
RansomwareScript.c2_beacon_p_of_success = 1
|
||||
RansomwareScript.ransomware_encrypt_p_of_success = 1
|
||||
RansomwareScript.repeat = False
|
||||
RansomwareScript.attack()
|
||||
|
||||
assert RansomwareScript.attack_stage == RansomwareAttackStage.SUCCEEDED
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
|
||||
assert computer.file_system.num_file_creations == 1
|
||||
|
||||
computer.apply_timestep(timestep=1)
|
||||
computer.pre_timestep(timestep=1)
|
||||
server.apply_timestep(timestep=1)
|
||||
server.pre_timestep(timestep=1)
|
||||
|
||||
assert RansomwareScript.attack_stage == RansomwareAttackStage.SUCCEEDED
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
|
||||
assert computer.file_system.num_file_creations == 0
|
||||
|
||||
|
||||
def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_green_client):
|
||||
@@ -154,10 +120,6 @@ def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_
|
||||
assert green_db_client_connection.query("SELECT")
|
||||
assert green_db_client.last_query_response.get("status_code") == 200
|
||||
|
||||
ransomware_script_application.target_scan_p_of_success = 1
|
||||
ransomware_script_application.ransomware_encrypt_p_of_success = 1
|
||||
ransomware_script_application.c2_beacon_p_of_success = 1
|
||||
ransomware_script_application.repeat = False
|
||||
ransomware_script_application.attack()
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
186
tests/integration_tests/system/test_nmap.py
Normal file
186
tests/integration_tests/system/test_nmap.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.nmap import NMAP
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
|
||||
def test_ping_scan_all_on(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
expected_result = [IPv4Address("192.168.1.10"), IPv4Address("192.168.1.14")]
|
||||
actual_result = client_1_nmap.ping_scan(target_ip_address=["192.168.1.10", "192.168.1.14"])
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_ping_scan_all_on_full_network(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
expected_result = [IPv4Address("192.168.1.1"), IPv4Address("192.168.1.10"), IPv4Address("192.168.1.14")]
|
||||
actual_result = client_1_nmap.ping_scan(target_ip_address=IPv4Network("192.168.1.0/24"))
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_ping_scan_some_on(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
network.get_node_by_hostname("server_2").power_off()
|
||||
|
||||
expected_result = [IPv4Address("192.168.1.1"), IPv4Address("192.168.1.10")]
|
||||
actual_result = client_1_nmap.ping_scan(target_ip_address=IPv4Network("192.168.1.0/24"))
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_ping_scan_all_off(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
network.get_node_by_hostname("server_1").power_off()
|
||||
network.get_node_by_hostname("server_2").power_off()
|
||||
|
||||
expected_result = []
|
||||
actual_result = client_1_nmap.ping_scan(target_ip_address=["192.168.1.10", "192.168.1.14"])
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_port_scan_one_node_one_port(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
client_2 = network.get_node_by_hostname("client_2")
|
||||
|
||||
actual_result = client_1_nmap.port_scan(
|
||||
target_ip_address=client_2.network_interface[1].ip_address, target_port=Port.DNS, target_protocol=IPProtocol.TCP
|
||||
)
|
||||
|
||||
expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.DNS]}}
|
||||
|
||||
assert actual_result == expected_result
|
||||
|
||||
|
||||
def sort_dict(d):
|
||||
"""Recursively sorts a dictionary."""
|
||||
if isinstance(d, dict):
|
||||
return {k: sort_dict(v) for k, v in sorted(d.items(), key=lambda item: str(item[0]))}
|
||||
elif isinstance(d, list):
|
||||
return sorted(d, key=lambda item: str(item) if isinstance(item, Enum) else item)
|
||||
elif isinstance(d, Enum):
|
||||
return str(d)
|
||||
else:
|
||||
return d
|
||||
|
||||
|
||||
def test_port_scan_full_subnet_all_ports_and_protocols(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
actual_result = client_1_nmap.port_scan(
|
||||
target_ip_address=IPv4Network("192.168.10.0/24"),
|
||||
)
|
||||
|
||||
expected_result = {
|
||||
IPv4Address("192.168.10.1"): {IPProtocol.UDP: [Port.ARP]},
|
||||
IPv4Address("192.168.10.22"): {
|
||||
IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS],
|
||||
IPProtocol.UDP: [Port.ARP, Port.NTP],
|
||||
},
|
||||
}
|
||||
|
||||
assert sort_dict(actual_result) == sort_dict(expected_result)
|
||||
|
||||
|
||||
def test_network_service_recon_all_ports_and_protocols(example_network):
|
||||
network = example_network
|
||||
|
||||
client_1 = network.get_node_by_hostname("client_1")
|
||||
client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa
|
||||
|
||||
actual_result = client_1_nmap.network_service_recon(
|
||||
target_ip_address=IPv4Network("192.168.10.0/24"), target_port=Port.HTTP, target_protocol=IPProtocol.TCP
|
||||
)
|
||||
|
||||
expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.HTTP]}}
|
||||
|
||||
assert sort_dict(actual_result) == sort_dict(expected_result)
|
||||
|
||||
|
||||
def test_ping_scan_red_agent():
|
||||
with open(TEST_ASSETS_ROOT / "configs/nmap_ping_scan_red_agent_config.yaml", "r") as file:
|
||||
cfg = yaml.safe_load(file)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
game.step()
|
||||
|
||||
expected_result = ["192.168.1.1", "192.168.1.10", "192.168.1.14"]
|
||||
|
||||
action_history = game.agents["client_1_red_nmap"].history
|
||||
assert len(action_history) == 1
|
||||
actual_result = action_history[0].response.data["live_hosts"]
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_port_scan_red_agent():
|
||||
with open(TEST_ASSETS_ROOT / "configs/nmap_port_scan_red_agent_config.yaml", "r") as file:
|
||||
cfg = yaml.safe_load(file)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
game.step()
|
||||
|
||||
expected_result = {
|
||||
"192.168.10.1": {"udp": [219]},
|
||||
"192.168.10.22": {
|
||||
"tcp": [80, 21, 53],
|
||||
"udp": [219, 123],
|
||||
},
|
||||
}
|
||||
|
||||
action_history = game.agents["client_1_red_nmap"].history
|
||||
assert len(action_history) == 1
|
||||
actual_result = action_history[0].response.data
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
|
||||
|
||||
def test_network_service_recon_red_agent():
|
||||
with open(TEST_ASSETS_ROOT / "configs/nmap_network_service_recon_red_agent_config.yaml", "r") as file:
|
||||
cfg = yaml.safe_load(file)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
game.step()
|
||||
|
||||
expected_result = {"192.168.10.22": {"tcp": [80]}}
|
||||
|
||||
action_history = game.agents["client_1_red_nmap"].history
|
||||
assert len(action_history) == 1
|
||||
actual_result = action_history[0].response.data
|
||||
|
||||
assert sorted(actual_result) == sorted(expected_result)
|
||||
1
tests/unit_tests/_primaite/_utils/__init__.py
Normal file
1
tests/unit_tests/_primaite/_utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
@@ -0,0 +1,84 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
|
||||
|
||||
|
||||
def test_simple_conversion():
|
||||
"""
|
||||
Test conversion of a simple dictionary with enum keys to enum values.
|
||||
|
||||
The original dictionary contains one level of nested dictionary with enums as keys.
|
||||
The expected output should have string values of enums as keys.
|
||||
"""
|
||||
original_dict = {IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0}}}
|
||||
expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0}}}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||
|
||||
|
||||
def test_no_enums():
|
||||
"""
|
||||
Test conversion of a dictionary with no enum keys.
|
||||
|
||||
The original dictionary contains only string keys.
|
||||
The expected output should be identical to the original dictionary.
|
||||
"""
|
||||
original_dict = {"protocol": {"port": {"inbound": 0, "outbound": 1016.0}}}
|
||||
expected_dict = {"protocol": {"port": {"inbound": 0, "outbound": 1016.0}}}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||
|
||||
|
||||
def test_mixed_keys():
|
||||
"""
|
||||
Test conversion of a dictionary with a mix of enum and string keys.
|
||||
|
||||
The original dictionary contains both enums and strings as keys.
|
||||
The expected output should have string values of enums and original string keys.
|
||||
"""
|
||||
original_dict = {
|
||||
IPProtocol.TCP: {"port": {"inbound": 0, "outbound": 1016.0}},
|
||||
"protocol": {Port.HTTP: {"inbound": 10, "outbound": 2020.0}},
|
||||
}
|
||||
expected_dict = {
|
||||
"tcp": {"port": {"inbound": 0, "outbound": 1016.0}},
|
||||
"protocol": {80: {"inbound": 10, "outbound": 2020.0}},
|
||||
}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||
|
||||
|
||||
def test_empty_dict():
|
||||
"""
|
||||
Test conversion of an empty dictionary.
|
||||
|
||||
The original dictionary is empty.
|
||||
The expected output should also be an empty dictionary.
|
||||
"""
|
||||
original_dict = {}
|
||||
expected_dict = {}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||
|
||||
|
||||
def test_nested_dicts():
|
||||
"""
|
||||
Test conversion of a nested dictionary with multiple levels of nested dictionaries and enums as keys.
|
||||
|
||||
The original dictionary contains nested dictionaries with enums as keys at different levels.
|
||||
The expected output should have string values of enums as keys at all levels.
|
||||
"""
|
||||
original_dict = {
|
||||
IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol.TCP: {"latency": "low"}}}}
|
||||
}
|
||||
expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0, "details": {"tcp": {"latency": "low"}}}}}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||
|
||||
|
||||
def test_non_dict_values():
|
||||
"""
|
||||
Test conversion of a dictionary where some values are not dictionaries.
|
||||
|
||||
The original dictionary contains lists and tuples as values.
|
||||
The expected output should preserve these non-dictionary values while converting enum keys to string values.
|
||||
"""
|
||||
original_dict = {IPProtocol.UDP: [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)}
|
||||
expected_dict = {"udp": [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)}
|
||||
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
|
||||
Reference in New Issue
Block a user