From db27bea4ec37dd2d4ad7082ad951a36881168144 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 25 Jun 2024 12:29:01 +0100 Subject: [PATCH 01/49] #2656 - Committing current state before lunch. New ActionPenalty reward added. Basic implementation returns a -1 reward if last_action_response.action isn't DONOTHING. Minor change in data_manipulation so I can see it working in the data_manipulation notebook. Need to use configured values but so far, promising?. Looks to result in a better average reward than without which is good, I think. --- .../_package_data/data_manipulation.yaml | 1 - src/primaite/game/agent/rewards.py | 33 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 6cded5f2..1ec98f39 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -740,7 +740,6 @@ agents: agent_name: client_2_green_user - agent_settings: flatten_obs: true diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index cabea5f4..7d14e097 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -360,6 +360,38 @@ class SharedReward(AbstractReward): return cls(agent_name=agent_name) +class ActionPenalty(AbstractReward): + """ + Apply a negative reward when taking any action except DONOTHING. + + Optional Configuration item therefore default value of 0 (?). + """ + + def __init__(self, agent_name: str, penalty: float = 0): + """ + Initialise the reward. + + Penalty will default to 0, as this is an optional param. + """ + self.agent_name = agent_name + self.penalty = penalty + + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: + """Calculate the penalty to be applied.""" + if last_action_response.action == "DONOTHING": + # No penalty for doing nothing at present + return 0 + else: + return -1 + + @classmethod + def from_config(cls, config: Dict) -> "ActionPenalty": + """Build the ActionPenalty object from config.""" + agent_name = config.get("agent_name") + # penalty_value = config.get("ACTION_PENALTY", 0) + return cls(agent_name=agent_name) + + class RewardFunction: """Manages the reward function for the agent.""" @@ -370,6 +402,7 @@ class RewardFunction: "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty, "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty, "SHARED_REWARD": SharedReward, + "ACTION_PENALTY": ActionPenalty, } """List of reward class identifiers.""" From 5ad16fdb7eecfd3d0e9f8e6349a0d542066405a7 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 25 Jun 2024 15:36:47 +0100 Subject: [PATCH 02/49] #2656 - Corrected from_config() for ActionPenalty so that it can pull the negative reward value from YAML and apply, defaulting to 0 still if not found/not configured. Currerntly prints to terminal when a negative reward is being applied, though this is for implementation and troubleshooting. To be removed before PR is pushed out of draft --- .../config/_package_data/data_manipulation.yaml | 5 +++++ src/primaite/game/agent/rewards.py | 11 +++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 1ec98f39..be613918 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -739,6 +739,11 @@ agents: options: agent_name: client_2_green_user + - type: ACTION_PENALTY + weight: 1.0 + options: + agent_name: defender + penalty_value: -1 agent_settings: flatten_obs: true diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 7d14e097..d75597f0 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -367,7 +367,7 @@ class ActionPenalty(AbstractReward): Optional Configuration item therefore default value of 0 (?). """ - def __init__(self, agent_name: str, penalty: float = 0): + def __init__(self, agent_name: str, penalty: float): """ Initialise the reward. @@ -382,14 +382,17 @@ class ActionPenalty(AbstractReward): # No penalty for doing nothing at present return 0 else: - return -1 + _LOGGER.info( + f"Blue agent has incurred a penalty of {self.penalty}, for action: {last_action_response.action}" + ) + return self.penalty @classmethod def from_config(cls, config: Dict) -> "ActionPenalty": """Build the ActionPenalty object from config.""" agent_name = config.get("agent_name") - # penalty_value = config.get("ACTION_PENALTY", 0) - return cls(agent_name=agent_name) + penalty_value = config.get("penalty_value", 0) # default to 0 so that no adverse effects. + return cls(agent_name=agent_name, penalty=penalty_value) class RewardFunction: From 824729276ec60f9dda1ebc2f62f159376c9c749f Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Tue, 25 Jun 2024 16:58:39 +0100 Subject: [PATCH 03/49] #2648 - updated benchmark process to output markdown file instead of LaTeX. Added pipeline that runs benchmarking at 2am on a weekday and automatically upon creation of release branch --- .azure/azure-benchmark-pipeline.yaml | 83 +++++++++++++++++ benchmark/primaite_benchmark.py | 4 +- benchmark/report.py | 133 +++++++++++---------------- pyproject.toml | 1 - 4 files changed, 139 insertions(+), 82 deletions(-) create mode 100644 .azure/azure-benchmark-pipeline.yaml diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml new file mode 100644 index 00000000..ac52ce2b --- /dev/null +++ b/.azure/azure-benchmark-pipeline.yaml @@ -0,0 +1,83 @@ +trigger: + branches: + exclude: + - '*' + include: + - 'refs/heads/release/*' # Trigger on creation of release branches + +schedules: +- cron: "0 2 * * 1-5" # Run at 2 AM every weekday + displayName: "Weekday Schedule" + branches: + include: + - dev + +pool: + vmImage: ubuntu-latest + +variables: + VERSION: '' + MAJOR_VERSION: '' + +steps: +- checkout: self + persistCredentials: true + +- script: | + VERSION=$(cat src/primaite/VERSION | tr -d '\n') + MAJOR_VERSION=$(echo $VERSION | cut -d. -f1) + echo "##vso[task.setvariable variable=VERSION]$VERSION" + echo "##vso[task.setvariable variable=MAJOR_VERSION]$MAJOR_VERSION" + displayName: 'Set Version Variables' + +- script: | + if [[ "$(Build.SourceBranch)" == "refs/heads/dev" ]]; then + DATE=$(date +%Y%m%d%H%M) + echo "${VERSION}+dev.${DATE}" > src/primaite/VERSION + fi + displayName: 'Update VERSION file for Dev Benchmark' + +- task: UsePythonVersion@0 + inputs: + versionSpec: '3.11' + addToPath: true + +- script: | + python -m pip install --upgrade pip + pip install -e . + primaite setup + displayName: 'Install Dependencies' + +- script: | + mkdir -p benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) + python benchmark.py --output benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) + displayName: 'Run Benchmarking Script' + +- script: | + git config --global user.email "oss@dstl.gov.uk" + git config --global user.name "Defence Science and Technology Laboratory UK" + workingDirectory: $(System.DefaultWorkingDirectory) + displayName: 'Configure Git' + condition: and(succeeded(), eq(variables['Build.Reason'], 'Manual'), startsWith(variables['Build.SourceBranch'], 'refs/heads/release/')) + +- script: | + git add benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)/* + git commit -m "Automated benchmark output commit for version $(VERSION)" + git push origin HEAD:$(Build.SourceBranchName) + displayName: 'Commit and Push Benchmark Results' + workingDirectory: $(System.DefaultWorkingDirectory) + env: + GIT_CREDENTIALS: $(System.AccessToken) + condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/release/')) + +- script: | + mkdir -p artifact_output/benchmark/results/v$(MAJOR_VERSION) + cp -r benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) artifact_output/benchmark/results/v$(MAJOR_VERSION)/ + displayName: 'Prepare Artifacts for Publishing' + +- task: PublishPipelineArtifact@1 + inputs: + targetPath: 'artifact_output/benchmark/results' # Path to the files you want to publish + artifactName: 'benchmark-output' # Name of the artifact + publishLocation: 'pipeline' + displayName: 'Publish Benchmark Output as Artifact' diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index b19dbb16..f3d0a10c 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -117,14 +117,14 @@ class BenchmarkSession: def generate_learn_metadata_dict(self) -> Dict[str, Any]: """Metadata specific to the learning session.""" total_s, s_per_step, s_per_100_steps_10_nodes = self._learn_benchmark_durations() - self.gym_env.average_reward_per_episode.pop(0) # remove episode 0 + self.gym_env.total_reward_per_episode.pop(0) # remove episode 0 return { "total_episodes": self.gym_env.episode_counter, "total_time_steps": self.gym_env.total_time_steps, "total_s": total_s, "s_per_step": s_per_step, "s_per_100_steps_10_nodes": s_per_100_steps_10_nodes, - "av_reward_per_episode": self.gym_env.average_reward_per_episode, + "av_reward_per_episode": self.gym_env.total_reward_per_episode, } diff --git a/benchmark/report.py b/benchmark/report.py index ca3d03a3..6a71ef57 100644 --- a/benchmark/report.py +++ b/benchmark/report.py @@ -9,10 +9,6 @@ import plotly.graph_objects as go import polars as pl import yaml from plotly.graph_objs import Figure -from pylatex import Command, Document -from pylatex import Figure as LatexFigure -from pylatex import Section, Subsection, Tabular -from pylatex.utils import bold from utils import _get_system_info import primaite @@ -140,7 +136,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure: converted into a polars dataframe, and plotted as a scatter line in plotly. """ major_v = primaite.__version__.split(".")[0] - title = f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*" + title = f"Learning Benchmarking of All Released Versions under Major v{major_v}.#.#" subtitle = "Rolling Av (Combined Session Av)" if title: if subtitle: @@ -208,98 +204,77 @@ def build_benchmark_latex_report( fig = _plot_all_benchmarks_combined_session_av(results_directory=results_root_path) - all_version_plot_path = results_root_path / "PrimAITE Versions Learning Benchmark.png" + all_version_plot_path = version_result_dir / "PrimAITE Versions Learning Benchmark.png" fig.write_image(all_version_plot_path) - geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"} data = benchmark_metadata_dict primaite_version = data["primaite_version"] - # Create a new document - doc = Document("report", geometry_options=geometry_options) - # Title - doc.preamble.append(Command("title", f"PrimAITE {primaite_version} Learning Benchmark")) - doc.preamble.append(Command("author", "PrimAITE Dev Team")) - doc.preamble.append(Command("date", datetime.now().date())) - doc.append(Command("maketitle")) + with open(version_result_dir / f"PrimAITE v{primaite_version} Learning Benchmark.md", "w") as file: + # Title + file.write(f"# PrimAITE v{primaite_version} Learning Benchmark\n") + file.write("## PrimAITE Dev Team\n") + file.write(f"### {datetime.now().date()}\n") + file.write("\n---\n") - sessions = data["total_sessions"] - episodes = session_metadata[1]["total_episodes"] - 1 - steps = data["config"]["game"]["max_episode_length"] + sessions = data["total_sessions"] + episodes = session_metadata[1]["total_episodes"] - 1 + steps = data["config"]["game"]["max_episode_length"] - # Body - with doc.create(Section("Introduction")): - doc.append( + # Body + file.write("## 1 Introduction\n") + file.write( f"PrimAITE v{primaite_version} was benchmarked automatically upon release. Learning rate metrics " - f"were captured to be referenced during system-level testing and user acceptance testing (UAT)." + f"were captured to be referenced during system-level testing and user acceptance testing (UAT).\n" ) - doc.append( - f"\nThe benchmarking process consists of running {sessions} training session using the same " + file.write( + f"The benchmarking process consists of running {sessions} training session using the same " f"config file. Each session trains an agent for {episodes} episodes, " - f"with each episode consisting of {steps} steps." + f"with each episode consisting of {steps} steps.\n" ) - doc.append( - f"\nThe total reward per episode from each session is captured. This is then used to calculate an " + file.write( + f"The total reward per episode from each session is captured. This is then used to calculate an " f"caverage total reward per episode from the {sessions} individual sessions for smoothing. " f"Finally, a 25-widow rolling average of the average total reward per session is calculated for " - f"further smoothing." + f"further smoothing.\n" ) - with doc.create(Section("System Information")): - with doc.create(Subsection("Python")): - with doc.create(Tabular("|l|l|")) as table: - table.add_hline() - table.add_row((bold("Version"), sys.version)) - table.add_hline() + file.write("## 2 System Information\n") + i = 1 + file.write(f"### 2.{i} Python\n") + file.write(f"**Version:** {sys.version}\n") + for section, section_data in data["system_info"].items(): + i += 1 if section_data: - with doc.create(Subsection(section)): - if isinstance(section_data, dict): - with doc.create(Tabular("|l|l|")) as table: - table.add_hline() - for key, value in section_data.items(): - table.add_row((bold(key), value)) - table.add_hline() - elif isinstance(section_data, list): - headers = section_data[0].keys() - tabs_str = "|".join(["l" for _ in range(len(headers))]) - tabs_str = f"|{tabs_str}|" - with doc.create(Tabular(tabs_str)) as table: - table.add_hline() - table.add_row([bold(h) for h in headers]) - table.add_hline() - for item in section_data: - table.add_row(item.values()) - table.add_hline() + file.write(f"### 2.{i} {section}\n") + if isinstance(section_data, dict): + for key, value in section_data.items(): + file.write(f"- **{key}:** {value}\n") - headers_map = { - "total_sessions": "Total Sessions", - "total_episodes": "Total Episodes", - "total_time_steps": "Total Steps", - "av_s_per_session": "Av Session Duration (s)", - "av_s_per_step": "Av Step Duration (s)", - "av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)", - } - with doc.create(Section("Stats")): - with doc.create(Subsection("Benchmark Results")): - with doc.create(Tabular("|l|l|")) as table: - table.add_hline() - for section, header in headers_map.items(): - if section.startswith("av_"): - table.add_row((bold(header), f"{data[section]:.4f}")) - else: - table.add_row((bold(header), data[section])) - table.add_hline() + headers_map = { + "total_sessions": "Total Sessions", + "total_episodes": "Total Episodes", + "total_time_steps": "Total Steps", + "av_s_per_session": "Av Session Duration (s)", + "av_s_per_step": "Av Step Duration (s)", + "av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)", + } - with doc.create(Section("Graphs")): - with doc.create(Subsection(f"v{primaite_version} Learning Benchmark Plot")): - with doc.create(LatexFigure(position="h!")) as pic: - pic.add_image(str(this_version_plot_path)) - pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot") + file.write("## 3 Stats\n") + for section, header in headers_map.items(): + if section.startswith("av_"): + file.write(f"- **{header}:** {data[section]:.4f}\n") + else: + file.write(f"- **{header}:** {data[section]}\n") - with doc.create(Subsection(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")): - with doc.create(LatexFigure(position="h!")) as pic: - pic.add_image(str(all_version_plot_path)) - pic.add_caption(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*") + file.write("## 4 Graphs\n") - doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True) + file.write(f"### 4.1 v{primaite_version} Learning Benchmark Plot\n") + file.write(f"![PrimAITE {primaite_version} Learning Benchmark Plot]({this_version_plot_path.name})\n") + + file.write(f"### 4.2 Learning Benchmarking of All Released Versions under Major v{major_v}.#.#\n") + file.write( + f"![Learning Benchmarking of All Released Versions under " + f"Major v{major_v}.#.#]({all_version_plot_path.name})\n" + ) diff --git a/pyproject.toml b/pyproject.toml index 31ce5312..9d53e961 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,6 @@ dev = [ "gputil==1.4.0", "pip-licenses==4.3.0", "pre-commit==2.20.0", - "pylatex==1.4.1", "pytest==7.2.0", "pytest-xdist==3.3.1", "pytest-cov==4.0.0", From 1033e696cd94e342bbc66bd4b298cc09161e75db Mon Sep 17 00:00:00 2001 From: Christopher McCarthy Date: Tue, 25 Jun 2024 16:01:18 +0000 Subject: [PATCH 04/49] Set up CI with Azure Pipelines [skip ci] --- .azure/azure-benchmark-pipeline.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index ac52ce2b..53df7155 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -10,7 +10,7 @@ schedules: displayName: "Weekday Schedule" branches: include: - - dev + - feature/2648_Automate-the-benchmarking-process pool: vmImage: ubuntu-latest @@ -31,7 +31,7 @@ steps: displayName: 'Set Version Variables' - script: | - if [[ "$(Build.SourceBranch)" == "refs/heads/dev" ]]; then + if [[ "$(Build.SourceBranch)" == "refs/heads/feature/2648_Automate-the-benchmarking-process" ]]; then DATE=$(date +%Y%m%d%H%M) echo "${VERSION}+dev.${DATE}" > src/primaite/VERSION fi From 55d69d6568bfd97950989df56db011cba33dc05c Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Tue, 25 Jun 2024 17:23:04 +0100 Subject: [PATCH 05/49] #2648 - updated benchmark run command --- .azure/azure-benchmark-pipeline.yaml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index ac52ce2b..43140d7d 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -49,8 +49,9 @@ steps: displayName: 'Install Dependencies' - script: | - mkdir -p benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) - python benchmark.py --output benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) + cd benchmark + python3 primaite_benchmark.py + cd .. displayName: 'Run Benchmarking Script' - script: | From 4249314672062c6edc8854e2ec9980075a3f6091 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Tue, 25 Jun 2024 17:28:05 +0100 Subject: [PATCH 06/49] #2648 - reduced the benchmark sessions and episodes for fail-fast speed while testing --- benchmark/primaite_benchmark.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index f3d0a10c..92c9bf0a 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -151,8 +151,8 @@ def _prepare_session_directory(): def run( - number_of_sessions: int = 5, - num_episodes: int = 1000, + number_of_sessions: int = 1, + num_episodes: int = 25, episode_len: int = 128, n_steps: int = 1280, batch_size: int = 32, From 0ee243a242b47dc109905b2a089d42406dea89c7 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 26 Jun 2024 09:09:16 +0100 Subject: [PATCH 07/49] #2648 - trying to fix the artifacts publish stage. currently creating tar.gz and publishing that --- .azure/azure-benchmark-pipeline.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index 133de713..f0c7c063 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -72,13 +72,13 @@ steps: condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/release/')) - script: | - mkdir -p artifact_output/benchmark/results/v$(MAJOR_VERSION) - cp -r benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) artifact_output/benchmark/results/v$(MAJOR_VERSION)/ + ls benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) + tar czf benchmark.tar.gz benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) displayName: 'Prepare Artifacts for Publishing' - task: PublishPipelineArtifact@1 inputs: - targetPath: 'artifact_output/benchmark/results' # Path to the files you want to publish + targetPath: benchmark.tar.gz # Path to the files you want to publish artifactName: 'benchmark-output' # Name of the artifact publishLocation: 'pipeline' displayName: 'Publish Benchmark Output as Artifact' From ad2b132a10b2b701e3337a7a11ad8539c5b8cbae Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 26 Jun 2024 09:13:55 +0100 Subject: [PATCH 08/49] #2648 - added dev and rl extras to the pip install step --- .azure/azure-benchmark-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index f0c7c063..c16a21cf 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -44,7 +44,7 @@ steps: - script: | python -m pip install --upgrade pip - pip install -e . + pip install -e .[dev,rl] primaite setup displayName: 'Install Dependencies' From 2192516c9aa360bec42c0ebda21dbae8a71d89a1 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 26 Jun 2024 09:40:16 +0100 Subject: [PATCH 09/49] #2648 - reordered steps so that dev version is set first before version variables are set --- .azure/azure-benchmark-pipeline.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index c16a21cf..1ad28197 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -23,13 +23,6 @@ steps: - checkout: self persistCredentials: true -- script: | - VERSION=$(cat src/primaite/VERSION | tr -d '\n') - MAJOR_VERSION=$(echo $VERSION | cut -d. -f1) - echo "##vso[task.setvariable variable=VERSION]$VERSION" - echo "##vso[task.setvariable variable=MAJOR_VERSION]$MAJOR_VERSION" - displayName: 'Set Version Variables' - - script: | if [[ "$(Build.SourceBranch)" == "refs/heads/feature/2648_Automate-the-benchmarking-process" ]]; then DATE=$(date +%Y%m%d%H%M) @@ -37,6 +30,13 @@ steps: fi displayName: 'Update VERSION file for Dev Benchmark' +- script: | + VERSION=$(cat src/primaite/VERSION | tr -d '\n') + MAJOR_VERSION=$(echo $VERSION | cut -d. -f1) + echo "##vso[task.setvariable variable=VERSION]$VERSION" + echo "##vso[task.setvariable variable=MAJOR_VERSION]$MAJOR_VERSION" + displayName: 'Set Version Variables' + - task: UsePythonVersion@0 inputs: versionSpec: '3.11' From 112f116a89a208b88502ef63302d828cdede33b5 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 26 Jun 2024 09:46:05 +0100 Subject: [PATCH 10/49] #2648 - fixed error whereby VERSION variable isn't set before setting the dev build version in the VERSION file --- .azure/azure-benchmark-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index 1ad28197..06ef72a4 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -24,6 +24,7 @@ steps: persistCredentials: true - script: | + VERSION=$(cat src/primaite/VERSION | tr -d '\n') if [[ "$(Build.SourceBranch)" == "refs/heads/feature/2648_Automate-the-benchmarking-process" ]]; then DATE=$(date +%Y%m%d%H%M) echo "${VERSION}+dev.${DATE}" > src/primaite/VERSION From 08ca4d8889fc386b49863ebc6a788ef9c5acbc62 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 26 Jun 2024 10:50:05 +0100 Subject: [PATCH 11/49] #2648 - now testing the release benchmark auto commit and push --- .azure/azure-benchmark-pipeline.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index 06ef72a4..5d84fc8c 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -3,7 +3,7 @@ trigger: exclude: - '*' include: - - 'refs/heads/release/*' # Trigger on creation of release branches + - feature/2648_Automate-the-benchmarking-process # Trigger on creation of release branches schedules: - cron: "0 2 * * 1-5" # Run at 2 AM every weekday @@ -60,7 +60,7 @@ steps: git config --global user.name "Defence Science and Technology Laboratory UK" workingDirectory: $(System.DefaultWorkingDirectory) displayName: 'Configure Git' - condition: and(succeeded(), eq(variables['Build.Reason'], 'Manual'), startsWith(variables['Build.SourceBranch'], 'refs/heads/release/')) + condition: and(succeeded(), eq(variables['Build.Reason'], 'Manual'), startsWith(variables['Build.SourceBranch'], 'refs/heads/feature/2648_Automate-the-benchmarking-process')) - script: | git add benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)/* @@ -70,7 +70,7 @@ steps: workingDirectory: $(System.DefaultWorkingDirectory) env: GIT_CREDENTIALS: $(System.AccessToken) - condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/release/')) + condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/feature/2648_Automate-the-benchmarking-process')) - script: | ls benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) From 3b118fa0add4f9696b531f84e29aab4d96b781c5 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 26 Jun 2024 10:52:23 +0100 Subject: [PATCH 12/49] #2648 - now testing the release benchmark auto commit and push --- .azure/azure-benchmark-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index 5d84fc8c..a6118570 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -3,7 +3,7 @@ trigger: exclude: - '*' include: - - feature/2648_Automate-the-benchmarking-process # Trigger on creation of release branches + - 'refs/heads/feature/2648_Automate-the-benchmarking-process' schedules: - cron: "0 2 * * 1-5" # Run at 2 AM every weekday From 20c9719f1e70849b2823142b4d1720007ae7de55 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 26 Jun 2024 11:32:17 +0100 Subject: [PATCH 13/49] #2648 - added full branch name reference for push --- .azure/azure-benchmark-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index a6118570..bc123454 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -65,7 +65,7 @@ steps: - script: | git add benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)/* git commit -m "Automated benchmark output commit for version $(VERSION)" - git push origin HEAD:$(Build.SourceBranchName) + git push origin HEAD:refs/heads/$(Build.SourceBranchName) displayName: 'Commit and Push Benchmark Results' workingDirectory: $(System.DefaultWorkingDirectory) env: From 795b5a80fbf69add363f60cedc2928c1f6fb2e45 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 26 Jun 2024 12:00:56 +0100 Subject: [PATCH 14/49] #2648 - testing the new benchmark artifact name --- .azure/azure-benchmark-pipeline.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index bc123454..cd8703ec 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -74,12 +74,12 @@ steps: - script: | ls benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) - tar czf benchmark.tar.gz benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) + tar czf primaite_v$(VERSION)_benchmark.tar.gz benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) displayName: 'Prepare Artifacts for Publishing' - task: PublishPipelineArtifact@1 inputs: - targetPath: benchmark.tar.gz # Path to the files you want to publish - artifactName: 'benchmark-output' # Name of the artifact + targetPath: primaite_v$(VERSION)_benchmark.tar.gz + artifactName: 'benchmark-output' publishLocation: 'pipeline' displayName: 'Publish Benchmark Output as Artifact' From 7a833afe2d608176eb2c775551b7b1093cab27e8 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 26 Jun 2024 12:20:28 +0100 Subject: [PATCH 15/49] #2656 - Unit tests for new ActionPenalty reward component, testing yaml and some minor changes to the implementation. Need to update Documentation to detail how this is added --- src/primaite/game/agent/rewards.py | 13 +- tests/assets/configs/action_penalty.yaml | 929 ++++++++++++++++++ .../game_layer/test_rewards.py | 66 +- 3 files changed, 999 insertions(+), 9 deletions(-) create mode 100644 tests/assets/configs/action_penalty.yaml diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index d75597f0..a0736bb0 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -361,17 +361,14 @@ class SharedReward(AbstractReward): class ActionPenalty(AbstractReward): - """ - Apply a negative reward when taking any action except DONOTHING. - - Optional Configuration item therefore default value of 0 (?). - """ + """Apply a negative reward when taking any action except DONOTHING.""" def __init__(self, agent_name: str, penalty: float): """ Initialise the reward. - Penalty will default to 0, as this is an optional param. + This negative reward should be applied when the agent in training chooses to take any + action that isn't DONOTHING. """ self.agent_name = agent_name self.penalty = penalty @@ -383,7 +380,7 @@ class ActionPenalty(AbstractReward): return 0 else: _LOGGER.info( - f"Blue agent has incurred a penalty of {self.penalty}, for action: {last_action_response.action}" + f"Blue Agent has incurred a penalty of {self.penalty}, for action: {last_action_response.action}" ) return self.penalty @@ -391,7 +388,7 @@ class ActionPenalty(AbstractReward): def from_config(cls, config: Dict) -> "ActionPenalty": """Build the ActionPenalty object from config.""" agent_name = config.get("agent_name") - penalty_value = config.get("penalty_value", 0) # default to 0 so that no adverse effects. + penalty_value = config.get("penalty_value", 0) # default to 0. return cls(agent_name=agent_name, penalty=penalty_value) diff --git a/tests/assets/configs/action_penalty.yaml b/tests/assets/configs/action_penalty.yaml new file mode 100644 index 00000000..4eb562fe --- /dev/null +++ b/tests/assets/configs/action_penalty.yaml @@ -0,0 +1,929 @@ +io_settings: + save_agent_actions: false + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + + +game: + max_episode_length: 256 + ports: + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + - ref: data_manipulation_attacker + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "NONE" + label: ICS + options: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_SERVICE_FIX + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: ROUTER_ACL_ADDRULE + - type: ROUTER_ACL_REMOVERULE + - type: HOST_NIC_ENABLE + - type: HOST_NIC_DISABLE + + action_map: + 0: + action: DONOTHING + options: {} + # scan webapp service + 1: + action: NODE_SERVICE_SCAN + options: + node_id: 1 + service_id: 0 + # stop webapp service + 2: + action: NODE_SERVICE_STOP + options: + node_id: 1 + service_id: 0 + # start webapp service + 3: + action: "NODE_SERVICE_START" + options: + node_id: 1 + service_id: 0 + 4: + action: "NODE_SERVICE_PAUSE" + options: + node_id: 1 + service_id: 0 + 5: + action: "NODE_SERVICE_RESUME" + options: + node_id: 1 + service_id: 0 + 6: + action: "NODE_SERVICE_RESTART" + options: + node_id: 1 + service_id: 0 + 7: + action: "NODE_SERVICE_DISABLE" + options: + node_id: 1 + service_id: 0 + 8: + action: "NODE_SERVICE_ENABLE" + options: + node_id: 1 + service_id: 0 + 9: # check database.db file + action: "NODE_FILE_SCAN" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 10: + action: "NODE_FILE_CHECKHASH" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 11: + action: "NODE_FILE_DELETE" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 12: + action: "NODE_FILE_REPAIR" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 13: + action: "NODE_SERVICE_FIX" + options: + node_id: 2 + service_id: 0 + 14: + action: "NODE_FOLDER_SCAN" + options: + node_id: 2 + folder_id: 0 + 15: + action: "NODE_FOLDER_CHECKHASH" + options: + node_id: 2 + folder_id: 0 + 16: + action: "NODE_FOLDER_REPAIR" + options: + node_id: 2 + folder_id: 0 + 17: + action: "NODE_FOLDER_RESTORE" + options: + node_id: 2 + folder_id: 0 + 18: + action: "NODE_OS_SCAN" + options: + node_id: 0 + 19: + action: "NODE_SHUTDOWN" + options: + node_id: 0 + 20: + action: NODE_STARTUP + options: + node_id: 0 + 21: + action: NODE_RESET + options: + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "ROUTER_ACL_ADDRULE" + options: + target_router_nodename: router_1 + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "ROUTER_ACL_ADDRULE" + options: + target_router_nodename: router_1 + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router_nodename: router_1 + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router_nodename: router_1 + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 50: # old action num: 26 + action: "ROUTER_ACL_ADDRULE" + options: + target_router_nodename: router_1 + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 51: # old action num: 27 + action: "ROUTER_ACL_ADDRULE" + options: + target_router_nodename: router_1 + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 52: # old action num: 28 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router_nodename: router_1 + position: 0 + 53: # old action num: 29 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router_nodename: router_1 + position: 1 + 54: # old action num: 30 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router_nodename: router_1 + position: 2 + 55: # old action num: 31 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router_nodename: router_1 + position: 3 + 56: # old action num: 32 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router_nodename: router_1 + position: 4 + 57: # old action num: 33 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router_nodename: router_1 + position: 5 + 58: # old action num: 34 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router_nodename: router_1 + position: 6 + 59: # old action num: 35 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router_nodename: router_1 + position: 7 + 60: # old action num: 36 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router_nodename: router_1 + position: 8 + 61: # old action num: 37 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router_nodename: router_1 + position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "HOST_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "HOST_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "HOST_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "HOST_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "HOST_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "HOST_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "HOST_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + + + + options: + nodes: + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + + + reward_function: + reward_components: + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user + - type: ACTION_PENALTY + weight: 1.0 + options: + agent_name: defender + penalty_value: -1 + + + agent_settings: + flatten_obs: true + + + +simulation: + network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE + nodes: + + - hostname: router_1 + type: router + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + 18: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 19: + action: PERMIT + src_port: DNS + dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - hostname: switch_1 + type: switch + num_ports: 8 + + - hostname: switch_2 + type: switch + num_ports: 8 + + - hostname: domain_controller + type: server + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.12 # web server + + - hostname: web_server + type: server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: WebServer + applications: + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + + + - hostname: database_server + type: server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - type: FTPClient + + - hostname: backup_server + type: server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: FTPServer + + - hostname: security_suite + type: server + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + network_interfaces: + 2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + + + links: + - endpoint_a_hostname: router_1 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 8 + - endpoint_a_hostname: router_1 + endpoint_a_port: 2 + endpoint_b_hostname: switch_2 + endpoint_b_port: 8 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: domain_controller + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: web_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 3 + endpoint_b_hostname: database_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 4 + endpoint_b_hostname: backup_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 2 diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index db2b0c3a..95e70271 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -2,7 +2,7 @@ import yaml from primaite.game.agent.interface import AgentHistoryItem -from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty +from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.host.server import Server @@ -119,3 +119,67 @@ def test_shared_reward(): g2_reward = env.game.agents["client_2_green_user"].reward_function.current_reward blue_reward = env.game.agents["defender"].reward_function.current_reward assert blue_reward == g1_reward + g2_reward + + +def test_action_penalty_loads_from_config(): + """Test to ensure that action penalty is correctly loaded from config into PrimaiteGymEnv""" + CFG_PATH = TEST_ASSETS_ROOT / "configs/action_penalty.yaml" + with open(CFG_PATH, "r") as f: + cfg = yaml.safe_load(f) + + env = PrimaiteGymEnv(env_config=cfg) + + env.reset() + + ActionPenalty_Value = env.game.agents["defender"].reward_function.reward_components[2][0].penalty + CFG_Penalty_Value = cfg["agents"][3]["reward_function"]["reward_components"][2]["options"]["penalty_value"] + + assert ActionPenalty_Value == CFG_Penalty_Value + + +def test_action_penalty(game_and_agent): + """Test that the action penalty is correctly applied when agent performs any action""" + + # Create an ActionPenalty Reward + Penalty = ActionPenalty(agent_name="Test_Blue_Agent", penalty=-1.0) + + game, _ = game_and_agent + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + server_1.software_manager.install(DatabaseService) + db_service = server_1.software_manager.software.get("DatabaseService") + db_service.start() + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DatabaseClient) + db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + db_client.configure(server_ip_address=server_1.network_interface[1].ip_address) + db_client.run() + + response = db_client.apply_request( + [ + "execute", + ] + ) + + state = game.get_sim_state() + + # Assert that penalty is applied if action isn't DONOTHING + reward_value = Penalty.calculate( + state, + last_action_response=AgentHistoryItem( + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + ), + ) + + assert reward_value == -1.0 + + # Assert that no penalty applied for a DONOTHING action + reward_value = Penalty.calculate( + state, + last_action_response=AgentHistoryItem( + timestep=0, action="DONOTHING", parameters={}, request=["execute"], response=response + ), + ) + + assert reward_value == 0 From e7f979b78e566139be8ccd50f89b8f27be0f9c04 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 26 Jun 2024 12:52:29 +0100 Subject: [PATCH 16/49] #2648 - reverted temp changes to benchmark durations and branch names for testing purposes --- .azure/azure-benchmark-pipeline.yaml | 13 ++++++------- benchmark/primaite_benchmark.py | 4 ++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index cd8703ec..1f7b8ebe 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -3,14 +3,14 @@ trigger: exclude: - '*' include: - - 'refs/heads/feature/2648_Automate-the-benchmarking-process' + - 'refs/heads/release/*' schedules: - cron: "0 2 * * 1-5" # Run at 2 AM every weekday displayName: "Weekday Schedule" branches: include: - - feature/2648_Automate-the-benchmarking-process + - 'refs/heads/dev' pool: vmImage: ubuntu-latest @@ -25,8 +25,8 @@ steps: - script: | VERSION=$(cat src/primaite/VERSION | tr -d '\n') - if [[ "$(Build.SourceBranch)" == "refs/heads/feature/2648_Automate-the-benchmarking-process" ]]; then - DATE=$(date +%Y%m%d%H%M) + if [[ "$(Build.SourceBranch)" == "refs/heads/dev" ]]; then + DATE=$(date +%Y%m%d) echo "${VERSION}+dev.${DATE}" > src/primaite/VERSION fi displayName: 'Update VERSION file for Dev Benchmark' @@ -60,7 +60,7 @@ steps: git config --global user.name "Defence Science and Technology Laboratory UK" workingDirectory: $(System.DefaultWorkingDirectory) displayName: 'Configure Git' - condition: and(succeeded(), eq(variables['Build.Reason'], 'Manual'), startsWith(variables['Build.SourceBranch'], 'refs/heads/feature/2648_Automate-the-benchmarking-process')) + condition: and(succeeded(), eq(variables['Build.Reason'], 'Manual'), startsWith(variables['Build.SourceBranch'], 'refs/heads/release')) - script: | git add benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)/* @@ -70,10 +70,9 @@ steps: workingDirectory: $(System.DefaultWorkingDirectory) env: GIT_CREDENTIALS: $(System.AccessToken) - condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/feature/2648_Automate-the-benchmarking-process')) + condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/release')) - script: | - ls benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) tar czf primaite_v$(VERSION)_benchmark.tar.gz benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) displayName: 'Prepare Artifacts for Publishing' diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index 92c9bf0a..f3d0a10c 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -151,8 +151,8 @@ def _prepare_session_directory(): def run( - number_of_sessions: int = 1, - num_episodes: int = 25, + number_of_sessions: int = 5, + num_episodes: int = 1000, episode_len: int = 128, n_steps: int = 1280, batch_size: int = 32, From 7f0a0c562fd89858fc1d39ef96fcef04da775063 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 26 Jun 2024 13:26:18 +0100 Subject: [PATCH 17/49] #2648 - changed av_reward_per_episode to total_reward_per_episode in primaite_benchmark.py and report.py --- benchmark/primaite_benchmark.py | 2 +- benchmark/report.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index f3d0a10c..27e25a0c 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -124,7 +124,7 @@ class BenchmarkSession: "total_s": total_s, "s_per_step": s_per_step, "s_per_100_steps_10_nodes": s_per_100_steps_10_nodes, - "av_reward_per_episode": self.gym_env.total_reward_per_episode, + "total_reward_per_episode": self.gym_env.total_reward_per_episode, } diff --git a/benchmark/report.py b/benchmark/report.py index 6a71ef57..dc8e51e4 100644 --- a/benchmark/report.py +++ b/benchmark/report.py @@ -35,19 +35,19 @@ def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict, "av_s_per_step": sum(d["s_per_step"] for d in metadata_dict.values()) / num_sessions, "av_s_per_100_steps_10_nodes": sum(d["s_per_100_steps_10_nodes"] for d in metadata_dict.values()) / num_sessions, - "combined_av_reward_per_episode": {}, - "session_av_reward_per_episode": {k: v["av_reward_per_episode"] for k, v in metadata_dict.items()}, + "combined_total_reward_per_episode": {}, + "session_total_reward_per_episode": {k: v["total_reward_per_episode"] for k, v in metadata_dict.items()}, "config": config, } # find the average of each episode across all sessions - episodes = metadata_dict[1]["av_reward_per_episode"].keys() + episodes = metadata_dict[1]["total_reward_per_episode"].keys() for episode in episodes: combined_av_reward = ( - sum(metadata_dict[k]["av_reward_per_episode"][episode] for k in metadata_dict.keys()) / num_sessions + sum(metadata_dict[k]["total_reward_per_episode"][episode] for k in metadata_dict.keys()) / num_sessions ) - averaged_data["combined_av_reward_per_episode"][episode] = combined_av_reward + averaged_data["combined_total_reward_per_episode"][episode] = combined_av_reward return averaged_data @@ -83,7 +83,7 @@ def _plot_benchmark_metadata( fig = go.Figure(layout=layout) fig.update_layout(template=PLOT_CONFIG["template"]) - for session, av_reward_dict in benchmark_metadata_dict["session_av_reward_per_episode"].items(): + for session, av_reward_dict in benchmark_metadata_dict["session_total_reward_per_episode"].items(): df = _get_df_from_episode_av_reward_dict(av_reward_dict) fig.add_trace( go.Scatter( @@ -96,7 +96,7 @@ def _plot_benchmark_metadata( ) ) - df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_av_reward_per_episode"]) + df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_total_reward_per_episode"]) fig.add_trace( go.Scatter( x=df["episode"], y=df["av_reward"], mode="lines", name="Combined Session Av", line={"color": "#FF0000"} @@ -132,7 +132,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure: Does this by iterating over the ``benchmark/results`` directory and extracting the benchmark metadata json for each version that has been - benchmarked. The combined_av_reward_per_episode is extracted from each, + benchmarked. The combined_total_reward_per_episode is extracted from each, converted into a polars dataframe, and plotted as a scatter line in plotly. """ major_v = primaite.__version__.split(".")[0] @@ -158,7 +158,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure: metadata_file = dir / f"{dir.name}_benchmark_metadata.json" with open(metadata_file, "r") as file: metadata_dict = json.load(file) - df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_av_reward_per_episode"]) + df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_total_reward_per_episode"]) fig.add_trace(go.Scatter(x=df["episode"], y=df["rolling_av_reward"], mode="lines", name=dir.name)) From e204afff6f1c6526356cfd7d76d367b5361df6f0 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 26 Jun 2024 20:58:52 +0100 Subject: [PATCH 18/49] #2656 - Removing the change to Data_Manipulation.yaml as this isn't necessary --- src/primaite/config/_package_data/data_manipulation.yaml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index be613918..f320c22f 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -739,12 +739,6 @@ agents: options: agent_name: client_2_green_user - - type: ACTION_PENALTY - weight: 1.0 - options: - agent_name: defender - penalty_value: -1 - agent_settings: flatten_obs: true From 7a680678aa4e69355f1c2a11bf2c8157f2bae321 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 27 Jun 2024 12:01:32 +0100 Subject: [PATCH 19/49] #2656 - Make action penalty more configurable --- src/primaite/game/agent/rewards.py | 28 ++-- tests/assets/configs/action_penalty.yaml | 141 +----------------- .../game_layer/test_rewards.py | 80 +++++----- 3 files changed, 62 insertions(+), 187 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index a0736bb0..4a17e9a5 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -363,33 +363,33 @@ class SharedReward(AbstractReward): class ActionPenalty(AbstractReward): """Apply a negative reward when taking any action except DONOTHING.""" - def __init__(self, agent_name: str, penalty: float): + def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: """ Initialise the reward. - This negative reward should be applied when the agent in training chooses to take any - action that isn't DONOTHING. + Reward or penalise agents for doing nothing or taking actions. + + :param action_penalty: Reward to give agents for taking any action except DONOTHING + :type action_penalty: float + :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action + :type do_nothing_penalty: float """ - self.agent_name = agent_name - self.penalty = penalty + self.action_penalty = action_penalty + self.do_nothing_penalty = do_nothing_penalty def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the penalty to be applied.""" if last_action_response.action == "DONOTHING": - # No penalty for doing nothing at present - return 0 + return self.do_nothing_penalty else: - _LOGGER.info( - f"Blue Agent has incurred a penalty of {self.penalty}, for action: {last_action_response.action}" - ) - return self.penalty + return self.action_penalty @classmethod def from_config(cls, config: Dict) -> "ActionPenalty": """Build the ActionPenalty object from config.""" - agent_name = config.get("agent_name") - penalty_value = config.get("penalty_value", 0) # default to 0. - return cls(agent_name=agent_name, penalty=penalty_value) + action_penalty = config.get("action_penalty", -1.0) + do_nothing_penalty = config.get("do_nothing_penalty", 0.0) + return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty) class RewardFunction: diff --git a/tests/assets/configs/action_penalty.yaml b/tests/assets/configs/action_penalty.yaml index 4eb562fe..1771ba5f 100644 --- a/tests/assets/configs/action_penalty.yaml +++ b/tests/assets/configs/action_penalty.yaml @@ -21,135 +21,6 @@ game: low: 0 agents: - - ref: client_2_green_user - team: GREEN - type: ProbabilisticAgent - agent_settings: - action_probabilities: - 0: 0.3 - 1: 0.6 - 2: 0.1 - observation_space: null - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 - action_map: - 0: - action: DONOTHING - options: {} - 1: - action: NODE_APPLICATION_EXECUTE - options: - node_id: 0 - application_id: 0 - 2: - action: NODE_APPLICATION_EXECUTE - options: - node_id: 0 - application_id: 1 - - reward_function: - reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.25 - options: - node_hostname: client_2 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY - weight: 0.05 - options: - node_hostname: client_2 - - - ref: client_1_green_user - team: GREEN - type: ProbabilisticAgent - agent_settings: - action_probabilities: - 0: 0.3 - 1: 0.6 - 2: 0.1 - observation_space: null - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 - action_map: - 0: - action: DONOTHING - options: {} - 1: - action: NODE_APPLICATION_EXECUTE - options: - node_id: 0 - application_id: 0 - 2: - action: NODE_APPLICATION_EXECUTE - options: - node_id: 0 - application_id: 1 - - reward_function: - reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.25 - options: - node_hostname: client_1 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY - weight: 0.05 - options: - node_hostname: client_1 - - - ref: data_manipulation_attacker - team: RED - type: RedDatabaseCorruptingAgent - - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: DataManipulationBot - - node_name: client_2 - applications: - - application_name: DataManipulationBot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY - - agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 - ref: defender team: BLUE @@ -712,19 +583,11 @@ agents: reward_function: reward_components: - - type: SHARED_REWARD - weight: 1.0 - options: - agent_name: client_1_green_user - - type: SHARED_REWARD - weight: 1.0 - options: - agent_name: client_2_green_user - type: ACTION_PENALTY weight: 1.0 options: - agent_name: defender - penalty_value: -1 + action_penalty: -0.75 + do_nothing_penalty: 0.125 agent_settings: diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 95e70271..2bf551c8 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,9 +1,11 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest import yaml from primaite.game.agent.interface import AgentHistoryItem from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty from primaite.game.game import PrimaiteGame +from primaite.interface.request import RequestResponse from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router @@ -130,56 +132,66 @@ def test_action_penalty_loads_from_config(): env = PrimaiteGymEnv(env_config=cfg) env.reset() - - ActionPenalty_Value = env.game.agents["defender"].reward_function.reward_components[2][0].penalty - CFG_Penalty_Value = cfg["agents"][3]["reward_function"]["reward_components"][2]["options"]["penalty_value"] - - assert ActionPenalty_Value == CFG_Penalty_Value + defender = env.game.agents["defender"] + act_penalty_obj = None + for comp in defender.reward_function.reward_components: + if isinstance(comp[0], ActionPenalty): + act_penalty_obj = comp[0] + if act_penalty_obj is None: + pytest.fail("Action penalty reward component was not added to the agent from config.") + assert act_penalty_obj.action_penalty == -0.75 + assert act_penalty_obj.do_nothing_penalty == 0.125 -def test_action_penalty(game_and_agent): +def test_action_penalty(): """Test that the action penalty is correctly applied when agent performs any action""" # Create an ActionPenalty Reward - Penalty = ActionPenalty(agent_name="Test_Blue_Agent", penalty=-1.0) - - game, _ = game_and_agent - - server_1: Server = game.simulation.network.get_node_by_hostname("server_1") - server_1.software_manager.install(DatabaseService) - db_service = server_1.software_manager.software.get("DatabaseService") - db_service.start() - - client_1 = game.simulation.network.get_node_by_hostname("client_1") - client_1.software_manager.install(DatabaseClient) - db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") - db_client.configure(server_ip_address=server_1.network_interface[1].ip_address) - db_client.run() - - response = db_client.apply_request( - [ - "execute", - ] - ) - - state = game.get_sim_state() + Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) # Assert that penalty is applied if action isn't DONOTHING reward_value = Penalty.calculate( - state, + state={}, last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + timestep=0, + action="NODE_APPLICATION_EXECUTE", + parameters={"node_id": 0, "application_id": 1}, + request=["execute"], + response=RequestResponse.from_bool(True), ), ) - assert reward_value == -1.0 + assert reward_value == -0.75 # Assert that no penalty applied for a DONOTHING action reward_value = Penalty.calculate( - state, + state={}, last_action_response=AgentHistoryItem( - timestep=0, action="DONOTHING", parameters={}, request=["execute"], response=response + timestep=0, + action="DONOTHING", + parameters={}, + request=["do_nothing"], + response=RequestResponse.from_bool(True), ), ) - assert reward_value == 0 + assert reward_value == 0.125 + + +def test_action_penalty_e2e(game_and_agent): + """Test that we get the right reward for doing actions to fetch a website.""" + game, agent = game_and_agent + agent: ControlledAgent + comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + + agent.reward_function.register_component(comp, 1.0) + + action = ("DONOTHING", {}) + agent.store_action(action) + game.step() + assert agent.reward_function.current_reward == 0.125 + + action = ("NODE_FILE_SCAN", {"node_id": 0, "folder_id": 0, "file_id": 0}) + agent.store_action(action) + game.step() + assert agent.reward_function.current_reward == -0.75 From f796babf93a9ff9d78561598445b9bf0b8ba0985 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 28 Jun 2024 11:57:54 +0100 Subject: [PATCH 20/49] #2705 - Move application registry into application module instead of hardcoding in game module --- src/primaite/game/game.py | 28 ++++++++----------- .../create-simulation_demo.ipynb | 2 +- src/primaite/simulator/core.py | 2 +- .../system/applications/application.py | 18 +++++++++++- .../system/applications/database_client.py | 2 +- .../simulator/system/applications/nmap.py | 2 +- .../red_applications/data_manipulation_bot.py | 2 +- .../applications/red_applications/dos_bot.py | 2 +- .../red_applications/ransomware_script.py | 2 +- .../system/applications/web_browser.py | 2 +- tests/conftest.py | 8 +++--- ...software_installation_and_configuration.py | 5 ++-- .../network/test_broadcast.py | 2 +- ...test_multi_lan_internet_example_network.py | 2 +- .../test_simulation/test_request_response.py | 4 +-- .../test_application_registry.py | 22 +++++++++++++++ 16 files changed, 69 insertions(+), 36 deletions(-) create mode 100644 tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8a79d068..05210278 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -26,11 +26,14 @@ from primaite.simulator.network.hardware.nodes.network.wireless_router import Wi from primaite.simulator.network.nmne import set_nmne_config from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation -from primaite.simulator.system.applications.database_client import DatabaseClient -from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot -from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot -from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript -from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.applications.application import Application +from primaite.simulator.system.applications.database_client import DatabaseClient # noqa: F401 +from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( # noqa: F401 + DataManipulationBot, +) +from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot # noqa: F401 +from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript # noqa: F401 +from primaite.simulator.system.applications.web_browser import WebBrowser # noqa: F401 from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer @@ -42,15 +45,6 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) -APPLICATION_TYPES_MAPPING = { - "WebBrowser": WebBrowser, - "DatabaseClient": DatabaseClient, - "DataManipulationBot": DataManipulationBot, - "DoSBot": DoSBot, - "RansomwareScript": RansomwareScript, -} -"""List of available applications that can be installed on nodes in the PrimAITE Simulation.""" - SERVICE_TYPES_MAPPING = { "DNSClient": DNSClient, "DNSServer": DNSServer, @@ -333,9 +327,9 @@ class PrimaiteGame: new_application = None application_type = application_cfg["type"] - if application_type in APPLICATION_TYPES_MAPPING: - new_node.software_manager.install(APPLICATION_TYPES_MAPPING[application_type]) - new_application = new_node.software_manager.software[application_type] + if application_type in Application._application_registry: + new_node.software_manager.install(Application._application_registry[application_type]) + new_application = new_node.software_manager.software[application_type] # grab the instance else: msg = f"Configuration contains an invalid application type: {application_type}" _LOGGER.error(msg) diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index 9f4abbf3..77ac4842 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -171,7 +171,7 @@ "from primaite.simulator.file_system.file_system import FileSystem\n", "\n", "# no applications exist yet so we will create our own.\n", - "class MSPaint(Application):\n", + "class MSPaint(Application, identifier=\"MSPaint\"):\n", " def describe_state(self):\n", " return super().describe_state()" ] diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 8c7d64c9..8d8425ec 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -196,7 +196,7 @@ class SimComponent(BaseModel): ..code::python - class WebBrowser(Application): + class WebBrowser(Application, identifier="WebBrowser"): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() # all requests generic to any Application get initialised rm.add_request(...) # initialise any requests specific to the web browser diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 1b9a9657..98ccc27f 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from abc import abstractmethod from enum import Enum -from typing import Any, Dict, Optional, Set +from typing import Any, ClassVar, Dict, Optional, Set, Type from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType @@ -39,6 +39,22 @@ class Application(IOSoftware): install_countdown: Optional[int] = None "The countdown to the end of the installation process. None if not currently installing" + _application_registry: ClassVar[Dict[str, Type["Application"]]] = {} + """Registry of application types. Automatically populated when subclasses are defined.""" + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + """ + Register an application type. + + :param identifier: Uniquely specifies an application class by name. + :type identifier: str + :raises ValueError: When attempting to register an application with a name that is already allocated. + """ + super().__init_subclass__(**kwargs) + if identifier in cls._application_registry: + raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.") + cls._application_registry[identifier] = cls + def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index bae2139b..bc2d426e 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -54,7 +54,7 @@ class DatabaseClientConnection(BaseModel): self.client._disconnect(self.connection_id) # noqa -class DatabaseClient(Application): +class DatabaseClient(Application, identifier="DatabaseClient"): """ A DatabaseClient application. diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index d8af1b7b..c87eaaf5 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -44,7 +44,7 @@ class PortScanPayload(SimComponent): return state -class NMAP(Application): +class NMAP(Application, identifier="NMAP"): """ A class representing the NMAP application for network scanning. diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index cf03d901..fefb22c3 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -37,7 +37,7 @@ class DataManipulationAttackStage(IntEnum): "Signifies that the attack has failed." -class DataManipulationBot(Application): +class DataManipulationBot(Application, identifier="DataManipulationBot"): """A bot that simulates a script which performs a SQL injection attack.""" payload: Optional[str] = None diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 65e34227..6bfce9ba 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -29,7 +29,7 @@ class DoSAttackStage(IntEnum): "Attack is completed." -class DoSBot(DatabaseClient): +class DoSBot(DatabaseClient, identifier="DoSBot"): """A bot that simulates a Denial of Service attack.""" target_ip_address: Optional[IPv4Address] = None diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index af4a59d4..3723585b 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -10,7 +10,7 @@ from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection -class RansomwareScript(Application): +class RansomwareScript(Application, identifier="RansomwareScript"): """Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network. :ivar payload: The attack stage query payload. (Default ENCRYPT) diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 19cc4065..73791676 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -23,7 +23,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient _LOGGER = getLogger(__name__) -class WebBrowser(Application): +class WebBrowser(Application, identifier="WebBrowser"): """ Represents a web browser in the simulation environment. diff --git a/tests/conftest.py b/tests/conftest.py index b8359323..171e1996 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,7 +51,7 @@ class TestService(Service): pass -class TestApplication(Application): +class DummyApplication(Application, identifier="DummyApplication"): """Test Application class""" def __init__(self, **kwargs): @@ -85,15 +85,15 @@ def service_class(): @pytest.fixture(scope="function") -def application(file_system) -> TestApplication: - return TestApplication( +def application(file_system) -> DummyApplication: + return DummyApplication( name="TestApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_application") ) @pytest.fixture(scope="function") def application_class(): - return TestApplication + return DummyApplication @pytest.fixture(scope="function") diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index 4da1b674..3e06d371 100644 --- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -9,9 +9,10 @@ from primaite.config.load import data_manipulation_config_path from primaite.game.agent.interface import ProxyAgent from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent -from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING +from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot @@ -85,7 +86,7 @@ def test_node_software_install(): assert client_2.software_manager.software.get(software.__name__) is not None # check that applications have been installed on client 1 - for applications in APPLICATION_TYPES_MAPPING: + for applications in Application._application_registry: assert client_1.software_manager.software.get(applications) is not None # check that services have been installed on client 1 diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index 8f65344f..b89d6db6 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -41,7 +41,7 @@ class BroadcastService(Service): super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol) -class BroadcastClient(Application): +class BroadcastClient(Application, identifier="BroadcastClient"): """A client application to receive broadcast and unicast messages.""" payloads_received: List = [] diff --git a/tests/integration_tests/network/test_multi_lan_internet_example_network.py b/tests/integration_tests/network/test_multi_lan_internet_example_network.py index fa290b79..bcc9ad94 100644 --- a/tests/integration_tests/network/test_multi_lan_internet_example_network.py +++ b/tests/integration_tests/network/test_multi_lan_internet_example_network.py @@ -3,9 +3,9 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.networks import multi_lan_internet_network_example from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.ftp.ftp_client import FTPClient -from src.primaite.simulator.system.applications.web_browser import WebBrowser def test_all_with_configured_dns_server_ip_can_resolve_url(): diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 79c72339..8c58bb42 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.transport_layer import Port -from tests.conftest import TestApplication, TestService +from tests.conftest import DummyApplication, TestService def test_successful_node_file_system_creation_request(example_network): @@ -47,7 +47,7 @@ def test_successful_application_requests(example_network): net = example_network client_1 = net.get_node_by_hostname("client_1") - client_1.software_manager.install(TestApplication) + client_1.software_manager.install(DummyApplication) client_1.software_manager.software.get("TestApplication").run() resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"]) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py new file mode 100644 index 00000000..d8d7dfab --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py @@ -0,0 +1,22 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.simulator.system.applications.application import Application + + +def test_adding_to_app_registry(): + class temp_application(Application, identifier="temp_app"): + pass + + assert Application._application_registry["temp_app"] is temp_application + + with pytest.raises(ValueError): + + class another_application(Application, identifier="temp_app"): + pass + + # This is kinda evil... + # Because pytest doesn't reimport classes from modules, registering this temporary test application will change the + # state of the Application registry for all subsequently run tests. So, we have to delete and unregister the class. + del temp_application + Application._application_registry.pop("temp_app") From 1ebeb27c53d5ab207938787f11d71943095ea4ad Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 28 Jun 2024 12:03:05 +0100 Subject: [PATCH 21/49] #2705 Update documentation link --- .../simulation_components/system/list_of_applications.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/simulation_components/system/list_of_applications.rst b/docs/source/simulation_components/system/list_of_applications.rst index a1d8bfd4..94090d93 100644 --- a/docs/source/simulation_components/system/list_of_applications.rst +++ b/docs/source/simulation_components/system/list_of_applications.rst @@ -8,7 +8,7 @@ applications/* -More info :py:mod:`primaite.game.game.APPLICATION_TYPES_MAPPING` +More info :py:mod:`primaite.simulator.system.applications.application.Application` .. include:: list_of_system_applications.rst From 3ac97f8c3f111c70ee28b1f2f839e70afbbf0eaf Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 28 Jun 2024 13:07:57 +0100 Subject: [PATCH 22/49] #2641: Added a check for software health state in db service + tests --- .../services/database/database_service.py | 10 +- .../system/test_database_on_node.py | 105 ++++++++++++++++++ .../test_web_client_server_and_database.py | 24 ++++ 3 files changed, 138 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index cbd640f6..22ae0ff3 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -197,7 +197,11 @@ class DatabaseService(Service): status_code = 503 # service unavailable if self.health_state_actual == SoftwareHealthState.OVERWHELMED: self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.") - if self.health_state_actual == SoftwareHealthState.GOOD: + if self.health_state_actual in [ + SoftwareHealthState.GOOD, + SoftwareHealthState.FIXING, + SoftwareHealthState.COMPROMISED, + ]: if self.password == password: status_code = 200 # ok connection_id = self._generate_connection_id() @@ -244,6 +248,10 @@ class DatabaseService(Service): self.sys_log.error(f"{self.name}: Failed to run {query} because the database file is missing.") return {"status_code": 404, "type": "sql", "data": False} + if self.health_state_actual is not SoftwareHealthState.GOOD: + self.sys_log.error(f"{self.name}: Failed to run {query} because the database service is unavailable.") + return {"status_code": 500, "type": "sql", "data": False} + if query == "SELECT": if self.db_file.health_status == FileSystemItemHealthStatus.CORRUPT: return { diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 8da3bb1a..965b4ae8 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -14,6 +14,7 @@ from primaite.simulator.system.applications.database_client import DatabaseClien from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.software import SoftwareHealthState @pytest.fixture(scope="function") @@ -213,6 +214,110 @@ def test_restore_backup_after_deleting_file_without_updating_scan(uc2_network): assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.GOOD # now looks good +def test_database_service_fix(uc2_network): + """Test that the software fix applies to database service.""" + db_server: Server = uc2_network.get_node_by_hostname("database_server") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + + assert db_service.backup_database() is True + + # delete database locally + db_service.file_system.delete_file(folder_name="database", file_name="database.db") + + # db file is gone, reduced to atoms + assert db_service.db_file is None + + db_service.fix() # fix the database service + + assert db_service.health_state_actual == SoftwareHealthState.FIXING + + # apply timestep until the fix is applied + for i in range(db_service.fixing_duration + 1): + uc2_network.apply_timestep(i) + + assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD + assert db_service.health_state_actual == SoftwareHealthState.GOOD + + +def test_database_cannot_be_queried_while_fixing(uc2_network): + """Tests that the database service cannot be queried if the service is being fixed.""" + db_server: Server = uc2_network.get_node_by_hostname("database_server") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + + db_connection: DatabaseClientConnection = db_client.get_new_connection() + + assert db_connection.query(sql="SELECT") + + assert db_service.backup_database() is True + + # delete database locally + db_service.file_system.delete_file(folder_name="database", file_name="database.db") + + # db file is gone, reduced to atoms + assert db_service.db_file is None + + db_service.fix() # fix the database service + assert db_service.health_state_actual == SoftwareHealthState.FIXING + + # fails to query because database is in FIXING state + assert db_connection.query(sql="SELECT") is False + + # apply timestep until the fix is applied + for i in range(db_service.fixing_duration + 1): + uc2_network.apply_timestep(i) + + assert db_service.health_state_actual == SoftwareHealthState.GOOD + + assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD + + assert db_connection.query(sql="SELECT") + + +def test_database_can_create_connection_while_fixing(uc2_network): + """Tests that connections cannot be created while the database is being fixed.""" + db_server: Server = uc2_network.get_node_by_hostname("database_server") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + + client_2: Server = uc2_network.get_node_by_hostname("client_2") + db_client: DatabaseClient = client_2.software_manager.software["DatabaseClient"] + + db_connection: DatabaseClientConnection = db_client.get_new_connection() + + assert db_connection.query(sql="SELECT") + + assert db_service.backup_database() is True + + # delete database locally + db_service.file_system.delete_file(folder_name="database", file_name="database.db") + + # db file is gone, reduced to atoms + assert db_service.db_file is None + + db_service.fix() # fix the database service + assert db_service.health_state_actual == SoftwareHealthState.FIXING + + # fails to query because database is in FIXING state + assert db_connection.query(sql="SELECT") is False + + # should be able to create a new connection + new_db_connection: DatabaseClientConnection = db_client.get_new_connection() + assert new_db_connection is not None + assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING + + # apply timestep until the fix is applied + for i in range(db_service.fixing_duration + 1): + uc2_network.apply_timestep(i) + + assert db_service.health_state_actual == SoftwareHealthState.GOOD + assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD + + assert db_connection.query(sql="SELECT") + assert new_db_connection.query(sql="SELECT") + + def test_database_client_cannot_query_offline_database_server(uc2_network): """Tests DB query across the network returns HTTP status 404 when db server is offline.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py index 3fe77fa0..5a765763 100644 --- a/tests/integration_tests/system/test_web_client_server_and_database.py +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -16,6 +16,7 @@ from primaite.simulator.system.services.database.database_service import Databas from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.simulator.system.software import SoftwareHealthState @pytest.fixture(scope="function") @@ -110,6 +111,29 @@ def test_web_client_requests_users(web_client_web_server_database): assert web_browser.get_webpage() +def test_database_fix_disrupts_web_client(uc2_network): + """Tests that the database service being in fixed state disrupts the web client.""" + computer: Computer = uc2_network.get_node_by_hostname("client_1") + db_server: Server = uc2_network.get_node_by_hostname("database_server") + + web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + database_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + + # fix the database service + database_service.fix() + + assert database_service.health_state_actual == SoftwareHealthState.FIXING + + assert web_browser.get_webpage() is False + + for i in range(database_service.fixing_duration + 1): + uc2_network.apply_timestep(i) + + assert database_service.health_state_actual == SoftwareHealthState.GOOD + + assert web_browser.get_webpage() + + class TestWebBrowserHistory: def test_populating_history(self, web_client_web_server_database): network, computer, _, _ = web_client_web_server_database From a4424608dd41f1b07e0778af858add47fc72b97a Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 28 Jun 2024 16:33:24 +0100 Subject: [PATCH 23/49] #2620: add nbmake and pytest xdist to run the notebooks as part of pieline --- .azure/azure-ci-build-pipeline.yaml | 5 +++++ pyproject.toml | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index aea94807..7ac50ea8 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -126,3 +126,8 @@ stages: inputs: codeCoverageTool: Cobertura summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' + # Run the notebooks to make sure that they work + - script: | + pytest --nbmake -n=auto src/primaite/notebooks + pytest --nbmake -n=auto src/primaite/simulator/_package_data + displayName: 'Run notebooks' diff --git a/pyproject.toml b/pyproject.toml index 9d53e961..badb1557 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,9 @@ dependencies = [ "typer[all]==0.9.0", "pydantic==2.7.0", "ipywidgets", - "deepdiff" + "deepdiff", + "nbmake==1.5.4", + "pytest-xdist==3.6.1" ] [tool.setuptools.dynamic] From ce58f3960c7f3c1ab496df5e89369029997d972a Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 28 Jun 2024 16:43:10 +0100 Subject: [PATCH 24/49] #2620: downgrade pytest-xdist version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index badb1557..cd4436bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ "ipywidgets", "deepdiff", "nbmake==1.5.4", - "pytest-xdist==3.6.1" + "pytest-xdist==3.3.1" ] [tool.setuptools.dynamic] From c34cb6d7ce6142a1989532d85240209e50ac06b8 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 1 Jul 2024 11:31:27 +0100 Subject: [PATCH 25/49] #2700 Add DatabaseConfigure action --- src/primaite/game/agent/actions.py | 23 +++++ .../system/applications/database_client.py | 18 +++- .../observations/actions/__init__.py | 1 + .../actions/test_configure_actions.py | 85 +++++++++++++++++++ 4 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 tests/integration_tests/game_layer/observations/actions/__init__.py create mode 100644 tests/integration_tests/game_layer/observations/actions/test_configure_actions.py diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index e165c9ad..4cb31d25 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -17,6 +17,7 @@ from gymnasium import spaces from pydantic import BaseModel, Field, field_validator, ValidationInfo from primaite import getLogger +from primaite.interface.request import RequestFormat _LOGGER = getLogger(__name__) @@ -245,6 +246,27 @@ class NodeApplicationInstallAction(AbstractAction): ] +class ConfigureDatabaseClientAction(AbstractAction): + """Action which sets config parameters for a database client on a node.""" + + class _Opts(BaseModel): + """Schema for options that can be passed to this action.""" + + server_ip_address: Optional[str] = None + server_password: Optional[str] = None + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: int, options: Dict) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + if node_name is None: + return ["do_nothing"] + ConfigureDatabaseClientAction._Opts.model_validate(options) # check that options adhere to schema + return ["network", "node", node_name, "application", "DatabaseClient", "configure", options] + + class NodeApplicationRemoveAction(AbstractAction): """Action which removes/uninstalls an application.""" @@ -1045,6 +1067,7 @@ class ActionManager: "NODE_NMAP_PING_SCAN": NodeNMAPPingScanAction, "NODE_NMAP_PORT_SCAN": NodeNMAPPortScanAction, "NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction, + "CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index bae2139b..6396c678 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -8,13 +8,14 @@ from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel -from primaite.interface.request import RequestResponse +from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.software_manager import SoftwareManager +from primaite.utils.validators import IPV4Address class DatabaseClientConnection(BaseModel): @@ -96,6 +97,14 @@ class DatabaseClient(Application): """ rm = super()._init_request_manager() rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute()))) + + def _configure(request: RequestFormat, context: Dict) -> RequestResponse: + ip, pw = request[-1].get("server_ip_address"), request[-1].get("server_password") + ip = None if ip is None else IPV4Address(ip) + success = self.configure(server_ip_address=ip, server_password=pw) + return RequestResponse.from_bool(success) + + rm.add_request("configure", RequestType(func=lambda request, context: _configure(request, context))) return rm def execute(self) -> bool: @@ -141,16 +150,17 @@ class DatabaseClient(Application): table.add_row([connection_id, connection.is_active]) print(table.get_string(sortby="Connection ID")) - def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None): + def configure(self, server_ip_address: Optional[IPv4Address] = None, server_password: Optional[str] = None) -> bool: """ Configure the DatabaseClient to communicate with a DatabaseService. :param server_ip_address: The IP address of the Node the DatabaseService is on. :param server_password: The password on the DatabaseService. """ - self.server_ip_address = server_ip_address - self.server_password = server_password + self.server_ip_address = server_ip_address or self.server_ip_address + self.server_password = server_password or self.server_password self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.") + return True def connect(self) -> bool: """Connect the native client connection.""" diff --git a/tests/integration_tests/game_layer/observations/actions/__init__.py b/tests/integration_tests/game_layer/observations/actions/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/actions/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/game_layer/observations/actions/test_configure_actions.py b/tests/integration_tests/game_layer/observations/actions/test_configure_actions.py new file mode 100644 index 00000000..17e262d1 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/actions/test_configure_actions.py @@ -0,0 +1,85 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from ipaddress import IPv4Address + +from primaite.game.agent.actions import ConfigureDatabaseClientAction +from primaite.simulator.system.applications.database_client import DatabaseClient +from tests.conftest import ControlledAgent + + +class TestConfigureDatabaseAction: + def test_configure_ip_password(self, game_and_agent): + game, agent = game_and_agent + agent: ControlledAgent + agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager) + + # make sure there is a database client on this node + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DatabaseClient) + db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"] + + action = ( + "CONFIGURE_DATABASE_CLIENT", + { + "node_id": 0, + "options": { + "server_ip_address": "192.168.1.99", + "server_password": "admin123", + }, + }, + ) + agent.store_action(action) + game.step() + + assert db_client.server_ip_address == IPv4Address("192.168.1.99") + assert db_client.server_password == "admin123" + + def test_configure_ip(self, game_and_agent): + game, agent = game_and_agent + agent: ControlledAgent + agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager) + + # make sure there is a database client on this node + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DatabaseClient) + db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"] + + action = ( + "CONFIGURE_DATABASE_CLIENT", + { + "node_id": 0, + "options": { + "server_ip_address": "192.168.1.99", + }, + }, + ) + agent.store_action(action) + game.step() + + assert db_client.server_ip_address == IPv4Address("192.168.1.99") + assert db_client.server_password is None + + def test_configure_password(self, game_and_agent): + game, agent = game_and_agent + agent: ControlledAgent + agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager) + + # make sure there is a database client on this node + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DatabaseClient) + db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"] + old_ip = db_client.server_ip_address + + action = ( + "CONFIGURE_DATABASE_CLIENT", + { + "node_id": 0, + "options": { + "server_password": "admin123", + }, + }, + ) + agent.store_action(action) + game.step() + + assert db_client.server_ip_address == old_ip + assert db_client.server_password is "admin123" From ee4e152f13d920ed2d78eebc66f452d253b969aa Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 1 Jul 2024 12:00:17 +0100 Subject: [PATCH 26/49] #2620: publish result of test and checking if pipeline fails --- .azure/azure-ci-build-pipeline.yaml | 11 +++++++++-- .gitignore | 1 + src/primaite/notebooks/multi-processing.ipynb | 11 ++++++++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 7ac50ea8..a45d26bc 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -128,6 +128,13 @@ stages: summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' # Run the notebooks to make sure that they work - script: | - pytest --nbmake -n=auto src/primaite/notebooks - pytest --nbmake -n=auto src/primaite/simulator/_package_data + pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml + pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml displayName: 'Run notebooks' + - task: PublishTestResults@2 + condition: succeededOrFailed() + inputs: + testRunner: JUnit + testResultsFiles: 'notebook-tests/**.xml' + testRunTitle: 'Publish Notebook run' + failTaskOnFailedTests: true diff --git a/.gitignore b/.gitignore index e48dc5dc..140a3d0b 100644 --- a/.gitignore +++ b/.gitignore @@ -54,6 +54,7 @@ cover/ tests/assets/**/*.png tests/assets/**/tensorboard_logs/ tests/assets/**/checkpoints/ +notebook-tests/*.xml # Translations *.mo diff --git a/src/primaite/notebooks/multi-processing.ipynb b/src/primaite/notebooks/multi-processing.ipynb index 86b549a7..6fc8b4c2 100644 --- a/src/primaite/notebooks/multi-processing.ipynb +++ b/src/primaite/notebooks/multi-processing.ipynb @@ -18,6 +18,15 @@ "Import packages and read config file." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "1/0" + ] + }, { "cell_type": "code", "execution_count": null, @@ -143,7 +152,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, From cb61756e43716ef3c82320cf4cce40b6bbec816c Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 1 Jul 2024 12:21:05 +0100 Subject: [PATCH 27/49] #2620: attempting to fail pipeline if notebook fails --- .azure/azure-ci-build-pipeline.yaml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index a45d26bc..a27fed25 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -109,6 +109,7 @@ stages: - task: PublishTestResults@2 condition: succeededOrFailed() + displayName: 'Publish Test Results' inputs: testRunner: JUnit testResultsFiles: 'junit/**.xml' @@ -131,10 +132,4 @@ stages: pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml displayName: 'Run notebooks' - - task: PublishTestResults@2 condition: succeededOrFailed() - inputs: - testRunner: JUnit - testResultsFiles: 'notebook-tests/**.xml' - testRunTitle: 'Publish Notebook run' - failTaskOnFailedTests: true From 655bc04a42b2aa645fbe63c6131f6127fc41b5c8 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 1 Jul 2024 13:21:08 +0100 Subject: [PATCH 28/49] #2705 Minor comment and name changes to applications --- src/primaite/simulator/system/applications/application.py | 2 +- tests/conftest.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 98ccc27f..848e1ef0 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -46,7 +46,7 @@ class Application(IOSoftware): """ Register an application type. - :param identifier: Uniquely specifies an application class by name. + :param identifier: Uniquely specifies an application class by name. Used for finding items by config. :type identifier: str :raises ValueError: When attempting to register an application with a name that is already allocated. """ diff --git a/tests/conftest.py b/tests/conftest.py index 171e1996..980e4aa9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,7 +55,7 @@ class DummyApplication(Application, identifier="DummyApplication"): """Test Application class""" def __init__(self, **kwargs): - kwargs["name"] = "TestApplication" + kwargs["name"] = "DummyApplication" kwargs["port"] = Port.HTTP kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -87,7 +87,7 @@ def service_class(): @pytest.fixture(scope="function") def application(file_system) -> DummyApplication: return DummyApplication( - name="TestApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_application") + name="DummyApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_application") ) From 2dd7546f3d1578625096012cc5806938c0f442ed Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 1 Jul 2024 13:25:16 +0100 Subject: [PATCH 29/49] 2705 Fix application tests by correctly renaming fixture --- .../integration_tests/system/test_application_on_node.py | 4 ++-- .../test_simulation/test_request_response.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/integration_tests/system/test_application_on_node.py b/tests/integration_tests/system/test_application_on_node.py index 275646c6..ffb5cc7f 100644 --- a/tests/integration_tests/system/test_application_on_node.py +++ b/tests/integration_tests/system/test_application_on_node.py @@ -21,7 +21,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]: computer.power_on() computer.software_manager.install(application_class) - app = computer.software_manager.software.get("TestApplication") + app = computer.software_manager.software.get("DummyApplication") app.run() return app, computer @@ -39,7 +39,7 @@ def test_application_on_offline_node(application_class): ) computer.software_manager.install(application_class) - app: Application = computer.software_manager.software.get("TestApplication") + app: Application = computer.software_manager.software.get("DummyApplication") computer.power_off() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 8c58bb42..a9f0b58d 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -48,13 +48,13 @@ def test_successful_application_requests(example_network): client_1 = net.get_node_by_hostname("client_1") client_1.software_manager.install(DummyApplication) - client_1.software_manager.software.get("TestApplication").run() + client_1.software_manager.software.get("DummyApplication").run() - resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"]) + resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"]) assert resp_1 == RequestResponse(status="success", data={}) - resp_2 = net.apply_request(["node", "client_1", "application", "TestApplication", "fix"]) + resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"]) assert resp_2 == RequestResponse(status="success", data={}) - resp_3 = net.apply_request(["node", "client_1", "application", "TestApplication", "compromise"]) + resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"]) assert resp_3 == RequestResponse(status="success", data={}) From 4dd50be11ac9242084ec30030a7ff6e58f4d5df1 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 1 Jul 2024 14:31:58 +0100 Subject: [PATCH 30/49] #2620: run notebooks after test so that the results can be published --- .azure/azure-ci-build-pipeline.yaml | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index a27fed25..9217748b 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -107,12 +107,20 @@ stages: coverage html -d htmlcov -i displayName: 'Run tests and code coverage' + # Run the notebooks to make sure that they work + - script: | + pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml + pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml + displayName: 'Run notebooks' + - task: PublishTestResults@2 condition: succeededOrFailed() displayName: 'Publish Test Results' inputs: testRunner: JUnit - testResultsFiles: 'junit/**.xml' + testResultsFiles: | + 'junit/**.xml' + 'notebook-tests/**.xml' testRunTitle: 'Publish test results' failTaskOnFailedTests: true @@ -127,9 +135,3 @@ stages: inputs: codeCoverageTool: Cobertura summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' - # Run the notebooks to make sure that they work - - script: | - pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml - pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml - displayName: 'Run notebooks' - condition: succeededOrFailed() From ab73ac20e84766b5c77688b9adf0e02a98b22d6b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 1 Jul 2024 14:41:41 +0100 Subject: [PATCH 31/49] #2700 add ransomware configure action --- src/primaite/game/agent/actions.py | 27 ++++++- .../system/applications/database_client.py | 2 +- .../red_applications/ransomware_script.py | 14 +++- .../{observations => }/actions/__init__.py | 0 .../actions/test_configure_actions.py | 75 ++++++++++++++++++- 5 files changed, 113 insertions(+), 5 deletions(-) rename tests/integration_tests/game_layer/{observations => }/actions/__init__.py (100%) rename tests/integration_tests/game_layer/{observations => }/actions/test_configure_actions.py (52%) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 4cb31d25..9f2693e5 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -14,7 +14,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Literal, Optional, Tuple, TYPE_CHECKING, Union from gymnasium import spaces -from pydantic import BaseModel, Field, field_validator, ValidationInfo +from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo from primaite import getLogger from primaite.interface.request import RequestFormat @@ -252,6 +252,7 @@ class ConfigureDatabaseClientAction(AbstractAction): class _Opts(BaseModel): """Schema for options that can be passed to this action.""" + model_config = ConfigDict(extra="forbid") server_ip_address: Optional[str] = None server_password: Optional[str] = None @@ -267,6 +268,29 @@ class ConfigureDatabaseClientAction(AbstractAction): return ["network", "node", node_name, "application", "DatabaseClient", "configure", options] +class ConfigureRansomwareScriptAction(AbstractAction): + """Action which sets config parameters for a ransomware script on a node.""" + + class _Opts(BaseModel): + """Schema for options that can be passed to this option.""" + + model_config = ConfigDict(extra="forbid") + server_ip_address: Optional[str] = None + server_password: Optional[str] = None + payload: Optional[str] = None + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: int, options: Dict) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + if node_name is None: + return ["do_nothing"] + ConfigureRansomwareScriptAction._Opts.model_validate(options) # check that options adhere to schema + return ["network", "node", node_name, "application", "RansomwareScript", "configure", options] + + class NodeApplicationRemoveAction(AbstractAction): """Action which removes/uninstalls an application.""" @@ -1068,6 +1092,7 @@ class ActionManager: "NODE_NMAP_PORT_SCAN": NodeNMAPPortScanAction, "NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction, "CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction, + "CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 6396c678..fcfd603b 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -104,7 +104,7 @@ class DatabaseClient(Application): success = self.configure(server_ip_address=ip, server_password=pw) return RequestResponse.from_bool(success) - rm.add_request("configure", RequestType(func=lambda request, context: _configure(request, context))) + rm.add_request("configure", RequestType(func=_configure)) return rm def execute(self) -> bool: diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index af4a59d4..46e42fc2 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -2,7 +2,7 @@ from ipaddress import IPv4Address from typing import Dict, Optional -from primaite.interface.request import RequestResponse +from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -62,6 +62,15 @@ class RansomwareScript(Application): name="execute", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())), ) + + def _configure(request: RequestFormat, context: Dict) -> RequestResponse: + ip = request[-1].get("server_ip_address") + ip = None if ip is None else IPv4Address(ip) + pw = request[-1].get("server_password") + payload = request[-1].get("payload") + return RequestResponse.from_bool(self.configure(ip, pw, payload)) + + rm.add_request("configure", request_type=RequestType(func=_configure)) return rm def run(self) -> bool: @@ -91,7 +100,7 @@ class RansomwareScript(Application): server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None, - ): + ) -> bool: """ Configure the Ransomware Script to communicate with a DatabaseService. @@ -108,6 +117,7 @@ class RansomwareScript(Application): self.sys_log.info( f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}." ) + return True def attack(self) -> bool: """Perform the attack steps after opening the application.""" diff --git a/tests/integration_tests/game_layer/observations/actions/__init__.py b/tests/integration_tests/game_layer/actions/__init__.py similarity index 100% rename from tests/integration_tests/game_layer/observations/actions/__init__.py rename to tests/integration_tests/game_layer/actions/__init__.py diff --git a/tests/integration_tests/game_layer/observations/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py similarity index 52% rename from tests/integration_tests/game_layer/observations/actions/test_configure_actions.py rename to tests/integration_tests/game_layer/actions/test_configure_actions.py index 17e262d1..5439f3c9 100644 --- a/tests/integration_tests/game_layer/observations/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -1,8 +1,12 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address -from primaite.game.agent.actions import ConfigureDatabaseClientAction +import pytest +from pydantic import ValidationError + +from primaite.game.agent.actions import ConfigureDatabaseClientAction, ConfigureRansomwareScriptAction from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from tests.conftest import ControlledAgent @@ -83,3 +87,72 @@ class TestConfigureDatabaseAction: assert db_client.server_ip_address == old_ip assert db_client.server_password is "admin123" + + +class TestConfigureRansomwareScriptAction: + @pytest.mark.parametrize( + "options", + [ + {}, + {"server_ip_address": "181.181.181.181"}, + {"server_password": "admin123"}, + {"payload": "ENCRYPT"}, + { + "server_ip_address": "181.181.181.181", + "server_password": "admin123", + "payload": "ENCRYPT", + }, + ], + ) + def test_configure_ip_password(self, game_and_agent, options): + game, agent = game_and_agent + agent: ControlledAgent + agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction( + agent.action_manager + ) + + # make sure there is a database client on this node + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(RansomwareScript) + ransomware_script: RansomwareScript = client_1.software_manager.software["RansomwareScript"] + + old_ip = ransomware_script.server_ip_address + old_pw = ransomware_script.server_password + old_payload = ransomware_script.payload + + action = ( + "CONFIGURE_RANSOMWARE_SCRIPT", + {"node_id": 0, "options": options}, + ) + agent.store_action(action) + game.step() + + expected_ip = old_ip if "server_ip_address" not in options else IPv4Address(options["server_ip_address"]) + expected_pw = old_pw if "server_password" not in options else options["server_password"] + expected_payload = old_payload if "payload" not in options else options["payload"] + + assert ransomware_script.server_ip_address == expected_ip + assert ransomware_script.server_password == expected_pw + assert ransomware_script.payload == expected_payload + + def test_invalid_options(self, game_and_agent): + game, agent = game_and_agent + agent: ControlledAgent + agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction( + agent.action_manager + ) + + # make sure there is a database client on this node + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(RansomwareScript) + ransomware_script: RansomwareScript = client_1.software_manager.software["RansomwareScript"] + action = ( + "CONFIGURE_RANSOMWARE_SCRIPT", + { + "node_id": 0, + "options": {"server_password": "admin123", "bad_option": 70}, + }, + ) + agent.store_action(action) + with pytest.raises(ValidationError): + game.step() From bf8ec6083331adfcd089a55a4b0b6fc60b0423b3 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 1 Jul 2024 15:25:20 +0100 Subject: [PATCH 32/49] #2700 Add configure dosbot action --- src/primaite/game/agent/actions.py | 28 ++++++++++++ .../applications/red_applications/dos_bot.py | 17 +++++-- .../actions/test_configure_actions.py | 45 ++++++++++++++++++- 3 files changed, 85 insertions(+), 5 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 9f2693e5..1de5276c 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -291,6 +291,33 @@ class ConfigureRansomwareScriptAction(AbstractAction): return ["network", "node", node_name, "application", "RansomwareScript", "configure", options] +class ConfigureDoSBotAction(AbstractAction): + """Action which sets config parameters for a DoS bot on a node.""" + + class _Opts(BaseModel): + """Schema for options that can be passed to this option.""" + + model_config = ConfigDict(extra="forbid") + target_ip_address: Optional[str] = None + target_port: Optional[str] = None + payload: Optional[str] = None + repeat: Optional[bool] = None + port_scan_p_of_success: Optional[float] = None + dos_intensity: Optional[float] = None + max_sessions: Optional[int] = None + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: int, options: Dict) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + if node_name is None: + return ["do_nothing"] + self._Opts.model_validate(options) # check that options adhere to schema + return ["network", "node", node_name, "application", "DoSBot", "configure", options] + + class NodeApplicationRemoveAction(AbstractAction): """Action which removes/uninstalls an application.""" @@ -1093,6 +1120,7 @@ class ActionManager: "NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction, "CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction, "CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction, + "CONFIGURE_DOSBOT": ConfigureDoSBotAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 65e34227..dccf45f5 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -1,11 +1,11 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from enum import IntEnum from ipaddress import IPv4Address -from typing import Optional +from typing import Dict, Optional from primaite import getLogger from primaite.game.science import simulate_trial -from primaite.interface.request import RequestResponse +from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient @@ -71,6 +71,14 @@ class DoSBot(DatabaseClient): request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.run())), ) + def _configure(request: RequestFormat, context: Dict) -> RequestResponse: + if "target_ip_address" in request[-1]: + request[-1]["target_ip_address"] = IPv4Address(request[-1]["target_ip_address"]) + if "target_port" in request[-1]: + request[-1]["target_port"] = Port[request[-1]["target_port"]] + return RequestResponse.from_bool(self.configure(**request[-1])) + + rm.add_request("configure", request_type=RequestType(func=_configure)) return rm def configure( @@ -82,7 +90,7 @@ class DoSBot(DatabaseClient): port_scan_p_of_success: float = 0.1, dos_intensity: float = 1.0, max_sessions: int = 1000, - ): + ) -> bool: """ Configure the Denial of Service bot. @@ -90,7 +98,7 @@ class DoSBot(DatabaseClient): :param: target_port: The port of the target service. Optional - Default is `Port.HTTP` :param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None` :param: repeat: If True, the bot will maintain the attack. Optional - Default is `True` - :param: port_scan_p_of_success: The chance of the port scan being sucessful. Optional - Default is 0.1 (10%) + :param: port_scan_p_of_success: The chance of the port scan being successful. Optional - Default is 0.1 (10%) :param: dos_intensity: The intensity of the DoS attack. Multiplied with the application's max session - Default is 1.0 :param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000 @@ -106,6 +114,7 @@ class DoSBot(DatabaseClient): f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, " f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}." ) + return True def run(self) -> bool: """Run the Denial of Service Bot.""" diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index 5439f3c9..6bcd3b52 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -4,8 +4,14 @@ from ipaddress import IPv4Address import pytest from pydantic import ValidationError -from primaite.game.agent.actions import ConfigureDatabaseClientAction, ConfigureRansomwareScriptAction +from primaite.game.agent.actions import ( + ConfigureDatabaseClientAction, + ConfigureDoSBotAction, + ConfigureRansomwareScriptAction, +) +from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from tests.conftest import ControlledAgent @@ -156,3 +162,40 @@ class TestConfigureRansomwareScriptAction: agent.store_action(action) with pytest.raises(ValidationError): game.step() + + +class TestConfigureDoSBot: + def test_configure_DoSBot(self, game_and_agent): + game, agent = game_and_agent + agent: ControlledAgent + agent.action_manager.actions["CONFIGURE_DOSBOT"] = ConfigureDoSBotAction(agent.action_manager) + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DoSBot) + dos_bot: DoSBot = client_1.software_manager.software["DoSBot"] + + action = ( + "CONFIGURE_DOSBOT", + { + "node_id": 0, + "options": { + "target_ip_address": "192.168.1.99", + "target_port": "POSTGRES_SERVER", + "payload": "HACC", + "repeat": False, + "port_scan_p_of_success": 0.875, + "dos_intensity": 0.75, + "max_sessions": 50, + }, + }, + ) + agent.store_action(action) + game.step() + + assert dos_bot.target_ip_address == IPv4Address("192.168.1.99") + assert dos_bot.target_port == Port.POSTGRES_SERVER + assert dos_bot.payload == "HACC" + assert not dos_bot.repeat + assert dos_bot.port_scan_p_of_success == 0.875 + assert dos_bot.dos_intensity == 0.75 + assert dos_bot.max_sessions == 50 From a47a14b86e23d5533f5f88f84c49d1d64394a6af Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 1 Jul 2024 15:57:15 +0100 Subject: [PATCH 33/49] #2620: Going around azure dev ops to fail the script --- .azure/azure-ci-build-pipeline.yaml | 23 +++++++++++++++++++++-- pyproject.toml | 8 ++++---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 9217748b..47daa10f 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -109,8 +109,27 @@ stages: # Run the notebooks to make sure that they work - script: | - pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml - pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml + # Detect OS + if [ "$(uname)" = "Linux" ] || [ "$(uname)" = "Darwin" ]; then + # Commands for Linux and macOS + pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml + notebooks_exit_code=$? + pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml + package_notebooks_exit_code=$? + # Fail step if either of these do not have exit code 0 + if [ $notebooks_exit_code -ne 0 ] || [ $package_notebooks_exit_code -ne 0 ]; then + exit 1 + fi + else + # Commands for Windows + pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml + set notebooks_exit_code=%ERRORLEVEL% + pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml + set package_notebooks_exit_code=%ERRORLEVEL% + # Fail step if either of these do not have exit code 0 + if %notebooks_exit_code% NEQ 0 exit /b 1 + if %package_notebooks_exit_code% NEQ 0 exit /b 1 + fi displayName: 'Run notebooks' - task: PublishTestResults@2 diff --git a/pyproject.toml b/pyproject.toml index cd4436bf..a0c2e3eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,9 +37,7 @@ dependencies = [ "typer[all]==0.9.0", "pydantic==2.7.0", "ipywidgets", - "deepdiff", - "nbmake==1.5.4", - "pytest-xdist==3.3.1" + "deepdiff" ] [tool.setuptools.dynamic] @@ -74,7 +72,9 @@ dev = [ "Sphinx==7.1.2", "sphinx-copybutton==0.5.2", "wheel==0.38.4", - "nbsphinx==0.9.4" + "nbsphinx==0.9.4", + "nbmake==1.5.4", + "pytest-xdist==3.3.1" ] [project.scripts] From dc2c64b2f67fae73767ea0637d46fccb230ef1c1 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 1 Jul 2024 16:23:10 +0100 Subject: [PATCH 34/49] #2701 - Remove ip address option from node application install --- src/primaite/game/agent/actions.py | 3 +- .../game/agent/scripted_agents/tap001.py | 1 - .../simulator/network/hardware/base.py | 164 ++++++------------ .../configs/test_application_install.yaml | 1 - .../game_layer/test_actions.py | 2 +- 5 files changed, 53 insertions(+), 118 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index e165c9ad..3a21a95f 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -228,7 +228,7 @@ class NodeApplicationInstallAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes} - def form_request(self, node_id: int, application_name: str, ip_address: str) -> List[str]: + def form_request(self, node_id: int, application_name: str) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None: @@ -241,7 +241,6 @@ class NodeApplicationInstallAction(AbstractAction): "application", "install", application_name, - ip_address, ] diff --git a/src/primaite/game/agent/scripted_agents/tap001.py b/src/primaite/game/agent/scripted_agents/tap001.py index b1a378ef..c4f6062a 100644 --- a/src/primaite/game/agent/scripted_agents/tap001.py +++ b/src/primaite/game/agent/scripted_agents/tap001.py @@ -55,7 +55,6 @@ class TAP001(AbstractScriptedAgent): return "NODE_APPLICATION_INSTALL", { "node_id": self.starting_node_idx, "application_name": "RansomwareScript", - "ip_address": self.ip_address, } return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 01745215..1982b08f 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -6,7 +6,7 @@ import secrets from abc import ABC, abstractmethod from ipaddress import IPv4Address, IPv4Network from pathlib import Path -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, Dict, Optional, TypeVar, Union from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel, Field @@ -884,6 +884,54 @@ class Node(SimComponent): More information in user guide and docstring for SimComponent._init_request_manager. """ + + def _install_application(request: RequestFormat, context: Dict) -> RequestResponse: + """ + Allows agents to install applications to the node. + + :param request: list containing the application name as the only element + :type application: str + """ + application_name = request[0] + if self.software_manager.software.get(application_name): + self.sys_log.warning(f"Can't install {application_name}. It's already installed.") + return RequestResponse.from_bool(False) + application_class = Application._application_registry[application_name] + self.software_manager.install(application_class) + application_instance = self.software_manager.software.get(application_name) + self.applications[application_instance.uuid] = application_instance + _LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}") + self._application_request_manager.add_request( + application_name, RequestType(func=application_instance._request_manager) + ) + application_instance.install() + if application_name in self.software_manager.software: + return RequestResponse.from_bool(True) + else: + return RequestResponse.from_bool(False) + + def _uninstall_application(request: RequestFormat, context: Dict) -> RequestResponse: + """ + Uninstall and completely remove application from this node. + + This method is useful for allowing agents to take this action. + + :param application: Application object that is currently associated with this node. + :type application: Application + :return: True if the application is uninstalled successfully, otherwise False. + """ + application_name = request[0] + if application_name not in self.software_manager.software: + self.sys_log.warning(f"Can't uninstall {application_name}. It's not installed.") + return RequestResponse.from_bool(False) + + application_instance = self.software_manager.software.get(application_name) + self.software_manager.uninstall(application_instance.name) + if application_instance.name not in self.software_manager.software: + return RequestResponse.from_bool(True) + else: + return RequestResponse.from_bool(False) + _node_is_on = Node._NodeIsOnValidator(node=self) rm = super()._init_request_manager() @@ -940,25 +988,8 @@ class Node(SimComponent): name="application", request_type=RequestType(func=self._application_manager) ) - self._application_manager.add_request( - name="install", - request_type=RequestType( - func=lambda request, context: RequestResponse.from_bool( - self.application_install_action( - application=self._read_application_type(request[0]), ip_address=request[1] - ) - ) - ), - ) - - self._application_manager.add_request( - name="uninstall", - request_type=RequestType( - func=lambda request, context: RequestResponse.from_bool( - self.application_uninstall_action(application=self._read_application_type(request[0])) - ) - ), - ) + self._application_manager.add_request(name="install", request_type=RequestType(func=_install_application)) + self._application_manager.add_request(name="uninstall", request_type=RequestType(func=_uninstall_application)) return rm @@ -966,29 +997,6 @@ class Node(SimComponent): """Install System Software - software that is usually provided with the OS.""" pass - def _read_application_type(self, application_class_str: str) -> Type[IOSoftwareClass]: - """Wrapper that converts the string from the request manager into the appropriate class for the application.""" - if application_class_str == "DoSBot": - from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot - - return DoSBot - elif application_class_str == "DataManipulationBot": - from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( - DataManipulationBot, - ) - - return DataManipulationBot - elif application_class_str == "WebBrowser": - from primaite.simulator.system.applications.web_browser import WebBrowser - - return WebBrowser - elif application_class_str == "RansomwareScript": - from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript - - return RansomwareScript - else: - return 0 - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -1417,76 +1425,6 @@ class Node(SimComponent): self.sys_log.info(f"Uninstalled application {application.name}") self._application_request_manager.remove_request(application.name) - def application_install_action(self, application: Application, ip_address: Optional[str] = None) -> bool: - """ - Install an application on this node and configure it. - - This method is useful for allowing agents to take this action. - - :param application: Application object that has not been installed on any node yet. - :type application: Application - :param ip_address: IP address used to configure the application - (target IP for the DoSBot or server IP for the DataManipulationBot) - :type ip_address: str - :return: True if the application is installed successfully, otherwise False. - """ - if application in self: - _LOGGER.warning( - f"Can't add application {application.__name__}" + f"to node {self.hostname}. It's already installed." - ) - return True - - self.software_manager.install(application) - application_instance = self.software_manager.software.get(str(application.__name__)) - self.applications[application_instance.uuid] = application_instance - _LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}") - self._application_request_manager.add_request( - application_instance.name, RequestType(func=application_instance._request_manager) - ) - - # Configure application if additional parameters are given - if ip_address: - if application_instance.name == "DoSBot": - application_instance.configure(target_ip_address=IPv4Address(ip_address)) - elif application_instance.name == "DataManipulationBot": - application_instance.configure(server_ip_address=IPv4Address(ip_address)) - elif application_instance.name == "RansomwareScript": - application_instance.configure(server_ip_address=IPv4Address(ip_address)) - else: - pass - application_instance.install() - if application_instance.name in self.software_manager.software: - return True - else: - return False - - def application_uninstall_action(self, application: Application) -> bool: - """ - Uninstall and completely remove application from this node. - - This method is useful for allowing agents to take this action. - - :param application: Application object that is currently associated with this node. - :type application: Application - :return: True if the application is uninstalled successfully, otherwise False. - """ - if application.__name__ not in self.software_manager.software: - _LOGGER.warning( - f"Can't remove application {application.__name__}" + f"from node {self.hostname}. It's not installed." - ) - return True - - application_instance = self.software_manager.software.get( - str(application.__name__) - ) # This works because we can't have two applications with the same name on the same node - # self.uninstall_application(application_instance) - self.software_manager.uninstall(application_instance.name) - - if application_instance.name not in self.software_manager.software: - return True - else: - return False - def _shut_down_actions(self): """Actions to perform when the node is shut down.""" # Turn off all the services in the node diff --git a/tests/assets/configs/test_application_install.yaml b/tests/assets/configs/test_application_install.yaml index 87402f73..a4e898ae 100644 --- a/tests/assets/configs/test_application_install.yaml +++ b/tests/assets/configs/test_application_install.yaml @@ -683,7 +683,6 @@ agents: options: node_id: 0 application_name: DoSBot - ip_address: 192.168.1.14 79: action: NODE_APPLICATION_REMOVE options: diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 0dcf125d..a1005f34 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -557,7 +557,7 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl assert client_1.software_manager.software.get("DoSBot") is None - action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot", "ip_address": "192.168.1.14"}) + action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot"}) agent.store_action(action) game.step() From afcb844501a887e7d1c3e761cac9c3dd2f8d3749 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 1 Jul 2024 16:45:28 +0100 Subject: [PATCH 35/49] #2620: remove cell that fails on purpose --- src/primaite/notebooks/multi-processing.ipynb | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/primaite/notebooks/multi-processing.ipynb b/src/primaite/notebooks/multi-processing.ipynb index 6fc8b4c2..305cfd70 100644 --- a/src/primaite/notebooks/multi-processing.ipynb +++ b/src/primaite/notebooks/multi-processing.ipynb @@ -18,15 +18,6 @@ "Import packages and read config file." ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "1/0" - ] - }, { "cell_type": "code", "execution_count": null, From ab3e84b8b9ad78d947a6efa8b4a6e706bef94dc3 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 1 Jul 2024 17:03:14 +0100 Subject: [PATCH 36/49] #2620: remove irrelevant change --- src/primaite/notebooks/multi-processing.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/notebooks/multi-processing.ipynb b/src/primaite/notebooks/multi-processing.ipynb index 305cfd70..86b549a7 100644 --- a/src/primaite/notebooks/multi-processing.ipynb +++ b/src/primaite/notebooks/multi-processing.ipynb @@ -143,7 +143,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, From 8097884ae2f3142fa75cd512a44d5d72895072cb Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 2 Jul 2024 00:00:20 +0100 Subject: [PATCH 37/49] #2620: modify script --- .azure/azure-ci-build-pipeline.yaml | 49 +++++++++++++++-------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 47daa10f..1176f52f 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -107,30 +107,33 @@ stages: coverage html -d htmlcov -i displayName: 'Run tests and code coverage' - # Run the notebooks to make sure that they work + # Run the notebooks - script: | - # Detect OS - if [ "$(uname)" = "Linux" ] || [ "$(uname)" = "Darwin" ]; then - # Commands for Linux and macOS - pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml - notebooks_exit_code=$? - pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml - package_notebooks_exit_code=$? - # Fail step if either of these do not have exit code 0 - if [ $notebooks_exit_code -ne 0 ] || [ $package_notebooks_exit_code -ne 0 ]; then - exit 1 - fi - else - # Commands for Windows - pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml - set notebooks_exit_code=%ERRORLEVEL% - pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml - set package_notebooks_exit_code=%ERRORLEVEL% - # Fail step if either of these do not have exit code 0 - if %notebooks_exit_code% NEQ 0 exit /b 1 - if %package_notebooks_exit_code% NEQ 0 exit /b 1 - fi - displayName: 'Run notebooks' + pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml + notebooks_exit_code=$? + pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml + package_notebooks_exit_code=$? + # Fail step if either of these do not have exit code 0 + if [ $notebooks_exit_code -ne 0 ] || [ $package_notebooks_exit_code -ne 0 ]; then + exit 1 + fi + displayName: 'Run notebooks on Linux and macOS' + condition: or(eq(variables['Agent.OS'], 'Linux'), eq(variables['Agent.OS'], 'Darwin')) + shell: bash + + # Run notebooks + - script: | + pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml + $notebooks_exit_code = $LASTEXITCODE + pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml + $package_notebooks_exit_code = $LASTEXITCODE + # Fail step if either of these do not have exit code 0 + if ($notebooks_exit_code -ne 0 -or $package_notebooks_exit_code -ne 0) { + exit 1 + } + displayName: 'Run notebooks on Windows' + condition: eq(variables['Agent.OS'], 'Windows_NT') + shell: pwsh - task: PublishTestResults@2 condition: succeededOrFailed() From 918eba2217b4bb2fbf97ef2092a25566302ca5ba Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 2 Jul 2024 00:25:30 +0100 Subject: [PATCH 38/49] #2620: fix indentation --- .azure/azure-ci-build-pipeline.yaml | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 1176f52f..24606c24 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -121,19 +121,19 @@ stages: condition: or(eq(variables['Agent.OS'], 'Linux'), eq(variables['Agent.OS'], 'Darwin')) shell: bash - # Run notebooks - - script: | - pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml - $notebooks_exit_code = $LASTEXITCODE - pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml - $package_notebooks_exit_code = $LASTEXITCODE - # Fail step if either of these do not have exit code 0 - if ($notebooks_exit_code -ne 0 -or $package_notebooks_exit_code -ne 0) { - exit 1 - } - displayName: 'Run notebooks on Windows' - condition: eq(variables['Agent.OS'], 'Windows_NT') - shell: pwsh + # Run notebooks + - script: | + pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml + $notebooks_exit_code = $LASTEXITCODE + pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml + $package_notebooks_exit_code = $LASTEXITCODE + # Fail step if either of these do not have exit code 0 + if ($notebooks_exit_code -ne 0 -or $package_notebooks_exit_code -ne 0) { + exit 1 + } + displayName: 'Run notebooks on Windows' + condition: eq(variables['Agent.OS'], 'Windows_NT') + shell: pwsh - task: PublishTestResults@2 condition: succeededOrFailed() From 1faacc8d7c3e7514356666b6d59b2143bb83b462 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 2 Jul 2024 00:28:59 +0100 Subject: [PATCH 39/49] #2620: fix indentation --- .azure/azure-ci-build-pipeline.yaml | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 24606c24..afff0dbf 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -109,17 +109,17 @@ stages: # Run the notebooks - script: | - pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml - notebooks_exit_code=$? - pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml - package_notebooks_exit_code=$? - # Fail step if either of these do not have exit code 0 - if [ $notebooks_exit_code -ne 0 ] || [ $package_notebooks_exit_code -ne 0 ]; then - exit 1 - fi - displayName: 'Run notebooks on Linux and macOS' - condition: or(eq(variables['Agent.OS'], 'Linux'), eq(variables['Agent.OS'], 'Darwin')) - shell: bash + pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml + notebooks_exit_code=$? + pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml + package_notebooks_exit_code=$? + # Fail step if either of these do not have exit code 0 + if [ $notebooks_exit_code -ne 0 ] || [ $package_notebooks_exit_code -ne 0 ]; then + exit 1 + fi + displayName: 'Run notebooks on Linux and macOS' + condition: or(eq(variables['Agent.OS'], 'Linux'), eq(variables['Agent.OS'], 'Darwin')) + shell: bash # Run notebooks - script: | From 88ab3c3ca152512086e732194eb1a952ad3c3254 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 2 Jul 2024 00:45:24 +0100 Subject: [PATCH 40/49] #2620: remove shell --- .azure/azure-ci-build-pipeline.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index afff0dbf..8d418ba5 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -119,7 +119,6 @@ stages: fi displayName: 'Run notebooks on Linux and macOS' condition: or(eq(variables['Agent.OS'], 'Linux'), eq(variables['Agent.OS'], 'Darwin')) - shell: bash # Run notebooks - script: | @@ -133,7 +132,6 @@ stages: } displayName: 'Run notebooks on Windows' condition: eq(variables['Agent.OS'], 'Windows_NT') - shell: pwsh - task: PublishTestResults@2 condition: succeededOrFailed() From 12a6aa5e7faa05aed0086e0f61433e33b3e4ce98 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 2 Jul 2024 01:31:59 +0100 Subject: [PATCH 41/49] #2620: make the windows script work --- .azure/azure-ci-build-pipeline.yaml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 8d418ba5..01111290 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -123,13 +123,12 @@ stages: # Run notebooks - script: | pytest --nbmake -n=auto src/primaite/notebooks --junit-xml=./notebook-tests/notebooks.xml - $notebooks_exit_code = $LASTEXITCODE + set notebooks_exit_code=%ERRORLEVEL% pytest --nbmake -n=auto src/primaite/simulator/_package_data --junit-xml=./notebook-tests/package-notebooks.xml - $package_notebooks_exit_code = $LASTEXITCODE - # Fail step if either of these do not have exit code 0 - if ($notebooks_exit_code -ne 0 -or $package_notebooks_exit_code -ne 0) { - exit 1 - } + set package_notebooks_exit_code=%ERRORLEVEL% + rem Fail step if either of these do not have exit code 0 + if %notebooks_exit_code% NEQ 0 exit /b 1 + if %package_notebooks_exit_code% NEQ 0 exit /b 1 displayName: 'Run notebooks on Windows' condition: eq(variables['Agent.OS'], 'Windows_NT') From a5a93fe8508dfd5e25898785aba244b216401ad7 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 2 Jul 2024 02:10:40 +0100 Subject: [PATCH 42/49] #2620: commit a div zero to check that the pipeline fails --- src/primaite/notebooks/multi-processing.ipynb | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/primaite/notebooks/multi-processing.ipynb b/src/primaite/notebooks/multi-processing.ipynb index 86b549a7..bdf9a6b5 100644 --- a/src/primaite/notebooks/multi-processing.ipynb +++ b/src/primaite/notebooks/multi-processing.ipynb @@ -18,6 +18,15 @@ "Import packages and read config file." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "1/0" + ] + }, { "cell_type": "code", "execution_count": null, From e2429df2200f9941d67898bd7b3549182479831f Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 2 Jul 2024 02:52:08 +0100 Subject: [PATCH 43/49] #2620: remove div zero from notebook --- src/primaite/notebooks/multi-processing.ipynb | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/primaite/notebooks/multi-processing.ipynb b/src/primaite/notebooks/multi-processing.ipynb index bdf9a6b5..305cfd70 100644 --- a/src/primaite/notebooks/multi-processing.ipynb +++ b/src/primaite/notebooks/multi-processing.ipynb @@ -18,15 +18,6 @@ "Import packages and read config file." ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "1/0" - ] - }, { "cell_type": "code", "execution_count": null, @@ -152,7 +143,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, From b27ac52d9ecdb05268175413b283bec4e9101946 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 2 Jul 2024 11:10:19 +0100 Subject: [PATCH 44/49] #2700 add E2E tests for application configure actions --- src/primaite/game/agent/actions.py | 18 +-- src/primaite/game/game.py | 3 +- .../applications/red_applications/dos_bot.py | 2 +- .../configs/install_and_configure_apps.yaml | 142 ++++++++++++++++++ .../configs/test_application_install.yaml | 9 ++ .../test_uc2_data_manipulation_scenario.py | 1 + .../actions/test_configure_actions.py | 115 ++++++++++++-- 7 files changed, 267 insertions(+), 23 deletions(-) create mode 100644 tests/assets/configs/install_and_configure_apps.yaml diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 60ff19e5..b3b7189c 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -258,13 +258,13 @@ class ConfigureDatabaseClientAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: int, options: Dict) -> RequestFormat: + def form_request(self, node_id: int, config: Dict) -> RequestFormat: """Return the action formatted as a request that can be ingested by the simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None: return ["do_nothing"] - ConfigureDatabaseClientAction._Opts.model_validate(options) # check that options adhere to schema - return ["network", "node", node_name, "application", "DatabaseClient", "configure", options] + ConfigureDatabaseClientAction._Opts.model_validate(config) # check that options adhere to schema + return ["network", "node", node_name, "application", "DatabaseClient", "configure", config] class ConfigureRansomwareScriptAction(AbstractAction): @@ -281,13 +281,13 @@ class ConfigureRansomwareScriptAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: int, options: Dict) -> RequestFormat: + def form_request(self, node_id: int, config: Dict) -> RequestFormat: """Return the action formatted as a request that can be ingested by the simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None: return ["do_nothing"] - ConfigureRansomwareScriptAction._Opts.model_validate(options) # check that options adhere to schema - return ["network", "node", node_name, "application", "RansomwareScript", "configure", options] + ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema + return ["network", "node", node_name, "application", "RansomwareScript", "configure", config] class ConfigureDoSBotAction(AbstractAction): @@ -308,13 +308,13 @@ class ConfigureDoSBotAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: int, options: Dict) -> RequestFormat: + def form_request(self, node_id: int, config: Dict) -> RequestFormat: """Return the action formatted as a request that can be ingested by the simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None: return ["do_nothing"] - self._Opts.model_validate(options) # check that options adhere to schema - return ["network", "node", node_name, "application", "DoSBot", "configure", options] + self._Opts.model_validate(config) # check that options adhere to schema + return ["network", "node", node_name, "application", "DoSBot", "configure", config] class NodeApplicationRemoveAction(AbstractAction): diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 05210278..89102afb 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -313,7 +313,8 @@ class PrimaiteGame: if "options" in service_cfg: opt = service_cfg["options"] new_service.password = opt.get("db_password", None) - new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip"))) + if "backup_server_ip" in opt: + new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip"))) if service_type == "FTPServer": if "options" in service_cfg: opt = service_cfg["options"] diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 17478b71..01a375ee 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -135,7 +135,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): self.sys_log.warning( f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}" ) - return True + return False self.clear_connections() self._perform_port_scan(p_of_success=self.port_scan_p_of_success) diff --git a/tests/assets/configs/install_and_configure_apps.yaml b/tests/assets/configs/install_and_configure_apps.yaml new file mode 100644 index 00000000..6b548f7e --- /dev/null +++ b/tests/assets/configs/install_and_configure_apps.yaml @@ -0,0 +1,142 @@ +io_settings: + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + save_agent_actions: false + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + protocols: + - ICMP + - TCP + +agents: + - ref: agent_1 + team: BLUE + type: ProxyAgent + + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_INSTALL + - type: CONFIGURE_DATABASE_CLIENT + - type: CONFIGURE_DOSBOT + - type: CONFIGURE_RANSOMWARE_SCRIPT + - type: NODE_APPLICATION_REMOVE + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_INSTALL + options: + node_id: 0 + application_name: DatabaseClient + 2: + action: NODE_APPLICATION_INSTALL + options: + node_id: 1 + application_name: RansomwareScript + 3: + action: NODE_APPLICATION_INSTALL + options: + node_id: 2 + application_name: DoSBot + 4: + action: CONFIGURE_DATABASE_CLIENT + options: + node_id: 0 + config: + server_ip_address: 10.0.0.5 + 5: + action: CONFIGURE_DATABASE_CLIENT + options: + node_id: 0 + config: + server_password: correct_password + 6: + action: CONFIGURE_RANSOMWARE_SCRIPT + options: + node_id: 1 + config: + server_ip_address: 10.0.0.5 + server_password: correct_password + payload: ENCRYPT + 7: + action: CONFIGURE_DOSBOT + options: + node_id: 2 + config: + target_ip_address: 10.0.0.5 + target_port: POSTGRES_SERVER + payload: DELETE + repeat: true + port_scan_p_of_success: 1.0 + dos_intensity: 1.0 + max_sessions: 1000 + 8: + action: NODE_APPLICATION_INSTALL + options: + node_id: 1 + application_name: DatabaseClient + options: + nodes: + - node_name: client_1 + - node_name: client_2 + - node_name: client_3 + ip_list: [] + reward_function: + reward_components: + - type: DUMMY + +simulation: + network: + nodes: + - type: computer + hostname: client_1 + ip_address: 10.0.0.2 + subnet_mask: 255.255.255.0 + default_gateway: 10.0.0.1 + - type: computer + hostname: client_2 + ip_address: 10.0.0.3 + subnet_mask: 255.255.255.0 + default_gateway: 10.0.0.1 + - type: computer + hostname: client_3 + ip_address: 10.0.0.4 + subnet_mask: 255.255.255.0 + default_gateway: 10.0.0.1 + - type: switch + hostname: switch_1 + num_ports: 8 + - type: server + hostname: server_1 + ip_address: 10.0.0.5 + subnet_mask: 255.255.255.0 + default_gateway: 10.0.0.1 + services: + - type: DatabaseService + options: + db_password: correct_password + links: + - endpoint_a_hostname: client_1 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 1 + - endpoint_a_hostname: client_2 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 2 + - endpoint_a_hostname: client_3 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 3 + - endpoint_a_hostname: server_1 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 8 diff --git a/tests/assets/configs/test_application_install.yaml b/tests/assets/configs/test_application_install.yaml index a4e898ae..3a3a6890 100644 --- a/tests/assets/configs/test_application_install.yaml +++ b/tests/assets/configs/test_application_install.yaml @@ -260,6 +260,7 @@ agents: - type: NODE_APPLICATION_INSTALL - type: NODE_APPLICATION_REMOVE - type: NODE_APPLICATION_EXECUTE + - type: CONFIGURE_DOSBOT action_map: 0: @@ -698,6 +699,14 @@ agents: options: node_id: 0 application_id: 0 + 82: + action: CONFIGURE_DOSBOT + options: + node_id: 0 + config: + target_ip_address: 192.168.1.14 + target_port: POSTGRES_SERVER + diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index e6cd113f..7ec38d72 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -69,6 +69,7 @@ def test_application_install_uninstall_on_uc2(): env.step(0) # Test we can now execute the DoSBot app + env.step(82) # configure dos bot with ip address and port _, _, _, _, info = env.step(81) assert info["agent_actions"]["defender"].response.status == "success" diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index 6bcd3b52..b7acc8a8 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -9,12 +9,19 @@ from primaite.game.agent.actions import ( ConfigureDoSBotAction, ConfigureRansomwareScriptAction, ) +from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript +from primaite.simulator.system.services.database.database_service import DatabaseService +from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent +APP_CONFIG_YAML = TEST_ASSETS_ROOT / "configs/install_and_configure_apps.yaml" + class TestConfigureDatabaseAction: def test_configure_ip_password(self, game_and_agent): @@ -31,7 +38,7 @@ class TestConfigureDatabaseAction: "CONFIGURE_DATABASE_CLIENT", { "node_id": 0, - "options": { + "config": { "server_ip_address": "192.168.1.99", "server_password": "admin123", }, @@ -57,7 +64,7 @@ class TestConfigureDatabaseAction: "CONFIGURE_DATABASE_CLIENT", { "node_id": 0, - "options": { + "config": { "server_ip_address": "192.168.1.99", }, }, @@ -83,7 +90,7 @@ class TestConfigureDatabaseAction: "CONFIGURE_DATABASE_CLIENT", { "node_id": 0, - "options": { + "config": { "server_password": "admin123", }, }, @@ -97,7 +104,7 @@ class TestConfigureDatabaseAction: class TestConfigureRansomwareScriptAction: @pytest.mark.parametrize( - "options", + "config", [ {}, {"server_ip_address": "181.181.181.181"}, @@ -110,7 +117,7 @@ class TestConfigureRansomwareScriptAction: }, ], ) - def test_configure_ip_password(self, game_and_agent, options): + def test_configure_ip_password(self, game_and_agent, config): game, agent = game_and_agent agent: ControlledAgent agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction( @@ -128,20 +135,20 @@ class TestConfigureRansomwareScriptAction: action = ( "CONFIGURE_RANSOMWARE_SCRIPT", - {"node_id": 0, "options": options}, + {"node_id": 0, "config": config}, ) agent.store_action(action) game.step() - expected_ip = old_ip if "server_ip_address" not in options else IPv4Address(options["server_ip_address"]) - expected_pw = old_pw if "server_password" not in options else options["server_password"] - expected_payload = old_payload if "payload" not in options else options["payload"] + expected_ip = old_ip if "server_ip_address" not in config else IPv4Address(config["server_ip_address"]) + expected_pw = old_pw if "server_password" not in config else config["server_password"] + expected_payload = old_payload if "payload" not in config else config["payload"] assert ransomware_script.server_ip_address == expected_ip assert ransomware_script.server_password == expected_pw assert ransomware_script.payload == expected_payload - def test_invalid_options(self, game_and_agent): + def test_invalid_config(self, game_and_agent): game, agent = game_and_agent agent: ControlledAgent agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction( @@ -156,7 +163,7 @@ class TestConfigureRansomwareScriptAction: "CONFIGURE_RANSOMWARE_SCRIPT", { "node_id": 0, - "options": {"server_password": "admin123", "bad_option": 70}, + "config": {"server_password": "admin123", "bad_option": 70}, }, ) agent.store_action(action) @@ -178,7 +185,7 @@ class TestConfigureDoSBot: "CONFIGURE_DOSBOT", { "node_id": 0, - "options": { + "config": { "target_ip_address": "192.168.1.99", "target_port": "POSTGRES_SERVER", "payload": "HACC", @@ -199,3 +206,87 @@ class TestConfigureDoSBot: assert dos_bot.port_scan_p_of_success == 0.875 assert dos_bot.dos_intensity == 0.75 assert dos_bot.max_sessions == 50 + + +class TestConfigureYAML: + def test_configure_db_client(self): + env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML) + + # make sure there's no db client on the node yet + client_1 = env.game.simulation.network.get_node_by_hostname("client_1") + assert client_1.software_manager.software.get("DatabaseClient") is None + + # take the install action, check that the db gets installed, step to get it to finish installing + env.step(1) + db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + assert isinstance(db_client, DatabaseClient) + assert db_client.operating_state == ApplicationOperatingState.INSTALLING + env.step(0) + env.step(0) + env.step(0) + env.step(0) + + # configure the ip address and check that it changes, but password doesn't change + assert db_client.server_ip_address is None + assert db_client.server_password is None + env.step(4) + assert db_client.server_ip_address == IPv4Address("10.0.0.5") + assert db_client.server_password is None + + # configure the password and check that it changes, make sure this lets us connect to the db + assert not db_client.connect() + env.step(5) + assert db_client.server_password == "correct_password" + assert db_client.connect() + + def test_configure_ransomware_script(self): + env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML) + client_2 = env.game.simulation.network.get_node_by_hostname("client_2") + assert client_2.software_manager.software.get("RansomwareScript") is None + + # install ransomware script + env.step(2) + ransom = client_2.software_manager.software.get("RansomwareScript") + assert isinstance(ransom, RansomwareScript) + assert ransom.operating_state == ApplicationOperatingState.INSTALLING + env.step(0) + env.step(0) + env.step(0) + env.step(0) + + # make sure it's not working yet because it's not configured and there's no db client + assert not ransom.attack() + env.step(8) # install db client on the same node + env.step(0) + env.step(0) + env.step(0) + env.step(0) # let it finish installing + assert not ransom.attack() + + # finally, configure the ransomware script with ip and password + env.step(6) + assert ransom.attack() + + db_server = env.game.simulation.network.get_node_by_hostname("server_1") + db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + assert db_service.db_file.health_status == FileSystemItemHealthStatus.CORRUPT + + def test_configure_dos_bot(self): + env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML) + client_3 = env.game.simulation.network.get_node_by_hostname("client_3") + assert client_3.software_manager.software.get("DoSBot") is None + + # install DoSBot + env.step(3) + bot = client_3.software_manager.software.get("DoSBot") + assert isinstance(bot, DoSBot) + assert bot.operating_state == ApplicationOperatingState.INSTALLING + env.step(0) + env.step(0) + env.step(0) + env.step(0) + + # make sure dos bot doesn't work before being configured + assert not bot.run() + env.step(7) + assert bot.run() From feabe5117c5d1471e61aeefc1f7609b42c819b28 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 2 Jul 2024 12:48:23 +0100 Subject: [PATCH 45/49] #2700 Fix docstrings in application configure methods --- src/primaite/simulator/network/hardware/base.py | 15 +++++++++++---- .../applications/red_applications/dos_bot.py | 15 +++++++++++++++ .../red_applications/ransomware_script.py | 10 ++++++++++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 1982b08f..6942d280 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -890,7 +890,11 @@ class Node(SimComponent): Allows agents to install applications to the node. :param request: list containing the application name as the only element - :type application: str + :type request: RequestFormat + :param context: additional context for resolving this action, currently unused + :type context: dict + :return: Request response with a success code if the application was installed. + :rtype: RequestResponse """ application_name = request[0] if self.software_manager.software.get(application_name): @@ -916,9 +920,12 @@ class Node(SimComponent): This method is useful for allowing agents to take this action. - :param application: Application object that is currently associated with this node. - :type application: Application - :return: True if the application is uninstalled successfully, otherwise False. + :param request: list containing the application name as the only element + :type request: RequestFormat + :param context: additional context for resolving this action, currently unused + :type context: dict + :return: Request response with a success code if the application was uninstalled. + :rtype: RequestResponse """ application_name = request[0] if application_name not in self.software_manager.software: diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 01a375ee..fcad3b3e 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -72,6 +72,16 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): ) def _configure(request: RequestFormat, context: Dict) -> RequestResponse: + """ + Configure the DoSBot. + + :param request: List with one element that is a dict of options to pass to the configure method. + :type request: RequestFormat + :param context: additional context for resolving this action, currently unused + :type context: dict + :return: Request Response object with a success code determining if the configuration was successful. + :rtype: RequestResponse + """ if "target_ip_address" in request[-1]: request[-1]["target_ip_address"] = IPv4Address(request[-1]["target_ip_address"]) if "target_port" in request[-1]: @@ -102,6 +112,8 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): :param: dos_intensity: The intensity of the DoS attack. Multiplied with the application's max session - Default is 1.0 :param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000 + :return: Always returns True + :rtype: bool """ self.target_ip_address = target_ip_address self.target_port = target_port @@ -126,6 +138,9 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): The main application loop for the Denial of Service bot. The loop goes through the stages of a DoS attack. + + :return: True if the application loop could be executed, False otherwise. + :rtype: bool """ if not self._can_perform_action(): return False diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index 8d9d0d18..71e422c3 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -64,6 +64,16 @@ class RansomwareScript(Application, identifier="RansomwareScript"): ) def _configure(request: RequestFormat, context: Dict) -> RequestResponse: + """ + Request for configuring the target database and payload. + + :param request: Request with one element contianing a dict of parameters for the configure method. + :type request: RequestFormat + :param context: additional context for resolving this action, currently unused + :type context: dict + :return: RequestResponse object with a success code reflecting whether the configuration could be applied. + :rtype: RequestResponse + """ ip = request[-1].get("server_ip_address") ip = None if ip is None else IPv4Address(ip) pw = request[-1].get("server_password") From cbc414bddf4eb43c6d096b9d2f240bbec1b9ae0f Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 2 Jul 2024 15:25:40 +0100 Subject: [PATCH 46/49] #2702 - update data manipulation notebook text --- .../Data-Manipulation-E2E-Demonstration.ipynb | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index b3a90cc0..0460f771 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -507,8 +507,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now service 1 on node 2 has `health_status = 3`, indicating that the webapp is compromised.\n", - "File 1 in folder 1 on node 3 has `health_status = 2`, indicating that the database file is compromised." + "Now service 1 on HOST1 has `health_status = 3`, indicating that the webapp is compromised.\n", + "File 1 in folder 1 on HOST2 has `health_status = 2`, indicating that the database file is compromised." ] }, { @@ -545,9 +545,9 @@ "source": [ "The fixing takes two steps, so the reward hasn't changed yet. Let's do nothing for another timestep, the reward should improve.\n", "\n", - "The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n", + "The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 0.9 when both green agents make successful requests.\n", "\n", - "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." + "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should increase. If you run it enough times, another red attack will happen and the reward will drop again." ] }, { @@ -708,7 +708,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.12" } }, "nbformat": 4, From 6a72f6af42e0a80eb246c956e89f5536cf76507e Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 2 Jul 2024 15:52:18 +0100 Subject: [PATCH 47/49] #2725: add fix duration to application and service configuration --- src/primaite/game/game.py | 12 +- .../red_applications/ransomware_script.py | 2 +- .../assets/configs/fix_duration_one_item.yaml | 248 +++++++++++++++++ .../assets/configs/software_fix_duration.yaml | 263 ++++++++++++++++++ .../test_software_fix_duration.py | 93 +++++++ 5 files changed, 616 insertions(+), 2 deletions(-) create mode 100644 tests/assets/configs/fix_duration_one_item.yaml create mode 100644 tests/assets/configs/software_fix_duration.yaml create mode 100644 tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8a79d068..62b7f231 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -297,6 +297,11 @@ class PrimaiteGame: new_node.software_manager.install(SERVICE_TYPES_MAPPING[service_type]) new_service = new_node.software_manager.software[service_type] + # fixing duration for the service + fix_duration = service_cfg.get("options", {}).get("fix_duration", None) + if fix_duration: + new_service.fixing_duration = fix_duration + # start the service new_service.start() else: @@ -336,6 +341,11 @@ class PrimaiteGame: if application_type in APPLICATION_TYPES_MAPPING: new_node.software_manager.install(APPLICATION_TYPES_MAPPING[application_type]) new_application = new_node.software_manager.software[application_type] + + # fixing duration for the application + fix_duration = application_cfg.get("options", {}).get("fix_duration", None) + if fix_duration: + new_application.fixing_duration = fix_duration else: msg = f"Configuration contains an invalid application type: {application_type}" _LOGGER.error(msg) @@ -358,7 +368,7 @@ class PrimaiteGame: if "options" in application_cfg: opt = application_cfg["options"] new_application.configure( - server_ip_address=IPv4Address(opt.get("server_ip")), + server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None, server_password=opt.get("server_password"), payload=opt.get("payload", "ENCRYPT"), ) diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index af4a59d4..f47dd72f 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -88,7 +88,7 @@ class RansomwareScript(Application): def configure( self, - server_ip_address: IPv4Address, + server_ip_address: Optional[IPv4Address] = None, server_password: Optional[str] = None, payload: Optional[str] = None, ): diff --git a/tests/assets/configs/fix_duration_one_item.yaml b/tests/assets/configs/fix_duration_one_item.yaml new file mode 100644 index 00000000..59bc15f9 --- /dev/null +++ b/tests/assets/configs/fix_duration_one_item.yaml @@ -0,0 +1,248 @@ +# Basic Switched network +# +# -------------- -------------- -------------- +# | client_1 |------| switch_1 |------| client_2 | +# -------------- -------------- -------------- +# +io_settings: + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + sys_log_level: WARNING + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 5 + frequency: 4 + variance: 3 + + + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: client_1 + - hostname: client_2 + - hostname: client_3 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.23 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - switch_1:eth-1<->client_1:eth-1 + - switch_1:eth-2<->client_2:eth-1 + - type: "NONE" + label: ICS + options: {} + + action_space: + action_list: + - type: DONOTHING + + action_map: + 0: + action: DONOTHING + options: {} + options: + nodes: + - node_name: switch + - node_name: client_1 + - node_name: client_2 + - node_name: client_3 + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_list: + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.23 + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.5 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + + - type: WEB_SERVER_404_PENALTY + weight: 0.5 + options: + node_hostname: web_server + service_name: web_server_web_service + + + agent_settings: + flatten_obs: true + +simulation: + network: + nodes: + + - type: switch + hostname: switch_1 + num_ports: 8 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: RansomwareScript + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.10 + server_password: arcd + fix_duration: 1 + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.21 + server_password: arcd + - type: DoSBot + options: + target_ip_address: 192.168.10.21 + payload: SPOOF DATA + port_scan_p_of_success: 0.8 + services: + - type: DNSClient + options: + dns_server: 192.168.1.10 + - type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.10 + - type: DatabaseService + options: + fix_duration: 5 + backup_server_ip: 192.168.1.10 + - type: WebServer + - type: FTPClient + - type: FTPServer + options: + server_password: arcd + - type: NTPClient + options: + ntp_server_ip: 192.168.1.10 + - type: NTPServer + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: DatabaseClient + options: + db_server_ip: 192.168.1.10 + server_password: arcd + services: + - type: DNSClient + options: + dns_server: 192.168.1.10 + + links: + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + bandwidth: 200 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + bandwidth: 200 diff --git a/tests/assets/configs/software_fix_duration.yaml b/tests/assets/configs/software_fix_duration.yaml new file mode 100644 index 00000000..beb176d1 --- /dev/null +++ b/tests/assets/configs/software_fix_duration.yaml @@ -0,0 +1,263 @@ +# Basic Switched network +# +# -------------- -------------- -------------- +# | client_1 |------| switch_1 |------| client_2 | +# -------------- -------------- -------------- +# +io_settings: + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + sys_log_level: WARNING + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 5 + frequency: 4 + variance: 3 + + + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: client_1 + - hostname: client_2 + - hostname: client_3 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.23 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - switch_1:eth-1<->client_1:eth-1 + - switch_1:eth-2<->client_2:eth-1 + - type: "NONE" + label: ICS + options: {} + + action_space: + action_list: + - type: DONOTHING + + action_map: + 0: + action: DONOTHING + options: {} + options: + nodes: + - node_name: switch + - node_name: client_1 + - node_name: client_2 + - node_name: client_3 + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_list: + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.23 + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.5 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + + - type: WEB_SERVER_404_PENALTY + weight: 0.5 + options: + node_hostname: web_server + service_name: web_server_web_service + + + agent_settings: + flatten_obs: true + +simulation: + network: + nodes: + + - type: switch + hostname: switch_1 + num_ports: 8 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: RansomwareScript + options: + fix_duration: 1 + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + fix_duration: 1 + - type: DatabaseClient + options: + db_server_ip: 192.168.1.10 + server_password: arcd + fix_duration: 1 + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.21 + server_password: arcd + fix_duration: 1 + - type: DoSBot + options: + target_ip_address: 192.168.10.21 + payload: SPOOF DATA + port_scan_p_of_success: 0.8 + fix_duration: 1 + services: + - type: DNSClient + options: + dns_server: 192.168.1.10 + fix_duration: 3 + - type: DNSServer + options: + fix_duration: 3 + domain_mapping: + arcd.com: 192.168.1.10 + - type: DatabaseService + options: + backup_server_ip: 192.168.1.10 + fix_duration: 3 + - type: WebServer + options: + fix_duration: 3 + - type: FTPClient + options: + fix_duration: 3 + - type: FTPServer + options: + server_password: arcd + fix_duration: 3 + - type: NTPClient + options: + ntp_server_ip: 192.168.1.10 + fix_duration: 3 + - type: NTPServer + options: + fix_duration: 3 + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: DatabaseClient + options: + db_server_ip: 192.168.1.10 + server_password: arcd + services: + - type: DNSClient + options: + dns_server: 192.168.1.10 + + links: + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + bandwidth: 200 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + bandwidth: 200 diff --git a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py new file mode 100644 index 00000000..bf325946 --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py @@ -0,0 +1,93 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import copy +from ipaddress import IPv4Address +from pathlib import Path +from typing import Union + +import yaml + +from primaite.config.load import data_manipulation_config_path +from primaite.game.agent.interface import ProxyAgent +from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent +from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent +from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot +from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.simulator.system.services.ftp.ftp_client import FTPClient +from primaite.simulator.system.services.ftp.ftp_server import FTPServer +from primaite.simulator.system.services.ntp.ntp_client import NTPClient +from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.simulator.system.services.web_server.web_server import WebServer +from tests import TEST_ASSETS_ROOT + +TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml" +ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +def test_default_fix_duration(): + """Test that software with no defined fix duration in config uses the default fix duration of 2.""" + game = load_config(TEST_CONFIG) + client_2: Computer = game.simulation.network.get_node_by_hostname("client_2") + + database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") + assert database_client.fixing_duration == 2 + + dns_client: DNSClient = client_2.software_manager.software.get("DNSClient") + assert dns_client.fixing_duration == 2 + + +def test_fix_duration_set_from_config(): + """Test to check that the fix duration set for applications and services works as intended.""" + game = load_config(TEST_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + # in config - services take 3 timesteps to fix + for service in SERVICE_TYPES_MAPPING: + assert client_1.software_manager.software.get(service) is not None + assert client_1.software_manager.software.get(service).fixing_duration == 3 + + # in config - applications take 1 timestep to fix + for applications in APPLICATION_TYPES_MAPPING: + assert client_1.software_manager.software.get(applications) is not None + assert client_1.software_manager.software.get(applications).fixing_duration == 1 + + +def test_fix_duration_for_one_item(): + """Test that setting fix duration for one application does not affect other components.""" + game = load_config(ONE_ITEM_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + # in config - services take 3 timesteps to fix + services = copy.copy(SERVICE_TYPES_MAPPING) + services.pop("DatabaseService") + for service in services: + assert client_1.software_manager.software.get(service) is not None + assert client_1.software_manager.software.get(service).fixing_duration == 2 + + # in config - applications take 1 timestep to fix + applications = copy.copy(APPLICATION_TYPES_MAPPING) + applications.pop("DatabaseClient") + for applications in applications: + assert client_1.software_manager.software.get(applications) is not None + assert client_1.software_manager.software.get(applications).fixing_duration == 2 + + database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + assert database_client.fixing_duration == 1 + + database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService") + assert database_service.fixing_duration == 5 From fcd12091567ef7873761bb9438453e7d26793bba Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 2 Jul 2024 16:55:28 +0100 Subject: [PATCH 48/49] #2725: documentation --- .../system/common/common_configuration.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/source/simulation_components/system/common/common_configuration.rst b/docs/source/simulation_components/system/common/common_configuration.rst index 7a5b6ab5..e35ee378 100644 --- a/docs/source/simulation_components/system/common/common_configuration.rst +++ b/docs/source/simulation_components/system/common/common_configuration.rst @@ -16,3 +16,12 @@ The type of software that should be added. To add |SOFTWARE_NAME| this must be | =========== The configuration options are the attributes that fall under the options for an application. + + + +``fix_duration`` +"""""""""""""""" + +Optional. Default value is ``2``. + +The number of timesteps the |SOFTWARE_NAME| will remain in a ``FIXING`` state before going into a ``GOOD`` state. From 55c457a87d5a5d55341da299675ae40d6655eee5 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Wed, 3 Jul 2024 10:34:44 +0100 Subject: [PATCH 49/49] #2725: apply PR suggestions --- src/primaite/game/game.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 62b7f231..d4e5d100 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -298,9 +298,8 @@ class PrimaiteGame: new_service = new_node.software_manager.software[service_type] # fixing duration for the service - fix_duration = service_cfg.get("options", {}).get("fix_duration", None) - if fix_duration: - new_service.fixing_duration = fix_duration + if "fix_duration" in service_cfg.get("options", {}): + new_service.fixing_duration = service_cfg["options"]["fix_duration"] # start the service new_service.start() @@ -343,9 +342,8 @@ class PrimaiteGame: new_application = new_node.software_manager.software[application_type] # fixing duration for the application - fix_duration = application_cfg.get("options", {}).get("fix_duration", None) - if fix_duration: - new_application.fixing_duration = fix_duration + if "fix_duration" in application_cfg.get("options", {}): + new_application.fixing_duration = application_cfg["options"]["fix_duration"] else: msg = f"Configuration contains an invalid application type: {application_type}" _LOGGER.error(msg)