Merge branch 'dev' into feature/2706-Terminal_Sim_Component
This commit is contained in:
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- **Transmission Feasibility Check**: Updated `_can_transmit` function in `Link` to account for current load and total bandwidth capacity, ensuring transmissions do not exceed limits.
|
||||
- **Frame Size Details**: Frame `size` attribute now includes both core size and payload size in bytes.
|
||||
- **Transmission Blocking**: Enhanced `AirSpace` logic to block transmissions that would exceed the available capacity.
|
||||
- **Software (un)install refactored**: Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality.
|
||||
|
||||
### Fixed
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Tuple
|
||||
|
||||
from report import build_benchmark_latex_report
|
||||
from report import build_benchmark_md_report
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
import primaite
|
||||
@@ -188,7 +188,7 @@ def run(
|
||||
with open(_SESSION_METADATA_ROOT / f"{i}.json", "r") as file:
|
||||
session_metadata_dict[i] = json.load(file)
|
||||
# generate report
|
||||
build_benchmark_latex_report(
|
||||
build_benchmark_md_report(
|
||||
benchmark_start_time=benchmark_start_time,
|
||||
session_metadata=session_metadata_dict,
|
||||
config_path=data_manipulation_config_path(),
|
||||
|
||||
@@ -234,10 +234,7 @@ def _plot_av_s_per_100_steps_10_nodes(
|
||||
"""
|
||||
major_v = primaite.__version__.split(".")[0]
|
||||
title = f"Performance of Minor and Bugfix Releases for Major Version {major_v}"
|
||||
subtitle = (
|
||||
f"Average Training Time per 100 Steps on 10 Nodes "
|
||||
f"(target: <= {PLOT_CONFIG['av_s_per_100_steps_10_nodes_benchmark_threshold']} seconds)"
|
||||
)
|
||||
subtitle = "Average Training Time per 100 Steps on 10 Nodes "
|
||||
title = f"{title} <br><sub>{subtitle}</sub>"
|
||||
|
||||
layout = go.Layout(
|
||||
@@ -250,10 +247,6 @@ def _plot_av_s_per_100_steps_10_nodes(
|
||||
|
||||
versions = sorted(list(version_times_dict.keys()))
|
||||
times = [version_times_dict[version] for version in versions]
|
||||
av_s_per_100_steps_10_nodes_benchmark_threshold = PLOT_CONFIG["av_s_per_100_steps_10_nodes_benchmark_threshold"]
|
||||
|
||||
# Calculate the appropriate maximum y-axis value
|
||||
max_y_axis_value = max(max(times), av_s_per_100_steps_10_nodes_benchmark_threshold) + 1
|
||||
|
||||
fig.add_trace(
|
||||
go.Bar(
|
||||
@@ -267,7 +260,6 @@ def _plot_av_s_per_100_steps_10_nodes(
|
||||
fig.update_layout(
|
||||
xaxis_title="PrimAITE Version",
|
||||
yaxis_title="Avg Time per 100 Steps on 10 Nodes (seconds)",
|
||||
yaxis=dict(range=[0, max_y_axis_value]),
|
||||
title=title,
|
||||
)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ license-files = ["LICENSE"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
rl = [
|
||||
"ray[rllib] >= 2.20.0, < 3",
|
||||
"ray[rllib] >= 2.20.0, <2.33",
|
||||
"tensorflow==2.12.0",
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
"sb3-contrib==2.1.0",
|
||||
|
||||
@@ -129,6 +129,10 @@ agents:
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nmne_config:
|
||||
capture_nmne: true
|
||||
nmne_capture_keywords:
|
||||
- DELETE
|
||||
nodes:
|
||||
- hostname: client
|
||||
type: computer
|
||||
|
||||
@@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.network.airspace import AirSpaceFrequency
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
|
||||
@@ -26,7 +26,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
|
||||
from primaite.simulator.network.hardware.nodes.network.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
@@ -266,6 +266,8 @@ class PrimaiteGame:
|
||||
|
||||
nodes_cfg = network_config.get("nodes", [])
|
||||
links_cfg = network_config.get("links", [])
|
||||
# Set the NMNE capture config
|
||||
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
|
||||
|
||||
for node_cfg in nodes_cfg:
|
||||
n_type = node_cfg["type"]
|
||||
@@ -535,10 +537,7 @@ class PrimaiteGame:
|
||||
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
|
||||
game.setup_reward_sharing()
|
||||
|
||||
# Set the NMNE capture config
|
||||
set_nmne_config(network_config.get("nmne_config", {}))
|
||||
game.update_agents(game.get_sim_state())
|
||||
|
||||
return game
|
||||
|
||||
def setup_reward_sharing(self):
|
||||
|
||||
@@ -6,12 +6,11 @@ import secrets
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, TypeVar, Union
|
||||
from typing import Any, ClassVar, Dict, Optional, TypeVar, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import primaite.simulator.network.nmne
|
||||
from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -20,15 +19,7 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis
|
||||
from primaite.simulator.domain.account import Account
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.nmne import (
|
||||
CAPTURE_BY_DIRECTION,
|
||||
CAPTURE_BY_IP_ADDRESS,
|
||||
CAPTURE_BY_KEYWORD,
|
||||
CAPTURE_BY_PORT,
|
||||
CAPTURE_BY_PROTOCOL,
|
||||
CAPTURE_NMNE,
|
||||
NMNE_CAPTURE_KEYWORDS,
|
||||
)
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
@@ -108,8 +99,11 @@ class NetworkInterface(SimComponent, ABC):
|
||||
pcap: Optional[PacketCapture] = None
|
||||
"A PacketCapture instance for capturing and analysing packets passing through this interface."
|
||||
|
||||
nmne_config: ClassVar[NMNEConfig] = NMNEConfig()
|
||||
"A dataclass defining malicious network events to be captured."
|
||||
|
||||
nmne: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the number of malicious network events captured."
|
||||
"A dict containing details of the number of malicious events captured."
|
||||
|
||||
traffic: Dict = Field(default_factory=lambda: {})
|
||||
"A dict containing details of the inbound and outbound traffic by port and protocol."
|
||||
@@ -167,8 +161,8 @@ class NetworkInterface(SimComponent, ABC):
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
)
|
||||
if CAPTURE_NMNE:
|
||||
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
|
||||
if self.nmne_config and self.nmne_config.capture_nmne:
|
||||
state.update({"nmne": self.nmne})
|
||||
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
|
||||
return state
|
||||
|
||||
@@ -201,7 +195,7 @@ class NetworkInterface(SimComponent, ABC):
|
||||
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
|
||||
"""
|
||||
# Exit function if NMNE capturing is disabled
|
||||
if not CAPTURE_NMNE:
|
||||
if not (self.nmne_config and self.nmne_config.capture_nmne):
|
||||
return
|
||||
|
||||
# Initialise basic frame data variables
|
||||
@@ -222,27 +216,27 @@ class NetworkInterface(SimComponent, ABC):
|
||||
frame_str = str(frame.payload)
|
||||
|
||||
# Proceed only if any NMNE keyword is present in the frame payload
|
||||
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
|
||||
if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords):
|
||||
# Start with the root of the NMNE capture structure
|
||||
current_level = self.nmne
|
||||
|
||||
# Update NMNE structure based on enabled settings
|
||||
if CAPTURE_BY_DIRECTION:
|
||||
if self.nmne_config.capture_by_direction:
|
||||
# Set or get the dictionary for the current direction
|
||||
current_level = current_level.setdefault("direction", {})
|
||||
current_level = current_level.setdefault(direction, {})
|
||||
|
||||
if CAPTURE_BY_IP_ADDRESS:
|
||||
if self.nmne_config.capture_by_ip_address:
|
||||
# Set or get the dictionary for the current IP address
|
||||
current_level = current_level.setdefault("ip_address", {})
|
||||
current_level = current_level.setdefault(ip_address, {})
|
||||
|
||||
if CAPTURE_BY_PROTOCOL:
|
||||
if self.nmne_config.capture_by_protocol:
|
||||
# Set or get the dictionary for the current protocol
|
||||
current_level = current_level.setdefault("protocol", {})
|
||||
current_level = current_level.setdefault(protocol, {})
|
||||
|
||||
if CAPTURE_BY_PORT:
|
||||
if self.nmne_config.capture_by_port:
|
||||
# Set or get the dictionary for the current port
|
||||
current_level = current_level.setdefault("port", {})
|
||||
current_level = current_level.setdefault(port, {})
|
||||
@@ -251,8 +245,8 @@ class NetworkInterface(SimComponent, ABC):
|
||||
keyword_level = current_level.setdefault("keywords", {})
|
||||
|
||||
# Increment the count for detected keywords in the payload
|
||||
if CAPTURE_BY_KEYWORD:
|
||||
for keyword in NMNE_CAPTURE_KEYWORDS:
|
||||
if self.nmne_config.capture_by_keyword:
|
||||
for keyword in self.nmne_config.nmne_capture_keywords:
|
||||
if keyword in frame_str:
|
||||
# Update the count for each keyword found
|
||||
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
|
||||
@@ -1173,7 +1167,7 @@ class Node(SimComponent):
|
||||
ip_address,
|
||||
network_interface.speed,
|
||||
"Enabled" if network_interface.enabled else "Disabled",
|
||||
network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled",
|
||||
network_interface.nmne if network_interface.nmne_config.capture_nmne else "Disabled",
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
@@ -1455,74 +1449,6 @@ class Node(SimComponent):
|
||||
else:
|
||||
return
|
||||
|
||||
def install_service(self, service: Service) -> None:
|
||||
"""
|
||||
Install a service on this node.
|
||||
|
||||
:param service: Service instance that has not been installed on any node yet.
|
||||
:type service: Service
|
||||
"""
|
||||
if service in self:
|
||||
_LOGGER.warning(f"Can't add service {service.name} to node {self.hostname}. It's already installed.")
|
||||
return
|
||||
self.services[service.uuid] = service
|
||||
service.parent = self
|
||||
service.install() # Perform any additional setup, such as creating files for this service on the node.
|
||||
self.sys_log.info(f"Installed service {service.name}")
|
||||
_LOGGER.debug(f"Added service {service.name} to node {self.hostname}")
|
||||
self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager))
|
||||
|
||||
def uninstall_service(self, service: Service) -> None:
|
||||
"""
|
||||
Uninstall and completely remove service from this node.
|
||||
|
||||
:param service: Service object that is currently associated with this node.
|
||||
:type service: Service
|
||||
"""
|
||||
if service not in self:
|
||||
_LOGGER.warning(f"Can't remove service {service.name} from node {self.hostname}. It's not installed.")
|
||||
return
|
||||
service.uninstall() # Perform additional teardown, such as removing files or restarting the machine.
|
||||
self.services.pop(service.uuid)
|
||||
service.parent = None
|
||||
self.sys_log.info(f"Uninstalled service {service.name}")
|
||||
self._service_request_manager.remove_request(service.name)
|
||||
|
||||
def install_application(self, application: Application) -> None:
|
||||
"""
|
||||
Install an application on this node.
|
||||
|
||||
:param application: Application instance that has not been installed on any node yet.
|
||||
:type application: Application
|
||||
"""
|
||||
if application in self:
|
||||
_LOGGER.warning(
|
||||
f"Can't add application {application.name} to node {self.hostname}. It's already installed."
|
||||
)
|
||||
return
|
||||
self.applications[application.uuid] = application
|
||||
application.parent = self
|
||||
self.sys_log.info(f"Installed application {application.name}")
|
||||
_LOGGER.debug(f"Added application {application.name} to node {self.hostname}")
|
||||
self._application_request_manager.add_request(application.name, RequestType(func=application._request_manager))
|
||||
|
||||
def uninstall_application(self, application: Application) -> None:
|
||||
"""
|
||||
Uninstall and completely remove application from this node.
|
||||
|
||||
:param application: Application object that is currently associated with this node.
|
||||
:type application: Application
|
||||
"""
|
||||
if application not in self:
|
||||
_LOGGER.warning(
|
||||
f"Can't remove application {application.name} from node {self.hostname}. It's not installed."
|
||||
)
|
||||
return
|
||||
self.applications.pop(application.uuid)
|
||||
application.parent = None
|
||||
self.sys_log.info(f"Uninstalled application {application.name}")
|
||||
self._application_request_manager.remove_request(application.name)
|
||||
|
||||
def _shut_down_actions(self):
|
||||
"""Actions to perform when the node is shut down."""
|
||||
# Turn off all the services in the node
|
||||
|
||||
@@ -1,48 +1,25 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import Dict, Final, List
|
||||
from typing import List
|
||||
|
||||
CAPTURE_NMNE: bool = True
|
||||
"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True."""
|
||||
|
||||
NMNE_CAPTURE_KEYWORDS: List[str] = []
|
||||
"""List of keywords to identify malicious network events."""
|
||||
|
||||
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
|
||||
CAPTURE_BY_DIRECTION: Final[bool] = True
|
||||
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
|
||||
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by source or destination IP address."""
|
||||
CAPTURE_BY_PROTOCOL: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP)."""
|
||||
CAPTURE_BY_PORT: Final[bool] = False
|
||||
"""Flag to determine if captures should be organized by source or destination port."""
|
||||
CAPTURE_BY_KEYWORD: Final[bool] = False
|
||||
"""Flag to determine if captures should be filtered and categorised based on specific keywords."""
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
def set_nmne_config(nmne_config: Dict):
|
||||
"""
|
||||
Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary.
|
||||
class NMNEConfig(BaseModel):
|
||||
"""Store all the information to perform NMNE operations."""
|
||||
|
||||
This function updates global settings related to NMNE capture, including whether to capture NMNEs and what
|
||||
keywords to use for identifying NMNEs.
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary,
|
||||
and maintains type integrity by checking the types of the provided values.
|
||||
|
||||
:param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include:
|
||||
"capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings)
|
||||
to specify keywords for NMNE identification.
|
||||
"""
|
||||
global NMNE_CAPTURE_KEYWORDS
|
||||
global CAPTURE_NMNE
|
||||
|
||||
# Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect
|
||||
CAPTURE_NMNE = nmne_config.get("capture_nmne", False)
|
||||
if not isinstance(CAPTURE_NMNE, bool):
|
||||
CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean
|
||||
|
||||
# Update the NMNE capture keywords, appending new keywords if provided
|
||||
NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", [])
|
||||
if not isinstance(NMNE_CAPTURE_KEYWORDS, list):
|
||||
NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list
|
||||
capture_nmne: bool = False
|
||||
"""Indicates whether Malicious Network Events (MNEs) should be captured."""
|
||||
nmne_capture_keywords: List[str] = []
|
||||
"""List of keywords to identify malicious network events."""
|
||||
capture_by_direction: bool = True
|
||||
"""Captures should be organized by traffic direction (inbound/outbound)."""
|
||||
capture_by_ip_address: bool = False
|
||||
"""Captures should be organized by source or destination IP address."""
|
||||
capture_by_protocol: bool = False
|
||||
"""Captures should be organized by network protocol (e.g., TCP, UDP)."""
|
||||
capture_by_port: bool = False
|
||||
"""Captures should be organized by source or destination port."""
|
||||
capture_by_keyword: bool = False
|
||||
"""Captures should be filtered and categorised based on specific keywords."""
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, field_validator, validate_call
|
||||
from pydantic_core.core_schema import FieldValidationInfo
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
@@ -96,7 +96,7 @@ class ICMPPacket(BaseModel):
|
||||
|
||||
@field_validator("icmp_code") # noqa
|
||||
@classmethod
|
||||
def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int:
|
||||
def _icmp_type_must_have_icmp_code(cls, v: int, info: ValidationInfo) -> int:
|
||||
"""Validates the icmp_type and icmp_code."""
|
||||
icmp_type = info.data["icmp_type"]
|
||||
if get_icmp_type_code_description(icmp_type, v):
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.core import RequestType
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
@@ -20,9 +21,7 @@ if TYPE_CHECKING:
|
||||
from primaite.simulator.system.services.arp.arp import ARP
|
||||
from primaite.simulator.system.services.icmp.icmp import ICMP
|
||||
|
||||
from typing import Type, TypeVar
|
||||
|
||||
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
|
||||
from typing import Type
|
||||
|
||||
|
||||
class SoftwareManager:
|
||||
@@ -51,7 +50,7 @@ class SoftwareManager:
|
||||
self.node = parent_node
|
||||
self.session_manager = session_manager
|
||||
self.software: Dict[str, Union[Service, Application]] = {}
|
||||
self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {}
|
||||
self._software_class_to_name_map: Dict[Type[IOSoftware], str] = {}
|
||||
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
|
||||
self.sys_log: SysLog = sys_log
|
||||
self.file_system: FileSystem = file_system
|
||||
@@ -104,33 +103,34 @@ class SoftwareManager:
|
||||
return True
|
||||
return False
|
||||
|
||||
def install(self, software_class: Type[IOSoftwareClass]):
|
||||
def install(self, software_class: Type[IOSoftware]):
|
||||
"""
|
||||
Install an Application or Service.
|
||||
|
||||
:param software_class: The software class.
|
||||
"""
|
||||
# TODO: Software manager and node itself both have an install method. Need to refactor to have more logical
|
||||
# separation of concerns.
|
||||
if software_class in self._software_class_to_name_map:
|
||||
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
|
||||
return
|
||||
software = software_class(
|
||||
software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server
|
||||
)
|
||||
software.parent = self.node
|
||||
if isinstance(software, Application):
|
||||
software.install()
|
||||
self.node.applications[software.uuid] = software
|
||||
self.node._application_request_manager.add_request(
|
||||
software.name, RequestType(func=software._request_manager)
|
||||
)
|
||||
elif isinstance(software, Service):
|
||||
self.node.services[software.uuid] = software
|
||||
self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager))
|
||||
software.install()
|
||||
software.software_manager = self
|
||||
self.software[software.name] = software
|
||||
self.port_protocol_mapping[(software.port, software.protocol)] = software
|
||||
if isinstance(software, Application):
|
||||
software.operating_state = ApplicationOperatingState.CLOSED
|
||||
|
||||
# add the software to the node's registry after it has been fully initialized
|
||||
if isinstance(software, Service):
|
||||
self.node.install_service(software)
|
||||
elif isinstance(software, Application):
|
||||
self.node.install_application(software)
|
||||
self.node.sys_log.info(f"Installed {software.name}")
|
||||
|
||||
def uninstall(self, software_name: str):
|
||||
"""
|
||||
@@ -138,25 +138,31 @@ class SoftwareManager:
|
||||
|
||||
:param software_name: The software name.
|
||||
"""
|
||||
if software_name in self.software:
|
||||
self.software[software_name].uninstall()
|
||||
software = self.software.pop(software_name) # noqa
|
||||
if isinstance(software, Application):
|
||||
self.node.uninstall_application(software)
|
||||
elif isinstance(software, Service):
|
||||
self.node.uninstall_service(software)
|
||||
for key, value in self.port_protocol_mapping.items():
|
||||
if value.name == software_name:
|
||||
self.port_protocol_mapping.pop(key)
|
||||
break
|
||||
for key, value in self._software_class_to_name_map.items():
|
||||
if value == software_name:
|
||||
self._software_class_to_name_map.pop(key)
|
||||
break
|
||||
del software
|
||||
self.sys_log.info(f"Uninstalled {software_name}")
|
||||
if software_name not in self.software:
|
||||
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
|
||||
return
|
||||
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
|
||||
|
||||
self.software[software_name].uninstall()
|
||||
software = self.software.pop(software_name) # noqa
|
||||
if isinstance(software, Application):
|
||||
self.node.applications.pop(software.uuid)
|
||||
self.node._application_request_manager.remove_request(software.name)
|
||||
elif isinstance(software, Service):
|
||||
self.node.services.pop(software.uuid)
|
||||
software.uninstall()
|
||||
self.node._service_request_manager.remove_request(software.name)
|
||||
software.parent = None
|
||||
for key, value in self.port_protocol_mapping.items():
|
||||
if value.name == software_name:
|
||||
self.port_protocol_mapping.pop(key)
|
||||
break
|
||||
for key, value in self._software_class_to_name_map.items():
|
||||
if value == software_name:
|
||||
self._software_class_to_name_map.pop(key)
|
||||
break
|
||||
del software
|
||||
self.sys_log.info(f"Uninstalled {software_name}")
|
||||
return
|
||||
|
||||
def send_internal_payload(self, target_software: str, payload: Any):
|
||||
"""
|
||||
|
||||
@@ -99,7 +99,7 @@ agents:
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -92,7 +92,7 @@ agents:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -111,7 +111,7 @@ agents:
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -68,7 +68,7 @@ agents:
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -89,7 +89,7 @@ agents:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -44,7 +44,7 @@ agents:
|
||||
num_files: 1
|
||||
num_nics: 1
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
|
||||
@@ -89,7 +89,7 @@ agents:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -120,7 +120,7 @@ agents:
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
|
||||
@@ -30,21 +30,21 @@ from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
rayinit(local_mode=True)
|
||||
rayinit()
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class TestService(Service):
|
||||
class DummyService(Service):
|
||||
"""Test Service class"""
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
return super().describe_state()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "TestService"
|
||||
kwargs["name"] = "DummyService"
|
||||
kwargs["port"] = Port.HTTP
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
@@ -75,15 +75,15 @@ def uc2_network() -> Network:
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def service(file_system) -> TestService:
|
||||
return TestService(
|
||||
name="TestService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service")
|
||||
def service(file_system) -> DummyService:
|
||||
return DummyService(
|
||||
name="DummyService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_service")
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def service_class():
|
||||
return TestService
|
||||
return DummyService
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
||||
@@ -22,8 +22,7 @@ def test_passing_actions_down(monkeypatch) -> None:
|
||||
for n in [pc1, pc2, srv, s1]:
|
||||
sim.network.add_node(n)
|
||||
|
||||
database_service = DatabaseService(file_system=srv.file_system)
|
||||
srv.install_service(database_service)
|
||||
srv.software_manager.install(DatabaseService)
|
||||
|
||||
downloads_folder = pc1.file_system.create_folder("downloads")
|
||||
pc1.file_system.create_file("bermuda_triangle.png", folder_name="downloads")
|
||||
|
||||
@@ -9,9 +9,11 @@ from gymnasium import spaces
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
@@ -75,6 +77,18 @@ def test_nic(simulation):
|
||||
|
||||
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
|
||||
|
||||
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
|
||||
nmne_config = {
|
||||
"capture_nmne": True, # Enable the capture of MNEs
|
||||
"nmne_capture_keywords": [
|
||||
"DELETE",
|
||||
"ENCRYPT",
|
||||
], # Specify "DELETE/ENCRYPT" SQL command as a keyword for MNE detection
|
||||
}
|
||||
|
||||
# Apply the NMNE configuration settings
|
||||
NetworkInterface.nmne_config = NMNEConfig(**nmne_config)
|
||||
|
||||
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
|
||||
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
|
||||
assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4)
|
||||
@@ -144,7 +158,7 @@ def test_nic_monitored_traffic(simulation):
|
||||
pc2: Computer = simulation.network.get_node_by_hostname("client_2")
|
||||
|
||||
nic_obs = NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True, monitored_traffic=monitored_traffic
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic
|
||||
)
|
||||
|
||||
simulation.pre_timestep(0) # apply timestep to whole sim
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
|
||||
|
||||
def test_capture_nmne(uc2_network):
|
||||
def test_capture_nmne(uc2_network: Network):
|
||||
"""
|
||||
Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured.
|
||||
|
||||
@@ -33,7 +35,7 @@ def test_capture_nmne(uc2_network):
|
||||
}
|
||||
|
||||
# Apply the NMNE configuration settings
|
||||
set_nmne_config(nmne_config)
|
||||
NIC.nmne_config = NMNEConfig(**nmne_config)
|
||||
|
||||
# Assert that initially, there are no captured MNEs on both web and database servers
|
||||
assert web_server_nic.nmne == {}
|
||||
@@ -82,7 +84,7 @@ def test_capture_nmne(uc2_network):
|
||||
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 3}}}}
|
||||
|
||||
|
||||
def test_describe_state_nmne(uc2_network):
|
||||
def test_describe_state_nmne(uc2_network: Network):
|
||||
"""
|
||||
Conducts a test to verify that Malicious Network Events (MNEs) are correctly represented in the nic state.
|
||||
|
||||
@@ -110,7 +112,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
}
|
||||
|
||||
# Apply the NMNE configuration settings
|
||||
set_nmne_config(nmne_config)
|
||||
NIC.nmne_config = NMNEConfig(**nmne_config)
|
||||
|
||||
# Assert that initially, there are no captured MNEs on both web and database servers
|
||||
web_server_nic_state = web_server_nic.describe_state()
|
||||
@@ -190,7 +192,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 4}}}}
|
||||
|
||||
|
||||
def test_capture_nmne_observations(uc2_network):
|
||||
def test_capture_nmne_observations(uc2_network: Network):
|
||||
"""
|
||||
Tests the NICObservation class's functionality within a simulated network environment.
|
||||
|
||||
@@ -219,7 +221,7 @@ def test_capture_nmne_observations(uc2_network):
|
||||
}
|
||||
|
||||
# Apply the NMNE configuration settings
|
||||
set_nmne_config(nmne_config)
|
||||
NIC.nmne_config = NMNEConfig(**nmne_config)
|
||||
|
||||
# Define observations for the NICs of the database and web servers
|
||||
db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True)
|
||||
|
||||
@@ -23,7 +23,7 @@ def populated_node(
|
||||
server.power_on()
|
||||
server.software_manager.install(service_class)
|
||||
|
||||
service = server.software_manager.software.get("TestService")
|
||||
service = server.software_manager.software.get("DummyService")
|
||||
service.start()
|
||||
|
||||
return server, service
|
||||
@@ -42,7 +42,7 @@ def test_service_on_offline_node(service_class):
|
||||
computer.power_on()
|
||||
computer.software_manager.install(service_class)
|
||||
|
||||
service: Service = computer.software_manager.software.get("TestService")
|
||||
service: Service = computer.software_manager.software.get("DummyService")
|
||||
|
||||
computer.power_off()
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from tests.conftest import DummyApplication, TestService
|
||||
from tests.conftest import DummyApplication, DummyService
|
||||
|
||||
|
||||
def test_successful_node_file_system_creation_request(example_network):
|
||||
@@ -61,7 +61,7 @@ def test_successful_application_requests(example_network):
|
||||
def test_successful_service_requests(example_network):
|
||||
net = example_network
|
||||
server_1 = net.get_node_by_hostname("server_1")
|
||||
server_1.software_manager.install(TestService)
|
||||
server_1.software_manager.install(DummyService)
|
||||
|
||||
# Careful: the order here is important, for example we cannot run "stop" unless we run "start" first
|
||||
for verb in [
|
||||
@@ -77,7 +77,7 @@ def test_successful_service_requests(example_network):
|
||||
"scan",
|
||||
"fix",
|
||||
]:
|
||||
resp_1 = net.apply_request(["node", "server_1", "service", "TestService", verb])
|
||||
resp_1 = net.apply_request(["node", "server_1", "service", "DummyService", verb])
|
||||
assert resp_1 == RequestResponse(status="success", data={})
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
|
||||
@@ -7,6 +7,7 @@ from primaite.simulator.file_system.folder import Folder
|
||||
from primaite.simulator.network.hardware.base import Node, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
from tests.conftest import DummyApplication, DummyService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -47,7 +48,7 @@ def test_node_shutdown(node):
|
||||
assert node.operating_state == NodeOperatingState.OFF
|
||||
|
||||
|
||||
def test_node_os_scan(node, service, application):
|
||||
def test_node_os_scan(node):
|
||||
"""Test OS Scanning."""
|
||||
node.operating_state = NodeOperatingState.ON
|
||||
|
||||
@@ -55,13 +56,15 @@ def test_node_os_scan(node, service, application):
|
||||
# TODO implement processes
|
||||
|
||||
# add services to node
|
||||
node.software_manager.install(DummyService)
|
||||
service = node.software_manager.software.get("DummyService")
|
||||
service.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
node.install_service(service=service)
|
||||
assert service.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
# add application to node
|
||||
node.software_manager.install(DummyApplication)
|
||||
application = node.software_manager.software.get("DummyApplication")
|
||||
application.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
node.install_application(application=application)
|
||||
assert application.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
# add folder and file to node
|
||||
@@ -91,7 +94,7 @@ def test_node_os_scan(node, service, application):
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
|
||||
def test_node_red_scan(node, service, application):
|
||||
def test_node_red_scan(node):
|
||||
"""Test revealing to red"""
|
||||
node.operating_state = NodeOperatingState.ON
|
||||
|
||||
@@ -99,12 +102,14 @@ def test_node_red_scan(node, service, application):
|
||||
# TODO implement processes
|
||||
|
||||
# add services to node
|
||||
node.install_service(service=service)
|
||||
node.software_manager.install(DummyService)
|
||||
service = node.software_manager.software.get("DummyService")
|
||||
assert service.revealed_to_red is False
|
||||
|
||||
# add application to node
|
||||
node.software_manager.install(DummyApplication)
|
||||
application = node.software_manager.software.get("DummyApplication")
|
||||
application.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
node.install_application(application=application)
|
||||
assert application.revealed_to_red is False
|
||||
|
||||
# add folder and file to node
|
||||
|
||||
Reference in New Issue
Block a user