Merge branch 'dev' into bugfix/2676_NMNE_var_access

This commit is contained in:
Nick Todd
2024-07-05 14:20:50 +01:00
46 changed files with 2578 additions and 281 deletions

View File

@@ -0,0 +1,84 @@
trigger:
branches:
exclude:
- '*'
include:
- 'refs/heads/release/*'
schedules:
- cron: "0 2 * * 1-5" # Run at 2 AM every weekday
displayName: "Weekday Schedule"
branches:
include:
- 'refs/heads/dev'
pool:
vmImage: ubuntu-latest
variables:
VERSION: ''
MAJOR_VERSION: ''
steps:
- checkout: self
persistCredentials: true
- script: |
VERSION=$(cat src/primaite/VERSION | tr -d '\n')
if [[ "$(Build.SourceBranch)" == "refs/heads/dev" ]]; then
DATE=$(date +%Y%m%d)
echo "${VERSION}+dev.${DATE}" > src/primaite/VERSION
fi
displayName: 'Update VERSION file for Dev Benchmark'
- script: |
VERSION=$(cat src/primaite/VERSION | tr -d '\n')
MAJOR_VERSION=$(echo $VERSION | cut -d. -f1)
echo "##vso[task.setvariable variable=VERSION]$VERSION"
echo "##vso[task.setvariable variable=MAJOR_VERSION]$MAJOR_VERSION"
displayName: 'Set Version Variables'
- task: UsePythonVersion@0
inputs:
versionSpec: '3.11'
addToPath: true
- script: |
python -m pip install --upgrade pip
pip install -e .[dev,rl]
primaite setup
displayName: 'Install Dependencies'
- script: |
cd benchmark
python3 primaite_benchmark.py
cd ..
displayName: 'Run Benchmarking Script'
- script: |
git config --global user.email "oss@dstl.gov.uk"
git config --global user.name "Defence Science and Technology Laboratory UK"
workingDirectory: $(System.DefaultWorkingDirectory)
displayName: 'Configure Git'
condition: and(succeeded(), eq(variables['Build.Reason'], 'Manual'), startsWith(variables['Build.SourceBranch'], 'refs/heads/release'))
- script: |
git add benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)/*
git commit -m "Automated benchmark output commit for version $(VERSION)"
git push origin HEAD:refs/heads/$(Build.SourceBranchName)
displayName: 'Commit and Push Benchmark Results'
workingDirectory: $(System.DefaultWorkingDirectory)
env:
GIT_CREDENTIALS: $(System.AccessToken)
condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/release'))
- script: |
tar czf primaite_v$(VERSION)_benchmark.tar.gz benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)
displayName: 'Prepare Artifacts for Publishing'
- task: PublishPipelineArtifact@1
inputs:
targetPath: primaite_v$(VERSION)_benchmark.tar.gz
artifactName: 'benchmark-output'
publishLocation: 'pipeline'
displayName: 'Publish Benchmark Output as Artifact'

View File

@@ -107,11 +107,39 @@ stages:
coverage html -d htmlcov -i
displayName: 'Run tests and code coverage'
# Run the notebooks
- script: |
pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml
notebooks_exit_code=$?
pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml
package_notebooks_exit_code=$?
# Fail step if either of these do not have exit code 0
if [ $notebooks_exit_code -ne 0 ] || [ $package_notebooks_exit_code -ne 0 ]; then
exit 1
fi
displayName: 'Run notebooks on Linux and macOS'
condition: or(eq(variables['Agent.OS'], 'Linux'), eq(variables['Agent.OS'], 'Darwin'))
# Run notebooks
- script: |
pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml
set notebooks_exit_code=%ERRORLEVEL%
pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml
set package_notebooks_exit_code=%ERRORLEVEL%
rem Fail step if either of these do not have exit code 0
if %notebooks_exit_code% NEQ 0 exit /b 1
if %package_notebooks_exit_code% NEQ 0 exit /b 1
displayName: 'Run notebooks on Windows'
condition: eq(variables['Agent.OS'], 'Windows_NT')
- task: PublishTestResults@2
condition: succeededOrFailed()
displayName: 'Publish Test Results'
inputs:
testRunner: JUnit
testResultsFiles: 'junit/**.xml'
testResultsFiles: |
'junit/**.xml'
'notebook-tests/**.xml'
testRunTitle: 'Publish test results'
failTaskOnFailedTests: true

1
.gitignore vendored
View File

@@ -54,6 +54,7 @@ cover/
tests/assets/**/*.png
tests/assets/**/tensorboard_logs/
tests/assets/**/checkpoints/
notebook-tests/*.xml
# Translations
*.mo

View File

@@ -117,14 +117,14 @@ class BenchmarkSession:
def generate_learn_metadata_dict(self) -> Dict[str, Any]:
"""Metadata specific to the learning session."""
total_s, s_per_step, s_per_100_steps_10_nodes = self._learn_benchmark_durations()
self.gym_env.average_reward_per_episode.pop(0) # remove episode 0
self.gym_env.total_reward_per_episode.pop(0) # remove episode 0
return {
"total_episodes": self.gym_env.episode_counter,
"total_time_steps": self.gym_env.total_time_steps,
"total_s": total_s,
"s_per_step": s_per_step,
"s_per_100_steps_10_nodes": s_per_100_steps_10_nodes,
"av_reward_per_episode": self.gym_env.average_reward_per_episode,
"total_reward_per_episode": self.gym_env.total_reward_per_episode,
}

View File

@@ -9,10 +9,6 @@ import plotly.graph_objects as go
import polars as pl
import yaml
from plotly.graph_objs import Figure
from pylatex import Command, Document
from pylatex import Figure as LatexFigure
from pylatex import Section, Subsection, Tabular
from pylatex.utils import bold
from utils import _get_system_info
import primaite
@@ -39,19 +35,19 @@ def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict,
"av_s_per_step": sum(d["s_per_step"] for d in metadata_dict.values()) / num_sessions,
"av_s_per_100_steps_10_nodes": sum(d["s_per_100_steps_10_nodes"] for d in metadata_dict.values())
/ num_sessions,
"combined_av_reward_per_episode": {},
"session_av_reward_per_episode": {k: v["av_reward_per_episode"] for k, v in metadata_dict.items()},
"combined_total_reward_per_episode": {},
"session_total_reward_per_episode": {k: v["total_reward_per_episode"] for k, v in metadata_dict.items()},
"config": config,
}
# find the average of each episode across all sessions
episodes = metadata_dict[1]["av_reward_per_episode"].keys()
episodes = metadata_dict[1]["total_reward_per_episode"].keys()
for episode in episodes:
combined_av_reward = (
sum(metadata_dict[k]["av_reward_per_episode"][episode] for k in metadata_dict.keys()) / num_sessions
sum(metadata_dict[k]["total_reward_per_episode"][episode] for k in metadata_dict.keys()) / num_sessions
)
averaged_data["combined_av_reward_per_episode"][episode] = combined_av_reward
averaged_data["combined_total_reward_per_episode"][episode] = combined_av_reward
return averaged_data
@@ -87,7 +83,7 @@ def _plot_benchmark_metadata(
fig = go.Figure(layout=layout)
fig.update_layout(template=PLOT_CONFIG["template"])
for session, av_reward_dict in benchmark_metadata_dict["session_av_reward_per_episode"].items():
for session, av_reward_dict in benchmark_metadata_dict["session_total_reward_per_episode"].items():
df = _get_df_from_episode_av_reward_dict(av_reward_dict)
fig.add_trace(
go.Scatter(
@@ -100,7 +96,7 @@ def _plot_benchmark_metadata(
)
)
df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_av_reward_per_episode"])
df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_total_reward_per_episode"])
fig.add_trace(
go.Scatter(
x=df["episode"], y=df["av_reward"], mode="lines", name="Combined Session Av", line={"color": "#FF0000"}
@@ -136,11 +132,11 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
Does this by iterating over the ``benchmark/results`` directory and
extracting the benchmark metadata json for each version that has been
benchmarked. The combined_av_reward_per_episode is extracted from each,
benchmarked. The combined_total_reward_per_episode is extracted from each,
converted into a polars dataframe, and plotted as a scatter line in plotly.
"""
major_v = primaite.__version__.split(".")[0]
title = f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*"
title = f"Learning Benchmarking of All Released Versions under Major v{major_v}.#.#"
subtitle = "Rolling Av (Combined Session Av)"
if title:
if subtitle:
@@ -162,7 +158,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
metadata_file = dir / f"{dir.name}_benchmark_metadata.json"
with open(metadata_file, "r") as file:
metadata_dict = json.load(file)
df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_av_reward_per_episode"])
df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_total_reward_per_episode"])
fig.add_trace(go.Scatter(x=df["episode"], y=df["rolling_av_reward"], mode="lines", name=dir.name))
@@ -208,98 +204,77 @@ def build_benchmark_latex_report(
fig = _plot_all_benchmarks_combined_session_av(results_directory=results_root_path)
all_version_plot_path = results_root_path / "PrimAITE Versions Learning Benchmark.png"
all_version_plot_path = version_result_dir / "PrimAITE Versions Learning Benchmark.png"
fig.write_image(all_version_plot_path)
geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
data = benchmark_metadata_dict
primaite_version = data["primaite_version"]
# Create a new document
doc = Document("report", geometry_options=geometry_options)
# Title
doc.preamble.append(Command("title", f"PrimAITE {primaite_version} Learning Benchmark"))
doc.preamble.append(Command("author", "PrimAITE Dev Team"))
doc.preamble.append(Command("date", datetime.now().date()))
doc.append(Command("maketitle"))
with open(version_result_dir / f"PrimAITE v{primaite_version} Learning Benchmark.md", "w") as file:
# Title
file.write(f"# PrimAITE v{primaite_version} Learning Benchmark\n")
file.write("## PrimAITE Dev Team\n")
file.write(f"### {datetime.now().date()}\n")
file.write("\n---\n")
sessions = data["total_sessions"]
episodes = session_metadata[1]["total_episodes"] - 1
steps = data["config"]["game"]["max_episode_length"]
sessions = data["total_sessions"]
episodes = session_metadata[1]["total_episodes"] - 1
steps = data["config"]["game"]["max_episode_length"]
# Body
with doc.create(Section("Introduction")):
doc.append(
# Body
file.write("## 1 Introduction\n")
file.write(
f"PrimAITE v{primaite_version} was benchmarked automatically upon release. Learning rate metrics "
f"were captured to be referenced during system-level testing and user acceptance testing (UAT)."
f"were captured to be referenced during system-level testing and user acceptance testing (UAT).\n"
)
doc.append(
f"\nThe benchmarking process consists of running {sessions} training session using the same "
file.write(
f"The benchmarking process consists of running {sessions} training session using the same "
f"config file. Each session trains an agent for {episodes} episodes, "
f"with each episode consisting of {steps} steps."
f"with each episode consisting of {steps} steps.\n"
)
doc.append(
f"\nThe total reward per episode from each session is captured. This is then used to calculate an "
file.write(
f"The total reward per episode from each session is captured. This is then used to calculate an "
f"caverage total reward per episode from the {sessions} individual sessions for smoothing. "
f"Finally, a 25-widow rolling average of the average total reward per session is calculated for "
f"further smoothing."
f"further smoothing.\n"
)
with doc.create(Section("System Information")):
with doc.create(Subsection("Python")):
with doc.create(Tabular("|l|l|")) as table:
table.add_hline()
table.add_row((bold("Version"), sys.version))
table.add_hline()
file.write("## 2 System Information\n")
i = 1
file.write(f"### 2.{i} Python\n")
file.write(f"**Version:** {sys.version}\n")
for section, section_data in data["system_info"].items():
i += 1
if section_data:
with doc.create(Subsection(section)):
if isinstance(section_data, dict):
with doc.create(Tabular("|l|l|")) as table:
table.add_hline()
for key, value in section_data.items():
table.add_row((bold(key), value))
table.add_hline()
elif isinstance(section_data, list):
headers = section_data[0].keys()
tabs_str = "|".join(["l" for _ in range(len(headers))])
tabs_str = f"|{tabs_str}|"
with doc.create(Tabular(tabs_str)) as table:
table.add_hline()
table.add_row([bold(h) for h in headers])
table.add_hline()
for item in section_data:
table.add_row(item.values())
table.add_hline()
file.write(f"### 2.{i} {section}\n")
if isinstance(section_data, dict):
for key, value in section_data.items():
file.write(f"- **{key}:** {value}\n")
headers_map = {
"total_sessions": "Total Sessions",
"total_episodes": "Total Episodes",
"total_time_steps": "Total Steps",
"av_s_per_session": "Av Session Duration (s)",
"av_s_per_step": "Av Step Duration (s)",
"av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
}
with doc.create(Section("Stats")):
with doc.create(Subsection("Benchmark Results")):
with doc.create(Tabular("|l|l|")) as table:
table.add_hline()
for section, header in headers_map.items():
if section.startswith("av_"):
table.add_row((bold(header), f"{data[section]:.4f}"))
else:
table.add_row((bold(header), data[section]))
table.add_hline()
headers_map = {
"total_sessions": "Total Sessions",
"total_episodes": "Total Episodes",
"total_time_steps": "Total Steps",
"av_s_per_session": "Av Session Duration (s)",
"av_s_per_step": "Av Step Duration (s)",
"av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
}
with doc.create(Section("Graphs")):
with doc.create(Subsection(f"v{primaite_version} Learning Benchmark Plot")):
with doc.create(LatexFigure(position="h!")) as pic:
pic.add_image(str(this_version_plot_path))
pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")
file.write("## 3 Stats\n")
for section, header in headers_map.items():
if section.startswith("av_"):
file.write(f"- **{header}:** {data[section]:.4f}\n")
else:
file.write(f"- **{header}:** {data[section]}\n")
with doc.create(Subsection(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")):
with doc.create(LatexFigure(position="h!")) as pic:
pic.add_image(str(all_version_plot_path))
pic.add_caption(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")
file.write("## 4 Graphs\n")
doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)
file.write(f"### 4.1 v{primaite_version} Learning Benchmark Plot\n")
file.write(f"![PrimAITE {primaite_version} Learning Benchmark Plot]({this_version_plot_path.name})\n")
file.write(f"### 4.2 Learning Benchmarking of All Released Versions under Major v{major_v}.#.#\n")
file.write(
f"![Learning Benchmarking of All Released Versions under "
f"Major v{major_v}.#.#]({all_version_plot_path.name})\n"
)

View File

@@ -16,3 +16,12 @@ The type of software that should be added. To add |SOFTWARE_NAME| this must be |
===========
The configuration options are the attributes that fall under the options for an application.
``fix_duration``
""""""""""""""""
Optional. Default value is ``2``.
The number of timesteps the |SOFTWARE_NAME| will remain in a ``FIXING`` state before going into a ``GOOD`` state.

View File

@@ -8,7 +8,7 @@
applications/*
More info :py:mod:`primaite.game.game.APPLICATION_TYPES_MAPPING`
More info :py:mod:`primaite.simulator.system.applications.application.Application`
.. include:: list_of_system_applications.rst

View File

@@ -64,7 +64,6 @@ dev = [
"gputil==1.4.0",
"pip-licenses==4.3.0",
"pre-commit==2.20.0",
"pylatex==1.4.1",
"pytest==7.2.0",
"pytest-xdist==3.3.1",
"pytest-cov==4.0.0",
@@ -73,7 +72,9 @@ dev = [
"Sphinx==7.1.2",
"sphinx-copybutton==0.5.2",
"wheel==0.38.4",
"nbsphinx==0.9.4"
"nbsphinx==0.9.4",
"nbmake==1.5.4",
"pytest-xdist==3.3.1"
]
[project.scripts]

View File

@@ -739,8 +739,6 @@ agents:
options:
agent_name: client_2_green_user
agent_settings:
flatten_obs: true

View File

@@ -14,9 +14,10 @@ from abc import ABC, abstractmethod
from typing import Dict, List, Literal, Optional, Tuple, TYPE_CHECKING, Union
from gymnasium import spaces
from pydantic import BaseModel, Field, field_validator, ValidationInfo
from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo
from primaite import getLogger
from primaite.interface.request import RequestFormat
_LOGGER = getLogger(__name__)
@@ -228,7 +229,7 @@ class NodeApplicationInstallAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
def form_request(self, node_id: int, application_name: str, ip_address: str) -> List[str]:
def form_request(self, node_id: int, application_name: str) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
@@ -241,10 +242,81 @@ class NodeApplicationInstallAction(AbstractAction):
"application",
"install",
application_name,
ip_address,
]
class ConfigureDatabaseClientAction(AbstractAction):
"""Action which sets config parameters for a database client on a node."""
class _Opts(BaseModel):
"""Schema for options that can be passed to this action."""
model_config = ConfigDict(extra="forbid")
server_ip_address: Optional[str] = None
server_password: Optional[str] = None
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: int, config: Dict) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
return ["do_nothing"]
ConfigureDatabaseClientAction._Opts.model_validate(config) # check that options adhere to schema
return ["network", "node", node_name, "application", "DatabaseClient", "configure", config]
class ConfigureRansomwareScriptAction(AbstractAction):
"""Action which sets config parameters for a ransomware script on a node."""
class _Opts(BaseModel):
"""Schema for options that can be passed to this option."""
model_config = ConfigDict(extra="forbid")
server_ip_address: Optional[str] = None
server_password: Optional[str] = None
payload: Optional[str] = None
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: int, config: Dict) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
return ["do_nothing"]
ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema
return ["network", "node", node_name, "application", "RansomwareScript", "configure", config]
class ConfigureDoSBotAction(AbstractAction):
"""Action which sets config parameters for a DoS bot on a node."""
class _Opts(BaseModel):
"""Schema for options that can be passed to this option."""
model_config = ConfigDict(extra="forbid")
target_ip_address: Optional[str] = None
target_port: Optional[str] = None
payload: Optional[str] = None
repeat: Optional[bool] = None
port_scan_p_of_success: Optional[float] = None
dos_intensity: Optional[float] = None
max_sessions: Optional[int] = None
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: int, config: Dict) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
return ["do_nothing"]
self._Opts.model_validate(config) # check that options adhere to schema
return ["network", "node", node_name, "application", "DoSBot", "configure", config]
class NodeApplicationRemoveAction(AbstractAction):
"""Action which removes/uninstalls an application."""
@@ -1045,6 +1117,9 @@ class ActionManager:
"NODE_NMAP_PING_SCAN": NodeNMAPPingScanAction,
"NODE_NMAP_PORT_SCAN": NodeNMAPPortScanAction,
"NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction,
"CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction,
"CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction,
"CONFIGURE_DOSBOT": ConfigureDoSBotAction,
}
"""Dictionary which maps action type strings to the corresponding action class."""

View File

@@ -360,6 +360,38 @@ class SharedReward(AbstractReward):
return cls(agent_name=agent_name)
class ActionPenalty(AbstractReward):
"""Apply a negative reward when taking any action except DONOTHING."""
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
"""
Initialise the reward.
Reward or penalise agents for doing nothing or taking actions.
:param action_penalty: Reward to give agents for taking any action except DONOTHING
:type action_penalty: float
:param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
:type do_nothing_penalty: float
"""
self.action_penalty = action_penalty
self.do_nothing_penalty = do_nothing_penalty
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the penalty to be applied."""
if last_action_response.action == "DONOTHING":
return self.do_nothing_penalty
else:
return self.action_penalty
@classmethod
def from_config(cls, config: Dict) -> "ActionPenalty":
"""Build the ActionPenalty object from config."""
action_penalty = config.get("action_penalty", -1.0)
do_nothing_penalty = config.get("do_nothing_penalty", 0.0)
return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
class RewardFunction:
"""Manages the reward function for the agent."""
@@ -370,6 +402,7 @@ class RewardFunction:
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
"SHARED_REWARD": SharedReward,
"ACTION_PENALTY": ActionPenalty,
}
"""List of reward class identifiers."""

View File

@@ -55,7 +55,6 @@ class TAP001(AbstractScriptedAgent):
return "NODE_APPLICATION_INSTALL", {
"node_id": self.starting_node_idx,
"application_name": "RansomwareScript",
"ip_address": self.ip_address,
}
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}

View File

@@ -26,11 +26,14 @@ from primaite.simulator.network.hardware.nodes.network.wireless_router import Wi
from primaite.simulator.network.nmne import NmneData, store_nmne_config
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient # noqa: F401
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( # noqa: F401
DataManipulationBot,
)
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot # noqa: F401
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript # noqa: F401
from primaite.simulator.system.applications.web_browser import WebBrowser # noqa: F401
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
@@ -42,15 +45,6 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
APPLICATION_TYPES_MAPPING = {
"WebBrowser": WebBrowser,
"DatabaseClient": DatabaseClient,
"DataManipulationBot": DataManipulationBot,
"DoSBot": DoSBot,
"RansomwareScript": RansomwareScript,
}
"""List of available applications that can be installed on nodes in the PrimAITE Simulation."""
SERVICE_TYPES_MAPPING = {
"DNSClient": DNSClient,
"DNSServer": DNSServer,
@@ -302,6 +296,10 @@ class PrimaiteGame:
new_node.software_manager.install(SERVICE_TYPES_MAPPING[service_type])
new_service = new_node.software_manager.software[service_type]
# fixing duration for the service
if "fix_duration" in service_cfg.get("options", {}):
new_service.fixing_duration = service_cfg["options"]["fix_duration"]
# start the service
new_service.start()
else:
@@ -324,7 +322,8 @@ class PrimaiteGame:
if "options" in service_cfg:
opt = service_cfg["options"]
new_service.password = opt.get("db_password", None)
new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip")))
if "backup_server_ip" in opt:
new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip")))
if service_type == "FTPServer":
if "options" in service_cfg:
opt = service_cfg["options"]
@@ -338,9 +337,13 @@ class PrimaiteGame:
new_application = None
application_type = application_cfg["type"]
if application_type in APPLICATION_TYPES_MAPPING:
new_node.software_manager.install(APPLICATION_TYPES_MAPPING[application_type])
new_application = new_node.software_manager.software[application_type]
if application_type in Application._application_registry:
new_node.software_manager.install(Application._application_registry[application_type])
new_application = new_node.software_manager.software[application_type] # grab the instance
# fixing duration for the application
if "fix_duration" in application_cfg.get("options", {}):
new_application.fixing_duration = application_cfg["options"]["fix_duration"]
else:
msg = f"Configuration contains an invalid application type: {application_type}"
_LOGGER.error(msg)
@@ -363,7 +366,7 @@ class PrimaiteGame:
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")),
server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None,
server_password=opt.get("server_password"),
payload=opt.get("payload", "ENCRYPT"),
)

View File

@@ -507,8 +507,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now service 1 on node 2 has `health_status = 3`, indicating that the webapp is compromised.\n",
"File 1 in folder 1 on node 3 has `health_status = 2`, indicating that the database file is compromised."
"Now service 1 on HOST1 has `health_status = 3`, indicating that the webapp is compromised.\n",
"File 1 in folder 1 on HOST2 has `health_status = 2`, indicating that the database file is compromised."
]
},
{
@@ -545,9 +545,9 @@
"source": [
"The fixing takes two steps, so the reward hasn't changed yet. Let's do nothing for another timestep, the reward should improve.\n",
"\n",
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n",
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 0.9 when both green agents make successful requests.\n",
"\n",
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again."
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should increase. If you run it enough times, another red attack will happen and the reward will drop again."
]
},
{
@@ -708,7 +708,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -143,7 +143,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -171,7 +171,7 @@
"from primaite.simulator.file_system.file_system import FileSystem\n",
"\n",
"# no applications exist yet so we will create our own.\n",
"class MSPaint(Application):\n",
"class MSPaint(Application, identifier=\"MSPaint\"):\n",
" def describe_state(self):\n",
" return super().describe_state()"
]

View File

@@ -196,7 +196,7 @@ class SimComponent(BaseModel):
..code::python
class WebBrowser(Application):
class WebBrowser(Application, identifier="WebBrowser"):
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager() # all requests generic to any Application get initialised
rm.add_request(...) # initialise any requests specific to the web browser

View File

@@ -6,7 +6,7 @@ import secrets
from abc import ABC, abstractmethod
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, Optional, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field
@@ -878,6 +878,61 @@ class Node(SimComponent):
More information in user guide and docstring for SimComponent._init_request_manager.
"""
def _install_application(request: RequestFormat, context: Dict) -> RequestResponse:
"""
Allows agents to install applications to the node.
:param request: list containing the application name as the only element
:type request: RequestFormat
:param context: additional context for resolving this action, currently unused
:type context: dict
:return: Request response with a success code if the application was installed.
:rtype: RequestResponse
"""
application_name = request[0]
if self.software_manager.software.get(application_name):
self.sys_log.warning(f"Can't install {application_name}. It's already installed.")
return RequestResponse.from_bool(False)
application_class = Application._application_registry[application_name]
self.software_manager.install(application_class)
application_instance = self.software_manager.software.get(application_name)
self.applications[application_instance.uuid] = application_instance
_LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}")
self._application_request_manager.add_request(
application_name, RequestType(func=application_instance._request_manager)
)
application_instance.install()
if application_name in self.software_manager.software:
return RequestResponse.from_bool(True)
else:
return RequestResponse.from_bool(False)
def _uninstall_application(request: RequestFormat, context: Dict) -> RequestResponse:
"""
Uninstall and completely remove application from this node.
This method is useful for allowing agents to take this action.
:param request: list containing the application name as the only element
:type request: RequestFormat
:param context: additional context for resolving this action, currently unused
:type context: dict
:return: Request response with a success code if the application was uninstalled.
:rtype: RequestResponse
"""
application_name = request[0]
if application_name not in self.software_manager.software:
self.sys_log.warning(f"Can't uninstall {application_name}. It's not installed.")
return RequestResponse.from_bool(False)
application_instance = self.software_manager.software.get(application_name)
self.software_manager.uninstall(application_instance.name)
if application_instance.name not in self.software_manager.software:
return RequestResponse.from_bool(True)
else:
return RequestResponse.from_bool(False)
_node_is_on = Node._NodeIsOnValidator(node=self)
rm = super()._init_request_manager()
@@ -934,25 +989,8 @@ class Node(SimComponent):
name="application", request_type=RequestType(func=self._application_manager)
)
self._application_manager.add_request(
name="install",
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.application_install_action(
application=self._read_application_type(request[0]), ip_address=request[1]
)
)
),
)
self._application_manager.add_request(
name="uninstall",
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.application_uninstall_action(application=self._read_application_type(request[0]))
)
),
)
self._application_manager.add_request(name="install", request_type=RequestType(func=_install_application))
self._application_manager.add_request(name="uninstall", request_type=RequestType(func=_uninstall_application))
return rm
@@ -960,29 +998,6 @@ class Node(SimComponent):
"""Install System Software - software that is usually provided with the OS."""
pass
def _read_application_type(self, application_class_str: str) -> Type[IOSoftwareClass]:
"""Wrapper that converts the string from the request manager into the appropriate class for the application."""
if application_class_str == "DoSBot":
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
return DoSBot
elif application_class_str == "DataManipulationBot":
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import (
DataManipulationBot,
)
return DataManipulationBot
elif application_class_str == "WebBrowser":
from primaite.simulator.system.applications.web_browser import WebBrowser
return WebBrowser
elif application_class_str == "RansomwareScript":
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
return RansomwareScript
else:
return 0
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -1411,76 +1426,6 @@ class Node(SimComponent):
self.sys_log.info(f"Uninstalled application {application.name}")
self._application_request_manager.remove_request(application.name)
def application_install_action(self, application: Application, ip_address: Optional[str] = None) -> bool:
"""
Install an application on this node and configure it.
This method is useful for allowing agents to take this action.
:param application: Application object that has not been installed on any node yet.
:type application: Application
:param ip_address: IP address used to configure the application
(target IP for the DoSBot or server IP for the DataManipulationBot)
:type ip_address: str
:return: True if the application is installed successfully, otherwise False.
"""
if application in self:
_LOGGER.warning(
f"Can't add application {application.__name__}" + f"to node {self.hostname}. It's already installed."
)
return True
self.software_manager.install(application)
application_instance = self.software_manager.software.get(str(application.__name__))
self.applications[application_instance.uuid] = application_instance
_LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}")
self._application_request_manager.add_request(
application_instance.name, RequestType(func=application_instance._request_manager)
)
# Configure application if additional parameters are given
if ip_address:
if application_instance.name == "DoSBot":
application_instance.configure(target_ip_address=IPv4Address(ip_address))
elif application_instance.name == "DataManipulationBot":
application_instance.configure(server_ip_address=IPv4Address(ip_address))
elif application_instance.name == "RansomwareScript":
application_instance.configure(server_ip_address=IPv4Address(ip_address))
else:
pass
application_instance.install()
if application_instance.name in self.software_manager.software:
return True
else:
return False
def application_uninstall_action(self, application: Application) -> bool:
"""
Uninstall and completely remove application from this node.
This method is useful for allowing agents to take this action.
:param application: Application object that is currently associated with this node.
:type application: Application
:return: True if the application is uninstalled successfully, otherwise False.
"""
if application.__name__ not in self.software_manager.software:
_LOGGER.warning(
f"Can't remove application {application.__name__}" + f"from node {self.hostname}. It's not installed."
)
return True
application_instance = self.software_manager.software.get(
str(application.__name__)
) # This works because we can't have two applications with the same name on the same node
# self.uninstall_application(application_instance)
self.software_manager.uninstall(application_instance.name)
if application_instance.name not in self.software_manager.software:
return True
else:
return False
def _shut_down_actions(self):
"""Actions to perform when the node is shut down."""
# Turn off all the services in the node

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, Set
from typing import Any, ClassVar, Dict, Optional, Set, Type
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
@@ -39,6 +39,22 @@ class Application(IOSoftware):
install_countdown: Optional[int] = None
"The countdown to the end of the installation process. None if not currently installing"
_application_registry: ClassVar[Dict[str, Type["Application"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register an application type.
:param identifier: Uniquely specifies an application class by name. Used for finding items by config.
:type identifier: str
:raises ValueError: When attempting to register an application with a name that is already allocated.
"""
super().__init_subclass__(**kwargs)
if identifier in cls._application_registry:
raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.")
cls._application_registry[identifier] = cls
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -8,13 +8,14 @@ from uuid import uuid4
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel
from primaite.interface.request import RequestResponse
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.utils.validators import IPV4Address
class DatabaseClientConnection(BaseModel):
@@ -54,7 +55,7 @@ class DatabaseClientConnection(BaseModel):
self.client._disconnect(self.connection_id) # noqa
class DatabaseClient(Application):
class DatabaseClient(Application, identifier="DatabaseClient"):
"""
A DatabaseClient application.
@@ -96,6 +97,14 @@ class DatabaseClient(Application):
"""
rm = super()._init_request_manager()
rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute())))
def _configure(request: RequestFormat, context: Dict) -> RequestResponse:
ip, pw = request[-1].get("server_ip_address"), request[-1].get("server_password")
ip = None if ip is None else IPV4Address(ip)
success = self.configure(server_ip_address=ip, server_password=pw)
return RequestResponse.from_bool(success)
rm.add_request("configure", RequestType(func=_configure))
return rm
def execute(self) -> bool:
@@ -141,16 +150,17 @@ class DatabaseClient(Application):
table.add_row([connection_id, connection.is_active])
print(table.get_string(sortby="Connection ID"))
def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None):
def configure(self, server_ip_address: Optional[IPv4Address] = None, server_password: Optional[str] = None) -> bool:
"""
Configure the DatabaseClient to communicate with a DatabaseService.
:param server_ip_address: The IP address of the Node the DatabaseService is on.
:param server_password: The password on the DatabaseService.
"""
self.server_ip_address = server_ip_address
self.server_password = server_password
self.server_ip_address = server_ip_address or self.server_ip_address
self.server_password = server_password or self.server_password
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
return True
def connect(self) -> bool:
"""Connect the native client connection."""

View File

@@ -44,7 +44,7 @@ class PortScanPayload(SimComponent):
return state
class NMAP(Application):
class NMAP(Application, identifier="NMAP"):
"""
A class representing the NMAP application for network scanning.

View File

@@ -37,7 +37,7 @@ class DataManipulationAttackStage(IntEnum):
"Signifies that the attack has failed."
class DataManipulationBot(Application):
class DataManipulationBot(Application, identifier="DataManipulationBot"):
"""A bot that simulates a script which performs a SQL injection attack."""
payload: Optional[str] = None

View File

@@ -1,11 +1,11 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import IntEnum
from ipaddress import IPv4Address
from typing import Optional
from typing import Dict, Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestResponse
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
@@ -29,7 +29,7 @@ class DoSAttackStage(IntEnum):
"Attack is completed."
class DoSBot(DatabaseClient):
class DoSBot(DatabaseClient, identifier="DoSBot"):
"""A bot that simulates a Denial of Service attack."""
target_ip_address: Optional[IPv4Address] = None
@@ -71,6 +71,24 @@ class DoSBot(DatabaseClient):
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.run())),
)
def _configure(request: RequestFormat, context: Dict) -> RequestResponse:
"""
Configure the DoSBot.
:param request: List with one element that is a dict of options to pass to the configure method.
:type request: RequestFormat
:param context: additional context for resolving this action, currently unused
:type context: dict
:return: Request Response object with a success code determining if the configuration was successful.
:rtype: RequestResponse
"""
if "target_ip_address" in request[-1]:
request[-1]["target_ip_address"] = IPv4Address(request[-1]["target_ip_address"])
if "target_port" in request[-1]:
request[-1]["target_port"] = Port[request[-1]["target_port"]]
return RequestResponse.from_bool(self.configure(**request[-1]))
rm.add_request("configure", request_type=RequestType(func=_configure))
return rm
def configure(
@@ -82,7 +100,7 @@ class DoSBot(DatabaseClient):
port_scan_p_of_success: float = 0.1,
dos_intensity: float = 1.0,
max_sessions: int = 1000,
):
) -> bool:
"""
Configure the Denial of Service bot.
@@ -90,10 +108,12 @@ class DoSBot(DatabaseClient):
:param: target_port: The port of the target service. Optional - Default is `Port.HTTP`
:param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None`
:param: repeat: If True, the bot will maintain the attack. Optional - Default is `True`
:param: port_scan_p_of_success: The chance of the port scan being sucessful. Optional - Default is 0.1 (10%)
:param: port_scan_p_of_success: The chance of the port scan being successful. Optional - Default is 0.1 (10%)
:param: dos_intensity: The intensity of the DoS attack.
Multiplied with the application's max session - Default is 1.0
:param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000
:return: Always returns True
:rtype: bool
"""
self.target_ip_address = target_ip_address
self.target_port = target_port
@@ -106,6 +126,7 @@ class DoSBot(DatabaseClient):
f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, "
f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}."
)
return True
def run(self) -> bool:
"""Run the Denial of Service Bot."""
@@ -117,6 +138,9 @@ class DoSBot(DatabaseClient):
The main application loop for the Denial of Service bot.
The loop goes through the stages of a DoS attack.
:return: True if the application loop could be executed, False otherwise.
:rtype: bool
"""
if not self._can_perform_action():
return False
@@ -126,7 +150,7 @@ class DoSBot(DatabaseClient):
self.sys_log.warning(
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
)
return True
return False
self.clear_connections()
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)

View File

@@ -2,7 +2,7 @@
from ipaddress import IPv4Address
from typing import Dict, Optional
from primaite.interface.request import RequestResponse
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
@@ -10,7 +10,7 @@ from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
class RansomwareScript(Application):
class RansomwareScript(Application, identifier="RansomwareScript"):
"""Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network.
:ivar payload: The attack stage query payload. (Default ENCRYPT)
@@ -62,6 +62,25 @@ class RansomwareScript(Application):
name="execute",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())),
)
def _configure(request: RequestFormat, context: Dict) -> RequestResponse:
"""
Request for configuring the target database and payload.
:param request: Request with one element contianing a dict of parameters for the configure method.
:type request: RequestFormat
:param context: additional context for resolving this action, currently unused
:type context: dict
:return: RequestResponse object with a success code reflecting whether the configuration could be applied.
:rtype: RequestResponse
"""
ip = request[-1].get("server_ip_address")
ip = None if ip is None else IPv4Address(ip)
pw = request[-1].get("server_password")
payload = request[-1].get("payload")
return RequestResponse.from_bool(self.configure(ip, pw, payload))
rm.add_request("configure", request_type=RequestType(func=_configure))
return rm
def run(self) -> bool:
@@ -88,10 +107,10 @@ class RansomwareScript(Application):
def configure(
self,
server_ip_address: IPv4Address,
server_ip_address: Optional[IPv4Address] = None,
server_password: Optional[str] = None,
payload: Optional[str] = None,
):
) -> bool:
"""
Configure the Ransomware Script to communicate with a DatabaseService.
@@ -108,6 +127,7 @@ class RansomwareScript(Application):
self.sys_log.info(
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}."
)
return True
def attack(self) -> bool:
"""Perform the attack steps after opening the application."""

View File

@@ -23,7 +23,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient
_LOGGER = getLogger(__name__)
class WebBrowser(Application):
class WebBrowser(Application, identifier="WebBrowser"):
"""
Represents a web browser in the simulation environment.

View File

@@ -197,7 +197,11 @@ class DatabaseService(Service):
status_code = 503 # service unavailable
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
if self.health_state_actual == SoftwareHealthState.GOOD:
if self.health_state_actual in [
SoftwareHealthState.GOOD,
SoftwareHealthState.FIXING,
SoftwareHealthState.COMPROMISED,
]:
if self.password == password:
status_code = 200 # ok
connection_id = self._generate_connection_id()
@@ -244,6 +248,10 @@ class DatabaseService(Service):
self.sys_log.error(f"{self.name}: Failed to run {query} because the database file is missing.")
return {"status_code": 404, "type": "sql", "data": False}
if self.health_state_actual is not SoftwareHealthState.GOOD:
self.sys_log.error(f"{self.name}: Failed to run {query} because the database service is unavailable.")
return {"status_code": 500, "type": "sql", "data": False}
if query == "SELECT":
if self.db_file.health_status == FileSystemItemHealthStatus.CORRUPT:
return {

View File

@@ -0,0 +1,792 @@
io_settings:
save_agent_actions: false
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: false
game:
max_episode_length: 256
ports:
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 10
medium: 5
low: 0
agents:
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1:eth-1<->switch_1:eth-8
- router_1:eth-2<->switch_2:eth-8
- switch_1:eth-1<->domain_controller:eth-1
- switch_1:eth-2<->web_server:eth-1
- switch_1:eth-3<->database_server:eth-1
- switch_1:eth-4<->backup_server:eth-1
- switch_1:eth-7<->security_suite:eth-1
- switch_2:eth-1<->client_1:eth-1
- switch_2:eth-2<->client_2:eth-1
- switch_2:eth-7<->security_suite:eth-2
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_SERVICE_FIX
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: ROUTER_ACL_ADDRULE
- type: ROUTER_ACL_REMOVERULE
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 1
service_id: 0
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 1
service_id: 0
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 1
service_id: 0
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 1
service_id: 0
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 1
service_id: 0
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 1
service_id: 0
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 1
service_id: 0
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 1
service_id: 0
9: # check database.db file
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_FIX"
options:
node_id: 2
service_id: 0
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
node_id: 0
19:
action: "NODE_SHUTDOWN"
options:
node_id: 0
20:
action: NODE_STARTUP
options:
node_id: 0
21:
action: NODE_RESET
options:
node_id: 0
22:
action: "NODE_OS_SCAN"
options:
node_id: 1
23:
action: "NODE_SHUTDOWN"
options:
node_id: 1
24:
action: NODE_STARTUP
options:
node_id: 1
25:
action: NODE_RESET
options:
node_id: 1
26: # old action num: 18
action: "NODE_OS_SCAN"
options:
node_id: 2
27:
action: "NODE_SHUTDOWN"
options:
node_id: 2
28:
action: NODE_STARTUP
options:
node_id: 2
29:
action: NODE_RESET
options:
node_id: 2
30:
action: "NODE_OS_SCAN"
options:
node_id: 3
31:
action: "NODE_SHUTDOWN"
options:
node_id: 3
32:
action: NODE_STARTUP
options:
node_id: 3
33:
action: NODE_RESET
options:
node_id: 3
34:
action: "NODE_OS_SCAN"
options:
node_id: 4
35:
action: "NODE_SHUTDOWN"
options:
node_id: 4
36:
action: NODE_STARTUP
options:
node_id: 4
37:
action: NODE_RESET
options:
node_id: 4
38:
action: "NODE_OS_SCAN"
options:
node_id: 5
39: # old action num: 19 # shutdown client 1
action: "NODE_SHUTDOWN"
options:
node_id: 5
40: # old action num: 20
action: NODE_STARTUP
options:
node_id: 5
41: # old action num: 21
action: NODE_RESET
options:
node_id: 5
42:
action: "NODE_OS_SCAN"
options:
node_id: 6
43:
action: "NODE_SHUTDOWN"
options:
node_id: 6
44:
action: NODE_STARTUP
options:
node_id: 6
45:
action: NODE_RESET
options:
node_id: 6
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 1
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 2
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
48: # old action num: 24 # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 3
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
49: # old action num: 25 # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 4
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
50: # old action num: 26
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 5
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
51: # old action num: 27
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 6
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
52: # old action num: 28
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 0
53: # old action num: 29
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 1
54: # old action num: 30
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 2
55: # old action num: 31
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 3
56: # old action num: 32
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 4
57: # old action num: 33
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 5
58: # old action num: 34
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 6
59: # old action num: 35
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 7
60: # old action num: 36
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 8
61: # old action num: 37
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 9
62: # old action num: 38
action: "HOST_NIC_DISABLE"
options:
node_id: 0
nic_id: 0
63: # old action num: 39
action: "HOST_NIC_ENABLE"
options:
node_id: 0
nic_id: 0
64: # old action num: 40
action: "HOST_NIC_DISABLE"
options:
node_id: 1
nic_id: 0
65: # old action num: 41
action: "HOST_NIC_ENABLE"
options:
node_id: 1
nic_id: 0
66: # old action num: 42
action: "HOST_NIC_DISABLE"
options:
node_id: 2
nic_id: 0
67: # old action num: 43
action: "HOST_NIC_ENABLE"
options:
node_id: 2
nic_id: 0
68: # old action num: 44
action: "HOST_NIC_DISABLE"
options:
node_id: 3
nic_id: 0
69: # old action num: 45
action: "HOST_NIC_ENABLE"
options:
node_id: 3
nic_id: 0
70: # old action num: 46
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 0
71: # old action num: 47
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 0
72: # old action num: 48
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
73: # old action num: 49
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
74: # old action num: 50
action: "HOST_NIC_DISABLE"
options:
node_id: 5
nic_id: 0
75: # old action num: 51
action: "HOST_NIC_ENABLE"
options:
node_id: 5
nic_id: 0
76: # old action num: 52
action: "HOST_NIC_DISABLE"
options:
node_id: 6
nic_id: 0
77: # old action num: 53
action: "HOST_NIC_ENABLE"
options:
node_id: 6
nic_id: 0
options:
nodes:
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:
- type: ACTION_PENALTY
weight: 1.0
options:
action_penalty: -0.75
do_nothing_penalty: 0.125
agent_settings:
flatten_obs: true
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- hostname: router_1
type: router
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
18:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
19:
action: PERMIT
src_port: DNS
dst_port: DNS
20:
action: PERMIT
src_port: FTP
dst_port: FTP
21:
action: PERMIT
src_port: HTTP
dst_port: HTTP
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- hostname: switch_1
type: switch
num_ports: 8
- hostname: switch_2
type: switch
num_ports: 8
- hostname: domain_controller
type: server
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- hostname: web_server
type: server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: WebServer
applications:
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- hostname: database_server
type: server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- type: FTPClient
- hostname: backup_server
type: server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: FTPServer
- hostname: security_suite
type: server
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
network_interfaces:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- type: DNSClient
- hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- type: DNSClient
links:
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 8
- endpoint_a_hostname: router_1
endpoint_a_port: 2
endpoint_b_hostname: switch_2
endpoint_b_port: 8
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_hostname: domain_controller
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_hostname: web_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 3
endpoint_b_hostname: database_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 4
endpoint_b_hostname: backup_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 7
endpoint_b_hostname: security_suite
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 1
endpoint_b_hostname: client_1
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 2
endpoint_b_hostname: client_2
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 7
endpoint_b_hostname: security_suite
endpoint_b_port: 2

View File

@@ -0,0 +1,248 @@
# Basic Switched network
#
# -------------- -------------- --------------
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
io_settings:
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true
sys_log_level: WARNING
game:
max_episode_length: 256
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
options:
nodes:
- node_name: client_2
applications:
- application_name: WebBrowser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 5
frequency: 4
variance: 3
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: client_1
- hostname: client_2
- hostname: client_3
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
monitored_traffic:
icmp:
- NONE
tcp:
- DNS
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.10.21
- 192.168.10.22
- 192.168.10.23
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- switch_1:eth-1<->client_1:eth-1
- switch_1:eth-2<->client_2:eth-1
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
- type: DONOTHING
action_map:
0:
action: DONOTHING
options: {}
options:
nodes:
- node_name: switch
- node_name: client_1
- node_name: client_2
- node_name: client_3
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_list:
- 192.168.10.21
- 192.168.10.22
- 192.168.10.23
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_hostname: web_server
service_name: web_server_web_service
agent_settings:
flatten_obs: true
simulation:
network:
nodes:
- type: switch
hostname: switch_1
num_ports: 8
- hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: RansomwareScript
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- type: DatabaseClient
options:
db_server_ip: 192.168.1.10
server_password: arcd
fix_duration: 1
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.21
server_password: arcd
- type: DoSBot
options:
target_ip_address: 192.168.10.21
payload: SPOOF DATA
port_scan_p_of_success: 0.8
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
- type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.10
- type: DatabaseService
options:
fix_duration: 5
backup_server_ip: 192.168.1.10
- type: WebServer
- type: FTPClient
- type: FTPServer
options:
server_password: arcd
- type: NTPClient
options:
ntp_server_ip: 192.168.1.10
- type: NTPServer
- hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: DatabaseClient
options:
db_server_ip: 192.168.1.10
server_password: arcd
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
links:
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_hostname: client_1
endpoint_b_port: 1
bandwidth: 200
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_hostname: client_2
endpoint_b_port: 1
bandwidth: 200

View File

@@ -0,0 +1,142 @@
io_settings:
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: false
save_agent_actions: false
game:
max_episode_length: 256
ports:
- ARP
- DNS
protocols:
- ICMP
- TCP
agents:
- ref: agent_1
team: BLUE
type: ProxyAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_INSTALL
- type: CONFIGURE_DATABASE_CLIENT
- type: CONFIGURE_DOSBOT
- type: CONFIGURE_RANSOMWARE_SCRIPT
- type: NODE_APPLICATION_REMOVE
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_INSTALL
options:
node_id: 0
application_name: DatabaseClient
2:
action: NODE_APPLICATION_INSTALL
options:
node_id: 1
application_name: RansomwareScript
3:
action: NODE_APPLICATION_INSTALL
options:
node_id: 2
application_name: DoSBot
4:
action: CONFIGURE_DATABASE_CLIENT
options:
node_id: 0
config:
server_ip_address: 10.0.0.5
5:
action: CONFIGURE_DATABASE_CLIENT
options:
node_id: 0
config:
server_password: correct_password
6:
action: CONFIGURE_RANSOMWARE_SCRIPT
options:
node_id: 1
config:
server_ip_address: 10.0.0.5
server_password: correct_password
payload: ENCRYPT
7:
action: CONFIGURE_DOSBOT
options:
node_id: 2
config:
target_ip_address: 10.0.0.5
target_port: POSTGRES_SERVER
payload: DELETE
repeat: true
port_scan_p_of_success: 1.0
dos_intensity: 1.0
max_sessions: 1000
8:
action: NODE_APPLICATION_INSTALL
options:
node_id: 1
application_name: DatabaseClient
options:
nodes:
- node_name: client_1
- node_name: client_2
- node_name: client_3
ip_list: []
reward_function:
reward_components:
- type: DUMMY
simulation:
network:
nodes:
- type: computer
hostname: client_1
ip_address: 10.0.0.2
subnet_mask: 255.255.255.0
default_gateway: 10.0.0.1
- type: computer
hostname: client_2
ip_address: 10.0.0.3
subnet_mask: 255.255.255.0
default_gateway: 10.0.0.1
- type: computer
hostname: client_3
ip_address: 10.0.0.4
subnet_mask: 255.255.255.0
default_gateway: 10.0.0.1
- type: switch
hostname: switch_1
num_ports: 8
- type: server
hostname: server_1
ip_address: 10.0.0.5
subnet_mask: 255.255.255.0
default_gateway: 10.0.0.1
services:
- type: DatabaseService
options:
db_password: correct_password
links:
- endpoint_a_hostname: client_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 1
- endpoint_a_hostname: client_2
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 2
- endpoint_a_hostname: client_3
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 3
- endpoint_a_hostname: server_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 8

View File

@@ -0,0 +1,263 @@
# Basic Switched network
#
# -------------- -------------- --------------
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
io_settings:
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true
sys_log_level: WARNING
game:
max_episode_length: 256
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
options:
nodes:
- node_name: client_2
applications:
- application_name: WebBrowser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 5
frequency: 4
variance: 3
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: client_1
- hostname: client_2
- hostname: client_3
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
monitored_traffic:
icmp:
- NONE
tcp:
- DNS
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.10.21
- 192.168.10.22
- 192.168.10.23
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- switch_1:eth-1<->client_1:eth-1
- switch_1:eth-2<->client_2:eth-1
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
- type: DONOTHING
action_map:
0:
action: DONOTHING
options: {}
options:
nodes:
- node_name: switch
- node_name: client_1
- node_name: client_2
- node_name: client_3
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_list:
- 192.168.10.21
- 192.168.10.22
- 192.168.10.23
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_hostname: web_server
service_name: web_server_web_service
agent_settings:
flatten_obs: true
simulation:
network:
nodes:
- type: switch
hostname: switch_1
num_ports: 8
- hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: RansomwareScript
options:
fix_duration: 1
- type: WebBrowser
options:
target_url: http://arcd.com/users/
fix_duration: 1
- type: DatabaseClient
options:
db_server_ip: 192.168.1.10
server_password: arcd
fix_duration: 1
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.21
server_password: arcd
fix_duration: 1
- type: DoSBot
options:
target_ip_address: 192.168.10.21
payload: SPOOF DATA
port_scan_p_of_success: 0.8
fix_duration: 1
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
fix_duration: 3
- type: DNSServer
options:
fix_duration: 3
domain_mapping:
arcd.com: 192.168.1.10
- type: DatabaseService
options:
backup_server_ip: 192.168.1.10
fix_duration: 3
- type: WebServer
options:
fix_duration: 3
- type: FTPClient
options:
fix_duration: 3
- type: FTPServer
options:
server_password: arcd
fix_duration: 3
- type: NTPClient
options:
ntp_server_ip: 192.168.1.10
fix_duration: 3
- type: NTPServer
options:
fix_duration: 3
- hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: DatabaseClient
options:
db_server_ip: 192.168.1.10
server_password: arcd
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
links:
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_hostname: client_1
endpoint_b_port: 1
bandwidth: 200
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_hostname: client_2
endpoint_b_port: 1
bandwidth: 200

View File

@@ -260,6 +260,7 @@ agents:
- type: NODE_APPLICATION_INSTALL
- type: NODE_APPLICATION_REMOVE
- type: NODE_APPLICATION_EXECUTE
- type: CONFIGURE_DOSBOT
action_map:
0:
@@ -683,7 +684,6 @@ agents:
options:
node_id: 0
application_name: DoSBot
ip_address: 192.168.1.14
79:
action: NODE_APPLICATION_REMOVE
options:
@@ -699,6 +699,14 @@ agents:
options:
node_id: 0
application_id: 0
82:
action: CONFIGURE_DOSBOT
options:
node_id: 0
config:
target_ip_address: 192.168.1.14
target_port: POSTGRES_SERVER

View File

@@ -51,11 +51,11 @@ class TestService(Service):
pass
class TestApplication(Application):
class DummyApplication(Application, identifier="DummyApplication"):
"""Test Application class"""
def __init__(self, **kwargs):
kwargs["name"] = "TestApplication"
kwargs["name"] = "DummyApplication"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
@@ -85,15 +85,15 @@ def service_class():
@pytest.fixture(scope="function")
def application(file_system) -> TestApplication:
return TestApplication(
name="TestApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_application")
def application(file_system) -> DummyApplication:
return DummyApplication(
name="DummyApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_application")
)
@pytest.fixture(scope="function")
def application_class():
return TestApplication
return DummyApplication
@pytest.fixture(scope="function")

View File

@@ -69,6 +69,7 @@ def test_application_install_uninstall_on_uc2():
env.step(0)
# Test we can now execute the DoSBot app
env.step(82) # configure dos bot with ip address and port
_, _, _, _, info = env.step(81)
assert info["agent_actions"]["defender"].response.status == "success"

View File

@@ -9,9 +9,10 @@ from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
@@ -85,7 +86,7 @@ def test_node_software_install():
assert client_2.software_manager.software.get(software.__name__) is not None
# check that applications have been installed on client 1
for applications in APPLICATION_TYPES_MAPPING:
for applications in Application._application_registry:
assert client_1.software_manager.software.get(applications) is not None
# check that services have been installed on client 1

View File

@@ -0,0 +1,93 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import copy
from ipaddress import IPv4Address
from pathlib import Path
from typing import Union
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml"
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml"
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
with open(config_path, "r") as f:
cfg = yaml.safe_load(f)
return PrimaiteGame.from_config(cfg)
def test_default_fix_duration():
"""Test that software with no defined fix duration in config uses the default fix duration of 2."""
game = load_config(TEST_CONFIG)
client_2: Computer = game.simulation.network.get_node_by_hostname("client_2")
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
assert database_client.fixing_duration == 2
dns_client: DNSClient = client_2.software_manager.software.get("DNSClient")
assert dns_client.fixing_duration == 2
def test_fix_duration_set_from_config():
"""Test to check that the fix duration set for applications and services works as intended."""
game = load_config(TEST_CONFIG)
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
for service in SERVICE_TYPES_MAPPING:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 3
# in config - applications take 1 timestep to fix
for applications in APPLICATION_TYPES_MAPPING:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 1
def test_fix_duration_for_one_item():
"""Test that setting fix duration for one application does not affect other components."""
game = load_config(ONE_ITEM_CONFIG)
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
services = copy.copy(SERVICE_TYPES_MAPPING)
services.pop("DatabaseService")
for service in services:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 2
# in config - applications take 1 timestep to fix
applications = copy.copy(APPLICATION_TYPES_MAPPING)
applications.pop("DatabaseClient")
for applications in applications:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 2
database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
assert database_client.fixing_duration == 1
database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService")
assert database_service.fixing_duration == 5

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1,292 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address
import pytest
from pydantic import ValidationError
from primaite.game.agent.actions import (
ConfigureDatabaseClientAction,
ConfigureDoSBotAction,
ConfigureRansomwareScriptAction,
)
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.services.database.database_service import DatabaseService
from tests import TEST_ASSETS_ROOT
from tests.conftest import ControlledAgent
APP_CONFIG_YAML = TEST_ASSETS_ROOT / "configs/install_and_configure_apps.yaml"
class TestConfigureDatabaseAction:
def test_configure_ip_password(self, game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
# make sure there is a database client on this node
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.install(DatabaseClient)
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
action = (
"CONFIGURE_DATABASE_CLIENT",
{
"node_id": 0,
"config": {
"server_ip_address": "192.168.1.99",
"server_password": "admin123",
},
},
)
agent.store_action(action)
game.step()
assert db_client.server_ip_address == IPv4Address("192.168.1.99")
assert db_client.server_password == "admin123"
def test_configure_ip(self, game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
# make sure there is a database client on this node
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.install(DatabaseClient)
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
action = (
"CONFIGURE_DATABASE_CLIENT",
{
"node_id": 0,
"config": {
"server_ip_address": "192.168.1.99",
},
},
)
agent.store_action(action)
game.step()
assert db_client.server_ip_address == IPv4Address("192.168.1.99")
assert db_client.server_password is None
def test_configure_password(self, game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
# make sure there is a database client on this node
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.install(DatabaseClient)
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
old_ip = db_client.server_ip_address
action = (
"CONFIGURE_DATABASE_CLIENT",
{
"node_id": 0,
"config": {
"server_password": "admin123",
},
},
)
agent.store_action(action)
game.step()
assert db_client.server_ip_address == old_ip
assert db_client.server_password is "admin123"
class TestConfigureRansomwareScriptAction:
@pytest.mark.parametrize(
"config",
[
{},
{"server_ip_address": "181.181.181.181"},
{"server_password": "admin123"},
{"payload": "ENCRYPT"},
{
"server_ip_address": "181.181.181.181",
"server_password": "admin123",
"payload": "ENCRYPT",
},
],
)
def test_configure_ip_password(self, game_and_agent, config):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction(
agent.action_manager
)
# make sure there is a database client on this node
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.install(RansomwareScript)
ransomware_script: RansomwareScript = client_1.software_manager.software["RansomwareScript"]
old_ip = ransomware_script.server_ip_address
old_pw = ransomware_script.server_password
old_payload = ransomware_script.payload
action = (
"CONFIGURE_RANSOMWARE_SCRIPT",
{"node_id": 0, "config": config},
)
agent.store_action(action)
game.step()
expected_ip = old_ip if "server_ip_address" not in config else IPv4Address(config["server_ip_address"])
expected_pw = old_pw if "server_password" not in config else config["server_password"]
expected_payload = old_payload if "payload" not in config else config["payload"]
assert ransomware_script.server_ip_address == expected_ip
assert ransomware_script.server_password == expected_pw
assert ransomware_script.payload == expected_payload
def test_invalid_config(self, game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction(
agent.action_manager
)
# make sure there is a database client on this node
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.install(RansomwareScript)
ransomware_script: RansomwareScript = client_1.software_manager.software["RansomwareScript"]
action = (
"CONFIGURE_RANSOMWARE_SCRIPT",
{
"node_id": 0,
"config": {"server_password": "admin123", "bad_option": 70},
},
)
agent.store_action(action)
with pytest.raises(ValidationError):
game.step()
class TestConfigureDoSBot:
def test_configure_DoSBot(self, game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_DOSBOT"] = ConfigureDoSBotAction(agent.action_manager)
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.install(DoSBot)
dos_bot: DoSBot = client_1.software_manager.software["DoSBot"]
action = (
"CONFIGURE_DOSBOT",
{
"node_id": 0,
"config": {
"target_ip_address": "192.168.1.99",
"target_port": "POSTGRES_SERVER",
"payload": "HACC",
"repeat": False,
"port_scan_p_of_success": 0.875,
"dos_intensity": 0.75,
"max_sessions": 50,
},
},
)
agent.store_action(action)
game.step()
assert dos_bot.target_ip_address == IPv4Address("192.168.1.99")
assert dos_bot.target_port == Port.POSTGRES_SERVER
assert dos_bot.payload == "HACC"
assert not dos_bot.repeat
assert dos_bot.port_scan_p_of_success == 0.875
assert dos_bot.dos_intensity == 0.75
assert dos_bot.max_sessions == 50
class TestConfigureYAML:
def test_configure_db_client(self):
env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML)
# make sure there's no db client on the node yet
client_1 = env.game.simulation.network.get_node_by_hostname("client_1")
assert client_1.software_manager.software.get("DatabaseClient") is None
# take the install action, check that the db gets installed, step to get it to finish installing
env.step(1)
db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
assert isinstance(db_client, DatabaseClient)
assert db_client.operating_state == ApplicationOperatingState.INSTALLING
env.step(0)
env.step(0)
env.step(0)
env.step(0)
# configure the ip address and check that it changes, but password doesn't change
assert db_client.server_ip_address is None
assert db_client.server_password is None
env.step(4)
assert db_client.server_ip_address == IPv4Address("10.0.0.5")
assert db_client.server_password is None
# configure the password and check that it changes, make sure this lets us connect to the db
assert not db_client.connect()
env.step(5)
assert db_client.server_password == "correct_password"
assert db_client.connect()
def test_configure_ransomware_script(self):
env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML)
client_2 = env.game.simulation.network.get_node_by_hostname("client_2")
assert client_2.software_manager.software.get("RansomwareScript") is None
# install ransomware script
env.step(2)
ransom = client_2.software_manager.software.get("RansomwareScript")
assert isinstance(ransom, RansomwareScript)
assert ransom.operating_state == ApplicationOperatingState.INSTALLING
env.step(0)
env.step(0)
env.step(0)
env.step(0)
# make sure it's not working yet because it's not configured and there's no db client
assert not ransom.attack()
env.step(8) # install db client on the same node
env.step(0)
env.step(0)
env.step(0)
env.step(0) # let it finish installing
assert not ransom.attack()
# finally, configure the ransomware script with ip and password
env.step(6)
assert ransom.attack()
db_server = env.game.simulation.network.get_node_by_hostname("server_1")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
assert db_service.db_file.health_status == FileSystemItemHealthStatus.CORRUPT
def test_configure_dos_bot(self):
env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML)
client_3 = env.game.simulation.network.get_node_by_hostname("client_3")
assert client_3.software_manager.software.get("DoSBot") is None
# install DoSBot
env.step(3)
bot = client_3.software_manager.software.get("DoSBot")
assert isinstance(bot, DoSBot)
assert bot.operating_state == ApplicationOperatingState.INSTALLING
env.step(0)
env.step(0)
env.step(0)
env.step(0)
# make sure dos bot doesn't work before being configured
assert not bot.run()
env.step(7)
assert bot.run()

View File

@@ -557,7 +557,7 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl
assert client_1.software_manager.software.get("DoSBot") is None
action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot", "ip_address": "192.168.1.14"})
action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot"})
agent.store_action(action)
game.step()

View File

@@ -1,9 +1,11 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import pytest
import yaml
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.game import PrimaiteGame
from primaite.interface.request import RequestResponse
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
@@ -119,3 +121,77 @@ def test_shared_reward():
g2_reward = env.game.agents["client_2_green_user"].reward_function.current_reward
blue_reward = env.game.agents["defender"].reward_function.current_reward
assert blue_reward == g1_reward + g2_reward
def test_action_penalty_loads_from_config():
"""Test to ensure that action penalty is correctly loaded from config into PrimaiteGymEnv"""
CFG_PATH = TEST_ASSETS_ROOT / "configs/action_penalty.yaml"
with open(CFG_PATH, "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(env_config=cfg)
env.reset()
defender = env.game.agents["defender"]
act_penalty_obj = None
for comp in defender.reward_function.reward_components:
if isinstance(comp[0], ActionPenalty):
act_penalty_obj = comp[0]
if act_penalty_obj is None:
pytest.fail("Action penalty reward component was not added to the agent from config.")
assert act_penalty_obj.action_penalty == -0.75
assert act_penalty_obj.do_nothing_penalty == 0.125
def test_action_penalty():
"""Test that the action penalty is correctly applied when agent performs any action"""
# Create an ActionPenalty Reward
Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
# Assert that penalty is applied if action isn't DONOTHING
reward_value = Penalty.calculate(
state={},
last_action_response=AgentHistoryItem(
timestep=0,
action="NODE_APPLICATION_EXECUTE",
parameters={"node_id": 0, "application_id": 1},
request=["execute"],
response=RequestResponse.from_bool(True),
),
)
assert reward_value == -0.75
# Assert that no penalty applied for a DONOTHING action
reward_value = Penalty.calculate(
state={},
last_action_response=AgentHistoryItem(
timestep=0,
action="DONOTHING",
parameters={},
request=["do_nothing"],
response=RequestResponse.from_bool(True),
),
)
assert reward_value == 0.125
def test_action_penalty_e2e(game_and_agent):
"""Test that we get the right reward for doing actions to fetch a website."""
game, agent = game_and_agent
agent: ControlledAgent
comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
agent.reward_function.register_component(comp, 1.0)
action = ("DONOTHING", {})
agent.store_action(action)
game.step()
assert agent.reward_function.current_reward == 0.125
action = ("NODE_FILE_SCAN", {"node_id": 0, "folder_id": 0, "file_id": 0})
agent.store_action(action)
game.step()
assert agent.reward_function.current_reward == -0.75

View File

@@ -41,7 +41,7 @@ class BroadcastService(Service):
super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol)
class BroadcastClient(Application):
class BroadcastClient(Application, identifier="BroadcastClient"):
"""A client application to receive broadcast and unicast messages."""
payloads_received: List = []

View File

@@ -3,9 +3,9 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.networks import multi_lan_internet_network_example
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from src.primaite.simulator.system.applications.web_browser import WebBrowser
def test_all_with_configured_dns_server_ip_can_resolve_url():

View File

@@ -21,7 +21,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]:
computer.power_on()
computer.software_manager.install(application_class)
app = computer.software_manager.software.get("TestApplication")
app = computer.software_manager.software.get("DummyApplication")
app.run()
return app, computer
@@ -39,7 +39,7 @@ def test_application_on_offline_node(application_class):
)
computer.software_manager.install(application_class)
app: Application = computer.software_manager.software.get("TestApplication")
app: Application = computer.software_manager.software.get("DummyApplication")
computer.power_off()

View File

@@ -14,6 +14,7 @@ from primaite.simulator.system.applications.database_client import DatabaseClien
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
@pytest.fixture(scope="function")
@@ -213,6 +214,110 @@ def test_restore_backup_after_deleting_file_without_updating_scan(uc2_network):
assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.GOOD # now looks good
def test_database_service_fix(uc2_network):
"""Test that the software fix applies to database service."""
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
assert db_service.backup_database() is True
# delete database locally
db_service.file_system.delete_file(folder_name="database", file_name="database.db")
# db file is gone, reduced to atoms
assert db_service.db_file is None
db_service.fix() # fix the database service
assert db_service.health_state_actual == SoftwareHealthState.FIXING
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
assert db_service.health_state_actual == SoftwareHealthState.GOOD
def test_database_cannot_be_queried_while_fixing(uc2_network):
"""Tests that the database service cannot be queried if the service is being fixed."""
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_connection: DatabaseClientConnection = db_client.get_new_connection()
assert db_connection.query(sql="SELECT")
assert db_service.backup_database() is True
# delete database locally
db_service.file_system.delete_file(folder_name="database", file_name="database.db")
# db file is gone, reduced to atoms
assert db_service.db_file is None
db_service.fix() # fix the database service
assert db_service.health_state_actual == SoftwareHealthState.FIXING
# fails to query because database is in FIXING state
assert db_connection.query(sql="SELECT") is False
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.health_state_actual == SoftwareHealthState.GOOD
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
assert db_connection.query(sql="SELECT")
def test_database_can_create_connection_while_fixing(uc2_network):
"""Tests that connections cannot be created while the database is being fixed."""
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
client_2: Server = uc2_network.get_node_by_hostname("client_2")
db_client: DatabaseClient = client_2.software_manager.software["DatabaseClient"]
db_connection: DatabaseClientConnection = db_client.get_new_connection()
assert db_connection.query(sql="SELECT")
assert db_service.backup_database() is True
# delete database locally
db_service.file_system.delete_file(folder_name="database", file_name="database.db")
# db file is gone, reduced to atoms
assert db_service.db_file is None
db_service.fix() # fix the database service
assert db_service.health_state_actual == SoftwareHealthState.FIXING
# fails to query because database is in FIXING state
assert db_connection.query(sql="SELECT") is False
# should be able to create a new connection
new_db_connection: DatabaseClientConnection = db_client.get_new_connection()
assert new_db_connection is not None
assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.health_state_actual == SoftwareHealthState.GOOD
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
assert db_connection.query(sql="SELECT")
assert new_db_connection.query(sql="SELECT")
def test_database_client_cannot_query_offline_database_server(uc2_network):
"""Tests DB query across the network returns HTTP status 404 when db server is offline."""
db_server: Server = uc2_network.get_node_by_hostname("database_server")

View File

@@ -16,6 +16,7 @@ from primaite.simulator.system.services.database.database_service import Databas
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.web_server.web_server import WebServer
from primaite.simulator.system.software import SoftwareHealthState
@pytest.fixture(scope="function")
@@ -110,6 +111,29 @@ def test_web_client_requests_users(web_client_web_server_database):
assert web_browser.get_webpage()
def test_database_fix_disrupts_web_client(uc2_network):
"""Tests that the database service being in fixed state disrupts the web client."""
computer: Computer = uc2_network.get_node_by_hostname("client_1")
db_server: Server = uc2_network.get_node_by_hostname("database_server")
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
database_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
# fix the database service
database_service.fix()
assert database_service.health_state_actual == SoftwareHealthState.FIXING
assert web_browser.get_webpage() is False
for i in range(database_service.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert database_service.health_state_actual == SoftwareHealthState.GOOD
assert web_browser.get_webpage()
class TestWebBrowserHistory:
def test_populating_history(self, web_client_web_server_database):
network, computer, _, _ = web_client_web_server_database

View File

@@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from tests.conftest import TestApplication, TestService
from tests.conftest import DummyApplication, TestService
def test_successful_node_file_system_creation_request(example_network):
@@ -47,14 +47,14 @@ def test_successful_application_requests(example_network):
net = example_network
client_1 = net.get_node_by_hostname("client_1")
client_1.software_manager.install(TestApplication)
client_1.software_manager.software.get("TestApplication").run()
client_1.software_manager.install(DummyApplication)
client_1.software_manager.software.get("DummyApplication").run()
resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"])
resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"])
assert resp_1 == RequestResponse(status="success", data={})
resp_2 = net.apply_request(["node", "client_1", "application", "TestApplication", "fix"])
resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"])
assert resp_2 == RequestResponse(status="success", data={})
resp_3 = net.apply_request(["node", "client_1", "application", "TestApplication", "compromise"])
resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"])
assert resp_3 == RequestResponse(status="success", data={})

View File

@@ -0,0 +1,22 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import pytest
from primaite.simulator.system.applications.application import Application
def test_adding_to_app_registry():
class temp_application(Application, identifier="temp_app"):
pass
assert Application._application_registry["temp_app"] is temp_application
with pytest.raises(ValueError):
class another_application(Application, identifier="temp_app"):
pass
# This is kinda evil...
# Because pytest doesn't reimport classes from modules, registering this temporary test application will change the
# state of the Application registry for all subsequently run tests. So, we have to delete and unregister the class.
del temp_application
Application._application_registry.pop("temp_app")