diff --git a/CHANGELOG.md b/CHANGELOG.md
index b27244bc..42519cdf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **Transmission Feasibility Check**: Updated `_can_transmit` function in `Link` to account for current load and total bandwidth capacity, ensuring transmissions do not exceed limits.
- **Frame Size Details**: Frame `size` attribute now includes both core size and payload size in bytes.
- **Transmission Blocking**: Enhanced `AirSpace` logic to block transmissions that would exceed the available capacity.
+- **Software (un)install refactored**: Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality.
### Fixed
diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py
index 27e25a0c..0e6c2acc 100644
--- a/benchmark/primaite_benchmark.py
+++ b/benchmark/primaite_benchmark.py
@@ -5,7 +5,7 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, Tuple
-from report import build_benchmark_latex_report
+from report import build_benchmark_md_report
from stable_baselines3 import PPO
import primaite
@@ -188,7 +188,7 @@ def run(
with open(_SESSION_METADATA_ROOT / f"{i}.json", "r") as file:
session_metadata_dict[i] = json.load(file)
# generate report
- build_benchmark_latex_report(
+ build_benchmark_md_report(
benchmark_start_time=benchmark_start_time,
session_metadata=session_metadata_dict,
config_path=data_manipulation_config_path(),
diff --git a/benchmark/report.py b/benchmark/report.py
index 5eaaab9f..e1ff46b9 100644
--- a/benchmark/report.py
+++ b/benchmark/report.py
@@ -234,10 +234,7 @@ def _plot_av_s_per_100_steps_10_nodes(
"""
major_v = primaite.__version__.split(".")[0]
title = f"Performance of Minor and Bugfix Releases for Major Version {major_v}"
- subtitle = (
- f"Average Training Time per 100 Steps on 10 Nodes "
- f"(target: <= {PLOT_CONFIG['av_s_per_100_steps_10_nodes_benchmark_threshold']} seconds)"
- )
+ subtitle = "Average Training Time per 100 Steps on 10 Nodes "
title = f"{title}
{subtitle}"
layout = go.Layout(
@@ -250,10 +247,6 @@ def _plot_av_s_per_100_steps_10_nodes(
versions = sorted(list(version_times_dict.keys()))
times = [version_times_dict[version] for version in versions]
- av_s_per_100_steps_10_nodes_benchmark_threshold = PLOT_CONFIG["av_s_per_100_steps_10_nodes_benchmark_threshold"]
-
- # Calculate the appropriate maximum y-axis value
- max_y_axis_value = max(max(times), av_s_per_100_steps_10_nodes_benchmark_threshold) + 1
fig.add_trace(
go.Bar(
@@ -267,7 +260,6 @@ def _plot_av_s_per_100_steps_10_nodes(
fig.update_layout(
xaxis_title="PrimAITE Version",
yaxis_title="Avg Time per 100 Steps on 10 Nodes (seconds)",
- yaxis=dict(range=[0, max_y_axis_value]),
title=title,
)
diff --git a/pyproject.toml b/pyproject.toml
index 9e919604..c9b7c062 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,7 +52,7 @@ license-files = ["LICENSE"]
[project.optional-dependencies]
rl = [
- "ray[rllib] >= 2.20.0, < 3",
+ "ray[rllib] >= 2.20.0, <2.33",
"tensorflow==2.12.0",
"stable-baselines3[extra]==2.1.0",
"sb3-contrib==2.1.0",
diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml
index 81848b2d..dfd200f3 100644
--- a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml
+++ b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml
@@ -129,6 +129,10 @@ agents:
simulation:
network:
+ nmne_config:
+ capture_nmne: true
+ nmne_capture_keywords:
+ - DELETE
nodes:
- hostname: client
type: computer
diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py
index 1b1231f6..6e29c2ce 100644
--- a/src/primaite/game/game.py
+++ b/src/primaite/game/game.py
@@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.airspace import AirSpaceFrequency
-from primaite.simulator.network.hardware.base import NodeOperatingState
+from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
@@ -26,7 +26,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
-from primaite.simulator.network.nmne import set_nmne_config
+from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
@@ -266,6 +266,8 @@ class PrimaiteGame:
nodes_cfg = network_config.get("nodes", [])
links_cfg = network_config.get("links", [])
+ # Set the NMNE capture config
+ NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
for node_cfg in nodes_cfg:
n_type = node_cfg["type"]
@@ -535,10 +537,7 @@ class PrimaiteGame:
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
game.setup_reward_sharing()
- # Set the NMNE capture config
- set_nmne_config(network_config.get("nmne_config", {}))
game.update_agents(game.get_sim_state())
-
return game
def setup_reward_sharing(self):
diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py
index 15c44821..98f47cc3 100644
--- a/src/primaite/simulator/network/hardware/base.py
+++ b/src/primaite/simulator/network/hardware/base.py
@@ -6,12 +6,11 @@ import secrets
from abc import ABC, abstractmethod
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
-from typing import Any, Dict, Optional, TypeVar, Union
+from typing import Any, ClassVar, Dict, Optional, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field
-import primaite.simulator.network.nmne
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.interface.request import RequestResponse
@@ -20,15 +19,7 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
-from primaite.simulator.network.nmne import (
- CAPTURE_BY_DIRECTION,
- CAPTURE_BY_IP_ADDRESS,
- CAPTURE_BY_KEYWORD,
- CAPTURE_BY_PORT,
- CAPTURE_BY_PROTOCOL,
- CAPTURE_NMNE,
- NMNE_CAPTURE_KEYWORDS,
-)
+from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.system.applications.application import Application
@@ -108,8 +99,11 @@ class NetworkInterface(SimComponent, ABC):
pcap: Optional[PacketCapture] = None
"A PacketCapture instance for capturing and analysing packets passing through this interface."
+ nmne_config: ClassVar[NMNEConfig] = NMNEConfig()
+ "A dataclass defining malicious network events to be captured."
+
nmne: Dict = Field(default_factory=lambda: {})
- "A dict containing details of the number of malicious network events captured."
+ "A dict containing details of the number of malicious events captured."
traffic: Dict = Field(default_factory=lambda: {})
"A dict containing details of the inbound and outbound traffic by port and protocol."
@@ -167,8 +161,8 @@ class NetworkInterface(SimComponent, ABC):
"enabled": self.enabled,
}
)
- if CAPTURE_NMNE:
- state.update({"nmne": {k: v for k, v in self.nmne.items()}})
+ if self.nmne_config and self.nmne_config.capture_nmne:
+ state.update({"nmne": self.nmne})
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
return state
@@ -201,7 +195,7 @@ class NetworkInterface(SimComponent, ABC):
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
"""
# Exit function if NMNE capturing is disabled
- if not CAPTURE_NMNE:
+ if not (self.nmne_config and self.nmne_config.capture_nmne):
return
# Initialise basic frame data variables
@@ -222,27 +216,27 @@ class NetworkInterface(SimComponent, ABC):
frame_str = str(frame.payload)
# Proceed only if any NMNE keyword is present in the frame payload
- if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
+ if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords):
# Start with the root of the NMNE capture structure
current_level = self.nmne
# Update NMNE structure based on enabled settings
- if CAPTURE_BY_DIRECTION:
+ if self.nmne_config.capture_by_direction:
# Set or get the dictionary for the current direction
current_level = current_level.setdefault("direction", {})
current_level = current_level.setdefault(direction, {})
- if CAPTURE_BY_IP_ADDRESS:
+ if self.nmne_config.capture_by_ip_address:
# Set or get the dictionary for the current IP address
current_level = current_level.setdefault("ip_address", {})
current_level = current_level.setdefault(ip_address, {})
- if CAPTURE_BY_PROTOCOL:
+ if self.nmne_config.capture_by_protocol:
# Set or get the dictionary for the current protocol
current_level = current_level.setdefault("protocol", {})
current_level = current_level.setdefault(protocol, {})
- if CAPTURE_BY_PORT:
+ if self.nmne_config.capture_by_port:
# Set or get the dictionary for the current port
current_level = current_level.setdefault("port", {})
current_level = current_level.setdefault(port, {})
@@ -251,8 +245,8 @@ class NetworkInterface(SimComponent, ABC):
keyword_level = current_level.setdefault("keywords", {})
# Increment the count for detected keywords in the payload
- if CAPTURE_BY_KEYWORD:
- for keyword in NMNE_CAPTURE_KEYWORDS:
+ if self.nmne_config.capture_by_keyword:
+ for keyword in self.nmne_config.nmne_capture_keywords:
if keyword in frame_str:
# Update the count for each keyword found
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
@@ -1173,7 +1167,7 @@ class Node(SimComponent):
ip_address,
network_interface.speed,
"Enabled" if network_interface.enabled else "Disabled",
- network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled",
+ network_interface.nmne if network_interface.nmne_config.capture_nmne else "Disabled",
]
)
print(table)
@@ -1455,74 +1449,6 @@ class Node(SimComponent):
else:
return
- def install_service(self, service: Service) -> None:
- """
- Install a service on this node.
-
- :param service: Service instance that has not been installed on any node yet.
- :type service: Service
- """
- if service in self:
- _LOGGER.warning(f"Can't add service {service.name} to node {self.hostname}. It's already installed.")
- return
- self.services[service.uuid] = service
- service.parent = self
- service.install() # Perform any additional setup, such as creating files for this service on the node.
- self.sys_log.info(f"Installed service {service.name}")
- _LOGGER.debug(f"Added service {service.name} to node {self.hostname}")
- self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager))
-
- def uninstall_service(self, service: Service) -> None:
- """
- Uninstall and completely remove service from this node.
-
- :param service: Service object that is currently associated with this node.
- :type service: Service
- """
- if service not in self:
- _LOGGER.warning(f"Can't remove service {service.name} from node {self.hostname}. It's not installed.")
- return
- service.uninstall() # Perform additional teardown, such as removing files or restarting the machine.
- self.services.pop(service.uuid)
- service.parent = None
- self.sys_log.info(f"Uninstalled service {service.name}")
- self._service_request_manager.remove_request(service.name)
-
- def install_application(self, application: Application) -> None:
- """
- Install an application on this node.
-
- :param application: Application instance that has not been installed on any node yet.
- :type application: Application
- """
- if application in self:
- _LOGGER.warning(
- f"Can't add application {application.name} to node {self.hostname}. It's already installed."
- )
- return
- self.applications[application.uuid] = application
- application.parent = self
- self.sys_log.info(f"Installed application {application.name}")
- _LOGGER.debug(f"Added application {application.name} to node {self.hostname}")
- self._application_request_manager.add_request(application.name, RequestType(func=application._request_manager))
-
- def uninstall_application(self, application: Application) -> None:
- """
- Uninstall and completely remove application from this node.
-
- :param application: Application object that is currently associated with this node.
- :type application: Application
- """
- if application not in self:
- _LOGGER.warning(
- f"Can't remove application {application.name} from node {self.hostname}. It's not installed."
- )
- return
- self.applications.pop(application.uuid)
- application.parent = None
- self.sys_log.info(f"Uninstalled application {application.name}")
- self._application_request_manager.remove_request(application.name)
-
def _shut_down_actions(self):
"""Actions to perform when the node is shut down."""
# Turn off all the services in the node
diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py
index 5c0c657b..c9cff5de 100644
--- a/src/primaite/simulator/network/nmne.py
+++ b/src/primaite/simulator/network/nmne.py
@@ -1,48 +1,25 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
-from typing import Dict, Final, List
+from typing import List
-CAPTURE_NMNE: bool = True
-"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True."""
-
-NMNE_CAPTURE_KEYWORDS: List[str] = []
-"""List of keywords to identify malicious network events."""
-
-# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
-CAPTURE_BY_DIRECTION: Final[bool] = True
-"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
-CAPTURE_BY_IP_ADDRESS: Final[bool] = False
-"""Flag to determine if captures should be organized by source or destination IP address."""
-CAPTURE_BY_PROTOCOL: Final[bool] = False
-"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP)."""
-CAPTURE_BY_PORT: Final[bool] = False
-"""Flag to determine if captures should be organized by source or destination port."""
-CAPTURE_BY_KEYWORD: Final[bool] = False
-"""Flag to determine if captures should be filtered and categorised based on specific keywords."""
+from pydantic import BaseModel, ConfigDict
-def set_nmne_config(nmne_config: Dict):
- """
- Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary.
+class NMNEConfig(BaseModel):
+ """Store all the information to perform NMNE operations."""
- This function updates global settings related to NMNE capture, including whether to capture NMNEs and what
- keywords to use for identifying NMNEs.
+ model_config = ConfigDict(extra="forbid")
- The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary,
- and maintains type integrity by checking the types of the provided values.
-
- :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include:
- "capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings)
- to specify keywords for NMNE identification.
- """
- global NMNE_CAPTURE_KEYWORDS
- global CAPTURE_NMNE
-
- # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect
- CAPTURE_NMNE = nmne_config.get("capture_nmne", False)
- if not isinstance(CAPTURE_NMNE, bool):
- CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean
-
- # Update the NMNE capture keywords, appending new keywords if provided
- NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", [])
- if not isinstance(NMNE_CAPTURE_KEYWORDS, list):
- NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list
+ capture_nmne: bool = False
+ """Indicates whether Malicious Network Events (MNEs) should be captured."""
+ nmne_capture_keywords: List[str] = []
+ """List of keywords to identify malicious network events."""
+ capture_by_direction: bool = True
+ """Captures should be organized by traffic direction (inbound/outbound)."""
+ capture_by_ip_address: bool = False
+ """Captures should be organized by source or destination IP address."""
+ capture_by_protocol: bool = False
+ """Captures should be organized by network protocol (e.g., TCP, UDP)."""
+ capture_by_port: bool = False
+ """Captures should be organized by source or destination port."""
+ capture_by_keyword: bool = False
+ """Captures should be filtered and categorised based on specific keywords."""
diff --git a/src/primaite/simulator/network/protocols/icmp.py b/src/primaite/simulator/network/protocols/icmp.py
index 743e2375..9f0626f0 100644
--- a/src/primaite/simulator/network/protocols/icmp.py
+++ b/src/primaite/simulator/network/protocols/icmp.py
@@ -4,7 +4,7 @@ from enum import Enum
from typing import Union
from pydantic import BaseModel, field_validator, validate_call
-from pydantic_core.core_schema import FieldValidationInfo
+from pydantic_core.core_schema import ValidationInfo
from primaite import getLogger
@@ -96,7 +96,7 @@ class ICMPPacket(BaseModel):
@field_validator("icmp_code") # noqa
@classmethod
- def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int:
+ def _icmp_type_must_have_icmp_code(cls, v: int, info: ValidationInfo) -> int:
"""Validates the icmp_type and icmp_code."""
icmp_type = info.data["icmp_type"]
if get_icmp_type_code_description(icmp_type, v):
diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py
index e2266c2d..9c4d7cf6 100644
--- a/src/primaite/simulator/system/core/software_manager.py
+++ b/src/primaite/simulator/system/core/software_manager.py
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
+from primaite.simulator.core import RequestType
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -20,9 +21,7 @@ if TYPE_CHECKING:
from primaite.simulator.system.services.arp.arp import ARP
from primaite.simulator.system.services.icmp.icmp import ICMP
-from typing import Type, TypeVar
-
-IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
+from typing import Type
class SoftwareManager:
@@ -51,7 +50,7 @@ class SoftwareManager:
self.node = parent_node
self.session_manager = session_manager
self.software: Dict[str, Union[Service, Application]] = {}
- self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {}
+ self._software_class_to_name_map: Dict[Type[IOSoftware], str] = {}
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
self.sys_log: SysLog = sys_log
self.file_system: FileSystem = file_system
@@ -104,33 +103,34 @@ class SoftwareManager:
return True
return False
- def install(self, software_class: Type[IOSoftwareClass]):
+ def install(self, software_class: Type[IOSoftware]):
"""
Install an Application or Service.
:param software_class: The software class.
"""
- # TODO: Software manager and node itself both have an install method. Need to refactor to have more logical
- # separation of concerns.
if software_class in self._software_class_to_name_map:
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
return
software = software_class(
software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server
)
+ software.parent = self.node
if isinstance(software, Application):
- software.install()
+ self.node.applications[software.uuid] = software
+ self.node._application_request_manager.add_request(
+ software.name, RequestType(func=software._request_manager)
+ )
+ elif isinstance(software, Service):
+ self.node.services[software.uuid] = software
+ self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager))
+ software.install()
software.software_manager = self
self.software[software.name] = software
self.port_protocol_mapping[(software.port, software.protocol)] = software
if isinstance(software, Application):
software.operating_state = ApplicationOperatingState.CLOSED
-
- # add the software to the node's registry after it has been fully initialized
- if isinstance(software, Service):
- self.node.install_service(software)
- elif isinstance(software, Application):
- self.node.install_application(software)
+ self.node.sys_log.info(f"Installed {software.name}")
def uninstall(self, software_name: str):
"""
@@ -138,25 +138,31 @@ class SoftwareManager:
:param software_name: The software name.
"""
- if software_name in self.software:
- self.software[software_name].uninstall()
- software = self.software.pop(software_name) # noqa
- if isinstance(software, Application):
- self.node.uninstall_application(software)
- elif isinstance(software, Service):
- self.node.uninstall_service(software)
- for key, value in self.port_protocol_mapping.items():
- if value.name == software_name:
- self.port_protocol_mapping.pop(key)
- break
- for key, value in self._software_class_to_name_map.items():
- if value == software_name:
- self._software_class_to_name_map.pop(key)
- break
- del software
- self.sys_log.info(f"Uninstalled {software_name}")
+ if software_name not in self.software:
+ self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
return
- self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
+
+ self.software[software_name].uninstall()
+ software = self.software.pop(software_name) # noqa
+ if isinstance(software, Application):
+ self.node.applications.pop(software.uuid)
+ self.node._application_request_manager.remove_request(software.name)
+ elif isinstance(software, Service):
+ self.node.services.pop(software.uuid)
+ software.uninstall()
+ self.node._service_request_manager.remove_request(software.name)
+ software.parent = None
+ for key, value in self.port_protocol_mapping.items():
+ if value.name == software_name:
+ self.port_protocol_mapping.pop(key)
+ break
+ for key, value in self._software_class_to_name_map.items():
+ if value == software_name:
+ self._software_class_to_name_map.pop(key)
+ break
+ del software
+ self.sys_log.info(f"Uninstalled {software_name}")
+ return
def send_internal_payload(self, target_software: str, payload: Any):
"""
diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml
index 8cbd3ae9..c83cadc8 100644
--- a/tests/assets/configs/bad_primaite_session.yaml
+++ b/tests/assets/configs/bad_primaite_session.yaml
@@ -99,7 +99,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
- include_nmne: true
+ include_nmne: false
routers:
- hostname: router_1
num_ports: 0
diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml
index 69187fa3..fed0f52d 100644
--- a/tests/assets/configs/basic_switched_network.yaml
+++ b/tests/assets/configs/basic_switched_network.yaml
@@ -92,7 +92,7 @@ agents:
- NONE
tcp:
- DNS
- include_nmne: true
+ include_nmne: false
routers:
- hostname: router_1
num_ports: 0
diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml
index de861dcc..3d60eb6e 100644
--- a/tests/assets/configs/eval_only_primaite_session.yaml
+++ b/tests/assets/configs/eval_only_primaite_session.yaml
@@ -111,7 +111,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
- include_nmne: true
+ include_nmne: false
routers:
- hostname: router_1
num_ports: 0
diff --git a/tests/assets/configs/firewall_actions_network.yaml b/tests/assets/configs/firewall_actions_network.yaml
index fd5b1bf8..2292616d 100644
--- a/tests/assets/configs/firewall_actions_network.yaml
+++ b/tests/assets/configs/firewall_actions_network.yaml
@@ -68,7 +68,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
- include_nmne: true
+ include_nmne: false
routers:
- hostname: router_1
num_ports: 0
diff --git a/tests/assets/configs/fix_duration_one_item.yaml b/tests/assets/configs/fix_duration_one_item.yaml
index 59bc15f9..bd0fb61f 100644
--- a/tests/assets/configs/fix_duration_one_item.yaml
+++ b/tests/assets/configs/fix_duration_one_item.yaml
@@ -89,7 +89,7 @@ agents:
- NONE
tcp:
- DNS
- include_nmne: true
+ include_nmne: false
routers:
- hostname: router_1
num_ports: 0
diff --git a/tests/assets/configs/scenario_with_placeholders/scenario.yaml b/tests/assets/configs/scenario_with_placeholders/scenario.yaml
index 81848b2d..ef930a1a 100644
--- a/tests/assets/configs/scenario_with_placeholders/scenario.yaml
+++ b/tests/assets/configs/scenario_with_placeholders/scenario.yaml
@@ -44,7 +44,7 @@ agents:
num_files: 1
num_nics: 1
include_num_access: false
- include_nmne: true
+ include_nmne: false
- type: LINKS
label: LINKS
diff --git a/tests/assets/configs/software_fix_duration.yaml b/tests/assets/configs/software_fix_duration.yaml
index 1acb05a9..1a28258b 100644
--- a/tests/assets/configs/software_fix_duration.yaml
+++ b/tests/assets/configs/software_fix_duration.yaml
@@ -89,7 +89,7 @@ agents:
- NONE
tcp:
- DNS
- include_nmne: true
+ include_nmne: false
routers:
- hostname: router_1
num_ports: 0
diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml
index eb8103e8..27cfa240 100644
--- a/tests/assets/configs/test_primaite_session.yaml
+++ b/tests/assets/configs/test_primaite_session.yaml
@@ -120,7 +120,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
- include_nmne: true
+ include_nmne: false
routers:
- hostname: router_1
num_ports: 0
diff --git a/tests/conftest.py b/tests/conftest.py
index 54519e2b..2d605c94 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,21 +30,21 @@ from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT
-rayinit(local_mode=True)
+rayinit()
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__)
-class TestService(Service):
+class DummyService(Service):
"""Test Service class"""
def describe_state(self) -> Dict:
return super().describe_state()
def __init__(self, **kwargs):
- kwargs["name"] = "TestService"
+ kwargs["name"] = "DummyService"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
@@ -75,15 +75,15 @@ def uc2_network() -> Network:
@pytest.fixture(scope="function")
-def service(file_system) -> TestService:
- return TestService(
- name="TestService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service")
+def service(file_system) -> DummyService:
+ return DummyService(
+ name="DummyService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_service")
)
@pytest.fixture(scope="function")
def service_class():
- return TestService
+ return DummyService
@pytest.fixture(scope="function")
diff --git a/tests/integration_tests/component_creation/test_action_integration.py b/tests/integration_tests/component_creation/test_action_integration.py
index a6f09436..7bdc80fc 100644
--- a/tests/integration_tests/component_creation/test_action_integration.py
+++ b/tests/integration_tests/component_creation/test_action_integration.py
@@ -22,8 +22,7 @@ def test_passing_actions_down(monkeypatch) -> None:
for n in [pc1, pc2, srv, s1]:
sim.network.add_node(n)
- database_service = DatabaseService(file_system=srv.file_system)
- srv.install_service(database_service)
+ srv.software_manager.install(DatabaseService)
downloads_folder = pc1.file_system.create_folder("downloads")
pc1.file_system.create_file("bermuda_triangle.png", folder_name="downloads")
diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py
index 88dd2bd5..ef789ba7 100644
--- a/tests/integration_tests/game_layer/observations/test_nic_observations.py
+++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py
@@ -9,9 +9,11 @@ from gymnasium import spaces
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.game import PrimaiteGame
+from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
+from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
@@ -75,6 +77,18 @@ def test_nic(simulation):
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
+ # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
+ nmne_config = {
+ "capture_nmne": True, # Enable the capture of MNEs
+ "nmne_capture_keywords": [
+ "DELETE",
+ "ENCRYPT",
+ ], # Specify "DELETE/ENCRYPT" SQL command as a keyword for MNE detection
+ }
+
+ # Apply the NMNE configuration settings
+ NetworkInterface.nmne_config = NMNEConfig(**nmne_config)
+
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4)
@@ -144,7 +158,7 @@ def test_nic_monitored_traffic(simulation):
pc2: Computer = simulation.network.get_node_by_hostname("client_2")
nic_obs = NICObservation(
- where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True, monitored_traffic=monitored_traffic
+ where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic
)
simulation.pre_timestep(0) # apply timestep to whole sim
diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py
index a8f1f245..debf5b1c 100644
--- a/tests/integration_tests/network/test_capture_nmne.py
+++ b/tests/integration_tests/network/test_capture_nmne.py
@@ -1,12 +1,14 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.game.agent.observations.nic_observations import NICObservation
+from primaite.simulator.network.container import Network
+from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
-from primaite.simulator.network.nmne import set_nmne_config
+from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
-def test_capture_nmne(uc2_network):
+def test_capture_nmne(uc2_network: Network):
"""
Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured.
@@ -33,7 +35,7 @@ def test_capture_nmne(uc2_network):
}
# Apply the NMNE configuration settings
- set_nmne_config(nmne_config)
+ NIC.nmne_config = NMNEConfig(**nmne_config)
# Assert that initially, there are no captured MNEs on both web and database servers
assert web_server_nic.nmne == {}
@@ -82,7 +84,7 @@ def test_capture_nmne(uc2_network):
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 3}}}}
-def test_describe_state_nmne(uc2_network):
+def test_describe_state_nmne(uc2_network: Network):
"""
Conducts a test to verify that Malicious Network Events (MNEs) are correctly represented in the nic state.
@@ -110,7 +112,7 @@ def test_describe_state_nmne(uc2_network):
}
# Apply the NMNE configuration settings
- set_nmne_config(nmne_config)
+ NIC.nmne_config = NMNEConfig(**nmne_config)
# Assert that initially, there are no captured MNEs on both web and database servers
web_server_nic_state = web_server_nic.describe_state()
@@ -190,7 +192,7 @@ def test_describe_state_nmne(uc2_network):
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 4}}}}
-def test_capture_nmne_observations(uc2_network):
+def test_capture_nmne_observations(uc2_network: Network):
"""
Tests the NICObservation class's functionality within a simulated network environment.
@@ -219,7 +221,7 @@ def test_capture_nmne_observations(uc2_network):
}
# Apply the NMNE configuration settings
- set_nmne_config(nmne_config)
+ NIC.nmne_config = NMNEConfig(**nmne_config)
# Define observations for the NICs of the database and web servers
db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True)
diff --git a/tests/integration_tests/system/test_service_on_node.py b/tests/integration_tests/system/test_service_on_node.py
index 15dbaf1d..cf9728ce 100644
--- a/tests/integration_tests/system/test_service_on_node.py
+++ b/tests/integration_tests/system/test_service_on_node.py
@@ -23,7 +23,7 @@ def populated_node(
server.power_on()
server.software_manager.install(service_class)
- service = server.software_manager.software.get("TestService")
+ service = server.software_manager.software.get("DummyService")
service.start()
return server, service
@@ -42,7 +42,7 @@ def test_service_on_offline_node(service_class):
computer.power_on()
computer.software_manager.install(service_class)
- service: Service = computer.software_manager.software.get("TestService")
+ service: Service = computer.software_manager.software.get("DummyService")
computer.power_off()
diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py
index a9f0b58d..95634cf1 100644
--- a/tests/integration_tests/test_simulation/test_request_response.py
+++ b/tests/integration_tests/test_simulation/test_request_response.py
@@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
-from tests.conftest import DummyApplication, TestService
+from tests.conftest import DummyApplication, DummyService
def test_successful_node_file_system_creation_request(example_network):
@@ -61,7 +61,7 @@ def test_successful_application_requests(example_network):
def test_successful_service_requests(example_network):
net = example_network
server_1 = net.get_node_by_hostname("server_1")
- server_1.software_manager.install(TestService)
+ server_1.software_manager.install(DummyService)
# Careful: the order here is important, for example we cannot run "stop" unless we run "start" first
for verb in [
@@ -77,7 +77,7 @@ def test_successful_service_requests(example_network):
"scan",
"fix",
]:
- resp_1 = net.apply_request(["node", "server_1", "service", "TestService", verb])
+ resp_1 = net.apply_request(["node", "server_1", "service", "DummyService", verb])
assert resp_1 == RequestResponse(status="success", data={})
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py
index 9b37ac80..44c5c781 100644
--- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py
+++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py
@@ -7,6 +7,7 @@ from primaite.simulator.file_system.folder import Folder
from primaite.simulator.network.hardware.base import Node, NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.software import SoftwareHealthState
+from tests.conftest import DummyApplication, DummyService
@pytest.fixture
@@ -47,7 +48,7 @@ def test_node_shutdown(node):
assert node.operating_state == NodeOperatingState.OFF
-def test_node_os_scan(node, service, application):
+def test_node_os_scan(node):
"""Test OS Scanning."""
node.operating_state = NodeOperatingState.ON
@@ -55,13 +56,15 @@ def test_node_os_scan(node, service, application):
# TODO implement processes
# add services to node
+ node.software_manager.install(DummyService)
+ service = node.software_manager.software.get("DummyService")
service.set_health_state(SoftwareHealthState.COMPROMISED)
- node.install_service(service=service)
assert service.health_state_visible == SoftwareHealthState.UNUSED
# add application to node
+ node.software_manager.install(DummyApplication)
+ application = node.software_manager.software.get("DummyApplication")
application.set_health_state(SoftwareHealthState.COMPROMISED)
- node.install_application(application=application)
assert application.health_state_visible == SoftwareHealthState.UNUSED
# add folder and file to node
@@ -91,7 +94,7 @@ def test_node_os_scan(node, service, application):
assert file2.visible_health_status == FileSystemItemHealthStatus.CORRUPT
-def test_node_red_scan(node, service, application):
+def test_node_red_scan(node):
"""Test revealing to red"""
node.operating_state = NodeOperatingState.ON
@@ -99,12 +102,14 @@ def test_node_red_scan(node, service, application):
# TODO implement processes
# add services to node
- node.install_service(service=service)
+ node.software_manager.install(DummyService)
+ service = node.software_manager.software.get("DummyService")
assert service.revealed_to_red is False
# add application to node
+ node.software_manager.install(DummyApplication)
+ application = node.software_manager.software.get("DummyApplication")
application.set_health_state(SoftwareHealthState.COMPROMISED)
- node.install_application(application=application)
assert application.revealed_to_red is False
# add folder and file to node