Merge branch 'dev' into bugfix/2676_NMNE_var_access
This commit is contained in:
84
.azure/azure-benchmark-pipeline.yaml
Normal file
84
.azure/azure-benchmark-pipeline.yaml
Normal file
@@ -0,0 +1,84 @@
|
||||
trigger:
|
||||
branches:
|
||||
exclude:
|
||||
- '*'
|
||||
include:
|
||||
- 'refs/heads/release/*'
|
||||
|
||||
schedules:
|
||||
- cron: "0 2 * * 1-5" # Run at 2 AM every weekday
|
||||
displayName: "Weekday Schedule"
|
||||
branches:
|
||||
include:
|
||||
- 'refs/heads/dev'
|
||||
|
||||
pool:
|
||||
vmImage: ubuntu-latest
|
||||
|
||||
variables:
|
||||
VERSION: ''
|
||||
MAJOR_VERSION: ''
|
||||
|
||||
steps:
|
||||
- checkout: self
|
||||
persistCredentials: true
|
||||
|
||||
- script: |
|
||||
VERSION=$(cat src/primaite/VERSION | tr -d '\n')
|
||||
if [[ "$(Build.SourceBranch)" == "refs/heads/dev" ]]; then
|
||||
DATE=$(date +%Y%m%d)
|
||||
echo "${VERSION}+dev.${DATE}" > src/primaite/VERSION
|
||||
fi
|
||||
displayName: 'Update VERSION file for Dev Benchmark'
|
||||
|
||||
- script: |
|
||||
VERSION=$(cat src/primaite/VERSION | tr -d '\n')
|
||||
MAJOR_VERSION=$(echo $VERSION | cut -d. -f1)
|
||||
echo "##vso[task.setvariable variable=VERSION]$VERSION"
|
||||
echo "##vso[task.setvariable variable=MAJOR_VERSION]$MAJOR_VERSION"
|
||||
displayName: 'Set Version Variables'
|
||||
|
||||
- task: UsePythonVersion@0
|
||||
inputs:
|
||||
versionSpec: '3.11'
|
||||
addToPath: true
|
||||
|
||||
- script: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -e .[dev,rl]
|
||||
primaite setup
|
||||
displayName: 'Install Dependencies'
|
||||
|
||||
- script: |
|
||||
cd benchmark
|
||||
python3 primaite_benchmark.py
|
||||
cd ..
|
||||
displayName: 'Run Benchmarking Script'
|
||||
|
||||
- script: |
|
||||
git config --global user.email "oss@dstl.gov.uk"
|
||||
git config --global user.name "Defence Science and Technology Laboratory UK"
|
||||
workingDirectory: $(System.DefaultWorkingDirectory)
|
||||
displayName: 'Configure Git'
|
||||
condition: and(succeeded(), eq(variables['Build.Reason'], 'Manual'), startsWith(variables['Build.SourceBranch'], 'refs/heads/release'))
|
||||
|
||||
- script: |
|
||||
git add benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)/*
|
||||
git commit -m "Automated benchmark output commit for version $(VERSION)"
|
||||
git push origin HEAD:refs/heads/$(Build.SourceBranchName)
|
||||
displayName: 'Commit and Push Benchmark Results'
|
||||
workingDirectory: $(System.DefaultWorkingDirectory)
|
||||
env:
|
||||
GIT_CREDENTIALS: $(System.AccessToken)
|
||||
condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/release'))
|
||||
|
||||
- script: |
|
||||
tar czf primaite_v$(VERSION)_benchmark.tar.gz benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)
|
||||
displayName: 'Prepare Artifacts for Publishing'
|
||||
|
||||
- task: PublishPipelineArtifact@1
|
||||
inputs:
|
||||
targetPath: primaite_v$(VERSION)_benchmark.tar.gz
|
||||
artifactName: 'benchmark-output'
|
||||
publishLocation: 'pipeline'
|
||||
displayName: 'Publish Benchmark Output as Artifact'
|
||||
@@ -107,11 +107,39 @@ stages:
|
||||
coverage html -d htmlcov -i
|
||||
displayName: 'Run tests and code coverage'
|
||||
|
||||
# Run the notebooks
|
||||
- script: |
|
||||
pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml
|
||||
notebooks_exit_code=$?
|
||||
pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml
|
||||
package_notebooks_exit_code=$?
|
||||
# Fail step if either of these do not have exit code 0
|
||||
if [ $notebooks_exit_code -ne 0 ] || [ $package_notebooks_exit_code -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
displayName: 'Run notebooks on Linux and macOS'
|
||||
condition: or(eq(variables['Agent.OS'], 'Linux'), eq(variables['Agent.OS'], 'Darwin'))
|
||||
|
||||
# Run notebooks
|
||||
- script: |
|
||||
pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml
|
||||
set notebooks_exit_code=%ERRORLEVEL%
|
||||
pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml
|
||||
set package_notebooks_exit_code=%ERRORLEVEL%
|
||||
rem Fail step if either of these do not have exit code 0
|
||||
if %notebooks_exit_code% NEQ 0 exit /b 1
|
||||
if %package_notebooks_exit_code% NEQ 0 exit /b 1
|
||||
displayName: 'Run notebooks on Windows'
|
||||
condition: eq(variables['Agent.OS'], 'Windows_NT')
|
||||
|
||||
- task: PublishTestResults@2
|
||||
condition: succeededOrFailed()
|
||||
displayName: 'Publish Test Results'
|
||||
inputs:
|
||||
testRunner: JUnit
|
||||
testResultsFiles: 'junit/**.xml'
|
||||
testResultsFiles: |
|
||||
'junit/**.xml'
|
||||
'notebook-tests/**.xml'
|
||||
testRunTitle: 'Publish test results'
|
||||
failTaskOnFailedTests: true
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -54,6 +54,7 @@ cover/
|
||||
tests/assets/**/*.png
|
||||
tests/assets/**/tensorboard_logs/
|
||||
tests/assets/**/checkpoints/
|
||||
notebook-tests/*.xml
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
|
||||
@@ -117,14 +117,14 @@ class BenchmarkSession:
|
||||
def generate_learn_metadata_dict(self) -> Dict[str, Any]:
|
||||
"""Metadata specific to the learning session."""
|
||||
total_s, s_per_step, s_per_100_steps_10_nodes = self._learn_benchmark_durations()
|
||||
self.gym_env.average_reward_per_episode.pop(0) # remove episode 0
|
||||
self.gym_env.total_reward_per_episode.pop(0) # remove episode 0
|
||||
return {
|
||||
"total_episodes": self.gym_env.episode_counter,
|
||||
"total_time_steps": self.gym_env.total_time_steps,
|
||||
"total_s": total_s,
|
||||
"s_per_step": s_per_step,
|
||||
"s_per_100_steps_10_nodes": s_per_100_steps_10_nodes,
|
||||
"av_reward_per_episode": self.gym_env.average_reward_per_episode,
|
||||
"total_reward_per_episode": self.gym_env.total_reward_per_episode,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -9,10 +9,6 @@ import plotly.graph_objects as go
|
||||
import polars as pl
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
from pylatex import Command, Document
|
||||
from pylatex import Figure as LatexFigure
|
||||
from pylatex import Section, Subsection, Tabular
|
||||
from pylatex.utils import bold
|
||||
from utils import _get_system_info
|
||||
|
||||
import primaite
|
||||
@@ -39,19 +35,19 @@ def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict,
|
||||
"av_s_per_step": sum(d["s_per_step"] for d in metadata_dict.values()) / num_sessions,
|
||||
"av_s_per_100_steps_10_nodes": sum(d["s_per_100_steps_10_nodes"] for d in metadata_dict.values())
|
||||
/ num_sessions,
|
||||
"combined_av_reward_per_episode": {},
|
||||
"session_av_reward_per_episode": {k: v["av_reward_per_episode"] for k, v in metadata_dict.items()},
|
||||
"combined_total_reward_per_episode": {},
|
||||
"session_total_reward_per_episode": {k: v["total_reward_per_episode"] for k, v in metadata_dict.items()},
|
||||
"config": config,
|
||||
}
|
||||
|
||||
# find the average of each episode across all sessions
|
||||
episodes = metadata_dict[1]["av_reward_per_episode"].keys()
|
||||
episodes = metadata_dict[1]["total_reward_per_episode"].keys()
|
||||
|
||||
for episode in episodes:
|
||||
combined_av_reward = (
|
||||
sum(metadata_dict[k]["av_reward_per_episode"][episode] for k in metadata_dict.keys()) / num_sessions
|
||||
sum(metadata_dict[k]["total_reward_per_episode"][episode] for k in metadata_dict.keys()) / num_sessions
|
||||
)
|
||||
averaged_data["combined_av_reward_per_episode"][episode] = combined_av_reward
|
||||
averaged_data["combined_total_reward_per_episode"][episode] = combined_av_reward
|
||||
|
||||
return averaged_data
|
||||
|
||||
@@ -87,7 +83,7 @@ def _plot_benchmark_metadata(
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=PLOT_CONFIG["template"])
|
||||
|
||||
for session, av_reward_dict in benchmark_metadata_dict["session_av_reward_per_episode"].items():
|
||||
for session, av_reward_dict in benchmark_metadata_dict["session_total_reward_per_episode"].items():
|
||||
df = _get_df_from_episode_av_reward_dict(av_reward_dict)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
@@ -100,7 +96,7 @@ def _plot_benchmark_metadata(
|
||||
)
|
||||
)
|
||||
|
||||
df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_av_reward_per_episode"])
|
||||
df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_total_reward_per_episode"])
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"], y=df["av_reward"], mode="lines", name="Combined Session Av", line={"color": "#FF0000"}
|
||||
@@ -136,11 +132,11 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
|
||||
|
||||
Does this by iterating over the ``benchmark/results`` directory and
|
||||
extracting the benchmark metadata json for each version that has been
|
||||
benchmarked. The combined_av_reward_per_episode is extracted from each,
|
||||
benchmarked. The combined_total_reward_per_episode is extracted from each,
|
||||
converted into a polars dataframe, and plotted as a scatter line in plotly.
|
||||
"""
|
||||
major_v = primaite.__version__.split(".")[0]
|
||||
title = f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*"
|
||||
title = f"Learning Benchmarking of All Released Versions under Major v{major_v}.#.#"
|
||||
subtitle = "Rolling Av (Combined Session Av)"
|
||||
if title:
|
||||
if subtitle:
|
||||
@@ -162,7 +158,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
|
||||
metadata_file = dir / f"{dir.name}_benchmark_metadata.json"
|
||||
with open(metadata_file, "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_av_reward_per_episode"])
|
||||
df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_total_reward_per_episode"])
|
||||
|
||||
fig.add_trace(go.Scatter(x=df["episode"], y=df["rolling_av_reward"], mode="lines", name=dir.name))
|
||||
|
||||
@@ -208,98 +204,77 @@ def build_benchmark_latex_report(
|
||||
|
||||
fig = _plot_all_benchmarks_combined_session_av(results_directory=results_root_path)
|
||||
|
||||
all_version_plot_path = results_root_path / "PrimAITE Versions Learning Benchmark.png"
|
||||
all_version_plot_path = version_result_dir / "PrimAITE Versions Learning Benchmark.png"
|
||||
fig.write_image(all_version_plot_path)
|
||||
|
||||
geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
|
||||
data = benchmark_metadata_dict
|
||||
primaite_version = data["primaite_version"]
|
||||
|
||||
# Create a new document
|
||||
doc = Document("report", geometry_options=geometry_options)
|
||||
# Title
|
||||
doc.preamble.append(Command("title", f"PrimAITE {primaite_version} Learning Benchmark"))
|
||||
doc.preamble.append(Command("author", "PrimAITE Dev Team"))
|
||||
doc.preamble.append(Command("date", datetime.now().date()))
|
||||
doc.append(Command("maketitle"))
|
||||
with open(version_result_dir / f"PrimAITE v{primaite_version} Learning Benchmark.md", "w") as file:
|
||||
# Title
|
||||
file.write(f"# PrimAITE v{primaite_version} Learning Benchmark\n")
|
||||
file.write("## PrimAITE Dev Team\n")
|
||||
file.write(f"### {datetime.now().date()}\n")
|
||||
file.write("\n---\n")
|
||||
|
||||
sessions = data["total_sessions"]
|
||||
episodes = session_metadata[1]["total_episodes"] - 1
|
||||
steps = data["config"]["game"]["max_episode_length"]
|
||||
sessions = data["total_sessions"]
|
||||
episodes = session_metadata[1]["total_episodes"] - 1
|
||||
steps = data["config"]["game"]["max_episode_length"]
|
||||
|
||||
# Body
|
||||
with doc.create(Section("Introduction")):
|
||||
doc.append(
|
||||
# Body
|
||||
file.write("## 1 Introduction\n")
|
||||
file.write(
|
||||
f"PrimAITE v{primaite_version} was benchmarked automatically upon release. Learning rate metrics "
|
||||
f"were captured to be referenced during system-level testing and user acceptance testing (UAT)."
|
||||
f"were captured to be referenced during system-level testing and user acceptance testing (UAT).\n"
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe benchmarking process consists of running {sessions} training session using the same "
|
||||
file.write(
|
||||
f"The benchmarking process consists of running {sessions} training session using the same "
|
||||
f"config file. Each session trains an agent for {episodes} episodes, "
|
||||
f"with each episode consisting of {steps} steps."
|
||||
f"with each episode consisting of {steps} steps.\n"
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe total reward per episode from each session is captured. This is then used to calculate an "
|
||||
file.write(
|
||||
f"The total reward per episode from each session is captured. This is then used to calculate an "
|
||||
f"caverage total reward per episode from the {sessions} individual sessions for smoothing. "
|
||||
f"Finally, a 25-widow rolling average of the average total reward per session is calculated for "
|
||||
f"further smoothing."
|
||||
f"further smoothing.\n"
|
||||
)
|
||||
|
||||
with doc.create(Section("System Information")):
|
||||
with doc.create(Subsection("Python")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
table.add_row((bold("Version"), sys.version))
|
||||
table.add_hline()
|
||||
file.write("## 2 System Information\n")
|
||||
i = 1
|
||||
file.write(f"### 2.{i} Python\n")
|
||||
file.write(f"**Version:** {sys.version}\n")
|
||||
|
||||
for section, section_data in data["system_info"].items():
|
||||
i += 1
|
||||
if section_data:
|
||||
with doc.create(Subsection(section)):
|
||||
if isinstance(section_data, dict):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for key, value in section_data.items():
|
||||
table.add_row((bold(key), value))
|
||||
table.add_hline()
|
||||
elif isinstance(section_data, list):
|
||||
headers = section_data[0].keys()
|
||||
tabs_str = "|".join(["l" for _ in range(len(headers))])
|
||||
tabs_str = f"|{tabs_str}|"
|
||||
with doc.create(Tabular(tabs_str)) as table:
|
||||
table.add_hline()
|
||||
table.add_row([bold(h) for h in headers])
|
||||
table.add_hline()
|
||||
for item in section_data:
|
||||
table.add_row(item.values())
|
||||
table.add_hline()
|
||||
file.write(f"### 2.{i} {section}\n")
|
||||
if isinstance(section_data, dict):
|
||||
for key, value in section_data.items():
|
||||
file.write(f"- **{key}:** {value}\n")
|
||||
|
||||
headers_map = {
|
||||
"total_sessions": "Total Sessions",
|
||||
"total_episodes": "Total Episodes",
|
||||
"total_time_steps": "Total Steps",
|
||||
"av_s_per_session": "Av Session Duration (s)",
|
||||
"av_s_per_step": "Av Step Duration (s)",
|
||||
"av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
|
||||
}
|
||||
with doc.create(Section("Stats")):
|
||||
with doc.create(Subsection("Benchmark Results")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for section, header in headers_map.items():
|
||||
if section.startswith("av_"):
|
||||
table.add_row((bold(header), f"{data[section]:.4f}"))
|
||||
else:
|
||||
table.add_row((bold(header), data[section]))
|
||||
table.add_hline()
|
||||
headers_map = {
|
||||
"total_sessions": "Total Sessions",
|
||||
"total_episodes": "Total Episodes",
|
||||
"total_time_steps": "Total Steps",
|
||||
"av_s_per_session": "Av Session Duration (s)",
|
||||
"av_s_per_step": "Av Step Duration (s)",
|
||||
"av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
|
||||
}
|
||||
|
||||
with doc.create(Section("Graphs")):
|
||||
with doc.create(Subsection(f"v{primaite_version} Learning Benchmark Plot")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(this_version_plot_path))
|
||||
pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")
|
||||
file.write("## 3 Stats\n")
|
||||
for section, header in headers_map.items():
|
||||
if section.startswith("av_"):
|
||||
file.write(f"- **{header}:** {data[section]:.4f}\n")
|
||||
else:
|
||||
file.write(f"- **{header}:** {data[section]}\n")
|
||||
|
||||
with doc.create(Subsection(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(all_version_plot_path))
|
||||
pic.add_caption(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")
|
||||
file.write("## 4 Graphs\n")
|
||||
|
||||
doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)
|
||||
file.write(f"### 4.1 v{primaite_version} Learning Benchmark Plot\n")
|
||||
file.write(f"\n")
|
||||
|
||||
file.write(f"### 4.2 Learning Benchmarking of All Released Versions under Major v{major_v}.#.#\n")
|
||||
file.write(
|
||||
f"\n"
|
||||
)
|
||||
|
||||
@@ -16,3 +16,12 @@ The type of software that should be added. To add |SOFTWARE_NAME| this must be |
|
||||
===========
|
||||
|
||||
The configuration options are the attributes that fall under the options for an application.
|
||||
|
||||
|
||||
|
||||
``fix_duration``
|
||||
""""""""""""""""
|
||||
|
||||
Optional. Default value is ``2``.
|
||||
|
||||
The number of timesteps the |SOFTWARE_NAME| will remain in a ``FIXING`` state before going into a ``GOOD`` state.
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
applications/*
|
||||
|
||||
More info :py:mod:`primaite.game.game.APPLICATION_TYPES_MAPPING`
|
||||
More info :py:mod:`primaite.simulator.system.applications.application.Application`
|
||||
|
||||
.. include:: list_of_system_applications.rst
|
||||
|
||||
|
||||
@@ -64,7 +64,6 @@ dev = [
|
||||
"gputil==1.4.0",
|
||||
"pip-licenses==4.3.0",
|
||||
"pre-commit==2.20.0",
|
||||
"pylatex==1.4.1",
|
||||
"pytest==7.2.0",
|
||||
"pytest-xdist==3.3.1",
|
||||
"pytest-cov==4.0.0",
|
||||
@@ -73,7 +72,9 @@ dev = [
|
||||
"Sphinx==7.1.2",
|
||||
"sphinx-copybutton==0.5.2",
|
||||
"wheel==0.38.4",
|
||||
"nbsphinx==0.9.4"
|
||||
"nbsphinx==0.9.4",
|
||||
"nbmake==1.5.4",
|
||||
"pytest-xdist==3.3.1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -739,8 +739,6 @@ agents:
|
||||
options:
|
||||
agent_name: client_2_green_user
|
||||
|
||||
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
|
||||
|
||||
@@ -14,9 +14,10 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Literal, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
from pydantic import BaseModel, Field, field_validator, ValidationInfo
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -228,7 +229,7 @@ class NodeApplicationInstallAction(AbstractAction):
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes}
|
||||
|
||||
def form_request(self, node_id: int, application_name: str, ip_address: str) -> List[str]:
|
||||
def form_request(self, node_id: int, application_name: str) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None:
|
||||
@@ -241,10 +242,81 @@ class NodeApplicationInstallAction(AbstractAction):
|
||||
"application",
|
||||
"install",
|
||||
application_name,
|
||||
ip_address,
|
||||
]
|
||||
|
||||
|
||||
class ConfigureDatabaseClientAction(AbstractAction):
|
||||
"""Action which sets config parameters for a database client on a node."""
|
||||
|
||||
class _Opts(BaseModel):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
server_ip_address: Optional[str] = None
|
||||
server_password: Optional[str] = None
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(self, node_id: int, config: Dict) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None:
|
||||
return ["do_nothing"]
|
||||
ConfigureDatabaseClientAction._Opts.model_validate(config) # check that options adhere to schema
|
||||
return ["network", "node", node_name, "application", "DatabaseClient", "configure", config]
|
||||
|
||||
|
||||
class ConfigureRansomwareScriptAction(AbstractAction):
|
||||
"""Action which sets config parameters for a ransomware script on a node."""
|
||||
|
||||
class _Opts(BaseModel):
|
||||
"""Schema for options that can be passed to this option."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
server_ip_address: Optional[str] = None
|
||||
server_password: Optional[str] = None
|
||||
payload: Optional[str] = None
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(self, node_id: int, config: Dict) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None:
|
||||
return ["do_nothing"]
|
||||
ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema
|
||||
return ["network", "node", node_name, "application", "RansomwareScript", "configure", config]
|
||||
|
||||
|
||||
class ConfigureDoSBotAction(AbstractAction):
|
||||
"""Action which sets config parameters for a DoS bot on a node."""
|
||||
|
||||
class _Opts(BaseModel):
|
||||
"""Schema for options that can be passed to this option."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
target_ip_address: Optional[str] = None
|
||||
target_port: Optional[str] = None
|
||||
payload: Optional[str] = None
|
||||
repeat: Optional[bool] = None
|
||||
port_scan_p_of_success: Optional[float] = None
|
||||
dos_intensity: Optional[float] = None
|
||||
max_sessions: Optional[int] = None
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(self, node_id: int, config: Dict) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
if node_name is None:
|
||||
return ["do_nothing"]
|
||||
self._Opts.model_validate(config) # check that options adhere to schema
|
||||
return ["network", "node", node_name, "application", "DoSBot", "configure", config]
|
||||
|
||||
|
||||
class NodeApplicationRemoveAction(AbstractAction):
|
||||
"""Action which removes/uninstalls an application."""
|
||||
|
||||
@@ -1045,6 +1117,9 @@ class ActionManager:
|
||||
"NODE_NMAP_PING_SCAN": NodeNMAPPingScanAction,
|
||||
"NODE_NMAP_PORT_SCAN": NodeNMAPPortScanAction,
|
||||
"NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction,
|
||||
"CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction,
|
||||
"CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction,
|
||||
"CONFIGURE_DOSBOT": ConfigureDoSBotAction,
|
||||
}
|
||||
"""Dictionary which maps action type strings to the corresponding action class."""
|
||||
|
||||
|
||||
@@ -360,6 +360,38 @@ class SharedReward(AbstractReward):
|
||||
return cls(agent_name=agent_name)
|
||||
|
||||
|
||||
class ActionPenalty(AbstractReward):
|
||||
"""Apply a negative reward when taking any action except DONOTHING."""
|
||||
|
||||
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
|
||||
"""
|
||||
Initialise the reward.
|
||||
|
||||
Reward or penalise agents for doing nothing or taking actions.
|
||||
|
||||
:param action_penalty: Reward to give agents for taking any action except DONOTHING
|
||||
:type action_penalty: float
|
||||
:param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
|
||||
:type do_nothing_penalty: float
|
||||
"""
|
||||
self.action_penalty = action_penalty
|
||||
self.do_nothing_penalty = do_nothing_penalty
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the penalty to be applied."""
|
||||
if last_action_response.action == "DONOTHING":
|
||||
return self.do_nothing_penalty
|
||||
else:
|
||||
return self.action_penalty
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "ActionPenalty":
|
||||
"""Build the ActionPenalty object from config."""
|
||||
action_penalty = config.get("action_penalty", -1.0)
|
||||
do_nothing_penalty = config.get("do_nothing_penalty", 0.0)
|
||||
return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
@@ -370,6 +402,7 @@ class RewardFunction:
|
||||
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
|
||||
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
|
||||
"SHARED_REWARD": SharedReward,
|
||||
"ACTION_PENALTY": ActionPenalty,
|
||||
}
|
||||
"""List of reward class identifiers."""
|
||||
|
||||
|
||||
@@ -55,7 +55,6 @@ class TAP001(AbstractScriptedAgent):
|
||||
return "NODE_APPLICATION_INSTALL", {
|
||||
"node_id": self.starting_node_idx,
|
||||
"application_name": "RansomwareScript",
|
||||
"ip_address": self.ip_address,
|
||||
}
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
|
||||
@@ -26,11 +26,14 @@ from primaite.simulator.network.hardware.nodes.network.wireless_router import Wi
|
||||
from primaite.simulator.network.nmne import NmneData, store_nmne_config
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient # noqa: F401
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( # noqa: F401
|
||||
DataManipulationBot,
|
||||
)
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot # noqa: F401
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript # noqa: F401
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser # noqa: F401
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
@@ -42,15 +45,6 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
APPLICATION_TYPES_MAPPING = {
|
||||
"WebBrowser": WebBrowser,
|
||||
"DatabaseClient": DatabaseClient,
|
||||
"DataManipulationBot": DataManipulationBot,
|
||||
"DoSBot": DoSBot,
|
||||
"RansomwareScript": RansomwareScript,
|
||||
}
|
||||
"""List of available applications that can be installed on nodes in the PrimAITE Simulation."""
|
||||
|
||||
SERVICE_TYPES_MAPPING = {
|
||||
"DNSClient": DNSClient,
|
||||
"DNSServer": DNSServer,
|
||||
@@ -302,6 +296,10 @@ class PrimaiteGame:
|
||||
new_node.software_manager.install(SERVICE_TYPES_MAPPING[service_type])
|
||||
new_service = new_node.software_manager.software[service_type]
|
||||
|
||||
# fixing duration for the service
|
||||
if "fix_duration" in service_cfg.get("options", {}):
|
||||
new_service.fixing_duration = service_cfg["options"]["fix_duration"]
|
||||
|
||||
# start the service
|
||||
new_service.start()
|
||||
else:
|
||||
@@ -324,7 +322,8 @@ class PrimaiteGame:
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
new_service.password = opt.get("db_password", None)
|
||||
new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip")))
|
||||
if "backup_server_ip" in opt:
|
||||
new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip")))
|
||||
if service_type == "FTPServer":
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
@@ -338,9 +337,13 @@ class PrimaiteGame:
|
||||
new_application = None
|
||||
application_type = application_cfg["type"]
|
||||
|
||||
if application_type in APPLICATION_TYPES_MAPPING:
|
||||
new_node.software_manager.install(APPLICATION_TYPES_MAPPING[application_type])
|
||||
new_application = new_node.software_manager.software[application_type]
|
||||
if application_type in Application._application_registry:
|
||||
new_node.software_manager.install(Application._application_registry[application_type])
|
||||
new_application = new_node.software_manager.software[application_type] # grab the instance
|
||||
|
||||
# fixing duration for the application
|
||||
if "fix_duration" in application_cfg.get("options", {}):
|
||||
new_application.fixing_duration = application_cfg["options"]["fix_duration"]
|
||||
else:
|
||||
msg = f"Configuration contains an invalid application type: {application_type}"
|
||||
_LOGGER.error(msg)
|
||||
@@ -363,7 +366,7 @@ class PrimaiteGame:
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")),
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None,
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload", "ENCRYPT"),
|
||||
)
|
||||
|
||||
@@ -507,8 +507,8 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now service 1 on node 2 has `health_status = 3`, indicating that the webapp is compromised.\n",
|
||||
"File 1 in folder 1 on node 3 has `health_status = 2`, indicating that the database file is compromised."
|
||||
"Now service 1 on HOST1 has `health_status = 3`, indicating that the webapp is compromised.\n",
|
||||
"File 1 in folder 1 on HOST2 has `health_status = 2`, indicating that the database file is compromised."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -545,9 +545,9 @@
|
||||
"source": [
|
||||
"The fixing takes two steps, so the reward hasn't changed yet. Let's do nothing for another timestep, the reward should improve.\n",
|
||||
"\n",
|
||||
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n",
|
||||
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 0.9 when both green agents make successful requests.\n",
|
||||
"\n",
|
||||
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again."
|
||||
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should increase. If you run it enough times, another red attack will happen and the reward will drop again."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -708,7 +708,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.8"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -143,7 +143,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -171,7 +171,7 @@
|
||||
"from primaite.simulator.file_system.file_system import FileSystem\n",
|
||||
"\n",
|
||||
"# no applications exist yet so we will create our own.\n",
|
||||
"class MSPaint(Application):\n",
|
||||
"class MSPaint(Application, identifier=\"MSPaint\"):\n",
|
||||
" def describe_state(self):\n",
|
||||
" return super().describe_state()"
|
||||
]
|
||||
|
||||
@@ -196,7 +196,7 @@ class SimComponent(BaseModel):
|
||||
|
||||
..code::python
|
||||
|
||||
class WebBrowser(Application):
|
||||
class WebBrowser(Application, identifier="WebBrowser"):
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager() # all requests generic to any Application get initialised
|
||||
rm.add_request(...) # initialise any requests specific to the web browser
|
||||
|
||||
@@ -6,7 +6,7 @@ import secrets
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
|
||||
from typing import Any, ClassVar, Dict, Optional, TypeVar, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -878,6 +878,61 @@ class Node(SimComponent):
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
|
||||
def _install_application(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""
|
||||
Allows agents to install applications to the node.
|
||||
|
||||
:param request: list containing the application name as the only element
|
||||
:type request: RequestFormat
|
||||
:param context: additional context for resolving this action, currently unused
|
||||
:type context: dict
|
||||
:return: Request response with a success code if the application was installed.
|
||||
:rtype: RequestResponse
|
||||
"""
|
||||
application_name = request[0]
|
||||
if self.software_manager.software.get(application_name):
|
||||
self.sys_log.warning(f"Can't install {application_name}. It's already installed.")
|
||||
return RequestResponse.from_bool(False)
|
||||
application_class = Application._application_registry[application_name]
|
||||
self.software_manager.install(application_class)
|
||||
application_instance = self.software_manager.software.get(application_name)
|
||||
self.applications[application_instance.uuid] = application_instance
|
||||
_LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}")
|
||||
self._application_request_manager.add_request(
|
||||
application_name, RequestType(func=application_instance._request_manager)
|
||||
)
|
||||
application_instance.install()
|
||||
if application_name in self.software_manager.software:
|
||||
return RequestResponse.from_bool(True)
|
||||
else:
|
||||
return RequestResponse.from_bool(False)
|
||||
|
||||
def _uninstall_application(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""
|
||||
Uninstall and completely remove application from this node.
|
||||
|
||||
This method is useful for allowing agents to take this action.
|
||||
|
||||
:param request: list containing the application name as the only element
|
||||
:type request: RequestFormat
|
||||
:param context: additional context for resolving this action, currently unused
|
||||
:type context: dict
|
||||
:return: Request response with a success code if the application was uninstalled.
|
||||
:rtype: RequestResponse
|
||||
"""
|
||||
application_name = request[0]
|
||||
if application_name not in self.software_manager.software:
|
||||
self.sys_log.warning(f"Can't uninstall {application_name}. It's not installed.")
|
||||
return RequestResponse.from_bool(False)
|
||||
|
||||
application_instance = self.software_manager.software.get(application_name)
|
||||
self.software_manager.uninstall(application_instance.name)
|
||||
if application_instance.name not in self.software_manager.software:
|
||||
return RequestResponse.from_bool(True)
|
||||
else:
|
||||
return RequestResponse.from_bool(False)
|
||||
|
||||
_node_is_on = Node._NodeIsOnValidator(node=self)
|
||||
|
||||
rm = super()._init_request_manager()
|
||||
@@ -934,25 +989,8 @@ class Node(SimComponent):
|
||||
name="application", request_type=RequestType(func=self._application_manager)
|
||||
)
|
||||
|
||||
self._application_manager.add_request(
|
||||
name="install",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.application_install_action(
|
||||
application=self._read_application_type(request[0]), ip_address=request[1]
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
self._application_manager.add_request(
|
||||
name="uninstall",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.application_uninstall_action(application=self._read_application_type(request[0]))
|
||||
)
|
||||
),
|
||||
)
|
||||
self._application_manager.add_request(name="install", request_type=RequestType(func=_install_application))
|
||||
self._application_manager.add_request(name="uninstall", request_type=RequestType(func=_uninstall_application))
|
||||
|
||||
return rm
|
||||
|
||||
@@ -960,29 +998,6 @@ class Node(SimComponent):
|
||||
"""Install System Software - software that is usually provided with the OS."""
|
||||
pass
|
||||
|
||||
def _read_application_type(self, application_class_str: str) -> Type[IOSoftwareClass]:
|
||||
"""Wrapper that converts the string from the request manager into the appropriate class for the application."""
|
||||
if application_class_str == "DoSBot":
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
|
||||
|
||||
return DoSBot
|
||||
elif application_class_str == "DataManipulationBot":
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import (
|
||||
DataManipulationBot,
|
||||
)
|
||||
|
||||
return DataManipulationBot
|
||||
elif application_class_str == "WebBrowser":
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
|
||||
return WebBrowser
|
||||
elif application_class_str == "RansomwareScript":
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
|
||||
return RansomwareScript
|
||||
else:
|
||||
return 0
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -1411,76 +1426,6 @@ class Node(SimComponent):
|
||||
self.sys_log.info(f"Uninstalled application {application.name}")
|
||||
self._application_request_manager.remove_request(application.name)
|
||||
|
||||
def application_install_action(self, application: Application, ip_address: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Install an application on this node and configure it.
|
||||
|
||||
This method is useful for allowing agents to take this action.
|
||||
|
||||
:param application: Application object that has not been installed on any node yet.
|
||||
:type application: Application
|
||||
:param ip_address: IP address used to configure the application
|
||||
(target IP for the DoSBot or server IP for the DataManipulationBot)
|
||||
:type ip_address: str
|
||||
:return: True if the application is installed successfully, otherwise False.
|
||||
"""
|
||||
if application in self:
|
||||
_LOGGER.warning(
|
||||
f"Can't add application {application.__name__}" + f"to node {self.hostname}. It's already installed."
|
||||
)
|
||||
return True
|
||||
|
||||
self.software_manager.install(application)
|
||||
application_instance = self.software_manager.software.get(str(application.__name__))
|
||||
self.applications[application_instance.uuid] = application_instance
|
||||
_LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}")
|
||||
self._application_request_manager.add_request(
|
||||
application_instance.name, RequestType(func=application_instance._request_manager)
|
||||
)
|
||||
|
||||
# Configure application if additional parameters are given
|
||||
if ip_address:
|
||||
if application_instance.name == "DoSBot":
|
||||
application_instance.configure(target_ip_address=IPv4Address(ip_address))
|
||||
elif application_instance.name == "DataManipulationBot":
|
||||
application_instance.configure(server_ip_address=IPv4Address(ip_address))
|
||||
elif application_instance.name == "RansomwareScript":
|
||||
application_instance.configure(server_ip_address=IPv4Address(ip_address))
|
||||
else:
|
||||
pass
|
||||
application_instance.install()
|
||||
if application_instance.name in self.software_manager.software:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def application_uninstall_action(self, application: Application) -> bool:
|
||||
"""
|
||||
Uninstall and completely remove application from this node.
|
||||
|
||||
This method is useful for allowing agents to take this action.
|
||||
|
||||
:param application: Application object that is currently associated with this node.
|
||||
:type application: Application
|
||||
:return: True if the application is uninstalled successfully, otherwise False.
|
||||
"""
|
||||
if application.__name__ not in self.software_manager.software:
|
||||
_LOGGER.warning(
|
||||
f"Can't remove application {application.__name__}" + f"from node {self.hostname}. It's not installed."
|
||||
)
|
||||
return True
|
||||
|
||||
application_instance = self.software_manager.software.get(
|
||||
str(application.__name__)
|
||||
) # This works because we can't have two applications with the same name on the same node
|
||||
# self.uninstall_application(application_instance)
|
||||
self.software_manager.uninstall(application_instance.name)
|
||||
|
||||
if application_instance.name not in self.software_manager.software:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _shut_down_actions(self):
|
||||
"""Actions to perform when the node is shut down."""
|
||||
# Turn off all the services in the node
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Set
|
||||
from typing import Any, ClassVar, Dict, Optional, Set, Type
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -39,6 +39,22 @@ class Application(IOSoftware):
|
||||
install_countdown: Optional[int] = None
|
||||
"The countdown to the end of the installation process. None if not currently installing"
|
||||
|
||||
_application_registry: ClassVar[Dict[str, Type["Application"]]] = {}
|
||||
"""Registry of application types. Automatically populated when subclasses are defined."""
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register an application type.
|
||||
|
||||
:param identifier: Uniquely specifies an application class by name. Used for finding items by config.
|
||||
:type identifier: str
|
||||
:raises ValueError: When attempting to register an application with a name that is already allocated.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._application_registry:
|
||||
raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.")
|
||||
cls._application_registry[identifier] = cls
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@@ -8,13 +8,14 @@ from uuid import uuid4
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.utils.validators import IPV4Address
|
||||
|
||||
|
||||
class DatabaseClientConnection(BaseModel):
|
||||
@@ -54,7 +55,7 @@ class DatabaseClientConnection(BaseModel):
|
||||
self.client._disconnect(self.connection_id) # noqa
|
||||
|
||||
|
||||
class DatabaseClient(Application):
|
||||
class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
"""
|
||||
A DatabaseClient application.
|
||||
|
||||
@@ -96,6 +97,14 @@ class DatabaseClient(Application):
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute())))
|
||||
|
||||
def _configure(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
ip, pw = request[-1].get("server_ip_address"), request[-1].get("server_password")
|
||||
ip = None if ip is None else IPV4Address(ip)
|
||||
success = self.configure(server_ip_address=ip, server_password=pw)
|
||||
return RequestResponse.from_bool(success)
|
||||
|
||||
rm.add_request("configure", RequestType(func=_configure))
|
||||
return rm
|
||||
|
||||
def execute(self) -> bool:
|
||||
@@ -141,16 +150,17 @@ class DatabaseClient(Application):
|
||||
table.add_row([connection_id, connection.is_active])
|
||||
print(table.get_string(sortby="Connection ID"))
|
||||
|
||||
def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None):
|
||||
def configure(self, server_ip_address: Optional[IPv4Address] = None, server_password: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Configure the DatabaseClient to communicate with a DatabaseService.
|
||||
|
||||
:param server_ip_address: The IP address of the Node the DatabaseService is on.
|
||||
:param server_password: The password on the DatabaseService.
|
||||
"""
|
||||
self.server_ip_address = server_ip_address
|
||||
self.server_password = server_password
|
||||
self.server_ip_address = server_ip_address or self.server_ip_address
|
||||
self.server_password = server_password or self.server_password
|
||||
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
|
||||
return True
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""Connect the native client connection."""
|
||||
|
||||
@@ -44,7 +44,7 @@ class PortScanPayload(SimComponent):
|
||||
return state
|
||||
|
||||
|
||||
class NMAP(Application):
|
||||
class NMAP(Application, identifier="NMAP"):
|
||||
"""
|
||||
A class representing the NMAP application for network scanning.
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class DataManipulationAttackStage(IntEnum):
|
||||
"Signifies that the attack has failed."
|
||||
|
||||
|
||||
class DataManipulationBot(Application):
|
||||
class DataManipulationBot(Application, identifier="DataManipulationBot"):
|
||||
"""A bot that simulates a script which performs a SQL injection attack."""
|
||||
|
||||
payload: Optional[str] = None
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
@@ -29,7 +29,7 @@ class DoSAttackStage(IntEnum):
|
||||
"Attack is completed."
|
||||
|
||||
|
||||
class DoSBot(DatabaseClient):
|
||||
class DoSBot(DatabaseClient, identifier="DoSBot"):
|
||||
"""A bot that simulates a Denial of Service attack."""
|
||||
|
||||
target_ip_address: Optional[IPv4Address] = None
|
||||
@@ -71,6 +71,24 @@ class DoSBot(DatabaseClient):
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.run())),
|
||||
)
|
||||
|
||||
def _configure(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""
|
||||
Configure the DoSBot.
|
||||
|
||||
:param request: List with one element that is a dict of options to pass to the configure method.
|
||||
:type request: RequestFormat
|
||||
:param context: additional context for resolving this action, currently unused
|
||||
:type context: dict
|
||||
:return: Request Response object with a success code determining if the configuration was successful.
|
||||
:rtype: RequestResponse
|
||||
"""
|
||||
if "target_ip_address" in request[-1]:
|
||||
request[-1]["target_ip_address"] = IPv4Address(request[-1]["target_ip_address"])
|
||||
if "target_port" in request[-1]:
|
||||
request[-1]["target_port"] = Port[request[-1]["target_port"]]
|
||||
return RequestResponse.from_bool(self.configure(**request[-1]))
|
||||
|
||||
rm.add_request("configure", request_type=RequestType(func=_configure))
|
||||
return rm
|
||||
|
||||
def configure(
|
||||
@@ -82,7 +100,7 @@ class DoSBot(DatabaseClient):
|
||||
port_scan_p_of_success: float = 0.1,
|
||||
dos_intensity: float = 1.0,
|
||||
max_sessions: int = 1000,
|
||||
):
|
||||
) -> bool:
|
||||
"""
|
||||
Configure the Denial of Service bot.
|
||||
|
||||
@@ -90,10 +108,12 @@ class DoSBot(DatabaseClient):
|
||||
:param: target_port: The port of the target service. Optional - Default is `Port.HTTP`
|
||||
:param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None`
|
||||
:param: repeat: If True, the bot will maintain the attack. Optional - Default is `True`
|
||||
:param: port_scan_p_of_success: The chance of the port scan being sucessful. Optional - Default is 0.1 (10%)
|
||||
:param: port_scan_p_of_success: The chance of the port scan being successful. Optional - Default is 0.1 (10%)
|
||||
:param: dos_intensity: The intensity of the DoS attack.
|
||||
Multiplied with the application's max session - Default is 1.0
|
||||
:param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000
|
||||
:return: Always returns True
|
||||
:rtype: bool
|
||||
"""
|
||||
self.target_ip_address = target_ip_address
|
||||
self.target_port = target_port
|
||||
@@ -106,6 +126,7 @@ class DoSBot(DatabaseClient):
|
||||
f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, "
|
||||
f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}."
|
||||
)
|
||||
return True
|
||||
|
||||
def run(self) -> bool:
|
||||
"""Run the Denial of Service Bot."""
|
||||
@@ -117,6 +138,9 @@ class DoSBot(DatabaseClient):
|
||||
The main application loop for the Denial of Service bot.
|
||||
|
||||
The loop goes through the stages of a DoS attack.
|
||||
|
||||
:return: True if the application loop could be executed, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
@@ -126,7 +150,7 @@ class DoSBot(DatabaseClient):
|
||||
self.sys_log.warning(
|
||||
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
self.clear_connections()
|
||||
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
@@ -10,7 +10,7 @@ from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
|
||||
|
||||
class RansomwareScript(Application):
|
||||
class RansomwareScript(Application, identifier="RansomwareScript"):
|
||||
"""Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network.
|
||||
|
||||
:ivar payload: The attack stage query payload. (Default ENCRYPT)
|
||||
@@ -62,6 +62,25 @@ class RansomwareScript(Application):
|
||||
name="execute",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())),
|
||||
)
|
||||
|
||||
def _configure(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""
|
||||
Request for configuring the target database and payload.
|
||||
|
||||
:param request: Request with one element contianing a dict of parameters for the configure method.
|
||||
:type request: RequestFormat
|
||||
:param context: additional context for resolving this action, currently unused
|
||||
:type context: dict
|
||||
:return: RequestResponse object with a success code reflecting whether the configuration could be applied.
|
||||
:rtype: RequestResponse
|
||||
"""
|
||||
ip = request[-1].get("server_ip_address")
|
||||
ip = None if ip is None else IPv4Address(ip)
|
||||
pw = request[-1].get("server_password")
|
||||
payload = request[-1].get("payload")
|
||||
return RequestResponse.from_bool(self.configure(ip, pw, payload))
|
||||
|
||||
rm.add_request("configure", request_type=RequestType(func=_configure))
|
||||
return rm
|
||||
|
||||
def run(self) -> bool:
|
||||
@@ -88,10 +107,10 @@ class RansomwareScript(Application):
|
||||
|
||||
def configure(
|
||||
self,
|
||||
server_ip_address: IPv4Address,
|
||||
server_ip_address: Optional[IPv4Address] = None,
|
||||
server_password: Optional[str] = None,
|
||||
payload: Optional[str] = None,
|
||||
):
|
||||
) -> bool:
|
||||
"""
|
||||
Configure the Ransomware Script to communicate with a DatabaseService.
|
||||
|
||||
@@ -108,6 +127,7 @@ class RansomwareScript(Application):
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}."
|
||||
)
|
||||
return True
|
||||
|
||||
def attack(self) -> bool:
|
||||
"""Perform the attack steps after opening the application."""
|
||||
|
||||
@@ -23,7 +23,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class WebBrowser(Application):
|
||||
class WebBrowser(Application, identifier="WebBrowser"):
|
||||
"""
|
||||
Represents a web browser in the simulation environment.
|
||||
|
||||
|
||||
@@ -197,7 +197,11 @@ class DatabaseService(Service):
|
||||
status_code = 503 # service unavailable
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.health_state_actual in [
|
||||
SoftwareHealthState.GOOD,
|
||||
SoftwareHealthState.FIXING,
|
||||
SoftwareHealthState.COMPROMISED,
|
||||
]:
|
||||
if self.password == password:
|
||||
status_code = 200 # ok
|
||||
connection_id = self._generate_connection_id()
|
||||
@@ -244,6 +248,10 @@ class DatabaseService(Service):
|
||||
self.sys_log.error(f"{self.name}: Failed to run {query} because the database file is missing.")
|
||||
return {"status_code": 404, "type": "sql", "data": False}
|
||||
|
||||
if self.health_state_actual is not SoftwareHealthState.GOOD:
|
||||
self.sys_log.error(f"{self.name}: Failed to run {query} because the database service is unavailable.")
|
||||
return {"status_code": 500, "type": "sql", "data": False}
|
||||
|
||||
if query == "SELECT":
|
||||
if self.db_file.health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
return {
|
||||
|
||||
792
tests/assets/configs/action_penalty.yaml
Normal file
792
tests/assets/configs/action_penalty.yaml
Normal file
@@ -0,0 +1,792 @@
|
||||
io_settings:
|
||||
save_agent_actions: false
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: false
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 0
|
||||
|
||||
agents:
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1:eth-1<->switch_1:eth-8
|
||||
- router_1:eth-2<->switch_2:eth-8
|
||||
- switch_1:eth-1<->domain_controller:eth-1
|
||||
- switch_1:eth-2<->web_server:eth-1
|
||||
- switch_1:eth-3<->database_server:eth-1
|
||||
- switch_1:eth-4<->backup_server:eth-1
|
||||
- switch_1:eth-7<->security_suite:eth-1
|
||||
- switch_2:eth-1<->client_1:eth-1
|
||||
- switch_2:eth-2<->client_2:eth-1
|
||||
- switch_2:eth-7<->security_suite:eth-2
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_SERVICE_FIX
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: ROUTER_ACL_ADDRULE
|
||||
- type: ROUTER_ACL_REMOVERULE
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
9: # check database.db file
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
13:
|
||||
action: "NODE_SERVICE_FIX"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 0
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 0
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 0
|
||||
20:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 0
|
||||
21:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 0
|
||||
22:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 1
|
||||
23:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 1
|
||||
24:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 1
|
||||
25:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 1
|
||||
26: # old action num: 18
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
27:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 2
|
||||
28:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 2
|
||||
29:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 2
|
||||
30:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
31:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 3
|
||||
32:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 3
|
||||
33:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 3
|
||||
34:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 4
|
||||
35:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 4
|
||||
36:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 4
|
||||
37:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 4
|
||||
38:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 5
|
||||
39: # old action num: 19 # shutdown client 1
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 5
|
||||
40: # old action num: 20
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 5
|
||||
41: # old action num: 21
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 5
|
||||
42:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 6
|
||||
43:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 6
|
||||
44:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 6
|
||||
45:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 6
|
||||
|
||||
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
48: # old action num: 24 # block tcp traffic from client 1 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
49: # old action num: 25 # block tcp traffic from client 2 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
50: # old action num: 26
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
51: # old action num: 27
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
52: # old action num: 28
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 0
|
||||
53: # old action num: 29
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 1
|
||||
54: # old action num: 30
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 2
|
||||
55: # old action num: 31
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 3
|
||||
56: # old action num: 32
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 4
|
||||
57: # old action num: 33
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 5
|
||||
58: # old action num: 34
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 6
|
||||
59: # old action num: 35
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 7
|
||||
60: # old action num: 36
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 8
|
||||
61: # old action num: 37
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 9
|
||||
62: # old action num: 38
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
63: # old action num: 39
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
64: # old action num: 40
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
65: # old action num: 41
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
66: # old action num: 42
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
67: # old action num: 43
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
68: # old action num: 44
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
69: # old action num: 45
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
70: # old action num: 46
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
71: # old action num: 47
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
72: # old action num: 48
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
73: # old action num: 49
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
74: # old action num: 50
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
75: # old action num: 51
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
76: # old action num: 52
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
77: # old action num: 53
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_name: domain_controller
|
||||
- node_name: web_server
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_name: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
services:
|
||||
- service_name: DatabaseService
|
||||
- node_name: backup_server
|
||||
- node_name: security_suite
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: ACTION_PENALTY
|
||||
weight: 1.0
|
||||
options:
|
||||
action_penalty: -0.75
|
||||
do_nothing_penalty: 0.125
|
||||
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nmne_config:
|
||||
capture_nmne: true
|
||||
nmne_capture_keywords:
|
||||
- DELETE
|
||||
nodes:
|
||||
|
||||
- hostname: router_1
|
||||
type: router
|
||||
num_ports: 5
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.10.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
18:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
19:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
20:
|
||||
action: PERMIT
|
||||
src_port: FTP
|
||||
dst_port: FTP
|
||||
21:
|
||||
action: PERMIT
|
||||
src_port: HTTP
|
||||
dst_port: HTTP
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
23:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- hostname: switch_1
|
||||
type: switch
|
||||
num_ports: 8
|
||||
|
||||
- hostname: switch_2
|
||||
type: switch
|
||||
num_ports: 8
|
||||
|
||||
- hostname: domain_controller
|
||||
type: server
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
- hostname: web_server
|
||||
type: server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- type: WebServer
|
||||
applications:
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
|
||||
|
||||
- hostname: database_server
|
||||
type: server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- type: DatabaseService
|
||||
options:
|
||||
backup_server_ip: 192.168.1.16
|
||||
- type: FTPClient
|
||||
|
||||
- hostname: backup_server
|
||||
type: server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- type: FTPServer
|
||||
|
||||
- hostname: security_suite
|
||||
type: server
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
network_interfaces:
|
||||
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
- type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
services:
|
||||
- type: DNSClient
|
||||
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
services:
|
||||
- type: DNSClient
|
||||
|
||||
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 8
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 8
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: web_server
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b_hostname: database_server
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b_hostname: backup_server
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_hostname: security_suite
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: client_1
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: client_2
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_hostname: security_suite
|
||||
endpoint_b_port: 2
|
||||
248
tests/assets/configs/fix_duration_one_item.yaml
Normal file
248
tests/assets/configs/fix_duration_one_item.yaml
Normal file
@@ -0,0 +1,248 @@
|
||||
# Basic Switched network
|
||||
#
|
||||
# -------------- -------------- --------------
|
||||
# | client_1 |------| switch_1 |------| client_2 |
|
||||
# -------------- -------------- --------------
|
||||
#
|
||||
io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
- hostname: client_3
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
monitored_traffic:
|
||||
icmp:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.23
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- switch_1:eth-1<->client_1:eth-1
|
||||
- switch_1:eth-2<->client_2:eth-1
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
options:
|
||||
nodes:
|
||||
- node_name: switch
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
- node_name: client_3
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_list:
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.23
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_hostname: web_server
|
||||
service_name: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
|
||||
- type: switch
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: RansomwareScript
|
||||
- type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.10
|
||||
server_password: arcd
|
||||
fix_duration: 1
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.21
|
||||
server_password: arcd
|
||||
- type: DoSBot
|
||||
options:
|
||||
target_ip_address: 192.168.10.21
|
||||
payload: SPOOF DATA
|
||||
port_scan_p_of_success: 0.8
|
||||
services:
|
||||
- type: DNSClient
|
||||
options:
|
||||
dns_server: 192.168.1.10
|
||||
- type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.10
|
||||
- type: DatabaseService
|
||||
options:
|
||||
fix_duration: 5
|
||||
backup_server_ip: 192.168.1.10
|
||||
- type: WebServer
|
||||
- type: FTPClient
|
||||
- type: FTPServer
|
||||
options:
|
||||
server_password: arcd
|
||||
- type: NTPClient
|
||||
options:
|
||||
ntp_server_ip: 192.168.1.10
|
||||
- type: NTPServer
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.10
|
||||
server_password: arcd
|
||||
services:
|
||||
- type: DNSClient
|
||||
options:
|
||||
dns_server: 192.168.1.10
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: client_1
|
||||
endpoint_b_port: 1
|
||||
bandwidth: 200
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: client_2
|
||||
endpoint_b_port: 1
|
||||
bandwidth: 200
|
||||
142
tests/assets/configs/install_and_configure_apps.yaml
Normal file
142
tests/assets/configs/install_and_configure_apps.yaml
Normal file
@@ -0,0 +1,142 @@
|
||||
io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: false
|
||||
save_agent_actions: false
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
|
||||
agents:
|
||||
- ref: agent_1
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_INSTALL
|
||||
- type: CONFIGURE_DATABASE_CLIENT
|
||||
- type: CONFIGURE_DOSBOT
|
||||
- type: CONFIGURE_RANSOMWARE_SCRIPT
|
||||
- type: NODE_APPLICATION_REMOVE
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_INSTALL
|
||||
options:
|
||||
node_id: 0
|
||||
application_name: DatabaseClient
|
||||
2:
|
||||
action: NODE_APPLICATION_INSTALL
|
||||
options:
|
||||
node_id: 1
|
||||
application_name: RansomwareScript
|
||||
3:
|
||||
action: NODE_APPLICATION_INSTALL
|
||||
options:
|
||||
node_id: 2
|
||||
application_name: DoSBot
|
||||
4:
|
||||
action: CONFIGURE_DATABASE_CLIENT
|
||||
options:
|
||||
node_id: 0
|
||||
config:
|
||||
server_ip_address: 10.0.0.5
|
||||
5:
|
||||
action: CONFIGURE_DATABASE_CLIENT
|
||||
options:
|
||||
node_id: 0
|
||||
config:
|
||||
server_password: correct_password
|
||||
6:
|
||||
action: CONFIGURE_RANSOMWARE_SCRIPT
|
||||
options:
|
||||
node_id: 1
|
||||
config:
|
||||
server_ip_address: 10.0.0.5
|
||||
server_password: correct_password
|
||||
payload: ENCRYPT
|
||||
7:
|
||||
action: CONFIGURE_DOSBOT
|
||||
options:
|
||||
node_id: 2
|
||||
config:
|
||||
target_ip_address: 10.0.0.5
|
||||
target_port: POSTGRES_SERVER
|
||||
payload: DELETE
|
||||
repeat: true
|
||||
port_scan_p_of_success: 1.0
|
||||
dos_intensity: 1.0
|
||||
max_sessions: 1000
|
||||
8:
|
||||
action: NODE_APPLICATION_INSTALL
|
||||
options:
|
||||
node_id: 1
|
||||
application_name: DatabaseClient
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
- node_name: client_3
|
||||
ip_list: []
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
- type: computer
|
||||
hostname: client_1
|
||||
ip_address: 10.0.0.2
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 10.0.0.1
|
||||
- type: computer
|
||||
hostname: client_2
|
||||
ip_address: 10.0.0.3
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 10.0.0.1
|
||||
- type: computer
|
||||
hostname: client_3
|
||||
ip_address: 10.0.0.4
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 10.0.0.1
|
||||
- type: switch
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
- type: server
|
||||
hostname: server_1
|
||||
ip_address: 10.0.0.5
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 10.0.0.1
|
||||
services:
|
||||
- type: DatabaseService
|
||||
options:
|
||||
db_password: correct_password
|
||||
links:
|
||||
- endpoint_a_hostname: client_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: client_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 2
|
||||
- endpoint_a_hostname: client_3
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 3
|
||||
- endpoint_a_hostname: server_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 8
|
||||
263
tests/assets/configs/software_fix_duration.yaml
Normal file
263
tests/assets/configs/software_fix_duration.yaml
Normal file
@@ -0,0 +1,263 @@
|
||||
# Basic Switched network
|
||||
#
|
||||
# -------------- -------------- --------------
|
||||
# | client_1 |------| switch_1 |------| client_2 |
|
||||
# -------------- -------------- --------------
|
||||
#
|
||||
io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
- hostname: client_3
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
monitored_traffic:
|
||||
icmp:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.23
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- switch_1:eth-1<->client_1:eth-1
|
||||
- switch_1:eth-2<->client_2:eth-1
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
options:
|
||||
nodes:
|
||||
- node_name: switch
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
- node_name: client_3
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_list:
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.23
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_hostname: web_server
|
||||
service_name: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
|
||||
- type: switch
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: RansomwareScript
|
||||
options:
|
||||
fix_duration: 1
|
||||
- type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
fix_duration: 1
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.10
|
||||
server_password: arcd
|
||||
fix_duration: 1
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.21
|
||||
server_password: arcd
|
||||
fix_duration: 1
|
||||
- type: DoSBot
|
||||
options:
|
||||
target_ip_address: 192.168.10.21
|
||||
payload: SPOOF DATA
|
||||
port_scan_p_of_success: 0.8
|
||||
fix_duration: 1
|
||||
services:
|
||||
- type: DNSClient
|
||||
options:
|
||||
dns_server: 192.168.1.10
|
||||
fix_duration: 3
|
||||
- type: DNSServer
|
||||
options:
|
||||
fix_duration: 3
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.10
|
||||
- type: DatabaseService
|
||||
options:
|
||||
backup_server_ip: 192.168.1.10
|
||||
fix_duration: 3
|
||||
- type: WebServer
|
||||
options:
|
||||
fix_duration: 3
|
||||
- type: FTPClient
|
||||
options:
|
||||
fix_duration: 3
|
||||
- type: FTPServer
|
||||
options:
|
||||
server_password: arcd
|
||||
fix_duration: 3
|
||||
- type: NTPClient
|
||||
options:
|
||||
ntp_server_ip: 192.168.1.10
|
||||
fix_duration: 3
|
||||
- type: NTPServer
|
||||
options:
|
||||
fix_duration: 3
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.10
|
||||
server_password: arcd
|
||||
services:
|
||||
- type: DNSClient
|
||||
options:
|
||||
dns_server: 192.168.1.10
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: client_1
|
||||
endpoint_b_port: 1
|
||||
bandwidth: 200
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: client_2
|
||||
endpoint_b_port: 1
|
||||
bandwidth: 200
|
||||
@@ -260,6 +260,7 @@ agents:
|
||||
- type: NODE_APPLICATION_INSTALL
|
||||
- type: NODE_APPLICATION_REMOVE
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
- type: CONFIGURE_DOSBOT
|
||||
|
||||
action_map:
|
||||
0:
|
||||
@@ -683,7 +684,6 @@ agents:
|
||||
options:
|
||||
node_id: 0
|
||||
application_name: DoSBot
|
||||
ip_address: 192.168.1.14
|
||||
79:
|
||||
action: NODE_APPLICATION_REMOVE
|
||||
options:
|
||||
@@ -699,6 +699,14 @@ agents:
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
82:
|
||||
action: CONFIGURE_DOSBOT
|
||||
options:
|
||||
node_id: 0
|
||||
config:
|
||||
target_ip_address: 192.168.1.14
|
||||
target_port: POSTGRES_SERVER
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -51,11 +51,11 @@ class TestService(Service):
|
||||
pass
|
||||
|
||||
|
||||
class TestApplication(Application):
|
||||
class DummyApplication(Application, identifier="DummyApplication"):
|
||||
"""Test Application class"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "TestApplication"
|
||||
kwargs["name"] = "DummyApplication"
|
||||
kwargs["port"] = Port.HTTP
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
@@ -85,15 +85,15 @@ def service_class():
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def application(file_system) -> TestApplication:
|
||||
return TestApplication(
|
||||
name="TestApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_application")
|
||||
def application(file_system) -> DummyApplication:
|
||||
return DummyApplication(
|
||||
name="DummyApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_application")
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def application_class():
|
||||
return TestApplication
|
||||
return DummyApplication
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
||||
@@ -69,6 +69,7 @@ def test_application_install_uninstall_on_uc2():
|
||||
env.step(0)
|
||||
|
||||
# Test we can now execute the DoSBot app
|
||||
env.step(82) # configure dos bot with ip address and port
|
||||
_, _, _, _, info = env.step(81)
|
||||
assert info["agent_actions"]["defender"].response.status == "success"
|
||||
|
||||
|
||||
@@ -9,9 +9,10 @@ from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
|
||||
from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
|
||||
@@ -85,7 +86,7 @@ def test_node_software_install():
|
||||
assert client_2.software_manager.software.get(software.__name__) is not None
|
||||
|
||||
# check that applications have been installed on client 1
|
||||
for applications in APPLICATION_TYPES_MAPPING:
|
||||
for applications in Application._application_registry:
|
||||
assert client_1.software_manager.software.get(applications) is not None
|
||||
|
||||
# check that services have been installed on client 1
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import copy
|
||||
from ipaddress import IPv4Address
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
|
||||
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml"
|
||||
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml"
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
|
||||
with open(config_path, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
return PrimaiteGame.from_config(cfg)
|
||||
|
||||
|
||||
def test_default_fix_duration():
|
||||
"""Test that software with no defined fix duration in config uses the default fix duration of 2."""
|
||||
game = load_config(TEST_CONFIG)
|
||||
client_2: Computer = game.simulation.network.get_node_by_hostname("client_2")
|
||||
|
||||
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
|
||||
assert database_client.fixing_duration == 2
|
||||
|
||||
dns_client: DNSClient = client_2.software_manager.software.get("DNSClient")
|
||||
assert dns_client.fixing_duration == 2
|
||||
|
||||
|
||||
def test_fix_duration_set_from_config():
|
||||
"""Test to check that the fix duration set for applications and services works as intended."""
|
||||
game = load_config(TEST_CONFIG)
|
||||
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# in config - services take 3 timesteps to fix
|
||||
for service in SERVICE_TYPES_MAPPING:
|
||||
assert client_1.software_manager.software.get(service) is not None
|
||||
assert client_1.software_manager.software.get(service).fixing_duration == 3
|
||||
|
||||
# in config - applications take 1 timestep to fix
|
||||
for applications in APPLICATION_TYPES_MAPPING:
|
||||
assert client_1.software_manager.software.get(applications) is not None
|
||||
assert client_1.software_manager.software.get(applications).fixing_duration == 1
|
||||
|
||||
|
||||
def test_fix_duration_for_one_item():
|
||||
"""Test that setting fix duration for one application does not affect other components."""
|
||||
game = load_config(ONE_ITEM_CONFIG)
|
||||
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# in config - services take 3 timesteps to fix
|
||||
services = copy.copy(SERVICE_TYPES_MAPPING)
|
||||
services.pop("DatabaseService")
|
||||
for service in services:
|
||||
assert client_1.software_manager.software.get(service) is not None
|
||||
assert client_1.software_manager.software.get(service).fixing_duration == 2
|
||||
|
||||
# in config - applications take 1 timestep to fix
|
||||
applications = copy.copy(APPLICATION_TYPES_MAPPING)
|
||||
applications.pop("DatabaseClient")
|
||||
for applications in applications:
|
||||
assert client_1.software_manager.software.get(applications) is not None
|
||||
assert client_1.software_manager.software.get(applications).fixing_duration == 2
|
||||
|
||||
database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
|
||||
assert database_client.fixing_duration == 1
|
||||
|
||||
database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService")
|
||||
assert database_service.fixing_duration == 5
|
||||
1
tests/integration_tests/game_layer/actions/__init__.py
Normal file
1
tests/integration_tests/game_layer/actions/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
@@ -0,0 +1,292 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from primaite.game.agent.actions import (
|
||||
ConfigureDatabaseClientAction,
|
||||
ConfigureDoSBotAction,
|
||||
ConfigureRansomwareScriptAction,
|
||||
)
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
from tests.conftest import ControlledAgent
|
||||
|
||||
APP_CONFIG_YAML = TEST_ASSETS_ROOT / "configs/install_and_configure_apps.yaml"
|
||||
|
||||
|
||||
class TestConfigureDatabaseAction:
|
||||
def test_configure_ip_password(self, game_and_agent):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
|
||||
|
||||
# make sure there is a database client on this node
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
|
||||
|
||||
action = (
|
||||
"CONFIGURE_DATABASE_CLIENT",
|
||||
{
|
||||
"node_id": 0,
|
||||
"config": {
|
||||
"server_ip_address": "192.168.1.99",
|
||||
"server_password": "admin123",
|
||||
},
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert db_client.server_ip_address == IPv4Address("192.168.1.99")
|
||||
assert db_client.server_password == "admin123"
|
||||
|
||||
def test_configure_ip(self, game_and_agent):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
|
||||
|
||||
# make sure there is a database client on this node
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
|
||||
|
||||
action = (
|
||||
"CONFIGURE_DATABASE_CLIENT",
|
||||
{
|
||||
"node_id": 0,
|
||||
"config": {
|
||||
"server_ip_address": "192.168.1.99",
|
||||
},
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert db_client.server_ip_address == IPv4Address("192.168.1.99")
|
||||
assert db_client.server_password is None
|
||||
|
||||
def test_configure_password(self, game_and_agent):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
|
||||
|
||||
# make sure there is a database client on this node
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
|
||||
old_ip = db_client.server_ip_address
|
||||
|
||||
action = (
|
||||
"CONFIGURE_DATABASE_CLIENT",
|
||||
{
|
||||
"node_id": 0,
|
||||
"config": {
|
||||
"server_password": "admin123",
|
||||
},
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert db_client.server_ip_address == old_ip
|
||||
assert db_client.server_password is "admin123"
|
||||
|
||||
|
||||
class TestConfigureRansomwareScriptAction:
|
||||
@pytest.mark.parametrize(
|
||||
"config",
|
||||
[
|
||||
{},
|
||||
{"server_ip_address": "181.181.181.181"},
|
||||
{"server_password": "admin123"},
|
||||
{"payload": "ENCRYPT"},
|
||||
{
|
||||
"server_ip_address": "181.181.181.181",
|
||||
"server_password": "admin123",
|
||||
"payload": "ENCRYPT",
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_configure_ip_password(self, game_and_agent, config):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction(
|
||||
agent.action_manager
|
||||
)
|
||||
|
||||
# make sure there is a database client on this node
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(RansomwareScript)
|
||||
ransomware_script: RansomwareScript = client_1.software_manager.software["RansomwareScript"]
|
||||
|
||||
old_ip = ransomware_script.server_ip_address
|
||||
old_pw = ransomware_script.server_password
|
||||
old_payload = ransomware_script.payload
|
||||
|
||||
action = (
|
||||
"CONFIGURE_RANSOMWARE_SCRIPT",
|
||||
{"node_id": 0, "config": config},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
expected_ip = old_ip if "server_ip_address" not in config else IPv4Address(config["server_ip_address"])
|
||||
expected_pw = old_pw if "server_password" not in config else config["server_password"]
|
||||
expected_payload = old_payload if "payload" not in config else config["payload"]
|
||||
|
||||
assert ransomware_script.server_ip_address == expected_ip
|
||||
assert ransomware_script.server_password == expected_pw
|
||||
assert ransomware_script.payload == expected_payload
|
||||
|
||||
def test_invalid_config(self, game_and_agent):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction(
|
||||
agent.action_manager
|
||||
)
|
||||
|
||||
# make sure there is a database client on this node
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(RansomwareScript)
|
||||
ransomware_script: RansomwareScript = client_1.software_manager.software["RansomwareScript"]
|
||||
action = (
|
||||
"CONFIGURE_RANSOMWARE_SCRIPT",
|
||||
{
|
||||
"node_id": 0,
|
||||
"config": {"server_password": "admin123", "bad_option": 70},
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
with pytest.raises(ValidationError):
|
||||
game.step()
|
||||
|
||||
|
||||
class TestConfigureDoSBot:
|
||||
def test_configure_DoSBot(self, game_and_agent):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_DOSBOT"] = ConfigureDoSBotAction(agent.action_manager)
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(DoSBot)
|
||||
dos_bot: DoSBot = client_1.software_manager.software["DoSBot"]
|
||||
|
||||
action = (
|
||||
"CONFIGURE_DOSBOT",
|
||||
{
|
||||
"node_id": 0,
|
||||
"config": {
|
||||
"target_ip_address": "192.168.1.99",
|
||||
"target_port": "POSTGRES_SERVER",
|
||||
"payload": "HACC",
|
||||
"repeat": False,
|
||||
"port_scan_p_of_success": 0.875,
|
||||
"dos_intensity": 0.75,
|
||||
"max_sessions": 50,
|
||||
},
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.1.99")
|
||||
assert dos_bot.target_port == Port.POSTGRES_SERVER
|
||||
assert dos_bot.payload == "HACC"
|
||||
assert not dos_bot.repeat
|
||||
assert dos_bot.port_scan_p_of_success == 0.875
|
||||
assert dos_bot.dos_intensity == 0.75
|
||||
assert dos_bot.max_sessions == 50
|
||||
|
||||
|
||||
class TestConfigureYAML:
|
||||
def test_configure_db_client(self):
|
||||
env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML)
|
||||
|
||||
# make sure there's no db client on the node yet
|
||||
client_1 = env.game.simulation.network.get_node_by_hostname("client_1")
|
||||
assert client_1.software_manager.software.get("DatabaseClient") is None
|
||||
|
||||
# take the install action, check that the db gets installed, step to get it to finish installing
|
||||
env.step(1)
|
||||
db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
|
||||
assert isinstance(db_client, DatabaseClient)
|
||||
assert db_client.operating_state == ApplicationOperatingState.INSTALLING
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
|
||||
# configure the ip address and check that it changes, but password doesn't change
|
||||
assert db_client.server_ip_address is None
|
||||
assert db_client.server_password is None
|
||||
env.step(4)
|
||||
assert db_client.server_ip_address == IPv4Address("10.0.0.5")
|
||||
assert db_client.server_password is None
|
||||
|
||||
# configure the password and check that it changes, make sure this lets us connect to the db
|
||||
assert not db_client.connect()
|
||||
env.step(5)
|
||||
assert db_client.server_password == "correct_password"
|
||||
assert db_client.connect()
|
||||
|
||||
def test_configure_ransomware_script(self):
|
||||
env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML)
|
||||
client_2 = env.game.simulation.network.get_node_by_hostname("client_2")
|
||||
assert client_2.software_manager.software.get("RansomwareScript") is None
|
||||
|
||||
# install ransomware script
|
||||
env.step(2)
|
||||
ransom = client_2.software_manager.software.get("RansomwareScript")
|
||||
assert isinstance(ransom, RansomwareScript)
|
||||
assert ransom.operating_state == ApplicationOperatingState.INSTALLING
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
|
||||
# make sure it's not working yet because it's not configured and there's no db client
|
||||
assert not ransom.attack()
|
||||
env.step(8) # install db client on the same node
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0) # let it finish installing
|
||||
assert not ransom.attack()
|
||||
|
||||
# finally, configure the ransomware script with ip and password
|
||||
env.step(6)
|
||||
assert ransom.attack()
|
||||
|
||||
db_server = env.game.simulation.network.get_node_by_hostname("server_1")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
assert db_service.db_file.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
def test_configure_dos_bot(self):
|
||||
env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML)
|
||||
client_3 = env.game.simulation.network.get_node_by_hostname("client_3")
|
||||
assert client_3.software_manager.software.get("DoSBot") is None
|
||||
|
||||
# install DoSBot
|
||||
env.step(3)
|
||||
bot = client_3.software_manager.software.get("DoSBot")
|
||||
assert isinstance(bot, DoSBot)
|
||||
assert bot.operating_state == ApplicationOperatingState.INSTALLING
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
|
||||
# make sure dos bot doesn't work before being configured
|
||||
assert not bot.run()
|
||||
env.step(7)
|
||||
assert bot.run()
|
||||
@@ -557,7 +557,7 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl
|
||||
|
||||
assert client_1.software_manager.software.get("DoSBot") is None
|
||||
|
||||
action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot", "ip_address": "192.168.1.14"})
|
||||
action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot"})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
@@ -119,3 +121,77 @@ def test_shared_reward():
|
||||
g2_reward = env.game.agents["client_2_green_user"].reward_function.current_reward
|
||||
blue_reward = env.game.agents["defender"].reward_function.current_reward
|
||||
assert blue_reward == g1_reward + g2_reward
|
||||
|
||||
|
||||
def test_action_penalty_loads_from_config():
|
||||
"""Test to ensure that action penalty is correctly loaded from config into PrimaiteGymEnv"""
|
||||
CFG_PATH = TEST_ASSETS_ROOT / "configs/action_penalty.yaml"
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
env.reset()
|
||||
defender = env.game.agents["defender"]
|
||||
act_penalty_obj = None
|
||||
for comp in defender.reward_function.reward_components:
|
||||
if isinstance(comp[0], ActionPenalty):
|
||||
act_penalty_obj = comp[0]
|
||||
if act_penalty_obj is None:
|
||||
pytest.fail("Action penalty reward component was not added to the agent from config.")
|
||||
assert act_penalty_obj.action_penalty == -0.75
|
||||
assert act_penalty_obj.do_nothing_penalty == 0.125
|
||||
|
||||
|
||||
def test_action_penalty():
|
||||
"""Test that the action penalty is correctly applied when agent performs any action"""
|
||||
|
||||
# Create an ActionPenalty Reward
|
||||
Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
|
||||
# Assert that penalty is applied if action isn't DONOTHING
|
||||
reward_value = Penalty.calculate(
|
||||
state={},
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0,
|
||||
action="NODE_APPLICATION_EXECUTE",
|
||||
parameters={"node_id": 0, "application_id": 1},
|
||||
request=["execute"],
|
||||
response=RequestResponse.from_bool(True),
|
||||
),
|
||||
)
|
||||
|
||||
assert reward_value == -0.75
|
||||
|
||||
# Assert that no penalty applied for a DONOTHING action
|
||||
reward_value = Penalty.calculate(
|
||||
state={},
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0,
|
||||
action="DONOTHING",
|
||||
parameters={},
|
||||
request=["do_nothing"],
|
||||
response=RequestResponse.from_bool(True),
|
||||
),
|
||||
)
|
||||
|
||||
assert reward_value == 0.125
|
||||
|
||||
|
||||
def test_action_penalty_e2e(game_and_agent):
|
||||
"""Test that we get the right reward for doing actions to fetch a website."""
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
|
||||
agent.reward_function.register_component(comp, 1.0)
|
||||
|
||||
action = ("DONOTHING", {})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == 0.125
|
||||
|
||||
action = ("NODE_FILE_SCAN", {"node_id": 0, "folder_id": 0, "file_id": 0})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == -0.75
|
||||
|
||||
@@ -41,7 +41,7 @@ class BroadcastService(Service):
|
||||
super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol)
|
||||
|
||||
|
||||
class BroadcastClient(Application):
|
||||
class BroadcastClient(Application, identifier="BroadcastClient"):
|
||||
"""A client application to receive broadcast and unicast messages."""
|
||||
|
||||
payloads_received: List = []
|
||||
|
||||
@@ -3,9 +3,9 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.networks import multi_lan_internet_network_example
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from src.primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
|
||||
|
||||
def test_all_with_configured_dns_server_ip_can_resolve_url():
|
||||
|
||||
@@ -21,7 +21,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]:
|
||||
computer.power_on()
|
||||
computer.software_manager.install(application_class)
|
||||
|
||||
app = computer.software_manager.software.get("TestApplication")
|
||||
app = computer.software_manager.software.get("DummyApplication")
|
||||
app.run()
|
||||
|
||||
return app, computer
|
||||
@@ -39,7 +39,7 @@ def test_application_on_offline_node(application_class):
|
||||
)
|
||||
computer.software_manager.install(application_class)
|
||||
|
||||
app: Application = computer.software_manager.software.get("TestApplication")
|
||||
app: Application = computer.software_manager.software.get("DummyApplication")
|
||||
|
||||
computer.power_off()
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from primaite.simulator.system.applications.database_client import DatabaseClien
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -213,6 +214,110 @@ def test_restore_backup_after_deleting_file_without_updating_scan(uc2_network):
|
||||
assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.GOOD # now looks good
|
||||
|
||||
|
||||
def test_database_service_fix(uc2_network):
|
||||
"""Test that the software fix applies to database service."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
assert db_service.backup_database() is True
|
||||
|
||||
# delete database locally
|
||||
db_service.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
|
||||
# db file is gone, reduced to atoms
|
||||
assert db_service.db_file is None
|
||||
|
||||
db_service.fix() # fix the database service
|
||||
|
||||
assert db_service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert db_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_database_cannot_be_queried_while_fixing(uc2_network):
|
||||
"""Tests that the database service cannot be queried if the service is being fixed."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
assert db_connection.query(sql="SELECT")
|
||||
|
||||
assert db_service.backup_database() is True
|
||||
|
||||
# delete database locally
|
||||
db_service.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
|
||||
# db file is gone, reduced to atoms
|
||||
assert db_service.db_file is None
|
||||
|
||||
db_service.fix() # fix the database service
|
||||
assert db_service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
# fails to query because database is in FIXING state
|
||||
assert db_connection.query(sql="SELECT") is False
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
|
||||
|
||||
assert db_connection.query(sql="SELECT")
|
||||
|
||||
|
||||
def test_database_can_create_connection_while_fixing(uc2_network):
|
||||
"""Tests that connections cannot be created while the database is being fixed."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
client_2: Server = uc2_network.get_node_by_hostname("client_2")
|
||||
db_client: DatabaseClient = client_2.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
assert db_connection.query(sql="SELECT")
|
||||
|
||||
assert db_service.backup_database() is True
|
||||
|
||||
# delete database locally
|
||||
db_service.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
|
||||
# db file is gone, reduced to atoms
|
||||
assert db_service.db_file is None
|
||||
|
||||
db_service.fix() # fix the database service
|
||||
assert db_service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
# fails to query because database is in FIXING state
|
||||
assert db_connection.query(sql="SELECT") is False
|
||||
|
||||
# should be able to create a new connection
|
||||
new_db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
assert new_db_connection is not None
|
||||
assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
|
||||
|
||||
assert db_connection.query(sql="SELECT")
|
||||
assert new_db_connection.query(sql="SELECT")
|
||||
|
||||
|
||||
def test_database_client_cannot_query_offline_database_server(uc2_network):
|
||||
"""Tests DB query across the network returns HTTP status 404 when db server is offline."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
|
||||
@@ -16,6 +16,7 @@ from primaite.simulator.system.services.database.database_service import Databas
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -110,6 +111,29 @@ def test_web_client_requests_users(web_client_web_server_database):
|
||||
assert web_browser.get_webpage()
|
||||
|
||||
|
||||
def test_database_fix_disrupts_web_client(uc2_network):
|
||||
"""Tests that the database service being in fixed state disrupts the web client."""
|
||||
computer: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
|
||||
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
|
||||
database_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
|
||||
# fix the database service
|
||||
database_service.fix()
|
||||
|
||||
assert database_service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
assert web_browser.get_webpage() is False
|
||||
|
||||
for i in range(database_service.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert database_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
assert web_browser.get_webpage()
|
||||
|
||||
|
||||
class TestWebBrowserHistory:
|
||||
def test_populating_history(self, web_client_web_server_database):
|
||||
network, computer, _, _ = web_client_web_server_database
|
||||
|
||||
@@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from tests.conftest import TestApplication, TestService
|
||||
from tests.conftest import DummyApplication, TestService
|
||||
|
||||
|
||||
def test_successful_node_file_system_creation_request(example_network):
|
||||
@@ -47,14 +47,14 @@ def test_successful_application_requests(example_network):
|
||||
net = example_network
|
||||
|
||||
client_1 = net.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(TestApplication)
|
||||
client_1.software_manager.software.get("TestApplication").run()
|
||||
client_1.software_manager.install(DummyApplication)
|
||||
client_1.software_manager.software.get("DummyApplication").run()
|
||||
|
||||
resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"])
|
||||
resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"])
|
||||
assert resp_1 == RequestResponse(status="success", data={})
|
||||
resp_2 = net.apply_request(["node", "client_1", "application", "TestApplication", "fix"])
|
||||
resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"])
|
||||
assert resp_2 == RequestResponse(status="success", data={})
|
||||
resp_3 = net.apply_request(["node", "client_1", "application", "TestApplication", "compromise"])
|
||||
resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"])
|
||||
assert resp_3 == RequestResponse(status="success", data={})
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
|
||||
|
||||
def test_adding_to_app_registry():
|
||||
class temp_application(Application, identifier="temp_app"):
|
||||
pass
|
||||
|
||||
assert Application._application_registry["temp_app"] is temp_application
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
class another_application(Application, identifier="temp_app"):
|
||||
pass
|
||||
|
||||
# This is kinda evil...
|
||||
# Because pytest doesn't reimport classes from modules, registering this temporary test application will change the
|
||||
# state of the Application registry for all subsequently run tests. So, we have to delete and unregister the class.
|
||||
del temp_application
|
||||
Application._application_registry.pop("temp_app")
|
||||
Reference in New Issue
Block a user