From c6b1d35215c3e266ecbee1c7138acd78774fea1c Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 4 Jul 2024 20:45:42 +0100 Subject: [PATCH 01/35] #2967 - Enhance AirSpace simulation with dynamic environment and bandwidth/channel management This commit introduces several key enhancements to the AirSpace class, improving the realism and configurability of the wireless network. Major additions include the AirSpaceEnvironmentType and ChannelWidth enums, dynamic adjustment of interface speeds based on environmental settings, and comprehensive bandwidth management features. Additionally, the software now supports configuration of channel widths via the config file, incorporates accurate SNR and capacity calculations, and enforces bandwidth limits more effectively across wireless interfaces. Updated tests ensure that the new functionalities integrate seamlessly with existing systems. --- CHANGELOG.md | 35 +- docs/source/configuration/simulation.rst | 13 +- docs/source/simulation.rst | 1 + .../network/airspace.rst | 100 +++ .../network/nodes/wireless_router.rst | 13 +- src/primaite/game/game.py | 6 + src/primaite/simulator/network/airspace.py | 721 +++++++++++++++--- src/primaite/simulator/network/container.py | 2 + .../simulator/network/hardware/base.py | 17 +- .../hardware/nodes/network/wireless_router.py | 26 +- .../network/transmission/data_link_layer.py | 5 +- .../configs/wireless_wan_network_config.yaml | 2 + ...s_wan_wifi_5_80_channel_width_blocked.yaml | 81 ++ ...ess_wan_wifi_5_80_channel_width_urban.yaml | 81 ++ .../test_airspace_capacity_configuration.py | 106 +++ ...ndwidth_load_checks_before_transmission.py | 138 ++++ 16 files changed, 1238 insertions(+), 109 deletions(-) create mode 100644 docs/source/simulation_components/network/airspace.rst create mode 100644 tests/assets/configs/wireless_wan_wifi_5_80_channel_width_blocked.yaml create mode 100644 tests/assets/configs/wireless_wan_wifi_5_80_channel_width_urban.yaml create mode 100644 tests/integration_tests/network/test_airspace_capacity_configuration.py create mode 100644 tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 17bf3557..0ed09b94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,42 @@ All notable changes to this project will be documented in this file. -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] +### Added + +- **AirSpaceEnvironmentType Enum Class**: Introduced in `airspace.py` to define different environmental settings affecting wireless network behavior. +- **ChannelWidth Enum Class**: Added in `airspace.py` to specify channel width options for wireless network interfaces. +- **Channel Width Attribute**: Incorporated into the `WirelessNetworkInterface` class to allow dynamic setting based on `AirSpaceFrequency` and `AirSpaceEnvironmentType`. +- **SNR and Capacity Calculation Functions**: Functions `estimate_snr` and `calculate_total_channel_capacity` added to `airspace.py` for computing signal-to-noise ratio and capacity based on frequency and channel width. +- **Dynamic Speed Setting**: WirelessInterface speed attribute now dynamically adjusts based on the operational environment, frequency, and channel width. +- **airspace_key Attribute**: Added to `WirelessNetworkInterface` as a tuple of frequency and channel width, serving as a key for bandwidth/channel management. +- **airspace_environment_type Attribute**: Determines the environmental type for the airspace, influencing data rate calculations and capacity sharing. +- **show_bandwidth_load Function**: Displays current bandwidth load for each frequency and channel width in the airspace. +- **Configuration Schema Update**: The `simulation.network` config file now includes settings for the `airspace_environment_type`. +- **Bandwidth Tracking**: Tracks data transmission across each frequency/channel width pairing. +- **Configuration Support for Wireless Routers**: `channel_width` can now be configured in the config file under `wireless_access_point`. +- **New Tests**: Added to validate the respect of bandwidth capacities and the correct parsing of airspace configurations from YAML files. + +### Changed + +- **NetworkInterface Speed Type**: The `speed` attribute of `NetworkInterface` has been changed from `int` to `float`. +- **Transmission Feasibility Check**: Updated `_can_transmit` function in `Link` to account for current load and total bandwidth capacity, ensuring transmissions do not exceed limits. +- **Frame Size Details**: Frame `size` attribute now includes both core size and payload size in bytes. +- **WirelessRouter Configuration Function**: `configure_wireless_access_point` function now accepts `channel_width` as a parameter. +- **Interface Grouping**: `WirelessNetworkInterfaces` are now grouped by both `AirSpaceFrequency` and `ChannelWidth`. +- **Interface Frequency/Channel Width Adjustment**: Changing an interface's settings now involves removal from the airspace, recalculation of its data rate, and re-addition under new settings. +- **Transmission Blocking**: Enhanced `AirSpace` logic to block transmissions that would exceed the available capacity. + +### Fixed + +- **Transmission Permission Logic**: Corrected the logic in `can_transmit_frame` to accurately prevent overloads by checking if the transmission of a frame stays within allowable bandwidth limits after considering current load. + + +[//]: # (This file needs tidying up between 2.0.0 and this line as it hasn't been segmented into 3.0.0 and 3.1.0 and isn't compliant with https://keepachangelog.com/en/1.1.0/) + ## 3.0.0b9 - Removed deprecated `PrimaiteSession` class. - Added ability to set log levels via configuration. diff --git a/docs/source/configuration/simulation.rst b/docs/source/configuration/simulation.rst index 2bcc8b66..d585711d 100644 --- a/docs/source/configuration/simulation.rst +++ b/docs/source/configuration/simulation.rst @@ -7,7 +7,7 @@ ============== In this section the network layout is defined. This part of the config follows a hierarchical structure. Almost every component defines a ``ref`` field which acts as a human-readable unique identifier, used by other parts of the config, such as agents. -At the top level of the network are ``nodes`` and ``links``. +At the top level of the network are ``nodes``, ``links`` and ``airspace``. e.g. @@ -19,6 +19,9 @@ e.g. ... links: ... + airspace: + ... + ``nodes`` --------- @@ -101,3 +104,11 @@ This accepts an integer value e.g. if port 1 is to be connected, the configurati ``bandwidth`` This is an integer value specifying the allowed bandwidth across the connection. Units are in Mbps. + +``airspace`` +------------ + +This is where the airspace settings for wireless networks arte set. + +``airspace_environment_type`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/simulation.rst b/docs/source/simulation.rst index a8870cb4..cc723e40 100644 --- a/docs/source/simulation.rst +++ b/docs/source/simulation.rst @@ -27,6 +27,7 @@ Contents simulation_components/network/nodes/firewall simulation_components/network/switch simulation_components/network/network + simulation_components/network/airspace simulation_components/system/internal_frame_processing simulation_components/system/sys_log simulation_components/system/pcap diff --git a/docs/source/simulation_components/network/airspace.rst b/docs/source/simulation_components/network/airspace.rst new file mode 100644 index 00000000..dcd762d4 --- /dev/null +++ b/docs/source/simulation_components/network/airspace.rst @@ -0,0 +1,100 @@ +.. only:: comment + + © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + +.. _airspace: + +AirSpace +======== + + +1. Introduction +--------------- + +The AirSpace class is the central component for wireless networks in PrimAITE and is designed to model and manage the behavior and interactions of wireless network interfaces within a simulated wireless network environment. This documentation provides a detailed overview of the AirSpace class, its components, and how they interact to create a realistic simulation of wireless network dynamics. + +2. Overview of the AirSpace System +---------------------------------- + +The AirSpace is a virtual representation of a physical wireless environment, managing multiple wireless network interfaces that simulate devices connected to the wireless network. These interfaces communicate over radio frequencies, with their interactions influenced by various factors modeled within the AirSpace. + +2.1 Key Components +^^^^^^^^^^^^^^^^^^ + +- **Wireless Network Interfaces**: Representations of network interfaces connected physical devices like routers, computers, or IoT devices that can send and receive data wirelessly. +- **Environmental Settings**: Different types of environments (e.g., urban, rural) that affect signal propagation and interference. +- **Channel Management**: Handles channels and their widths (e.g., 20 MHz, 40 MHz) to determine data transmission over different frequencies. +- **Bandwidth Management**: Tracks data transmission over channels to prevent overloading and simulate real-world network congestion. + +3. AirSpace Environment Types +----------------------------- + +The AirspaceEnvironmentType is a critical component that simulates different physical environments: + +- Urban, Suburban, Rural, etc. +- Each type simulates different levels of electromagnetic interference and signal propagation characteristics. +- Changing the AirspaceEnvironmentType impacts data rates by affecting the signal-to-noise ratio (SNR). + +4. Simulation of Environment Changes +------------------------------------ + +When an AirspaceEnvironmentType is set or changed, the AirSpace: + +1. Recalculates the maximum data transmission capacities for all managed frequencies and channel widths. +2. Updates all wireless interfaces to reflect new capacities. + +5. Managing Wireless Network Interfaces +--------------------------------------- + +- Interfaces can be dynamically added or removed. +- Configurations can be changed in real-time. +- The AirSpace handles data transmissions, ensuring data sent by an interface is received by all other interfaces on the same frequency and channel. + +6. Signal-to-Noise Ratio (SNR) Calculation +------------------------------------------ + +SNR is crucial in determining the quality of a wireless communication channel: + +.. math:: + + SNR = \frac{\text{Signal Power}}{\text{Noise Power}} + +- Impacted by environment type, frequency, and channel width +- Higher SNR indicates a clearer signal, leading to higher data transmission rates + +7. Total Channel Capacity Calculation +------------------------------------- + +Channel capacity is calculated using the Shannon-Hartley theorem: + +.. math:: + + C = B \cdot \log_2(1 + SNR) + +Where: + +- C: channel capacity in bits per second (bps) +- B: bandwidth of the channel in hertz (Hz) +- SNR: signal-to-noise ratio + +Implementation in AirSpace: + +1. Convert channel width from MHz to Hz. +2. Recalculate SNR based on new environment or interface settings. +3. Apply Shannon-Hartley theorem to determine new maximum channel capacity in Mbps. + +8. Shared Maximum Capacity Across Devices +----------------------------------------- + +While individual devices have theoretical maximum data rates, the actual achievable rate is often less due to: + +- Shared wireless medium among all devices on the same frequency and channel width +- Interference and congestion from multiple devices transmitting simultaneously + +9. AirSpace Inspection +---------------------- + +The AirSpace class provides methods for visualizing network behavior: + +- ``show_wireless_interfaces()``: Displays current state of all interfaces +- ``show_bandwidth_load()``: Shows channel loads and bandwidth utilization diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index 29110a52..eb7f95e3 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -37,7 +37,7 @@ additional steps to configure wireless settings: .. code-block:: python from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter - from primaite.simulator.network.airspace import AirSpaceFrequency + from primaite.simulator.network.airspace import AirSpaceFrequency, ChannelWidth # Instantiate the WirelessRouter wireless_router = WirelessRouter(hostname="MyWirelessRouter") @@ -49,7 +49,8 @@ additional steps to configure wireless settings: wireless_router.configure_wireless_access_point( port=1, ip_address="192.168.2.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4 + frequency=AirSpaceFrequency.WIFI_2_4, + channel_width=ChannelWidth.ChannelWidth.WIDTH_40_MHZ ) @@ -71,7 +72,7 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. .. code-block:: python - from primaite.simulator.network.airspace import AIR_SPACE, AirSpaceFrequency + from primaite.simulator.network.airspace import AirSpaceFrequency, ChannelWidth from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction @@ -130,13 +131,15 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4 + frequency=AirSpaceFrequency.WIFI_2_4, + channel_width=ChannelWidth.ChannelWidth.WIDTH_40_MHZ ) router_2.configure_wireless_access_point( port=1, ip_address="192.168.1.2", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4 + frequency=AirSpaceFrequency.WIFI_2_4, + channel_width=ChannelWidth.ChannelWidth.WIDTH_40_MHZ ) # Configure routes for inter-router communication diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8a79d068..38bd3597 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -15,6 +15,7 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort +from primaite.simulator.network.airspace import AirspaceEnvironmentType from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC @@ -233,6 +234,11 @@ class PrimaiteGame: simulation_config = cfg.get("simulation", {}) network_config = simulation_config.get("network", {}) + airspace_cfg = network_config.get("airspace", {}) + airspace_environment_type_str = airspace_cfg.get("airspace_environment_type", "urban") + + airspace_environment_type: AirspaceEnvironmentType = AirspaceEnvironmentType(airspace_environment_type_str) + net.airspace.airspace_environment_type = airspace_environment_type nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 5fec098b..6060d969 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -3,9 +3,11 @@ from __future__ import annotations from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Tuple -from prettytable import PrettyTable +import numpy as np +from prettytable import MARKDOWN, PrettyTable +from pydantic import BaseModel, computed_field, Field, model_validator from primaite import getLogger from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface @@ -15,90 +17,29 @@ from primaite.simulator.system.core.packet_capture import PacketCapture _LOGGER = getLogger(__name__) -__all__ = ["AirSpaceFrequency", "WirelessNetworkInterface", "IPWirelessNetworkInterface"] +def format_hertz(hertz: float, format_terahertz: bool = False, decimals: int = 3) -> str: + """ + Convert a frequency in Hertz to a formatted string using the most appropriate unit. -class AirSpace: - """Represents a wireless airspace, managing wireless network interfaces and handling wireless transmission.""" + Optionally includes formatting for Terahertz. - def __init__(self): - self._wireless_interfaces: Dict[str, WirelessNetworkInterface] = {} - self._wireless_interfaces_by_frequency: Dict[AirSpaceFrequency, List[WirelessNetworkInterface]] = {} - - def show(self, frequency: Optional[AirSpaceFrequency] = None): - """ - Displays a summary of wireless interfaces in the airspace, optionally filtered by a specific frequency. - - :param frequency: The frequency band to filter devices by. If None, devices for all frequencies are shown. - """ - table = PrettyTable() - table.field_names = ["Connected Node", "MAC Address", "IP Address", "Subnet Mask", "Frequency", "Status"] - - # If a specific frequency is provided, filter by it; otherwise, use all frequencies. - frequencies_to_show = [frequency] if frequency else self._wireless_interfaces_by_frequency.keys() - - for freq in frequencies_to_show: - interfaces = self._wireless_interfaces_by_frequency.get(freq, []) - for interface in interfaces: - status = "Enabled" if interface.enabled else "Disabled" - table.add_row( - [ - interface._connected_node.hostname, # noqa - interface.mac_address, - interface.ip_address if hasattr(interface, "ip_address") else None, - interface.subnet_mask if hasattr(interface, "subnet_mask") else None, - str(freq), - status, - ] - ) - - print(table) - - def add_wireless_interface(self, wireless_interface: WirelessNetworkInterface): - """ - Adds a wireless network interface to the airspace if it's not already present. - - :param wireless_interface: The wireless network interface to be added. - """ - if wireless_interface.mac_address not in self._wireless_interfaces: - self._wireless_interfaces[wireless_interface.mac_address] = wireless_interface - if wireless_interface.frequency not in self._wireless_interfaces_by_frequency: - self._wireless_interfaces_by_frequency[wireless_interface.frequency] = [] - self._wireless_interfaces_by_frequency[wireless_interface.frequency].append(wireless_interface) - - def remove_wireless_interface(self, wireless_interface: WirelessNetworkInterface): - """ - Removes a wireless network interface from the airspace if it's present. - - :param wireless_interface: The wireless network interface to be removed. - """ - if wireless_interface.mac_address in self._wireless_interfaces: - self._wireless_interfaces.pop(wireless_interface.mac_address) - self._wireless_interfaces_by_frequency[wireless_interface.frequency].remove(wireless_interface) - - def clear(self): - """ - Clears all wireless network interfaces and their frequency associations from the airspace. - - After calling this method, the airspace will contain no wireless network interfaces, and transmissions cannot - occur until new interfaces are added again. - """ - self._wireless_interfaces.clear() - self._wireless_interfaces_by_frequency.clear() - - def transmit(self, frame: Frame, sender_network_interface: WirelessNetworkInterface): - """ - Transmits a frame to all enabled wireless network interfaces on a specific frequency within the airspace. - - This ensures that a wireless interface does not receive its own transmission. - - :param frame: The frame to be transmitted. - :param sender_network_interface: The wireless network interface sending the frame. This interface will be - excluded from the list of receivers to prevent it from receiving its own transmission. - """ - for wireless_interface in self._wireless_interfaces_by_frequency.get(sender_network_interface.frequency, []): - if wireless_interface != sender_network_interface and wireless_interface.enabled: - wireless_interface.receive_frame(frame) + :param hertz: Frequency in Hertz. + :param format_terahertz: Whether to format frequency in Terahertz, default is False. + :param decimals: Number of decimal places to round to, default is 3. + :returns: Formatted string with the frequency in the most suitable unit. + """ + format_str = f"{{:.{decimals}f}}" + if format_terahertz and hertz >= 1e12: # Terahertz + return format_str.format(hertz / 1e12) + " THz" + elif hertz >= 1e9: # Gigahertz + return format_str.format(hertz / 1e9) + " GHz" + elif hertz >= 1e6: # Megahertz + return format_str.format(hertz / 1e6) + " MHz" + elif hertz >= 1e3: # Kilohertz + return format_str.format(hertz / 1e3) + " kHz" + else: # Hertz + return format_str.format(hertz) + " Hz" class AirSpaceFrequency(Enum): @@ -110,12 +51,478 @@ class AirSpaceFrequency(Enum): """WiFi 5 GHz. Known for its higher data transmission speeds and reduced interference from other devices.""" def __str__(self) -> str: + hertz_str = format_hertz(hertz=self.value) if self == AirSpaceFrequency.WIFI_2_4: - return "WiFi 2.4 GHz" - elif self == AirSpaceFrequency.WIFI_5: - return "WiFi 5 GHz" - else: - return "Unknown Frequency" + return f"WiFi {hertz_str}" + if self == AirSpaceFrequency.WIFI_5: + return f"WiFi {hertz_str}" + return "Unknown Frequency" + + +class ChannelWidth(Enum): + """ + Enumeration representing the available channel widths in MHz for wireless communications. + + This enum facilitates standardising and validating channel width configurations. + + Attributes: + WIDTH_20_MHZ (int): Represents a channel width of 20 MHz, commonly used for basic + Wi-Fi connectivity with standard range and interference resistance. + WIDTH_40_MHZ (int): Represents a channel width of 40 MHz, offering higher data + throughput at the expense of potentially increased interference. + WIDTH_80_MHZ (int): Represents a channel width of 80 MHz, typically used in modern + Wi-Fi setups for high data rate applications but with higher susceptibility to interference. + WIDTH_160_MHZ (int): Represents a channel width of 160 MHz, used for ultra-high-speed + network applications, providing maximum data throughput with significant + requirements on the spectral environment to minimize interference. + """ + + WIDTH_20_MHZ = 20 + """ + Represents a channel width of 20 MHz, commonly used for basic Wi-Fi connectivity with standard range and + interference resistance + """ + + WIDTH_40_MHZ = 40 + """ + Represents a channel width of 40 MHz, offering higher data throughput at the expense of potentially increased + interference. + """ + + WIDTH_80_MHZ = 80 + """ + Represents a channel width of 80 MHz, typically used in modern Wi-Fi setups for high data rate applications but + with higher susceptibility to interference. + """ + + WIDTH_160_MHZ = 160 + """ + Represents a channel width of 160 MHz, used for ultra-high-speed network applications, providing maximum data + throughput with significant requirements on the spectral environment to minimize interference. + """ + + def __str__(self) -> str: + """ + Returns a string representation of the channel width. + + :return: String in the format of " MHz" indicating the channel width. + """ + return f"{self.value} MHz" + + +AirSpaceKeyType = Tuple[AirSpaceFrequency, ChannelWidth] + + +class AirspaceEnvironmentType(Enum): + """Enum representing different types of airspace environments which affect wireless communication signals.""" + + RURAL = "rural" + """ + A rural environment offers clear channel conditions due to low population density and minimal electronic device + presence. + """ + + OUTDOOR = "outdoor" + """ + Outdoor environments like parks or fields have minimal electronic interference. + """ + + SUBURBAN = "suburban" + """ + Suburban environments strike a balance with fewer electronic interferences than urban but more than rural. + """ + + OFFICE = "office" + """ + Office environments have moderate interference from numerous electronic devices and overlapping networks. + """ + + URBAN = "urban" + """ + Urban environments are characterized by tall buildings and a high density of electronic devices, leading to + significant interference. + """ + + INDUSTRIAL = "industrial" + """ + Industrial areas face high interference from heavy machinery and numerous electronic devices. + """ + + TRANSPORT = "transport" + """ + Environments such as subways and buses where metal structures and high mobility create complex interference + patterns. + """ + + DENSE_URBAN = "dense_urban" + """ + Dense urban areas like city centers have the highest level of signal interference due to the very high density of + buildings and devices. + """ + + JAMMING_ZONE = "jamming_zone" + """ + A jamming zone environment where signals are actively interfered with, typically through the use of signal jammers + or scrambling devices. This represents the environment with the highest level of interference. + """ + + BLOCKED = "blocked" + """ + A jamming zone environment with total levels of interference. Airspace is completely blocked. + """ + + @property + def snr_impact(self) -> int: + """ + Returns the SNR impact associated with the environment. + + :return: SNR impact in dB. + """ + impacts = { + AirspaceEnvironmentType.RURAL: 0, + AirspaceEnvironmentType.OUTDOOR: 1, + AirspaceEnvironmentType.SUBURBAN: -5, + AirspaceEnvironmentType.OFFICE: -7, + AirspaceEnvironmentType.URBAN: -10, + AirspaceEnvironmentType.INDUSTRIAL: -15, + AirspaceEnvironmentType.TRANSPORT: -12, + AirspaceEnvironmentType.DENSE_URBAN: -20, + AirspaceEnvironmentType.JAMMING_ZONE: -40, + AirspaceEnvironmentType.BLOCKED: -100, + } + return impacts[self] + + def __str__(self) -> str: + return f"{self.value.title()} Environment (SNR Impact: {self.snr_impact})" + + +def estimate_snr( + frequency: AirSpaceFrequency, environment_type: AirspaceEnvironmentType, channel_width: ChannelWidth +) -> float: + """ + Estimate the Signal-to-Noise Ratio (SNR) based on the communication frequency, environment, and channel width. + + This function considers both the base SNR value dependent on the frequency and the impact of environmental + factors and channel width on the SNR. + + The SNR is adjusted by reducing it for wider channels, reflecting the increased noise floor from a broader + frequency range. + + :param frequency: The operating frequency as defined by AirSpaceFrequency enum, influencing the base SNR. Higher + frequencies like 5 GHz generally start with a higher base SNR due to less noise. + :param environment_type: The type of environment from AirspaceEnvironmentType enum, which adjusts the SNR based on + expected environmental noise and interference levels. + :param channel_width: The channel width from ChannelWidth enum, where wider channels (80 MHz and 160 MHz) decrease + the SNR slightly due to an increased noise floor. + :return: Estimated SNR in dB, calculated as the base SNR modified by environmental and channel width impacts. + """ + base_snr = 40 if frequency == AirSpaceFrequency.WIFI_5 else 30 + snr_impact = environment_type.snr_impact + + # Adjust SNR impact based on channel width + if channel_width == ChannelWidth.WIDTH_80_MHZ or channel_width == ChannelWidth.WIDTH_160_MHZ: + snr_impact -= 3 # Assume wider channels have slightly lower SNR due to increased noise floor + + return base_snr + snr_impact + + +def calculate_total_channel_capacity( + channel_width: ChannelWidth, frequency: AirSpaceFrequency, environment_type: AirspaceEnvironmentType +) -> float: + """ + Calculate the total theoretical data rate for the channel using the Shannon-Hartley theorem. + + This function determines the channel's capacity by considering the bandwidth (derived from channel width), + and the signal-to-noise ratio (SNR) adjusted by frequency and environmental conditions. + + The Shannon-Hartley theorem states that channel capacity C (in bits per second) can be calculated as: + ``C = B * log2(1 + SNR)`` where B is the bandwidth in Hertz and SNR is the signal-to-noise ratio. + + :param channel_width: The width of the channel as defined by ChannelWidth enum, converted to Hz for calculation. + :param frequency: The operating frequency as defined by AirSpaceFrequency enum, influencing the base SNR and part + of the SNR estimation. + :param environment_type: The type of environment as defined by AirspaceEnvironmentType enum, used in SNR estimation. + :return: Theoretical total data rate in Mbps for the entire channel. + """ + bandwidth_hz = channel_width.value * 1_000_000 # Convert MHz to Hz + snr_db = estimate_snr(frequency, environment_type, channel_width) + snr_linear = 10 ** (snr_db / 10) + + total_capacity_bps = bandwidth_hz * np.log2(1 + snr_linear) + total_capacity_mbps = total_capacity_bps / 1_000_000 + + return total_capacity_mbps + + +def calculate_individual_device_rate( + channel_width: ChannelWidth, + frequency: AirSpaceFrequency, + environment_type: AirspaceEnvironmentType, + device_count: int, +) -> float: + """ + Calculate the theoretical data rate available to each individual device on the channel. + + This function first calculates the total channel capacity and then divides this capacity by the number + of active devices to estimate each device's share of the bandwidth. This reflects the practical limitation + that multiple devices must share the same channel resources. + + :param channel_width: The channel width as defined by ChannelWidth enum, used in total capacity calculation. + :param frequency: The operating frequency as defined by AirSpaceFrequency enum, used in total capacity calculation. + :param environment_type: The environment type as defined by AirspaceEnvironmentType enum, impacting SNR and + capacity. + :param device_count: The number of devices sharing the channel. If zero, returns zero to avoid division by zero. + :return: Theoretical data rate in Mbps available per device, based on shared channel capacity. + """ + total_capacity_mbps = calculate_total_channel_capacity(channel_width, frequency, environment_type) + if device_count == 0: + return 0 # Avoid division by zero + individual_device_rate_mbps = total_capacity_mbps / device_count + + return individual_device_rate_mbps + + +class AirSpace(BaseModel): + """ + Represents a wireless airspace, managing wireless network interfaces and handling wireless transmission. + + This class provides functionalities to manage a collection of wireless network interfaces, each associated with + specific frequencies and channel widths. It includes methods to calculate and manage bandwidth loads, add and + remove wireless interfaces, and handle data transmission across these interfaces. + """ + + airspace_environment_type_: AirspaceEnvironmentType = AirspaceEnvironmentType.URBAN + wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {}) + wireless_interfaces_by_frequency_channel_width: Dict[AirSpaceKeyType, List[WirelessNetworkInterface]] = Field( + default_factory=lambda: {} + ) + bandwidth_load: Dict[AirSpaceKeyType, float] = Field(default_factory=lambda: {}) + frequency_channel_width_max_capacity_mbps: Dict[AirSpaceKeyType, float] = Field(default_factory=lambda: {}) + + def model_post_init(self, __context: Any) -> None: + """ + Initialize the airspace metadata after instantiation. + + This method is called to set up initial configurations like the maximum capacity of each channel width and + frequency based on the current environment setting. + + :param __context: Contextual data or settings, typically used for further initializations beyond + the basic constructor. + """ + self._set_frequency_channel_width_max_capacity_mbps() + + def _set_frequency_channel_width_max_capacity_mbps(self): + """ + Private method to compute and set the maximum channel capacity in Mbps for each frequency and channel width. + + Based on the airspace environment type, this method calculates the maximum possible data transmission + capacity for each combination of frequency and channel width available and stores these values. + These capacities are critical for managing and limiting bandwidth load during operations. + """ + print( + f"Rebuilding the frequency channel width maximum capacity dictionary based on " + f"airspace environment type {self.airspace_environment_type_}" + ) + for frequency in AirSpaceFrequency: + for channel_width in ChannelWidth: + max_capacity = calculate_total_channel_capacity( + frequency=frequency, channel_width=channel_width, environment_type=self.airspace_environment_type + ) + self.frequency_channel_width_max_capacity_mbps[frequency, channel_width] = max_capacity + + @computed_field + @property + def airspace_environment_type(self) -> AirspaceEnvironmentType: + """ + Gets the current environment type of the airspace. + + :return: The AirspaceEnvironmentType representing the current environment type. + """ + return self.airspace_environment_type_ + + @airspace_environment_type.setter + def airspace_environment_type(self, value: AirspaceEnvironmentType) -> None: + """ + Sets a new environment type for the airspace and updates related configurations. + + Changing the environment type triggers a re-calculation of the maximum channel capacities and + adjustments to the current setup of wireless interfaces to ensure they are aligned with the + new environment settings. + + :param value: The new environment type as an AirspaceEnvironmentType. + """ + if value != self.airspace_environment_type_: + print(f"Setting airspace_environment_type to {value}") + self.airspace_environment_type_ = value + self._set_frequency_channel_width_max_capacity_mbps() + wireless_interface_keys = list(self.wireless_interfaces.keys()) + for wireless_interface_key in wireless_interface_keys: + wireless_interface = self.wireless_interfaces[wireless_interface_key] + self.remove_wireless_interface(wireless_interface) + self.add_wireless_interface(wireless_interface) + + def show_bandwidth_load(self, markdown: bool = False): + """ + Prints a table of the current bandwidth load for each frequency and channel width combination on the airspace. + + This method prints a tabulated view showing the utilisation of available bandwidth capacities for all configured + frequency and channel width pairings. The table includes the current capacity usage as a percentage of the + maximum capacity, alongside the absolute maximum capacity values in Mbps. + + :param markdown: Flag indicating if output should be in markdown format. + """ + headers = ["Frequency", "Channel Width", "Current Capacity (%)", "Maximum Capacity (Mbit)"] + table = PrettyTable(headers) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = "Airspace Frequency Channel Loads" + for key, load in self.bandwidth_load.items(): + frequency, channel_width = key + maximum_capacity = self.frequency_channel_width_max_capacity_mbps[key] + load_percent = load / maximum_capacity + if load_percent > 1.0: + load_percent = 1.0 + table.add_row( + [format_hertz(frequency.value), str(channel_width), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"] + ) + print(table) + + def show_wireless_interfaces(self, markdown: bool = False): + """ + Prints a table of wireless interfaces in the airspace. + + :param markdown: Flag indicating if output should be in markdown format. + """ + headers = [ + "Connected Node", + "MAC Address", + "IP Address", + "Subnet Mask", + "Frequency", + "Channel Width", + "Speed (Mbps)", + "Status", + ] + table = PrettyTable(headers) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"Devices on Air Space - {self.airspace_environment_type}" + + for interface in self.wireless_interfaces.values(): + status = "Enabled" if interface.enabled else "Disabled" + table.add_row( + [ + interface._connected_node.hostname, # noqa + interface.mac_address, + interface.ip_address if hasattr(interface, "ip_address") else None, + interface.subnet_mask if hasattr(interface, "subnet_mask") else None, + format_hertz(interface.frequency.value), + str(interface.channel_width), + f"{interface.speed:.3f}", + status, + ] + ) + print(table.get_string(sortby="Frequency")) + + def show(self, markdown: bool = False): + """ + Prints a summary of the current state of the airspace, including both wireless interfaces and bandwidth loads. + + This method is a convenient wrapper that calls two separate methods to display detailed tables: one for + wireless interfaces and another for bandwidth load across all frequencies and channel widths managed within the + airspace. It provides a holistic view of the operational status and performance metrics of the airspace. + + :param markdown: Flag indicating if output should be in markdown format. + """ + self.show_wireless_interfaces(markdown) + self.show_bandwidth_load(markdown) + + def add_wireless_interface(self, wireless_interface: WirelessNetworkInterface): + """ + Adds a wireless network interface to the airspace if it's not already present. + + :param wireless_interface: The wireless network interface to be added. + """ + if wireless_interface.mac_address not in self.wireless_interfaces: + self.wireless_interfaces[wireless_interface.mac_address] = wireless_interface + if wireless_interface.airspace_key not in self.wireless_interfaces_by_frequency_channel_width: + self.wireless_interfaces_by_frequency_channel_width[wireless_interface.airspace_key] = [] + self.wireless_interfaces_by_frequency_channel_width[wireless_interface.airspace_key].append( + wireless_interface + ) + speed = calculate_total_channel_capacity( + wireless_interface.channel_width, wireless_interface.frequency, self.airspace_environment_type + ) + wireless_interface.set_speed(speed) + + def remove_wireless_interface(self, wireless_interface: WirelessNetworkInterface): + """ + Removes a wireless network interface from the airspace if it's present. + + :param wireless_interface: The wireless network interface to be removed. + """ + if wireless_interface.mac_address in self.wireless_interfaces: + self.wireless_interfaces.pop(wireless_interface.mac_address) + self.wireless_interfaces_by_frequency_channel_width[wireless_interface.airspace_key].remove( + wireless_interface + ) + + def clear(self): + """ + Clears all wireless network interfaces and their frequency associations from the airspace. + + After calling this method, the airspace will contain no wireless network interfaces, and transmissions cannot + occur until new interfaces are added again. + """ + self.wireless_interfaces.clear() + self.wireless_interfaces_by_frequency_channel_width.clear() + + def reset_bandwidth_load(self): + """ + Resets the bandwidth load tracking for all frequencies in the airspace. + + This method clears the current load metrics for all operating frequencies, effectively setting the load to zero. + """ + self.bandwidth_load = {} + + def can_transmit_frame(self, frame: Frame, sender_network_interface: WirelessNetworkInterface) -> bool: + """ + Determines if a frame can be transmitted by the sender network interface based on the current bandwidth load. + + This method checks if adding the size of the frame to the current bandwidth load of the frequency used by the + sender network interface would exceed the maximum allowed bandwidth for that frequency. It returns True if the + frame can be transmitted without exceeding the limit, and False otherwise. + + :param frame: The frame to be transmitted, used to check its size against the frequency's bandwidth limit. + :param sender_network_interface: The network interface attempting to transmit the frame, used to determine the + relevant frequency and its current bandwidth load. + :return: True if the frame can be transmitted within the bandwidth limit, False if it would exceed the limit. + """ + if sender_network_interface.airspace_key not in self.bandwidth_load: + self.bandwidth_load[sender_network_interface.airspace_key] = 0.0 + return ( + self.bandwidth_load[sender_network_interface.airspace_key] + frame.size_Mbits + <= self.frequency_channel_width_max_capacity_mbps[sender_network_interface.airspace_key] + ) + + def transmit(self, frame: Frame, sender_network_interface: WirelessNetworkInterface): + """ + Transmits a frame to all enabled wireless network interfaces on a specific frequency within the airspace. + + This ensures that a wireless interface does not receive its own transmission. + + :param frame: The frame to be transmitted. + :param sender_network_interface: The wireless network interface sending the frame. This interface will be + excluded from the list of receivers to prevent it from receiving its own transmission. + """ + self.bandwidth_load[sender_network_interface.airspace_key] += frame.size_Mbits + for wireless_interface in self.wireless_interfaces_by_frequency_channel_width.get( + sender_network_interface.airspace_key, [] + ): + if wireless_interface != sender_network_interface and wireless_interface.enabled: + wireless_interface.receive_frame(frame) class WirelessNetworkInterface(NetworkInterface, ABC): @@ -139,7 +546,135 @@ class WirelessNetworkInterface(NetworkInterface, ABC): """ airspace: AirSpace - frequency: AirSpaceFrequency = AirSpaceFrequency.WIFI_2_4 + frequency_: AirSpaceFrequency = AirSpaceFrequency.WIFI_2_4 + channel_width_: ChannelWidth = ChannelWidth.WIDTH_40_MHZ + + @model_validator(mode="after") # noqa + def validate_channel_width_for_2_4_ghz(self) -> "WirelessNetworkInterface": + """ + Validate the wireless interface's channel width settings after model changes. + + This method serves as a model validator to ensure that the channel width settings for the 2.4 GHz frequency + comply with accepted standards (either 20 MHz or 40 MHz). It's triggered after model instantiation. + + Ensures that the channel width is appropriate for the current frequency setting, particularly checking + and adjusting the settings for the 2.4 GHz frequency band to not exceed 40 MHz. This is crucial for + avoiding interference and ensuring optimal performance in densely populated wireless environments. + """ + self._check_wifi_24_channel_width() + return self + + def model_post_init(self, __context: Any) -> None: + """Initialise the model after its creation, setting the speed based on the calculated channel capacity.""" + speed = calculate_total_channel_capacity( + channel_width=self.channel_width, + frequency=self.frequency, + environment_type=self.airspace.airspace_environment_type, + ) + self.set_speed(speed) + + def _check_wifi_24_channel_width(self) -> None: + """ + Ensures that the channel width for 2.4 GHz frequency does not exceed 40 MHz. + + This method checks the current frequency and channel width settings and adjusts the channel width + to 40 MHz if the frequency is set to 2.4 GHz and the channel width exceeds 40 MHz. This is done to + comply with typical Wi-Fi standards for 2.4 GHz frequencies, which commonly support up to 40 MHz. + + Logs a SysLog warning if the channel width had to be adjusted, logging this change either to the connected + node's system log or the global logger, depending on whether the interface is connected to a node. + """ + if self.frequency_ == AirSpaceFrequency.WIFI_2_4 and self.channel_width_.value > 40: + self.channel_width_ = ChannelWidth.WIDTH_40_MHZ + msg = ( + f"Channel width must be either 20 Mhz or 40 Mhz when using {AirSpaceFrequency.WIFI_2_4}. " + f"Overriding value to use {ChannelWidth.WIDTH_40_MHZ}." + ) + if self._connected_node: + self._connected_node.sys_log.warning(f"Wireless Interface {self.port_num}: {msg}") + else: + _LOGGER.warning(msg) + + @computed_field + @property + def frequency(self) -> AirSpaceFrequency: + """ + Get the current operating frequency of the wireless interface. + + :return: The current frequency as an AirSpaceFrequency enum value. + """ + return self.frequency_ + + @frequency.setter + def frequency(self, value: AirSpaceFrequency) -> None: + """ + Set the operating frequency of the wireless interface and update the network configuration. + + This setter updates the frequency of the wireless interface if the new value differs from the current setting. + It handles the update by first removing the interface from the current airspace management to avoid conflicts, + setting the new frequency, ensuring the channel width remains compliant, and then re-adding the interface + to the airspace with the new settings. + + :param value: The new frequency to set, as an AirSpaceFrequency enum value. + """ + if value != self.frequency_: + self.airspace.remove_wireless_interface(self) + self.frequency_ = value + self._check_wifi_24_channel_width() + self.airspace.add_wireless_interface(self) + + @computed_field + @property + def channel_width(self) -> ChannelWidth: + """ + Get the current channel width setting of the wireless interface. + + :return: The current channel width as a ChannelWidth enum value. + """ + return self.channel_width_ + + @channel_width.setter + def channel_width(self, value: ChannelWidth) -> None: + """ + Set the channel width of the wireless interface and manage configuration compliance. + + Updates the channel width of the wireless interface. If the new channel width is different from the existing + one, it first removes the interface from the airspace to prevent configuration conflicts, sets the new channel + width, checks and adjusts it if necessary (especially for 2.4 GHz frequency to comply with typical standards), + and then re-registers the interface in the airspace with updated settings. + + :param value: The new channel width to set, as a ChannelWidth enum value. + """ + if value != self.channel_width_: + self.airspace.remove_wireless_interface(self) + self.channel_width_ = value + self._check_wifi_24_channel_width() + self.airspace.add_wireless_interface(self) + + @property + def airspace_key(self) -> tuple: + """ + The airspace bandwidth/channel identifier for the wireless interface based on its frequency and channel width. + + :return: A tuple containing the frequency and channel width, serving as a bandwidth/channel key. + """ + return self.frequency_, self.channel_width_ + + def set_speed(self, speed: float): + """ + Sets the network interface speed to the specified value and logs this action. + + This method updates the speed attribute of the network interface to the given value, reflecting + the theoretical maximum data rate that the interface can support based on the current settings. + It logs the new speed to the system log of the connected node if available. + + :param speed: The speed in Mbps to be set for the network interface. + """ + self.speed = speed + if self._connected_node: + self._connected_node.sys_log.info( + f"Wireless Interface {self.port_num}: Setting theoretical maximum data rate to {speed:.3f} Mbps." + ) def enable(self): """Attempt to enable the network interface.""" @@ -188,8 +723,12 @@ class WirelessNetworkInterface(NetworkInterface, ABC): if self.enabled: frame.set_sent_timestamp() self.pcap.capture_outbound(frame) - self.airspace.transmit(frame, self) - return True + if self.airspace.can_transmit_frame(frame, self): + self.airspace.transmit(frame, self) + return True + else: + # Cannot send Frame as the frequency bandwidth is at capacity + return False # Cannot send Frame as the network interface is not enabled return False diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 2b9f3e53..0408acde 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -96,6 +96,8 @@ class Network(SimComponent): """Apply pre-timestep logic.""" super().pre_timestep(timestep) + self.airspace.reset_bandwidth_load() + for node in self.nodes.values(): node.pre_timestep(timestep) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 01745215..743b2e76 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -87,7 +87,7 @@ class NetworkInterface(SimComponent, ABC): mac_address: str = Field(default_factory=generate_mac_address) "The MAC address of the interface." - speed: int = 100 + speed: float = 100.0 "The speed of the interface in Mbps. Default is 100 Mbps." mtu: int = 1500 @@ -679,11 +679,20 @@ class Link(SimComponent): return self.endpoint_a.enabled and self.endpoint_b.enabled def _can_transmit(self, frame: Frame) -> bool: + """ + Determines whether a frame can be transmitted considering the current Link load and the Link's bandwidth. + + This method assesses if the transmission of a given frame is possible without exceeding the Link's total + bandwidth capacity. It checks if the current load of the Link plus the size of the frame (expressed in Mbps) + would remain within the defined bandwidth limits. The transmission is only feasible if the Link is active + ('up') and the total load including the new frame does not surpass the bandwidth limit. + + :param frame: The frame intended for transmission, which contains its size in Mbps. + :return: True if the frame can be transmitted without exceeding the bandwidth limit, False otherwise. + """ if self.is_up: frame_size_Mbits = frame.size_Mbits # noqa - Leaving it as Mbits as this is how they're expressed - # return self.current_load + frame_size_Mbits <= self.bandwidth - # TODO: re add this check once packet size limiting and MTU checks are implemented - return True + return self.current_load + frame.size_Mbits <= self.bandwidth return False def transmit_frame(self, sender_nic: WiredNetworkInterface, frame: Frame) -> bool: diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index e329f7a1..dda9e4f8 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -1,10 +1,10 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union from pydantic import validate_call -from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, IPWirelessNetworkInterface +from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, ChannelWidth, IPWirelessNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface from primaite.simulator.network.transmission.data_link_layer import Frame @@ -153,7 +153,8 @@ class WirelessRouter(Router): self, ip_address: IPV4Address, subnet_mask: IPV4Address, - frequency: AirSpaceFrequency = AirSpaceFrequency.WIFI_2_4, + frequency: Optional[AirSpaceFrequency] = AirSpaceFrequency.WIFI_2_4, + channel_width: Optional[ChannelWidth] = ChannelWidth.WIDTH_40_MHZ, ): """ Configures a wireless access point (WAP). @@ -170,13 +171,23 @@ class WirelessRouter(Router): enum. This determines the frequency band (e.g., 2.4 GHz or 5 GHz) the access point will use for wireless communication. Default is AirSpaceFrequency.WIFI_2_4. """ + if not frequency: + frequency = AirSpaceFrequency.WIFI_2_4 + if not channel_width: + channel_width = ChannelWidth.WIDTH_40_MHZ + self.sys_log.info("Configuring wireless access point") + self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration + network_interface = self.network_interface[1] + network_interface.ip_address = ip_address network_interface.subnet_mask = subnet_mask - self.sys_log.info(f"Configured WAP {network_interface}") + self.wireless_access_point.frequency = frequency # Set operating frequency + self.wireless_access_point.channel_width = channel_width self.wireless_access_point.enable() # Re-enable the WAP with new settings + self.sys_log.info(f"Configured WAP {network_interface}") @property def router_interface(self) -> RouterInterface: @@ -258,7 +269,12 @@ class WirelessRouter(Router): ip_address = cfg["wireless_access_point"]["ip_address"] subnet_mask = cfg["wireless_access_point"]["subnet_mask"] frequency = AirSpaceFrequency[cfg["wireless_access_point"]["frequency"]] - router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency) + channel_width = cfg["wireless_access_point"].get("channel_width") + if channel_width: + channel_width = ChannelWidth(channel_width) + router.configure_wireless_access_point( + ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency, channel_width=channel_width + ) if "acl" in cfg: for r_num, r_cfg in cfg["acl"].items(): diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index 776a5bfb..159eca7f 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -133,10 +133,11 @@ class Frame(BaseModel): def size(self) -> float: # noqa - Keep it as MBits as this is how they're expressed """The size of the Frame in Bytes.""" # get the payload size if it is a data packet + payload_size = 0.0 if isinstance(self.payload, DataPacket): - return self.payload.get_packet_size() + payload_size = self.payload.get_packet_size() - return float(len(self.model_dump_json().encode("utf-8"))) + return float(len(self.model_dump_json().encode("utf-8"))) + payload_size @property def size_Mbits(self) -> float: # noqa - Keep it as MBits as this is how they're expressed diff --git a/tests/assets/configs/wireless_wan_network_config.yaml b/tests/assets/configs/wireless_wan_network_config.yaml index c8f61bad..684acaf7 100644 --- a/tests/assets/configs/wireless_wan_network_config.yaml +++ b/tests/assets/configs/wireless_wan_network_config.yaml @@ -9,6 +9,8 @@ game: simulation: network: + airspace: + airspace_environment_type: blocked nodes: - type: computer hostname: pc_a diff --git a/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_blocked.yaml b/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_blocked.yaml new file mode 100644 index 00000000..21b0fe5e --- /dev/null +++ b/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_blocked.yaml @@ -0,0 +1,81 @@ +game: + max_episode_length: 256 + ports: + - ARP + protocols: + - ICMP + - TCP + - UDP + +simulation: + network: + airspace: + airspace_environment_type: blocked + nodes: + - type: computer + hostname: pc_a + ip_address: 192.168.0.2 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.0.1 + start_up_duration: 0 + + - type: computer + hostname: pc_b + ip_address: 192.168.2.2 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.2.1 + start_up_duration: 0 + + - type: wireless_router + hostname: router_1 + start_up_duration: 0 + + router_interface: + ip_address: 192.168.0.1 + subnet_mask: 255.255.255.0 + + wireless_access_point: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + frequency: WIFI_5 + channel_width: 80 + acl: + 1: + action: PERMIT + routes: + - address: 192.168.2.0 # PC B subnet + subnet_mask: 255.255.255.0 + next_hop_ip_address: 192.168.1.2 + metric: 0 + + - type: wireless_router + hostname: router_2 + start_up_duration: 0 + + router_interface: + ip_address: 192.168.2.1 + subnet_mask: 255.255.255.0 + + wireless_access_point: + ip_address: 192.168.1.2 + subnet_mask: 255.255.255.0 + frequency: WIFI_5 + channel_width: 80 + acl: + 1: + action: PERMIT + routes: + - address: 192.168.0.0 # PC A subnet + subnet_mask: 255.255.255.0 + next_hop_ip_address: 192.168.1.1 + metric: 0 + links: + - endpoint_a_hostname: pc_a + endpoint_a_port: 1 + endpoint_b_hostname: router_1 + endpoint_b_port: 2 + + - endpoint_a_hostname: pc_b + endpoint_a_port: 1 + endpoint_b_hostname: router_2 + endpoint_b_port: 2 diff --git a/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_urban.yaml b/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_urban.yaml new file mode 100644 index 00000000..ed27cd35 --- /dev/null +++ b/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_urban.yaml @@ -0,0 +1,81 @@ +game: + max_episode_length: 256 + ports: + - ARP + protocols: + - ICMP + - TCP + - UDP + +simulation: + network: + airspace: + airspace_environment_type: urban + nodes: + - type: computer + hostname: pc_a + ip_address: 192.168.0.2 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.0.1 + start_up_duration: 0 + + - type: computer + hostname: pc_b + ip_address: 192.168.2.2 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.2.1 + start_up_duration: 0 + + - type: wireless_router + hostname: router_1 + start_up_duration: 0 + + router_interface: + ip_address: 192.168.0.1 + subnet_mask: 255.255.255.0 + + wireless_access_point: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + frequency: WIFI_5 + channel_width: 80 + acl: + 1: + action: PERMIT + routes: + - address: 192.168.2.0 # PC B subnet + subnet_mask: 255.255.255.0 + next_hop_ip_address: 192.168.1.2 + metric: 0 + + - type: wireless_router + hostname: router_2 + start_up_duration: 0 + + router_interface: + ip_address: 192.168.2.1 + subnet_mask: 255.255.255.0 + + wireless_access_point: + ip_address: 192.168.1.2 + subnet_mask: 255.255.255.0 + frequency: WIFI_5 + channel_width: 80 + acl: + 1: + action: PERMIT + routes: + - address: 192.168.0.0 # PC A subnet + subnet_mask: 255.255.255.0 + next_hop_ip_address: 192.168.1.1 + metric: 0 + links: + - endpoint_a_hostname: pc_a + endpoint_a_port: 1 + endpoint_b_hostname: router_1 + endpoint_b_port: 2 + + - endpoint_a_hostname: pc_b + endpoint_a_port: 1 + endpoint_b_hostname: router_2 + endpoint_b_port: 2 diff --git a/tests/integration_tests/network/test_airspace_capacity_configuration.py b/tests/integration_tests/network/test_airspace_capacity_configuration.py new file mode 100644 index 00000000..f91f1290 --- /dev/null +++ b/tests/integration_tests/network/test_airspace_capacity_configuration.py @@ -0,0 +1,106 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import yaml + +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.airspace import ( + AirspaceEnvironmentType, + AirSpaceFrequency, + calculate_total_channel_capacity, + ChannelWidth, +) +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter +from tests import TEST_ASSETS_ROOT + + +def test_wireless_wan_wifi_5_80_channel_width_urban(): + config_path = TEST_ASSETS_ROOT / "configs" / "wireless_wan_wifi_5_80_channel_width_urban.yaml" + + with open(config_path, "r") as f: + config_dict = yaml.safe_load(f) + network = PrimaiteGame.from_config(cfg=config_dict).simulation.network + + airspace = network.airspace + + assert airspace.airspace_environment_type == AirspaceEnvironmentType.URBAN + + router_1: WirelessRouter = network.get_node_by_hostname("router_1") + router_2: WirelessRouter = network.get_node_by_hostname("router_2") + + expected_speed = calculate_total_channel_capacity( + channel_width=ChannelWidth.WIDTH_80_MHZ, + frequency=AirSpaceFrequency.WIFI_5, + environment_type=AirspaceEnvironmentType.URBAN, + ) + + assert router_1.wireless_access_point.speed == expected_speed + assert router_2.wireless_access_point.speed == expected_speed + + pc_a = network.get_node_by_hostname("pc_a") + pc_b = network.get_node_by_hostname("pc_b") + + assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully." + assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully." + + assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." + assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." + + +def test_wireless_wan_wifi_5_80_channel_width_blocked(): + config_path = TEST_ASSETS_ROOT / "configs" / "wireless_wan_wifi_5_80_channel_width_blocked.yaml" + + with open(config_path, "r") as f: + config_dict = yaml.safe_load(f) + network = PrimaiteGame.from_config(cfg=config_dict).simulation.network + + airspace = network.airspace + + assert airspace.airspace_environment_type == AirspaceEnvironmentType.BLOCKED + + router_1: WirelessRouter = network.get_node_by_hostname("router_1") + router_2: WirelessRouter = network.get_node_by_hostname("router_2") + + expected_speed = calculate_total_channel_capacity( + channel_width=ChannelWidth.WIDTH_80_MHZ, + frequency=AirSpaceFrequency.WIFI_5, + environment_type=AirspaceEnvironmentType.BLOCKED, + ) + + assert router_1.wireless_access_point.speed == expected_speed + assert router_2.wireless_access_point.speed == expected_speed + + pc_a = network.get_node_by_hostname("pc_a") + pc_b = network.get_node_by_hostname("pc_b") + + assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully." + assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully." + + assert not pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers unsuccessfully." + assert not pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers unsuccessfully." + + +def test_wireless_wan_blocking_and_unblocking_airspace(): + config_path = TEST_ASSETS_ROOT / "configs" / "wireless_wan_wifi_5_80_channel_width_urban.yaml" + + with open(config_path, "r") as f: + config_dict = yaml.safe_load(f) + network = PrimaiteGame.from_config(cfg=config_dict).simulation.network + + airspace = network.airspace + + assert airspace.airspace_environment_type == AirspaceEnvironmentType.URBAN + + pc_a = network.get_node_by_hostname("pc_a") + pc_b = network.get_node_by_hostname("pc_b") + + assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." + assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." + + airspace.airspace_environment_type = AirspaceEnvironmentType.BLOCKED + + assert not pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers unsuccessfully." + assert not pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers unsuccessfully." + + airspace.airspace_environment_type = AirspaceEnvironmentType.URBAN + + assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." + assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." diff --git a/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py b/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py new file mode 100644 index 00000000..cf03ea8e --- /dev/null +++ b/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py @@ -0,0 +1,138 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.simulator.file_system.file_type import FileType +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.simulator.system.services.ftp.ftp_client import FTPClient +from primaite.simulator.system.services.ftp.ftp_server import FTPServer +from tests.integration_tests.network.test_wireless_router import wireless_wan_network +from tests.integration_tests.system.test_ftp_client_server import ftp_client_and_ftp_server + + +def test_wireless_link_loading(wireless_wan_network): + client, server, router_1, router_2 = wireless_wan_network + + # Configure Router 1 ACLs + router_1.acl.add_rule(action=ACLAction.PERMIT, position=1) + + # Configure Router 2 ACLs + router_2.acl.add_rule(action=ACLAction.PERMIT, position=1) + + airspace = router_1.airspace + + client.software_manager.install(FTPClient) + ftp_client: FTPClient = client.software_manager.software.get("FTPClient") + ftp_client.start() + + server.software_manager.install(FTPServer) + ftp_server: FTPServer = server.software_manager.software.get("FTPServer") + ftp_server.start() + + client.file_system.create_file(file_name="mixtape", size=10 * 10**6, file_type=FileType.MP3, folder_name="music") + + assert ftp_client.send_file( + src_file_name="mixtape.mp3", + src_folder_name="music", + dest_ip_address=server.network_interface[1].ip_address, + dest_file_name="mixtape.mp3", + dest_folder_name="music", + ) + + # Reset the physical links between the host nodes and the routers + client.network_interface[1]._connected_link.pre_timestep(1) + server.network_interface[1]._connected_link.pre_timestep(1) + + assert ftp_client.send_file( + src_file_name="mixtape.mp3", + src_folder_name="music", + dest_ip_address=server.network_interface[1].ip_address, + dest_file_name="mixtape1.mp3", + dest_folder_name="music", + ) + + # Reset the physical links between the host nodes and the routers + client.network_interface[1]._connected_link.pre_timestep(1) + server.network_interface[1]._connected_link.pre_timestep(1) + + assert ftp_client.send_file( + src_file_name="mixtape.mp3", + src_folder_name="music", + dest_ip_address=server.network_interface[1].ip_address, + dest_file_name="mixtape2.mp3", + dest_folder_name="music", + ) + + # Reset the physical links between the host nodes and the routers + client.network_interface[1]._connected_link.pre_timestep(1) + server.network_interface[1]._connected_link.pre_timestep(1) + + assert not ftp_client.send_file( + src_file_name="mixtape.mp3", + src_folder_name="music", + dest_ip_address=server.network_interface[1].ip_address, + dest_file_name="mixtape3.mp3", + dest_folder_name="music", + ) + + # Reset the physical links between the host nodes and the routers + client.network_interface[1]._connected_link.pre_timestep(1) + server.network_interface[1]._connected_link.pre_timestep(1) + + airspace.reset_bandwidth_load() + + assert ftp_client.send_file( + src_file_name="mixtape.mp3", + src_folder_name="music", + dest_ip_address=server.network_interface[1].ip_address, + dest_file_name="mixtape3.mp3", + dest_folder_name="music", + ) + + +def test_wired_link_loading(ftp_client_and_ftp_server): + ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server + + link = computer.network_interface[1]._connected_link # noqa + + assert link.is_up + + link.pre_timestep(1) + + computer.file_system.create_file( + file_name="mixtape", size=10 * 10**6, file_type=FileType.MP3, folder_name="music" + ) + link_load = link.current_load + assert link_load == 0.0 + + assert ftp_client.send_file( + src_file_name="mixtape.mp3", + src_folder_name="music", + dest_ip_address=server.network_interface[1].ip_address, + dest_file_name="mixtape.mp3", + dest_folder_name="music", + ) + + new_link_load = link.current_load + assert new_link_load > link_load + + assert not ftp_client.send_file( + src_file_name="mixtape.mp3", + src_folder_name="music", + dest_ip_address=server.network_interface[1].ip_address, + dest_file_name="mixtape1.mp3", + dest_folder_name="music", + ) + + link.pre_timestep(2) + + link_load = link.current_load + assert link_load == 0.0 + + assert ftp_client.send_file( + src_file_name="mixtape.mp3", + src_folder_name="music", + dest_ip_address=server.network_interface[1].ip_address, + dest_file_name="mixtape1.mp3", + dest_folder_name="music", + ) + + new_link_load = link.current_load + assert new_link_load > link_load From 9468adb6060bc067f3b4ccd857db424c896f718d Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 4 Jul 2024 20:52:20 +0100 Subject: [PATCH 02/35] #2967 - Updated the airspace configuration description in simulation.rst --- docs/source/configuration/simulation.rst | 26 +++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/source/configuration/simulation.rst b/docs/source/configuration/simulation.rst index d585711d..bd66914d 100644 --- a/docs/source/configuration/simulation.rst +++ b/docs/source/configuration/simulation.rst @@ -108,7 +108,31 @@ This is an integer value specifying the allowed bandwidth across the connection. ``airspace`` ------------ -This is where the airspace settings for wireless networks arte set. +This section configures settings specific to the wireless network's virtual airspace. It defines how wireless interfaces within the simulation will interact and perform under various environmental conditions. ``airspace_environment_type`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This setting specifies the environmental conditions of the airspace which affect the propagation and interference characteristics of wireless signals. Changing this environment type impacts how signal noise and interference are calculated, thus affecting the overall network performance, including data transmission rates and signal quality. + +**Configurable Options** + +- **rural**: A rural environment offers clear channel conditions due to low population density and minimal electronic device presence. + +- **outdoor**: Outdoor environments like parks or fields have minimal electronic interference. + +- **suburban**: Suburban environments strike a balance with fewer electronic interferences than urban but more than rural. + +- **office**: Office environments have moderate interference from numerous electronic devices and overlapping networks. + +- **urban**: Urban environments are characterized by tall buildings and a high density of electronic devices, leading to significant interference. + +- **industrial**: Industrial areas face high interference from heavy machinery and numerous electronic devices. + +- **transport**: Environments such as subways and buses where metal structures and high mobility create complex interference patterns. + +- **dense_urban**: Dense urban areas like city centers have the highest level of signal interference due to the very high density of buildings and devices. + +- **jamming_zone**: A jamming zone environment where signals are actively interfered with, typically through the use of signal jammers or scrambling devices. This represents the environment with the highest level of interference. + +- **blocked**: A jamming zone environment with total levels of interference. Airspace is completely blocked. From 2a0695d0d123def6f9310ea761db4d3b9775e2f6 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 5 Jul 2024 15:06:17 +0100 Subject: [PATCH 03/35] #2688: apply the request validators + fixing the fix duration test + refactor test class names --- .../simulator/network/hardware/base.py | 90 ++++++++++++++++++- .../system/applications/application.py | 53 ++++++++++- .../simulator/system/services/service.py | 84 +++++++++++++++-- .../assets/configs/software_fix_duration.yaml | 3 + tests/conftest.py | 15 ++-- .../test_software_fix_duration.py | 34 +++---- .../actions/test_node_request_permission.py | 1 + .../network/test_broadcast.py | 32 +++---- .../system/test_application_on_node.py | 4 +- .../test_simulation/test_request_response.py | 12 +-- 10 files changed, 263 insertions(+), 65 deletions(-) create mode 100644 tests/integration_tests/game_layer/actions/test_node_request_permission.py diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 6942d280..e728ae97 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -130,10 +130,25 @@ class NetworkInterface(SimComponent, ABC): More information in user guide and docstring for SimComponent._init_request_manager. """ + _is_network_interface_enabled = NetworkInterface._EnabledValidator(network_interface=self) + _is_network_interface_disabled = NetworkInterface._DisabledValidator(network_interface=self) + rm = super()._init_request_manager() - rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable()))) - rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable()))) + rm.add_request( + "enable", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.enable()), + validator=_is_network_interface_disabled, + ), + ) + rm.add_request( + "disable", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.disable()), + validator=_is_network_interface_enabled, + ), + ) return rm @@ -332,6 +347,50 @@ class NetworkInterface(SimComponent, ABC): super().pre_timestep(timestep) self.traffic = {} + class _EnabledValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the NetworkInterface is enabled. + + This is useful because most actions should be being resolved if the NetworkInterface is disabled. + """ + + network_interface: NetworkInterface + """Save a reference to the node instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the NetworkInterface is enabled or not.""" + return self.network_interface.enabled + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return ( + f"Cannot perform request on NetworkInterface " + f"'{self.network_interface.mac_address}' because it is not enabled." + ) + + class _DisabledValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the NetworkInterface is disabled. + + This is useful because some actions should be being resolved if the NetworkInterface is disabled. + """ + + network_interface: NetworkInterface + """Save a reference to the node instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the NetworkInterface is disabled or not.""" + return not self.network_interface.enabled + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return ( + f"Cannot perform request on NetworkInterface " + f"'{self.network_interface.mac_address}' because it is not disabled." + ) + class WiredNetworkInterface(NetworkInterface, ABC): """ @@ -878,6 +937,25 @@ class Node(SimComponent): """Message that is reported when a request is rejected by this validator.""" return f"Cannot perform request on node '{self.node.hostname}' because it is not turned on." + class _NodeIsOffValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the node is off. + + This is useful because some actions require the node to be in an off state. + """ + + node: Node + """Save a reference to the node instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the node is on or off.""" + return self.node.operating_state == NodeOperatingState.OFF + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return f"Cannot perform request on node '{self.node.hostname}' because it is not turned off." + def _init_request_manager(self) -> RequestManager: """ Initialise the request manager. @@ -940,6 +1018,7 @@ class Node(SimComponent): return RequestResponse.from_bool(False) _node_is_on = Node._NodeIsOnValidator(node=self) + _node_is_off = Node._NodeIsOffValidator(node=self) rm = super()._init_request_manager() # since there are potentially many services, create an request manager that can map service name @@ -969,7 +1048,12 @@ class Node(SimComponent): func=lambda request, context: RequestResponse.from_bool(self.power_off()), validator=_node_is_on ), ) - rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on()))) + rm.add_request( + "startup", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.power_on()), validator=_node_is_off + ), + ) rm.add_request( "reset", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()), validator=_node_is_on), diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 848e1ef0..dc16a725 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -1,10 +1,12 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from __future__ import annotations + from abc import abstractmethod from enum import Enum from typing import Any, ClassVar, Dict, Optional, Set, Type -from primaite.interface.request import RequestResponse -from primaite.simulator.core import RequestManager, RequestType +from primaite.interface.request import RequestFormat, RequestResponse +from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType from primaite.simulator.system.software import IOSoftware, SoftwareHealthState @@ -64,9 +66,27 @@ class Application(IOSoftware): More information in user guide and docstring for SimComponent._init_request_manager. """ - rm = super()._init_request_manager() + _is_application_running = Application._StateValidator(application=self, state=ApplicationOperatingState.RUNNING) - rm.add_request("close", RequestType(func=lambda request, context: RequestResponse.from_bool(self.close()))) + rm = super()._init_request_manager() + rm.add_request( + "scan", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_is_application_running + ), + ) + rm.add_request( + "close", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.close()), validator=_is_application_running + ), + ) + rm.add_request( + "fix", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.fix()), validator=_is_application_running + ), + ) return rm @abstractmethod @@ -169,3 +189,28 @@ class Application(IOSoftware): :return: True if successful, False otherwise. """ return super().receive(payload=payload, session_id=session_id, **kwargs) + + class _StateValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the application is in the correct state. + + This is useful because most actions require the application to be in a specific state. + """ + + application: Application + """Save a reference to the application instance.""" + + state: ApplicationOperatingState + """The state of the application to validate.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the application is in the state we are validating for.""" + return self.application.operating_state == self.state + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return ( + f"Cannot perform request on application '{self.application.name}' because it is not in the " + f"{self.state.name} state." + ) diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index e6ce2c87..8167a8a9 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -1,11 +1,13 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from __future__ import annotations + from abc import abstractmethod from enum import Enum from typing import Any, Dict, Optional from primaite import getLogger -from primaite.interface.request import RequestResponse -from primaite.simulator.core import RequestManager, RequestType +from primaite.interface.request import RequestFormat, RequestResponse +from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType from primaite.simulator.system.software import IOSoftware, SoftwareHealthState _LOGGER = getLogger(__name__) @@ -40,6 +42,7 @@ class Service(IOSoftware): restart_duration: int = 5 "How many timesteps does it take to restart this service." + restart_countdown: Optional[int] = None "If currently restarting, how many timesteps remain until the restart is finished." @@ -86,15 +89,55 @@ class Service(IOSoftware): More information in user guide and docstring for SimComponent._init_request_manager. """ + _is_service_running = Service._StateValidator(service=self, state=ServiceOperatingState.RUNNING) + _is_service_stopped = Service._StateValidator(service=self, state=ServiceOperatingState.STOPPED) + _is_service_paused = Service._StateValidator(service=self, state=ServiceOperatingState.PAUSED) + rm = super()._init_request_manager() - rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))) - rm.add_request("stop", RequestType(func=lambda request, context: RequestResponse.from_bool(self.stop()))) - rm.add_request("start", RequestType(func=lambda request, context: RequestResponse.from_bool(self.start()))) - rm.add_request("pause", RequestType(func=lambda request, context: RequestResponse.from_bool(self.pause()))) - rm.add_request("resume", RequestType(func=lambda request, context: RequestResponse.from_bool(self.resume()))) - rm.add_request("restart", RequestType(func=lambda request, context: RequestResponse.from_bool(self.restart()))) + rm.add_request( + "scan", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_is_service_running + ), + ) + rm.add_request( + "stop", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.stop()), validator=_is_service_running + ), + ) + rm.add_request( + "start", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.start()), validator=_is_service_stopped + ), + ) + rm.add_request( + "pause", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.pause()), validator=_is_service_running + ), + ) + rm.add_request( + "resume", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.resume()), validator=_is_service_paused + ), + ) + rm.add_request( + "restart", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.restart()), validator=_is_service_running + ), + ) rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable()))) rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable()))) + rm.add_request( + "fix", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.fix()), validator=_is_service_running + ), + ) return rm @abstractmethod @@ -191,3 +234,28 @@ class Service(IOSoftware): self.sys_log.debug(f"Restarting finished for service {self.name}") self.operating_state = ServiceOperatingState.RUNNING self.restart_countdown -= 1 + + class _StateValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the service is in the correct state. + + This is useful because most actions require the service to be in a specific state. + """ + + service: Service + """Save a reference to the service instance.""" + + state: ServiceOperatingState + """The state of the service to validate.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the service is in the state we are validating for.""" + return self.service.operating_state == self.state + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return ( + f"Cannot perform request on service '{self.service.name}' because it is not in the " + f"{self.state.name} state." + ) diff --git a/tests/assets/configs/software_fix_duration.yaml b/tests/assets/configs/software_fix_duration.yaml index beb176d1..1acb05a9 100644 --- a/tests/assets/configs/software_fix_duration.yaml +++ b/tests/assets/configs/software_fix_duration.yaml @@ -177,6 +177,9 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: + - type: NMAP + options: + fix_duration: 1 - type: RansomwareScript options: fix_duration: 1 diff --git a/tests/conftest.py b/tests/conftest.py index 980e4aa9..e36a2460 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,11 +51,11 @@ class TestService(Service): pass -class DummyApplication(Application, identifier="DummyApplication"): +class TestDummyApplication(Application, identifier="TestDummyApplication"): """Test Application class""" def __init__(self, **kwargs): - kwargs["name"] = "DummyApplication" + kwargs["name"] = "TestDummyApplication" kwargs["port"] = Port.HTTP kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -85,15 +85,18 @@ def service_class(): @pytest.fixture(scope="function") -def application(file_system) -> DummyApplication: - return DummyApplication( - name="DummyApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_application") +def application(file_system) -> TestDummyApplication: + return TestDummyApplication( + name="TestDummyApplication", + port=Port.ARP, + file_system=file_system, + sys_log=SysLog(hostname="dummy_application"), ) @pytest.fixture(scope="function") def application_class(): - return DummyApplication + return TestDummyApplication @pytest.fixture(scope="function") diff --git a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py index bf325946..04160f8f 100644 --- a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py +++ b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py @@ -1,35 +1,23 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK import copy -from ipaddress import IPv4Address from pathlib import Path from typing import Union import yaml -from primaite.config.load import data_manipulation_config_path -from primaite.game.agent.interface import ProxyAgent -from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent -from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent -from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING -from primaite.simulator.network.container import Network +from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient -from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot -from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot -from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_client import DNSClient -from primaite.simulator.system.services.dns.dns_server import DNSServer -from primaite.simulator.system.services.ftp.ftp_client import FTPClient -from primaite.simulator.system.services.ftp.ftp_server import FTPServer -from primaite.simulator.system.services.ntp.ntp_client import NTPClient -from primaite.simulator.system.services.ntp.ntp_server import NTPServer -from primaite.simulator.system.services.web_server.web_server import WebServer from tests import TEST_ASSETS_ROOT TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml" ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml" +TestApplications = ["TestDummyApplication", "TestBroadcastClient"] + def load_config(config_path: Union[str, Path]) -> PrimaiteGame: """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" @@ -62,9 +50,12 @@ def test_fix_duration_set_from_config(): assert client_1.software_manager.software.get(service).fixing_duration == 3 # in config - applications take 1 timestep to fix - for applications in APPLICATION_TYPES_MAPPING: - assert client_1.software_manager.software.get(applications) is not None - assert client_1.software_manager.software.get(applications).fixing_duration == 1 + # remove test applications from list + applications = set(Application._application_registry) - set(TestApplications) + + for application in applications: + assert client_1.software_manager.software.get(application) is not None + assert client_1.software_manager.software.get(application).fixing_duration == 1 def test_fix_duration_for_one_item(): @@ -80,8 +71,9 @@ def test_fix_duration_for_one_item(): assert client_1.software_manager.software.get(service).fixing_duration == 2 # in config - applications take 1 timestep to fix - applications = copy.copy(APPLICATION_TYPES_MAPPING) - applications.pop("DatabaseClient") + # remove test applications from list + applications = set(Application._application_registry) - set(TestApplications) + applications.remove("DatabaseClient") for applications in applications: assert client_1.software_manager.software.get(applications) is not None assert client_1.software_manager.software.get(applications).fixing_duration == 2 diff --git a/tests/integration_tests/game_layer/actions/test_node_request_permission.py b/tests/integration_tests/game_layer/actions/test_node_request_permission.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_node_request_permission.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index b89d6db6..bcf7b9b0 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -14,7 +14,7 @@ from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service -class BroadcastService(Service): +class TestBroadcastService(Service): """A service for sending broadcast and unicast messages over a network.""" def __init__(self, **kwargs): @@ -41,14 +41,14 @@ class BroadcastService(Service): super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol) -class BroadcastClient(Application, identifier="BroadcastClient"): +class TestBroadcastClient(Application, identifier="TestBroadcastClient"): """A client application to receive broadcast and unicast messages.""" payloads_received: List = [] def __init__(self, **kwargs): # Set default client properties - kwargs["name"] = "BroadcastClient" + kwargs["name"] = "TestBroadcastClient" kwargs["port"] = Port.HTTP kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -75,8 +75,8 @@ def broadcast_network() -> Network: start_up_duration=0, ) client_1.power_on() - client_1.software_manager.install(BroadcastClient) - application_1 = client_1.software_manager.software["BroadcastClient"] + client_1.software_manager.install(TestBroadcastClient) + application_1 = client_1.software_manager.software["TestBroadcastClient"] application_1.run() client_2 = Computer( @@ -87,8 +87,8 @@ def broadcast_network() -> Network: start_up_duration=0, ) client_2.power_on() - client_2.software_manager.install(BroadcastClient) - application_2 = client_2.software_manager.software["BroadcastClient"] + client_2.software_manager.install(TestBroadcastClient) + application_2 = client_2.software_manager.software["TestBroadcastClient"] application_2.run() server_1 = Server( @@ -100,8 +100,8 @@ def broadcast_network() -> Network: ) server_1.power_on() - server_1.software_manager.install(BroadcastService) - service: BroadcastService = server_1.software_manager.software["BroadcastService"] + server_1.software_manager.install(TestBroadcastService) + service: TestBroadcastService = server_1.software_manager.software["BroadcastService"] service.start() switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) @@ -115,14 +115,16 @@ def broadcast_network() -> Network: @pytest.fixture(scope="function") -def broadcast_service_and_clients(broadcast_network) -> Tuple[BroadcastService, BroadcastClient, BroadcastClient]: - client_1: BroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[ - "BroadcastClient" +def broadcast_service_and_clients( + broadcast_network, +) -> Tuple[TestBroadcastService, TestBroadcastClient, TestBroadcastClient]: + client_1: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[ + "TestBroadcastClient" ] - client_2: BroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[ - "BroadcastClient" + client_2: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[ + "TestBroadcastClient" ] - service: BroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[ + service: TestBroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[ "BroadcastService" ] diff --git a/tests/integration_tests/system/test_application_on_node.py b/tests/integration_tests/system/test_application_on_node.py index ffb5cc7f..400ab082 100644 --- a/tests/integration_tests/system/test_application_on_node.py +++ b/tests/integration_tests/system/test_application_on_node.py @@ -21,7 +21,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]: computer.power_on() computer.software_manager.install(application_class) - app = computer.software_manager.software.get("DummyApplication") + app = computer.software_manager.software.get("TestDummyApplication") app.run() return app, computer @@ -39,7 +39,7 @@ def test_application_on_offline_node(application_class): ) computer.software_manager.install(application_class) - app: Application = computer.software_manager.software.get("DummyApplication") + app: Application = computer.software_manager.software.get("TestDummyApplication") computer.power_off() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index a9f0b58d..29c70566 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.transport_layer import Port -from tests.conftest import DummyApplication, TestService +from tests.conftest import TestDummyApplication, TestService def test_successful_node_file_system_creation_request(example_network): @@ -47,14 +47,14 @@ def test_successful_application_requests(example_network): net = example_network client_1 = net.get_node_by_hostname("client_1") - client_1.software_manager.install(DummyApplication) - client_1.software_manager.software.get("DummyApplication").run() + client_1.software_manager.install(TestDummyApplication) + client_1.software_manager.software.get("TestDummyApplication").run() - resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"]) + resp_1 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "scan"]) assert resp_1 == RequestResponse(status="success", data={}) - resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"]) + resp_2 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "fix"]) assert resp_2 == RequestResponse(status="success", data={}) - resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"]) + resp_3 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "compromise"]) assert resp_3 == RequestResponse(status="success", data={}) From 4410e05e3ef0cd2758017063d0890a3d688afff2 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 5 Jul 2024 16:27:03 +0100 Subject: [PATCH 04/35] #2967 - Updated the DB filesize so that it doesn't fill the 100mbit link. moved the can transmit checks to the network interface to enable frame dropped syslog. narrowed the scope of the NODE_NMAP_PORT_SCAN action in nmap_port_scan_red_agent_config.yaml to select ports and protocols as the link was filling up on the full box scan. --- .../simulator/file_system/file_type.py | 2 +- src/primaite/simulator/network/airspace.py | 23 +++++++++--------- .../simulator/network/hardware/base.py | 24 +++++++++---------- .../network/hardware/nodes/network/switch.py | 16 ++++++++----- .../nmap_port_scan_red_agent_config.yaml | 6 +++++ .../configs/wireless_wan_network_config.yaml | 2 +- tests/conftest.py | 3 +-- 7 files changed, 42 insertions(+), 34 deletions(-) diff --git a/src/primaite/simulator/file_system/file_type.py b/src/primaite/simulator/file_system/file_type.py index 8f0cb778..e6e81070 100644 --- a/src/primaite/simulator/file_system/file_type.py +++ b/src/primaite/simulator/file_system/file_type.py @@ -185,5 +185,5 @@ file_type_sizes_bytes = { FileType.ZIP: 1024000, FileType.TAR: 1024000, FileType.GZ: 819200, - FileType.DB: 15360000, + FileType.DB: 5_000_000, } diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 6060d969..2ac11a20 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -720,17 +720,18 @@ class WirelessNetworkInterface(NetworkInterface, ABC): :param frame: The network frame to be sent. :return: True if the frame is sent successfully, False if the network interface is disabled. """ - if self.enabled: - frame.set_sent_timestamp() - self.pcap.capture_outbound(frame) - if self.airspace.can_transmit_frame(frame, self): - self.airspace.transmit(frame, self) - return True - else: - # Cannot send Frame as the frequency bandwidth is at capacity - return False - # Cannot send Frame as the network interface is not enabled - return False + if not self.enabled: + return False + if not self.airspace.can_transmit_frame(frame, self): + # Drop frame for now. Queuing will happen here (probably) if it's done in the future. + self._connected_node.sys_log.info(f"{self}: Frame dropped as Link is at capacity") + return False + + super().send_frame(frame) + frame.set_sent_timestamp() + self.pcap.capture_outbound(frame) + self.airspace.transmit(frame, self) + return True def receive_frame(self, frame: Frame) -> bool: """ diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 743b2e76..5ed27658 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -440,14 +440,17 @@ class WiredNetworkInterface(NetworkInterface, ABC): :param frame: The network frame to be sent. :return: True if the frame is sent, False if the Network Interface is disabled or not connected to a link. """ + if not self.enabled: + return False + if not self._connected_link.can_transmit_frame(frame): + # Drop frame for now. Queuing will happen here (probably) if it's done in the future. + self._connected_node.sys_log.info(f"{self}: Frame dropped as Link is at capacity") + return False super().send_frame(frame) - if self.enabled: - frame.set_sent_timestamp() - self.pcap.capture_outbound(frame) - self._connected_link.transmit_frame(sender_nic=self, frame=frame) - return True - # Cannot send Frame as the NIC is not enabled - return False + frame.set_sent_timestamp() + self.pcap.capture_outbound(frame) + self._connected_link.transmit_frame(sender_nic=self, frame=frame) + return True @abstractmethod def receive_frame(self, frame: Frame) -> bool: @@ -678,7 +681,7 @@ class Link(SimComponent): """ return self.endpoint_a.enabled and self.endpoint_b.enabled - def _can_transmit(self, frame: Frame) -> bool: + def can_transmit_frame(self, frame: Frame) -> bool: """ Determines whether a frame can be transmitted considering the current Link load and the Link's bandwidth. @@ -703,11 +706,6 @@ class Link(SimComponent): :param frame: The network frame to be sent. :return: True if the Frame can be sent, otherwise False. """ - can_transmit = self._can_transmit(frame) - if not can_transmit: - _LOGGER.debug(f"Cannot transmit frame as {self} is at capacity") - return False - receiver = self.endpoint_a if receiver == sender_nic: receiver = self.endpoint_b diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 6eee0d40..1a7da2e7 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -58,12 +58,16 @@ class SwitchPort(WiredNetworkInterface): :param frame: The network frame to be sent. :return: A boolean indicating whether the frame was successfully sent. """ - if self.enabled: - self.pcap.capture_outbound(frame) - self._connected_link.transmit_frame(sender_nic=self, frame=frame) - return True - # Cannot send Frame as the SwitchPort is not enabled - return False + if not self.enabled: + return False + if not self._connected_link.can_transmit_frame(frame): + # Drop frame for now. Queuing will happen here (probably) if it's done in the future. + self._connected_node.sys_log.info(f"{self}: Frame dropped as Link is at capacity") + return False + + self.pcap.capture_outbound(frame) + self._connected_link.transmit_frame(sender_nic=self, frame=frame) + return True def receive_frame(self, frame: Frame) -> bool: """ diff --git a/tests/assets/configs/nmap_port_scan_red_agent_config.yaml b/tests/assets/configs/nmap_port_scan_red_agent_config.yaml index 08944ee5..8ed715c1 100644 --- a/tests/assets/configs/nmap_port_scan_red_agent_config.yaml +++ b/tests/assets/configs/nmap_port_scan_red_agent_config.yaml @@ -41,6 +41,12 @@ agents: options: source_node: client_1 target_ip_address: 192.168.10.0/24 + target_port: + - 21 + - 53 + - 80 + - 123 + - 219 reward_function: reward_components: diff --git a/tests/assets/configs/wireless_wan_network_config.yaml b/tests/assets/configs/wireless_wan_network_config.yaml index 684acaf7..7172f66d 100644 --- a/tests/assets/configs/wireless_wan_network_config.yaml +++ b/tests/assets/configs/wireless_wan_network_config.yaml @@ -10,7 +10,7 @@ game: simulation: network: airspace: - airspace_environment_type: blocked + airspace_environment_type: urban nodes: - type: computer hostname: pc_a diff --git a/tests/conftest.py b/tests/conftest.py index b8359323..a0117eb6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -252,8 +252,7 @@ def example_network() -> Network: server_2.power_on() network.connect(endpoint_b=server_2.network_interface[1], endpoint_a=switch_1.network_interface[2]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, position=1) assert all(link.is_up for link in network.links.values()) From 7173c329b09c8ccf95f2d9fe2aedc2ba72bdedb4 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 5 Jul 2024 16:56:07 +0100 Subject: [PATCH 05/35] #2739 - Updated azure-benchmark-pipeline.yaml to allow it to run for unlimited time on the Imaginary Yak Pool --- .azure/azure-benchmark-pipeline.yaml | 119 ++++++++++++++------------- 1 file changed, 61 insertions(+), 58 deletions(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index 1f7b8ebe..e764c3c1 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -11,74 +11,77 @@ schedules: branches: include: - 'refs/heads/dev' - -pool: - vmImage: ubuntu-latest - variables: VERSION: '' MAJOR_VERSION: '' -steps: -- checkout: self - persistCredentials: true +jobs: +- job: PrimAITE Benchmark + timeoutInMinutes: 0 # Set to unlimited timeout + pool: + name: 'Imaginary Yak Pool' + workspace: + clean: all + steps: + - checkout: self + persistCredentials: true -- script: | - VERSION=$(cat src/primaite/VERSION | tr -d '\n') - if [[ "$(Build.SourceBranch)" == "refs/heads/dev" ]]; then - DATE=$(date +%Y%m%d) - echo "${VERSION}+dev.${DATE}" > src/primaite/VERSION - fi - displayName: 'Update VERSION file for Dev Benchmark' + - script: | + VERSION=$(cat src/primaite/VERSION | tr -d '\n') + if [[ "$(Build.SourceBranch)" == "refs/heads/dev" ]]; then + DATE=$(date +%Y%m%d) + echo "${VERSION}+dev.${DATE}" > src/primaite/VERSION + fi + displayName: 'Update VERSION file for Dev Benchmark' -- script: | - VERSION=$(cat src/primaite/VERSION | tr -d '\n') - MAJOR_VERSION=$(echo $VERSION | cut -d. -f1) - echo "##vso[task.setvariable variable=VERSION]$VERSION" - echo "##vso[task.setvariable variable=MAJOR_VERSION]$MAJOR_VERSION" - displayName: 'Set Version Variables' + - script: | + VERSION=$(cat src/primaite/VERSION | tr -d '\n') + MAJOR_VERSION=$(echo $VERSION | cut -d. -f1) + echo "##vso[task.setvariable variable=VERSION]$VERSION" + echo "##vso[task.setvariable variable=MAJOR_VERSION]$MAJOR_VERSION" + displayName: 'Set Version Variables' -- task: UsePythonVersion@0 - inputs: - versionSpec: '3.11' - addToPath: true + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.11' + addToPath: true -- script: | - python -m pip install --upgrade pip - pip install -e .[dev,rl] - primaite setup - displayName: 'Install Dependencies' + - script: | + python -m pip install --upgrade pip + pip install -e .[dev,rl] + primaite setup + displayName: 'Install Dependencies' -- script: | - cd benchmark - python3 primaite_benchmark.py - cd .. - displayName: 'Run Benchmarking Script' + - script: | + cd benchmark + python3 primaite_benchmark.py + cd .. + displayName: 'Run Benchmarking Script' -- script: | - git config --global user.email "oss@dstl.gov.uk" - git config --global user.name "Defence Science and Technology Laboratory UK" - workingDirectory: $(System.DefaultWorkingDirectory) - displayName: 'Configure Git' - condition: and(succeeded(), eq(variables['Build.Reason'], 'Manual'), startsWith(variables['Build.SourceBranch'], 'refs/heads/release')) + - script: | + git config --global user.email "oss@dstl.gov.uk" + git config --global user.name "Defence Science and Technology Laboratory UK" + workingDirectory: $(System.DefaultWorkingDirectory) + displayName: 'Configure Git' + condition: and(succeeded(), eq(variables['Build.Reason'], 'Manual'), startsWith(variables['Build.SourceBranch'], 'refs/heads/release')) -- script: | - git add benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)/* - git commit -m "Automated benchmark output commit for version $(VERSION)" - git push origin HEAD:refs/heads/$(Build.SourceBranchName) - displayName: 'Commit and Push Benchmark Results' - workingDirectory: $(System.DefaultWorkingDirectory) - env: - GIT_CREDENTIALS: $(System.AccessToken) - condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/release')) + - script: | + git add benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)/* + git commit -m "Automated benchmark output commit for version $(VERSION)" + git push origin HEAD:refs/heads/$(Build.SourceBranchName) + displayName: 'Commit and Push Benchmark Results' + workingDirectory: $(System.DefaultWorkingDirectory) + env: + GIT_CREDENTIALS: $(System.AccessToken) + condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/release')) -- script: | - tar czf primaite_v$(VERSION)_benchmark.tar.gz benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) - displayName: 'Prepare Artifacts for Publishing' + - script: | + tar czf primaite_v$(VERSION)_benchmark.tar.gz benchmark/results/v$(MAJOR_VERSION)/v$(VERSION) + displayName: 'Prepare Artifacts for Publishing' -- task: PublishPipelineArtifact@1 - inputs: - targetPath: primaite_v$(VERSION)_benchmark.tar.gz - artifactName: 'benchmark-output' - publishLocation: 'pipeline' - displayName: 'Publish Benchmark Output as Artifact' + - task: PublishPipelineArtifact@1 + inputs: + targetPath: primaite_v$(VERSION)_benchmark.tar.gz + artifactName: 'benchmark-output' + publishLocation: 'pipeline' + displayName: 'Publish Benchmark Output as Artifact' From 5e8343ca9154870b5eb6c5a2f88cd75a64903dd3 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 5 Jul 2024 16:57:40 +0100 Subject: [PATCH 06/35] #2739 - update job name in azure-benchmark-pipeline.yaml --- .azure/azure-benchmark-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index e764c3c1..e8e6ec9c 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -16,7 +16,7 @@ variables: MAJOR_VERSION: '' jobs: -- job: PrimAITE Benchmark +- job: PrimAITE_Benchmark timeoutInMinutes: 0 # Set to unlimited timeout pool: name: 'Imaginary Yak Pool' From c14699230714843406e8538f0e0d39bfb7babf7f Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 5 Jul 2024 17:02:35 +0100 Subject: [PATCH 07/35] #2739 - updated azure-benchmark-pipeline.yaml to run on ubuntu-latest while we wait for authorisation to use the yak pool --- .azure/azure-benchmark-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index e8e6ec9c..350123e5 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -19,7 +19,7 @@ jobs: - job: PrimAITE_Benchmark timeoutInMinutes: 0 # Set to unlimited timeout pool: - name: 'Imaginary Yak Pool' + vmImage: ubuntu-latest workspace: clean: all steps: From 2a003eece9d9845e4cbad1ea5bb6c02f7435af20 Mon Sep 17 00:00:00 2001 From: Christopher McCarthy Date: Mon, 8 Jul 2024 08:25:05 +0000 Subject: [PATCH 08/35] Updated v3.0.0_benchmark_metadata.json so that combined_av_reward_per_episode is now named combined_total_reward_per_episode to match the new script --- benchmark/results/v3/v3.0.0/v3.0.0_benchmark_metadata.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/results/v3/v3.0.0/v3.0.0_benchmark_metadata.json b/benchmark/results/v3/v3.0.0/v3.0.0_benchmark_metadata.json index b6780eac..ed3ea4eb 100644 --- a/benchmark/results/v3/v3.0.0/v3.0.0_benchmark_metadata.json +++ b/benchmark/results/v3/v3.0.0/v3.0.0_benchmark_metadata.json @@ -26,7 +26,7 @@ "av_s_per_session": 3205.6340542, "av_s_per_step": 0.10017606419375, "av_s_per_100_steps_10_nodes": 10.017606419375, - "combined_av_reward_per_episode": { + "combined_total_reward_per_episode": { "1": -53.42999999999999, "2": -25.18000000000001, "3": -42.00000000000002, From 2d2f2df360d12e1e27c38b3ea8f2baa05090ffb8 Mon Sep 17 00:00:00 2001 From: Christopher McCarthy Date: Mon, 8 Jul 2024 08:26:52 +0000 Subject: [PATCH 09/35] Added set -e to the Run Benchmark Script step in azure-benchmark-pipeline.yaml so that the pipeline fails if the python script fails. --- .azure/azure-benchmark-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index 350123e5..7eab2114 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -53,6 +53,7 @@ jobs: displayName: 'Install Dependencies' - script: | + set -e cd benchmark python3 primaite_benchmark.py cd .. From 829a6371deb0a0c18b5ce3e97bb8baa15e580e1d Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 8 Jul 2024 14:39:37 +0100 Subject: [PATCH 10/35] #2688: tests --- .../test_application_request_permission.py | 54 +++++++++ .../actions/test_nic_request_permission.py | 97 ++++++++++++++++ .../actions/test_node_request_permission.py | 93 +++++++++++++++ .../test_service_request_permission.py | 106 ++++++++++++++++++ 4 files changed, 350 insertions(+) create mode 100644 tests/integration_tests/game_layer/actions/test_application_request_permission.py create mode 100644 tests/integration_tests/game_layer/actions/test_nic_request_permission.py create mode 100644 tests/integration_tests/game_layer/actions/test_service_request_permission.py diff --git a/tests/integration_tests/game_layer/actions/test_application_request_permission.py b/tests/integration_tests/game_layer/actions/test_application_request_permission.py new file mode 100644 index 00000000..36a7ae57 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_application_request_permission.py @@ -0,0 +1,54 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.system.applications.application import ApplicationOperatingState +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.service import ServiceOperatingState + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_application_cannot_perform_actions_unless_running(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test the the request permissions prevent any actions unless application is running.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + + browser.close() + assert browser.operating_state == ApplicationOperatingState.CLOSED + + action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0}) + agent.store_action(action) + game.step() + assert browser.operating_state == ApplicationOperatingState.CLOSED + + action = ("NODE_APPLICATION_CLOSE", {"node_id": 0, "application_id": 0}) + agent.store_action(action) + game.step() + assert browser.operating_state == ApplicationOperatingState.CLOSED + + action = ("NODE_APPLICATION_FIX", {"node_id": 0, "application_id": 0}) + agent.store_action(action) + game.step() + assert browser.operating_state == ApplicationOperatingState.CLOSED + + action = ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}) + agent.store_action(action) + game.step() + assert browser.operating_state == ApplicationOperatingState.CLOSED diff --git a/tests/integration_tests/game_layer/actions/test_nic_request_permission.py b/tests/integration_tests/game_layer/actions/test_nic_request_permission.py new file mode 100644 index 00000000..4c1619e7 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_nic_request_permission.py @@ -0,0 +1,97 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.system.services.service import ServiceOperatingState + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_nic_cannot_be_turned_off_if_not_on(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that a NIC cannot be disabled if it is not enabled.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + nic = client_1.network_interface[1] + nic.disable() + assert nic.enabled is False + + action = ( + "HOST_NIC_DISABLE", + { + "node_id": 0, # client_1 + "nic_id": 0, # the only nic (eth-1) + }, + ) + agent.store_action(action) + game.step() + + assert nic.enabled is False + + +def test_nic_cannot_be_turned_on_if_already_on(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that a NIC cannot be enabled if it is already enabled.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + nic = client_1.network_interface[1] + assert nic.enabled + + action = ( + "HOST_NIC_ENABLE", + { + "node_id": 0, # client_1 + "nic_id": 0, # the only nic (eth-1) + }, + ) + agent.store_action(action) + game.step() + + assert nic.enabled + + +def test_that_a_nic_can_be_enabled_and_disabled(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Tests that a NIC can be enabled and disabled.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + nic = client_1.network_interface[1] + assert nic.enabled + + action = ( + "HOST_NIC_DISABLE", + { + "node_id": 0, # client_1 + "nic_id": 0, # the only nic (eth-1) + }, + ) + agent.store_action(action) + game.step() + + assert nic.enabled is False + + action = ( + "HOST_NIC_ENABLE", + { + "node_id": 0, # client_1 + "nic_id": 0, # the only nic (eth-1) + }, + ) + agent.store_action(action) + game.step() + + assert nic.enabled diff --git a/tests/integration_tests/game_layer/actions/test_node_request_permission.py b/tests/integration_tests/game_layer/actions/test_node_request_permission.py index be6c00e7..fdf04ad5 100644 --- a/tests/integration_tests/game_layer/actions/test_node_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_node_request_permission.py @@ -1 +1,94 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.computer import Computer + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the node can be shut down and started up.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + assert client_1.operating_state == NodeOperatingState.ON + + # turn it off + action = ("NODE_SHUTDOWN", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN + + for i in range(client_1.shut_down_duration + 1): + action = ("DONOTHING", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.OFF + + # turn it on + action = ("NODE_STARTUP", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.BOOTING + + for i in range(client_1.start_up_duration + 1): + action = ("DONOTHING", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.ON + + +def test_node_cannot_be_started_up_if_node_is_already_on(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that a node cannot be started up if it is already on.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + assert client_1.operating_state == NodeOperatingState.ON + + # turn it on + action = ("NODE_STARTUP", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.ON + + +def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that a node cannot be shut down if it is already off.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.power_off() + + for i in range(client_1.shut_down_duration + 1): + action = ("DONOTHING", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.OFF + + # turn it ff + action = ("NODE_SHUTDOWN", {"node_id": 0}) + agent.store_action(action) + game.step() + + assert client_1.operating_state == NodeOperatingState.OFF diff --git a/tests/integration_tests/game_layer/actions/test_service_request_permission.py b/tests/integration_tests/game_layer/actions/test_service_request_permission.py new file mode 100644 index 00000000..3054c73b --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_service_request_permission.py @@ -0,0 +1,106 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.system.services.service import ServiceOperatingState + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_service_start(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator makes sure that the service is stopped before starting the service.""" + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + dns_server = server_1.software_manager.software.get("DNSServer") + + dns_server.pause() + assert dns_server.operating_state == ServiceOperatingState.PAUSED + + action = ("NODE_SERVICE_START", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.PAUSED + + dns_server.stop() + + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_START", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + + assert dns_server.operating_state == ServiceOperatingState.RUNNING + + +def test_service_resume(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator checks if the service is paused before resuming.""" + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + dns_server = server_1.software_manager.software.get("DNSServer") + + action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.RUNNING + + dns_server.pause() + + assert dns_server.operating_state == ServiceOperatingState.PAUSED + + action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + + assert dns_server.operating_state == ServiceOperatingState.RUNNING + + +def test_service_cannot_perform_actions_unless_running(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test to make sure that the service cannot perform certain actions while not running.""" + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + dns_server = server_1.software_manager.software.get("DNSServer") + + dns_server.stop() + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_PAUSE", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_RESTART", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.STOPPED + + action = ("NODE_SERVICE_FIX", {"node_id": 1, "service_id": 0}) + agent.store_action(action) + game.step() + assert dns_server.operating_state == ServiceOperatingState.STOPPED From cbf54d442c37fb774b7e6708b9c2f3404df8bcaa Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 8 Jul 2024 15:17:35 +0100 Subject: [PATCH 11/35] #2623 Make it possible to view currently valid simulation requests --- src/primaite/simulator/core.py | 31 ++++++++++++++++----- src/primaite/simulator/domain/controller.py | 2 ++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 8d8425ec..be5eb4b9 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -3,9 +3,10 @@ """Core of the PrimAITE Simulator.""" import warnings from abc import abstractmethod -from typing import Callable, Dict, List, Literal, Optional, Union +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union from uuid import uuid4 +from prettytable import PrettyTable from pydantic import BaseModel, ConfigDict, Field, validate_call from primaite import getLogger @@ -150,18 +151,34 @@ class RequestManager(BaseModel): self.request_types.pop(name) - def get_request_types_recursively(self) -> List[List[str]]: - """Recursively generate request tree for this component.""" + def get_request_types_recursively(self, _parent_valid: bool = True) -> List[Tuple[RequestFormat, bool]]: + """ + Recursively generate request tree for this component. + + :param parent_valid: Whether this sub-request's parent request was valid. This value should not be specified by + users, it is used by the recursive call. + :type parent_valid: bool + :returns: A list of tuples where the first tuple element is the request string and the second is whether that + request is currently possible to execute. + :rtype: List[Tuple[RequestFormat, bool]] + """ requests = [] for req_name, req in self.request_types.items(): + valid = req.validator([], {}) and _parent_valid # if parent is invalid, all children are invalid if isinstance(req.func, RequestManager): - sub_requests = req.func.get_request_types_recursively() - sub_requests = [[req_name] + a for a in sub_requests] + sub_requests = req.func.get_request_types_recursively(valid) # recurse + sub_requests = [([req_name] + a, valid) for a, valid in sub_requests] # prepend parent request to leaf requests.extend(sub_requests) - else: - requests.append([req_name]) + else: # leaf node found + requests.append(([req_name], valid)) return requests + def show(self) -> None: + table = PrettyTable(["request", "valid"]) + table.align = "l" + table.add_rows(self.get_request_types_recursively()) + print(table) + class SimComponent(BaseModel): """Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator.""" diff --git a/src/primaite/simulator/domain/controller.py b/src/primaite/simulator/domain/controller.py index 37e60aaa..a264ba24 100644 --- a/src/primaite/simulator/domain/controller.py +++ b/src/primaite/simulator/domain/controller.py @@ -52,6 +52,8 @@ class GroupMembershipValidator(RequestPermissionValidator): def __call__(self, request: List[str], context: Dict) -> bool: """Permit the action if the request comes from an account which belongs to the right group.""" # if context request source is part of any groups mentioned in self.allow_groups, return true, otherwise false + if not context: + return False requestor_groups: List[str] = context["request_source"]["groups"] for allowed_group in self.allowed_groups: if allowed_group.name in requestor_groups: From a3f74087fa27f0830f688cd4adb4837c297759f2 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 8 Jul 2024 15:26:30 +0100 Subject: [PATCH 12/35] #2688: refactor test classes --- tests/conftest.py | 12 ++++---- .../test_software_fix_duration.py | 2 +- .../network/test_broadcast.py | 30 +++++++++---------- .../system/test_application_on_node.py | 4 +-- .../test_simulation/test_request_response.py | 12 ++++---- 5 files changed, 30 insertions(+), 30 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e36a2460..e3c84e6d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,11 +51,11 @@ class TestService(Service): pass -class TestDummyApplication(Application, identifier="TestDummyApplication"): +class DummyApplication(Application, identifier="DummyApplication"): """Test Application class""" def __init__(self, **kwargs): - kwargs["name"] = "TestDummyApplication" + kwargs["name"] = "DummyApplication" kwargs["port"] = Port.HTTP kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -85,9 +85,9 @@ def service_class(): @pytest.fixture(scope="function") -def application(file_system) -> TestDummyApplication: - return TestDummyApplication( - name="TestDummyApplication", +def application(file_system) -> DummyApplication: + return DummyApplication( + name="DummyApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_application"), @@ -96,7 +96,7 @@ def application(file_system) -> TestDummyApplication: @pytest.fixture(scope="function") def application_class(): - return TestDummyApplication + return DummyApplication @pytest.fixture(scope="function") diff --git a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py index 04160f8f..ae4825ff 100644 --- a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py +++ b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py @@ -16,7 +16,7 @@ from tests import TEST_ASSETS_ROOT TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml" ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml" -TestApplications = ["TestDummyApplication", "TestBroadcastClient"] +TestApplications = ["DummyApplication", "BroadcastTestClient"] def load_config(config_path: Union[str, Path]) -> PrimaiteGame: diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index bcf7b9b0..80007c46 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -14,7 +14,7 @@ from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service -class TestBroadcastService(Service): +class BroadcastTestService(Service): """A service for sending broadcast and unicast messages over a network.""" def __init__(self, **kwargs): @@ -41,14 +41,14 @@ class TestBroadcastService(Service): super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol) -class TestBroadcastClient(Application, identifier="TestBroadcastClient"): +class BroadcastTestClient(Application, identifier="BroadcastTestClient"): """A client application to receive broadcast and unicast messages.""" payloads_received: List = [] def __init__(self, **kwargs): # Set default client properties - kwargs["name"] = "TestBroadcastClient" + kwargs["name"] = "BroadcastTestClient" kwargs["port"] = Port.HTTP kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -75,8 +75,8 @@ def broadcast_network() -> Network: start_up_duration=0, ) client_1.power_on() - client_1.software_manager.install(TestBroadcastClient) - application_1 = client_1.software_manager.software["TestBroadcastClient"] + client_1.software_manager.install(BroadcastTestClient) + application_1 = client_1.software_manager.software["BroadcastTestClient"] application_1.run() client_2 = Computer( @@ -87,8 +87,8 @@ def broadcast_network() -> Network: start_up_duration=0, ) client_2.power_on() - client_2.software_manager.install(TestBroadcastClient) - application_2 = client_2.software_manager.software["TestBroadcastClient"] + client_2.software_manager.install(BroadcastTestClient) + application_2 = client_2.software_manager.software["BroadcastTestClient"] application_2.run() server_1 = Server( @@ -100,8 +100,8 @@ def broadcast_network() -> Network: ) server_1.power_on() - server_1.software_manager.install(TestBroadcastService) - service: TestBroadcastService = server_1.software_manager.software["BroadcastService"] + server_1.software_manager.install(BroadcastTestService) + service: BroadcastTestService = server_1.software_manager.software["BroadcastService"] service.start() switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) @@ -117,14 +117,14 @@ def broadcast_network() -> Network: @pytest.fixture(scope="function") def broadcast_service_and_clients( broadcast_network, -) -> Tuple[TestBroadcastService, TestBroadcastClient, TestBroadcastClient]: - client_1: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[ - "TestBroadcastClient" +) -> Tuple[BroadcastTestService, BroadcastTestClient, BroadcastTestClient]: + client_1: BroadcastTestClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[ + "BroadcastTestClient" ] - client_2: TestBroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[ - "TestBroadcastClient" + client_2: BroadcastTestClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[ + "BroadcastTestClient" ] - service: TestBroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[ + service: BroadcastTestService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[ "BroadcastService" ] diff --git a/tests/integration_tests/system/test_application_on_node.py b/tests/integration_tests/system/test_application_on_node.py index 400ab082..ffb5cc7f 100644 --- a/tests/integration_tests/system/test_application_on_node.py +++ b/tests/integration_tests/system/test_application_on_node.py @@ -21,7 +21,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]: computer.power_on() computer.software_manager.install(application_class) - app = computer.software_manager.software.get("TestDummyApplication") + app = computer.software_manager.software.get("DummyApplication") app.run() return app, computer @@ -39,7 +39,7 @@ def test_application_on_offline_node(application_class): ) computer.software_manager.install(application_class) - app: Application = computer.software_manager.software.get("TestDummyApplication") + app: Application = computer.software_manager.software.get("DummyApplication") computer.power_off() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 29c70566..a9f0b58d 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.transport_layer import Port -from tests.conftest import TestDummyApplication, TestService +from tests.conftest import DummyApplication, TestService def test_successful_node_file_system_creation_request(example_network): @@ -47,14 +47,14 @@ def test_successful_application_requests(example_network): net = example_network client_1 = net.get_node_by_hostname("client_1") - client_1.software_manager.install(TestDummyApplication) - client_1.software_manager.software.get("TestDummyApplication").run() + client_1.software_manager.install(DummyApplication) + client_1.software_manager.software.get("DummyApplication").run() - resp_1 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "scan"]) + resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"]) assert resp_1 == RequestResponse(status="success", data={}) - resp_2 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "fix"]) + resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"]) assert resp_2 == RequestResponse(status="success", data={}) - resp_3 = net.apply_request(["node", "client_1", "application", "TestDummyApplication", "compromise"]) + resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"]) assert resp_3 == RequestResponse(status="success", data={}) From 470fa28ee1b38b9596b99af90cd5337018dbb6cf Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 9 Jul 2024 13:13:13 +0100 Subject: [PATCH 13/35] 2623 Implement basic action masking logic --- src/primaite/game/agent/actions.py | 38 +++++++++++++++-------------- src/primaite/session/environment.py | 23 ++++++++++++++++- src/primaite/simulator/core.py | 19 ++++++++++++++- 3 files changed, 60 insertions(+), 20 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index b3b7189c..9a5fedc9 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -49,7 +49,7 @@ class AbstractAction(ABC): objects.""" @abstractmethod - def form_request(self) -> List[str]: + def form_request(self) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return [] @@ -67,7 +67,7 @@ class DoNothingAction(AbstractAction): # i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter # with one option. This just aids the Action Manager to enumerate all possibilities. - def form_request(self, **kwargs) -> List[str]: + def form_request(self, **kwargs) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return ["do_nothing"] @@ -86,7 +86,7 @@ class NodeServiceAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, service_id: int) -> List[str]: + def form_request(self, node_id: int, service_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) service_name = self.manager.get_service_name_by_idx(node_id, service_id) @@ -181,7 +181,7 @@ class NodeApplicationAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, application_id: int) -> List[str]: + def form_request(self, node_id: int, application_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) application_name = self.manager.get_application_name_by_idx(node_id, application_id) @@ -229,7 +229,7 @@ class NodeApplicationInstallAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes} - def form_request(self, node_id: int, application_name: str) -> List[str]: + def form_request(self, node_id: int, application_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None: @@ -324,7 +324,7 @@ class NodeApplicationRemoveAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"node_id": num_nodes} - def form_request(self, node_id: int, application_name: str) -> List[str]: + def form_request(self, node_id: int, application_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None: @@ -346,7 +346,7 @@ class NodeFolderAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, folder_id: int) -> List[str]: + def form_request(self, node_id: int, folder_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) @@ -394,7 +394,9 @@ class NodeFileCreateAction(AbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "create" - def form_request(self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False) -> List[str]: + def form_request( + self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False + ) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None or folder_name is None or file_name is None: @@ -409,7 +411,7 @@ class NodeFolderCreateAction(AbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "create" - def form_request(self, node_id: int, folder_name: str) -> List[str]: + def form_request(self, node_id: int, folder_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None or folder_name is None: @@ -430,7 +432,7 @@ class NodeFileAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: + def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) @@ -463,7 +465,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb: str = "delete" - def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: + def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id) @@ -504,7 +506,7 @@ class NodeFileAccessAction(AbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs) self.verb: str = "access" - def form_request(self, node_id: int, folder_name: str, file_name: str) -> List[str]: + def form_request(self, node_id: int, folder_name: str, file_name: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) if node_name is None or folder_name is None or file_name is None: @@ -525,7 +527,7 @@ class NodeAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int) -> List[str]: + def form_request(self, node_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) return ["network", "node", node_name, self.verb] @@ -740,7 +742,7 @@ class RouterACLRemoveRuleAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"position": max_acl_rules} - def form_request(self, target_router: str, position: int) -> List[str]: + def form_request(self, target_router: str, position: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return ["network", "node", target_router, "acl", "remove_rule", position] @@ -923,7 +925,7 @@ class HostNICAbstractAction(AbstractAction): self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node} self.verb: str # define but don't initialise: defends against children classes not defining this - def form_request(self, node_id: int, nic_id: int) -> List[str]: + def form_request(self, node_id: int, nic_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_idx=node_id) nic_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=nic_id) @@ -960,7 +962,7 @@ class NetworkPortEnableAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"port_id": max_nics_per_node} - def form_request(self, target_nodename: str, port_id: int) -> List[str]: + def form_request(self, target_nodename: str, port_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" if target_nodename is None or port_id is None: return ["do_nothing"] @@ -979,7 +981,7 @@ class NetworkPortDisableAction(AbstractAction): super().__init__(manager=manager) self.shape: Dict[str, int] = {"port_id": max_nics_per_node} - def form_request(self, target_nodename: str, port_id: int) -> List[str]: + def form_request(self, target_nodename: str, port_id: int) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" if target_nodename is None or port_id is None: return ["do_nothing"] @@ -1315,7 +1317,7 @@ class ActionManager: act_identifier, act_options = self.action_map[action] return act_identifier, act_options - def form_request(self, action_identifier: str, action_options: Dict) -> List[str]: + def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat: """Take action in CAOS format and use the execution definition to change it into PrimAITE request format.""" act_obj = self.actions[action_identifier] return act_obj.form_request(**action_options) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 6cc1282f..aa2bc308 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK import json from os import PathLike -from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union +from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union import gymnasium from gymnasium.core import ActType, ObsType @@ -40,6 +40,27 @@ class PrimaiteGymEnv(gymnasium.Env): """Current episode number.""" self.total_reward_per_episode: Dict[int, float] = {} """Average rewards of agents per episode.""" + self.action_masking: bool = False + """Whether to use action masking.""" + + def action_masks(self) -> List[bool]: + """ + Return the action mask for the agent. + + This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be + performed during this step. + + :return: Action mask + :rtype: List[bool] + """ + mask = [True] * len(self.agent.action_manager.action_map) + if not self.action_masking: + return mask + + for i, action in self.agent.action_manager.action_map.items(): + request = self.agent.action_manager.form_request(action_identifier=action[0], action_options=action[1]) + mask[i] = self.game.simulation._request_manager.check_valid(request, {}) + return mask @property def agent(self) -> ProxyAgent: diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index be5eb4b9..7653a3ab 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -3,7 +3,7 @@ """Core of the PrimAITE Simulator.""" import warnings from abc import abstractmethod -from typing import Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union from uuid import uuid4 from prettytable import PrettyTable @@ -179,6 +179,23 @@ class RequestManager(BaseModel): table.add_rows(self.get_request_types_recursively()) print(table) + def check_valid(self, request: RequestFormat, context: Dict) -> bool: + """Check if this request would be valid in the current state of the simulation without invoking it.""" + + request_key = request[0] + request_options = request[1:] + + if request_key not in self.request_types: + return False + + request_type = self.request_types[request_key] + + # recurse if we are not at a leaf node + if isinstance(request_type.func, RequestManager): + return request_type.func.check_valid(request_options, context) + + return request_type.validator(request_options, context) + class SimComponent(BaseModel): """Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator.""" From 5367f9ad5376241ee40005381652cc13f628c53c Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 9 Jul 2024 15:27:03 +0100 Subject: [PATCH 14/35] 2623 Ray single agent action masking --- .../_package_data/data_manipulation.yaml | 1 + src/primaite/game/agent/interface.py | 3 + src/primaite/notebooks/Action-masking.ipynb | 184 ++++++++++++++++++ src/primaite/session/environment.py | 11 +- src/primaite/session/ray_envs.py | 31 ++- 5 files changed, 214 insertions(+), 16 deletions(-) create mode 100644 src/primaite/notebooks/Action-masking.ipynb diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 6d4ec9b4..97442903 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -741,6 +741,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 95468331..01b7fb0a 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -69,6 +69,8 @@ class AgentSettings(BaseModel): "Configuration for when an agent begins performing it's actions" flatten_obs: bool = True "Whether to flatten the observation space before passing it to the agent. True by default." + action_masking: bool = True + "Whether to return action masks at each step." @classmethod def from_config(cls, config: Optional[Dict]) -> "AgentSettings": @@ -205,6 +207,7 @@ class ProxyAgent(AbstractAgent): ) self.most_recent_action: ActType self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False + self.action_masking: bool = agent_settings.action_masking if agent_settings else False def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb new file mode 100644 index 00000000..822b8451 --- /dev/null +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -0,0 +1,184 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Action Masking\n", + "\n", + "PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.session.environment import PrimaiteGymEnv\n", + "from primaite.config.load import data_manipulation_config_path\n", + "from prettytable import PrettyTable" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env = PrimaiteGymEnv(data_manipulation_config_path())\n", + "env.action_masking = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The action mask is a list of booleans that specifies whether each action in the agent's action map is currently possible. Demonstrated here:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "act_table = PrettyTable((\"number\", \"action\", \"parameters\", \"mask\"))\n", + "mask = env.action_masks()\n", + "actions = env.agent.action_manager.action_map\n", + "max_str_len = 70\n", + "for act,mask in zip(actions.items(), mask):\n", + " act_num, act_data = act\n", + " act_type, act_params = act_data\n", + " act_params = s if len(s:=str(act_params)) List[bool]: + def action_masks(self) -> np.ndarray: """ Return the action mask for the agent. @@ -54,13 +53,13 @@ class PrimaiteGymEnv(gymnasium.Env): :rtype: List[bool] """ mask = [True] * len(self.agent.action_manager.action_map) - if not self.action_masking: + if not self.agent.action_masking: return mask for i, action in self.agent.action_manager.action_map.items(): request = self.agent.action_manager.form_request(action_identifier=action[0], action_options=action[1]) mask[i] = self.game.simulation._request_manager.check_valid(request, {}) - return mask + return np.asarray(mask) @property def agent(self) -> ProxyAgent: diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index fc5d73d8..1fc7624f 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -3,6 +3,7 @@ import json from typing import Dict, SupportsFloat, Tuple import gymnasium +from gymnasium import spaces from gymnasium.core import ActType, ObsType from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -38,15 +39,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): self.terminateds = set() self.truncateds = set() - self.observation_space = gymnasium.spaces.Dict( - { - name: gymnasium.spaces.flatten_space(agent.observation_manager.space) - for name, agent in self.agents.items() - } - ) - self.action_space = gymnasium.spaces.Dict( - {name: agent.action_manager.space for name, agent in self.agents.items()} + self.observation_space = spaces.Dict( + {name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()} ) + self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()}) self._obs_space_in_preferred_format = True self._action_space_in_preferred_format = True super().__init__() @@ -158,15 +154,30 @@ class PrimaiteRayEnv(gymnasium.Env): self.env = PrimaiteGymEnv(env_config=env_config) # self.env.episode_counter -= 1 self.action_space = self.env.action_space - self.observation_space = self.env.observation_space + if self.env.agent.action_masking: + self.observation_space = spaces.Dict( + {"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space} + ) + else: + self.observation_space = self.env.observation_space def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" + if self.env.agent.action_masking: + obs, *_ = self.env.reset(seed=seed) + new_obs = {"action_mask": self.env.action_masks(), "observations": obs} + return new_obs, *_ return self.env.reset(seed=seed) def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: """Perform a step in the environment.""" - return self.env.step(action) + # if action masking is enabled, intercept the step method and add action mask to observation + if self.env.agent.action_masking: + obs, *_ = self.env.step(action) + new_obs = {"action_mask": self.env.action_masks(), "observations": obs} + return new_obs, *_ + else: + return self.env.step(action) def close(self): """Close the simulation.""" From faf268a9b9b4fd847bab7f31518b8488dadc169b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 9 Jul 2024 15:59:50 +0100 Subject: [PATCH 15/35] 2623 move action mask generation to game and fix MARL masking --- .../_package_data/data_manipulation_marl.yaml | 2 + src/primaite/config/load.py | 15 ++++++ src/primaite/game/game.py | 18 +++++++ src/primaite/notebooks/Action-masking.ipynb | 53 ++++++++++++++++--- src/primaite/session/environment.py | 10 ++-- src/primaite/session/ray_envs.py | 19 +++++-- 6 files changed, 101 insertions(+), 16 deletions(-) diff --git a/src/primaite/config/_package_data/data_manipulation_marl.yaml b/src/primaite/config/_package_data/data_manipulation_marl.yaml index 2e8221a0..ba666781 100644 --- a/src/primaite/config/_package_data/data_manipulation_marl.yaml +++ b/src/primaite/config/_package_data/data_manipulation_marl.yaml @@ -733,6 +733,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true - ref: defender_2 team: BLUE @@ -1316,6 +1317,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py index 3483fc87..144e0733 100644 --- a/src/primaite/config/load.py +++ b/src/primaite/config/load.py @@ -44,3 +44,18 @@ def data_manipulation_config_path() -> Path: _LOGGER.error(msg) raise FileNotFoundError(msg) return path + + +def data_manipulation_marl_config_path() -> Path: + """ + Get the path to the MARL example config. + + :return: Path to yaml config file for the MARL scenario. + :rtype: Path + """ + path = _EXAMPLE_CFG / "data_manipulation_marl.yaml" + if not path.exists(): + msg = f"Example config does not exist: {path}. Have you run `primaite setup`?" + _LOGGER.error(msg) + raise FileNotFoundError(msg) + return path diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 3dc9571f..e7d13061 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -3,6 +3,7 @@ from ipaddress import IPv4Address from typing import Dict, List, Optional +import numpy as np from pydantic import BaseModel, ConfigDict from primaite import DEFAULT_BANDWIDTH, getLogger @@ -192,6 +193,23 @@ class PrimaiteGame: return True return False + def action_mask(self, agent_name: str) -> np.ndarray: + """ + Return the action mask for the agent. + + This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be + performed during this step. + + :return: Action mask + :rtype: List[bool] + """ + agent = self.agents[agent_name] + mask = [True] * len(agent.action_manager.action_map) + for i, action in agent.action_manager.action_map.items(): + request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1]) + mask[i] = self.simulation._request_manager.check_valid(request, {}) + return np.asarray(mask) + def close(self) -> None: """Close the game, this will close the simulation.""" return NotImplemented diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb index 822b8451..8090dacc 100644 --- a/src/primaite/notebooks/Action-masking.ipynb +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -96,7 +96,7 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv\n", + "from primaite.session.ray_envs import PrimaiteRayEnv\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "import yaml\n", "from ray import air, tune\n" @@ -146,18 +146,59 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "## Action masking with MARL in Ray RLLib\n", + "Each agent has their own action mask, this is useful if the agents have different action spaces." + ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from primaite.session.ray_envs import PrimaiteRayMARLEnv\n", + "from primaite.config.load import data_manipulation_marl_config_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(data_manipulation_marl_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "env_config = cfg\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = (\n", + " PPOConfig()\n", + " .multi_agent(\n", + " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", + " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", + " )\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n", + " .env_runners(num_env_runners=0)\n", + " .training(train_batch_size=128)\n", + " )\n", + "\n", + "tune.Tuner(\n", + " \"PPO\",\n", + " run_config=air.RunConfig(\n", + " stop={\"timesteps_total\": 5 * 128},\n", + " ),\n", + " param_space=config\n", + ").fit()" + ] } ], "metadata": { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 0520cce9..a87f0cde 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -52,14 +52,10 @@ class PrimaiteGymEnv(gymnasium.Env): :return: Action mask :rtype: List[bool] """ - mask = [True] * len(self.agent.action_manager.action_map) if not self.agent.action_masking: - return mask - - for i, action in self.agent.action_manager.action_map.items(): - request = self.agent.action_manager.form_request(action_identifier=action[0], action_options=action[1]) - mask[i] = self.game.simulation._request_manager.check_valid(request, {}) - return np.asarray(mask) + return np.asarray([True] * len(self.agent.action_manager.action_map)) + else: + return self.game.action_mask(self._agent_name) @property def agent(self) -> ProxyAgent: diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 1fc7624f..12167f89 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -42,6 +42,15 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): self.observation_space = spaces.Dict( {name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()} ) + for agent_name in self._agent_ids: + agent = self.game.rl_agents[agent_name] + if agent.action_masking: + self.observation_space[agent_name] = spaces.Dict( + { + "action_mask": spaces.MultiBinary(agent.action_manager.space.n), + "observations": self.observation_space[agent_name], + } + ) self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()}) self._obs_space_in_preferred_format = True self._action_space_in_preferred_format = True @@ -127,13 +136,17 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" - obs = {} + all_obs = {} for agent_name in self._agent_ids: agent = self.game.rl_agents[agent_name] unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation - obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) - return obs + obs = gymnasium.spaces.flatten(unflat_space, unflat_obs) + if agent.action_masking: + all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs} + else: + all_obs[agent_name] = obs + return all_obs def close(self): """Close the simulation.""" From 20b9b61be58b218c59abaa5478bbb3852600361f Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 9 Jul 2024 16:44:02 +0100 Subject: [PATCH 16/35] #2740: added ability to merge validators + validators for folders --- src/primaite/simulator/core.py | 14 ++ .../simulator/file_system/file_system.py | 83 +++++++++++- src/primaite/simulator/file_system/folder.py | 30 ++++- .../actions/test_file_request_permission.py | 82 ++++++++++++ .../actions/test_folder_request_permission.py | 123 ++++++++++++++++++ .../actions/test_nic_request_permission.py | 2 - 6 files changed, 323 insertions(+), 11 deletions(-) create mode 100644 tests/integration_tests/game_layer/actions/test_file_request_permission.py create mode 100644 tests/integration_tests/game_layer/actions/test_folder_request_permission.py diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 8d8425ec..70485af5 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -34,6 +34,20 @@ class RequestPermissionValidator(BaseModel): """Message that is reported when a request is rejected by this validator.""" return "request rejected" + def __add__(self, other: "RequestPermissionValidator") -> "_CombinedValidator": + return _CombinedValidator(validators=[self, other]) + + +class _CombinedValidator(RequestPermissionValidator): + validators: List[RequestPermissionValidator] = [] + + def __call__(self, request, context) -> bool: + return all(x(request, context) for x in self.validators) + + @property + def fail_message(self): + return f"One of the following conditions are not met: {[v.fail_message for v in self.validators]}" + class AllowAllValidator(RequestPermissionValidator): """Always allows the request.""" diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index 456b800c..42aa0573 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -6,8 +6,8 @@ from typing import Any, Dict, List, Optional from prettytable import MARKDOWN, PrettyTable -from primaite.interface.request import RequestResponse -from primaite.simulator.core import RequestManager, RequestType, SimComponent +from primaite.interface.request import RequestFormat, RequestResponse +from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType, SimComponent from primaite.simulator.file_system.file import File from primaite.simulator.file_system.file_type import FileType from primaite.simulator.file_system.folder import Folder @@ -42,6 +42,10 @@ class FileSystem(SimComponent): More information in user guide and docstring for SimComponent._init_request_manager. """ + self._folder_exists = FileSystem._FolderExistsValidator(file_system=self) + self._folder_deleted = FileSystem._FolderNotDeletedValidator(file_system=self) + self._file_exists = FileSystem._FileExistsValidator(file_system=self) + rm = super()._init_request_manager() self._delete_manager = RequestManager() @@ -50,13 +54,15 @@ class FileSystem(SimComponent): request_type=RequestType( func=lambda request, context: RequestResponse.from_bool( self.delete_file(folder_name=request[0], file_name=request[1]) - ) + ), + validator=self._file_exists, ), ) self._delete_manager.add_request( name="folder", request_type=RequestType( - func=lambda request, context: RequestResponse.from_bool(self.delete_folder(folder_name=request[0])) + func=lambda request, context: RequestResponse.from_bool(self.delete_folder(folder_name=request[0])), + validator=self._folder_exists, ), ) rm.add_request( @@ -144,10 +150,13 @@ class FileSystem(SimComponent): ) self._folder_request_manager = RequestManager() - rm.add_request("folder", RequestType(func=self._folder_request_manager)) + rm.add_request( + "folder", + RequestType(func=self._folder_request_manager, validator=self._folder_exists + self._folder_deleted), + ) self._file_request_manager = RequestManager() - rm.add_request("file", RequestType(func=self._file_request_manager)) + rm.add_request("file", RequestType(func=self._file_request_manager, validator=self._file_exists)) return rm @@ -626,3 +635,65 @@ class FileSystem(SimComponent): self.sys_log.error(f"Unable to access file that does not exist. (file name: {file_name})") return False + + class _FolderExistsValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the Folder exists. + + Actions cannot be performed on a non-existent folder. + """ + + file_system: FileSystem + """Save a reference to the FileSystem instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Returns True if folder exists.""" + return self.file_system.get_folder(folder_name=request[0]) is not None + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return "Cannot perform request on folder because it does not exist" + + class _FolderNotDeletedValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the Folder has not been deleted. + + Actions cannot be performed on a deleted folder. + """ + + file_system: FileSystem + """Save a reference to the FileSystem instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Returns True if folder exists and is not deleted.""" + # get folder + folder = self.file_system.get_folder(folder_name=request[0], include_deleted=True) + return folder is not None and not folder.deleted + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return "Cannot perform request on folder because it is deleted." + + class _FileExistsValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the File exists. + + Actions cannot be performed on a non-existent file. + """ + + file_system: FileSystem + """Save a reference to the FileSystem instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Returns True if file exists.""" + return self.file_system.get_file(folder_name=request[0], file_name=request[1]) is not None + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return ( + f"Cannot perform request on application '{self.application.name}' because it is not in the " + f"{self.state.name} state." + ) diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index dd2a4c70..af7cc660 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -6,8 +6,8 @@ from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable -from primaite.interface.request import RequestResponse -from primaite.simulator.core import RequestManager, RequestType +from primaite.interface.request import RequestFormat, RequestResponse +from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType from primaite.simulator.file_system.file import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus @@ -55,6 +55,8 @@ class Folder(FileSystemItemABC): More information in user guide and docstring for SimComponent._init_request_manager. """ + self._file_exists = Folder._FileExistsValidator(folder=self) + rm = super()._init_request_manager() rm.add_request( name="delete", @@ -65,7 +67,7 @@ class Folder(FileSystemItemABC): self._file_request_manager = RequestManager() rm.add_request( name="file", - request_type=RequestType(func=self._file_request_manager), + request_type=RequestType(func=self._file_request_manager, validator=self._file_exists), ) return rm @@ -469,3 +471,25 @@ class Folder(FileSystemItemABC): self.deleted = True return True + + class _FileExistsValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the File exists. + + Actions cannot be performed on a non-existent file. + """ + + folder: Folder + """Save a reference to the Folder instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Returns True if file exists.""" + return self.folder.get_file(file_name=request[0]) is not None + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return ( + f"Cannot perform request on application '{self.application.name}' because it is not in the " + f"{self.state.name} state." + ) diff --git a/tests/integration_tests/game_layer/actions/test_file_request_permission.py b/tests/integration_tests/game_layer/actions/test_file_request_permission.py new file mode 100644 index 00000000..c422ad43 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_file_request_permission.py @@ -0,0 +1,82 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import uuid +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus +from primaite.simulator.network.hardware.nodes.host.computer import Computer + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_create_file(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator allows a folder to be created.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + random_folder = str(uuid.uuid4()) + random_file = str(uuid.uuid4()) + + assert client_1.file_system.get_file(folder_name=random_folder, file_name=random_file) is None + + action = ( + "NODE_FILE_CREATE", + {"node_id": 0, "folder_name": random_folder, "file_name": random_file}, + ) + agent.store_action(action) + game.step() + + assert client_1.file_system.get_file(folder_name=random_folder, file_name=random_file) is not None + + +def test_file_delete_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator allows a folder to be created.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png") + assert file.deleted is False + + action = ( + "NODE_FILE_DELETE", + {"node_id": 0, "folder_id": 0, "file_id": 0}, + ) + agent.store_action(action) + game.step() + + assert file.deleted + + +def test_file_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator allows a folder to be created.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png") + + file.corrupt() + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + assert file.visible_health_status == FileSystemItemHealthStatus.GOOD + + action = ( + "NODE_FILE_SCAN", + {"node_id": 0, "folder_id": 0, "file_id": 0}, + ) + agent.store_action(action) + game.step() + + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + assert file.visible_health_status == FileSystemItemHealthStatus.CORRUPT diff --git a/tests/integration_tests/game_layer/actions/test_folder_request_permission.py b/tests/integration_tests/game_layer/actions/test_folder_request_permission.py new file mode 100644 index 00000000..e5e0806a --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_folder_request_permission.py @@ -0,0 +1,123 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import uuid +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus +from primaite.simulator.network.hardware.nodes.host.computer import Computer + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_create_folder(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator allows a folder to be created.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + random_folder = str(uuid.uuid4()) + + assert client_1.file_system.get_folder(folder_name=random_folder) is None + + action = ( + "NODE_FOLDER_CREATE", + { + "node_id": 0, + "folder_name": random_folder, + }, + ) + agent.store_action(action) + game.step() + + assert client_1.file_system.get_folder(folder_name=random_folder) is not None + + +def test_folder_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test to make sure that the validator checks if the folder exists before scanning.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + folder = client_1.file_system.get_folder(folder_name="downloads") + assert folder.health_status == FileSystemItemHealthStatus.GOOD + assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD + + folder.corrupt() + + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD + + action = ( + "NODE_FOLDER_SCAN", + { + "node_id": 0, # client_1, + "folder_id": 0, # downloads + }, + ) + agent.store_action(action) + game.step() + + for i in range(folder.scan_duration + 1): + game.step() + + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + assert folder.visible_health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_folder_repair_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test to make sure that the validator checks if the folder exists before repairing.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + folder = client_1.file_system.get_folder(folder_name="downloads") + folder.corrupt() + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + + action = ( + "NODE_FOLDER_REPAIR", + { + "node_id": 0, # client_1, + "folder_id": 0, # downloads + }, + ) + agent.store_action(action) + game.step() + + assert folder.health_status == FileSystemItemHealthStatus.GOOD + + +def test_folder_restore_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test to make sure that the validator checks if the folder exists before restoring.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + folder = client_1.file_system.get_folder(folder_name="downloads") + folder.corrupt() + + assert folder.health_status == FileSystemItemHealthStatus.CORRUPT + + action = ( + "NODE_FOLDER_RESTORE", + { + "node_id": 0, # client_1, + "folder_id": 0, # downloads + }, + ) + agent.store_action(action) + game.step() + + assert folder.health_status == FileSystemItemHealthStatus.RESTORING diff --git a/tests/integration_tests/game_layer/actions/test_nic_request_permission.py b/tests/integration_tests/game_layer/actions/test_nic_request_permission.py index 4c1619e7..d796b75e 100644 --- a/tests/integration_tests/game_layer/actions/test_nic_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_nic_request_permission.py @@ -6,8 +6,6 @@ import pytest from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.system.services.service import ServiceOperatingState @pytest.fixture From 48645d2e72adf0c66a48d8183952c959c02bedae Mon Sep 17 00:00:00 2001 From: "Archer.Bowen" Date: Tue, 9 Jul 2024 16:46:31 +0100 Subject: [PATCH 17/35] #2716 Initial Implementation + Initial Tests and updated changelog and sphinx documentation. --- CHANGELOG.md | 3 +- docs/source/configuration/io_settings.rst | 24 ++- src/primaite/game/agent/agent_log.py | 168 ++++++++++++++++++ src/primaite/game/agent/interface.py | 2 + .../scripted_agents/data_manipulation_bot.py | 4 +- .../scripted_agents/probabilistic_agent.py | 1 + src/primaite/session/io.py | 14 +- src/primaite/simulator/__init__.py | 34 ++++ .../configs/basic_switched_network.yaml | 3 + .../test_io_settings.py | 4 + .../_primaite/_game/_agent/test_agent_log.py | 137 ++++++++++++++ 11 files changed, 387 insertions(+), 7 deletions(-) create mode 100644 src/primaite/game/agent/agent_log.py create mode 100644 tests/unit_tests/_primaite/_game/_agent/test_agent_log.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 17bf3557..beec6d11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,8 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Activating dev-mode will change the location where the sessions will be output - by default will output where the PrimAITE repository is located - Refactored all air-space usage to that a new instance of AirSpace is created for each instance of Network. This 1:1 relationship between network and airspace will allow parallelization. - Added notebook to demonstrate use of SubprocVecEnv from SB3 to vectorise environments to speed up training. - - +- Added a new agent simulation log which are more human friendly than agent action logging. Includes timesteps so that the agent action log can be cross referenced. These Logs are found in simulation_output directory, similar to that of sys_logs and can be enabled in the I/O settings in a yaml configuration file. ## [Unreleased] - Made requests fail to reach their target if the node is off diff --git a/docs/source/configuration/io_settings.rst b/docs/source/configuration/io_settings.rst index 82fd7408..1c9585c9 100644 --- a/docs/source/configuration/io_settings.rst +++ b/docs/source/configuration/io_settings.rst @@ -18,8 +18,11 @@ This section configures how PrimAITE saves data during simulation and training. save_step_metadata: False save_pcap_logs: False save_sys_logs: False + save_agent_logs: False write_sys_log_to_terminal: False + write_agent_log_to_terminal: False sys_log_level: WARNING + agent_log_level: INFO ``save_logs`` @@ -57,6 +60,12 @@ Optional. Default value is ``False``. If ``True``, then the log files which contain all node actions during the simulation will be saved. +``save_agent_logs`` +----------------- + +Optional. Default value is ``False``. + +If ``True``, then the log files which contain all human readable agent behaviour during the simulation will be saved. ``write_sys_log_to_terminal`` ----------------------------- @@ -65,16 +74,25 @@ Optional. Default value is ``False``. If ``True``, PrimAITE will print sys log to the terminal. +``write_agent_log_to_terminal`` +----------------------------- -``sys_log_level`` -------------- +Optional. Default value is ``False``. + +If ``True``, PrimAITE will print all human readable agent behaviour logs to the terminal. + + +``sys_log_level & agent_log_level`` +--------------------------------- Optional. Default value is ``WARNING``. -The level of logging that should be visible in the sys logs or the logs output to the terminal. +The level of logging that should be visible in the syslog, agent logs or the logs output to the terminal. ``save_sys_logs`` or ``write_sys_log_to_terminal`` has to be set to ``True`` for this setting to be used. +This is also true for agent behaviour logging. + Available options are: - ``DEBUG``: Debug level items and the items below diff --git a/src/primaite/game/agent/agent_log.py b/src/primaite/game/agent/agent_log.py new file mode 100644 index 00000000..1e51dcad --- /dev/null +++ b/src/primaite/game/agent/agent_log.py @@ -0,0 +1,168 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import logging +from pathlib import Path + +from prettytable import MARKDOWN, PrettyTable + +from primaite.simulator import LogLevel, SIM_OUTPUT + + +class _NotJSONFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + """ + Determines if a log message does not start and end with '{' and '}' (i.e., it is not a JSON-like message). + + :param record: LogRecord object containing all the information pertinent to the event being logged. + :return: True if log message is not JSON-like, False otherwise. + """ + return not record.getMessage().startswith("{") and not record.getMessage().endswith("}") + + +class AgentLog: + """ + A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent. + + Each log message is written to a file located at: /agent/agent_name.log + """ + + def __init__(self, agent_name: str): + """ + Constructs a Agent Log instance for a given hostname. + + :param hostname: The hostname associated with the system logs being recorded. + """ + self.agent_name = agent_name + self.current_episode: int = 1 + self.setup_logger() + + def setup_logger(self): + """ + Configures the logger for this Agent Log instance. + + The logger is set to the DEBUG level, and is equipped with a handler that writes to a file and filters out + JSON-like messages. + """ + if not SIM_OUTPUT.save_agent_logs: + return + + log_path = self._get_log_path() + file_handler = logging.FileHandler(filename=log_path) + file_handler.setLevel(logging.DEBUG) + + log_format = "%(timestep)s::%(levelname)s::%(message)s" + file_handler.setFormatter(logging.Formatter(log_format)) + + self.logger = logging.getLogger(f"{self.agent_name}_log") + for handler in self.logger.handlers: + self.logger.removeHandler(handler) + self.logger.setLevel(logging.DEBUG) + self.logger.addHandler(file_handler) + + def _get_log_path(self) -> Path: + """ + Constructs the path for the log file based on the hostname. + + :return: Path object representing the location of the log file. + """ + root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / "agent_logs" / self.agent_name + root.mkdir(exist_ok=True, parents=True) + return root / f"{self.agent_name}.log" + + def _write_to_terminal(self, msg: str, timestep: int, level: str, to_terminal: bool = False): + if to_terminal or SIM_OUTPUT.write_agent_log_to_terminal: + print(f"{self.agent_name}: ({timestep}) ({level}) {msg}") + + def debug(self, msg: str, time_step: int, to_terminal: bool = False): + """ + Logs a message with the DEBUG level. + + :param msg: The message to be logged. + :param to_terminal: If True, prints to the terminal too. + """ + if SIM_OUTPUT.agent_log_level > LogLevel.DEBUG: + return + + if SIM_OUTPUT.save_agent_logs: + self.logger.debug(msg, extra={"timestep": time_step}) + self._write_to_terminal(msg, "DEBUG", to_terminal) + + def info(self, msg: str, time_step: int, to_terminal: bool = False): + """ + Logs a message with the INFO level. + + :param msg: The message to be logged. + :param timestep: The current timestep. + :param to_terminal: If True, prints to the terminal too. + """ + if SIM_OUTPUT.agent_log_level > LogLevel.INFO: + return + + if SIM_OUTPUT.save_agent_logs: + self.logger.info(msg, extra={"timestep": time_step}) + self._write_to_terminal(msg, "INFO", to_terminal) + + def warning(self, msg: str, time_step: int, to_terminal: bool = False): + """ + Logs a message with the WARNING level. + + :param msg: The message to be logged. + :param timestep: The current timestep. + :param to_terminal: If True, prints to the terminal too. + """ + if SIM_OUTPUT.agent_log_level > LogLevel.WARNING: + return + + if SIM_OUTPUT.save_agent_logs: + self.logger.warning(msg, extra={"timestep": time_step}) + self._write_to_terminal(msg, "WARNING", to_terminal) + + def error(self, msg: str, time_step: int, to_terminal: bool = False): + """ + Logs a message with the ERROR level. + + :param msg: The message to be logged. + :param timestep: The current timestep. + :param to_terminal: If True, prints to the terminal too. + """ + if SIM_OUTPUT.agent_log_level > LogLevel.ERROR: + return + + if SIM_OUTPUT.save_agent_logs: + self.logger.error(msg, extra={"timestep": time_step}) + self._write_to_terminal(msg, "ERROR", to_terminal) + + def critical(self, msg: str, time_step: int, to_terminal: bool = False): + """ + Logs a message with the CRITICAL level. + + :param msg: The message to be logged. + :param timestep: The current timestep. + :param to_terminal: If True, prints to the terminal too. + """ + if LogLevel.CRITICAL < SIM_OUTPUT.agent_log_level: + return + + if SIM_OUTPUT.save_agent_logs: + self.logger.critical(msg, extra={"timestep": time_step}) + self._write_to_terminal(msg, "CRITICAL", to_terminal) + + def show(self, last_n: int = 10, markdown: bool = False): + """ + Print an Agents Log as a table. + + Generate and print PrettyTable instance that shows the agents behaviour log, with columns Time step, + Level and Message. + + :param markdown: Use Markdown style in table output. Defaults to False. + """ + table = PrettyTable(["Time Step", "Level", "Message"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.agent_name} Behaviour Log" + if self._get_log_path().exists(): + with open(self._get_log_path()) as file: + lines = file.readlines() + for line in lines[-last_n:]: + table.add_row(line.strip().split("::")) + print(table) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 95468331..c53b1956 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -7,6 +7,7 @@ from gymnasium.core import ActType, ObsType from pydantic import BaseModel, model_validator from primaite.game.agent.actions import ActionManager +from primaite.game.agent.agent_log import AgentLog from primaite.game.agent.observations.observation_manager import ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.interface.request import RequestFormat, RequestResponse @@ -116,6 +117,7 @@ class AbstractAgent(ABC): self.reward_function: Optional[RewardFunction] = reward_function self.agent_settings = agent_settings or AgentSettings() self.history: List[AgentHistoryItem] = [] + self.logger = AgentLog(agent_name) def update_observation(self, state: Dict) -> ObsType: """ diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index 3a91f1fe..cd72e001 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -38,10 +38,11 @@ class DataManipulationAgent(AbstractScriptedAgent): :rtype: Tuple[str, Dict] """ if timestep < self.next_execution_timestep: + self.logger.debug(msg="Performing do NOTHING", time_step=timestep) return "DONOTHING", {} self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) - + self.logger.info(msg="Performing a data manipulation attack!", time_step=timestep) return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} def setup_agent(self) -> None: @@ -54,3 +55,4 @@ class DataManipulationAgent(AbstractScriptedAgent): # we are assuming that every node in the node manager has a data manipulation application at idx 0 num_nodes = len(self.action_manager.node_names) self.starting_node_idx = random.randint(0, num_nodes - 1) + self.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}", time_step=0) diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index fc168687..e0f41302 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -85,4 +85,5 @@ class ProbabilisticAgent(AbstractScriptedAgent): :rtype: Tuple[str, Dict] """ choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities) + self.logger.info(f"Performing Action: {choice}", time_step=timestep) return self.action_manager.get_action(choice) diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 7bfd16f1..05a5ee09 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -35,10 +35,16 @@ class PrimaiteIO: """Whether to save PCAP logs.""" save_sys_logs: bool = True """Whether to save system logs.""" + save_agent_logs: bool = True + """Whether to save agent logs.""" write_sys_log_to_terminal: bool = False """Whether to write the sys log to the terminal.""" + write_agent_log_to_terminal: bool = False + """Whether to write the agent log to the terminal.""" sys_log_level: LogLevel = LogLevel.INFO - """The level of log that should be included in the logfiles/logged into terminal.""" + """The level of sys logs that should be included in the logfiles/logged into terminal.""" + agent_log_level: LogLevel = LogLevel.INFO + """The level of agent logs that should be included in the logfiles/logged into terminal.""" def __init__(self, settings: Optional[Settings] = None) -> None: """ @@ -53,8 +59,11 @@ class PrimaiteIO: SIM_OUTPUT.path = self.session_path / "simulation_output" SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs + SIM_OUTPUT.save_agent_logs = self.settings.save_agent_logs + SIM_OUTPUT.write_agent_log_to_terminal = self.settings.write_agent_log_to_terminal SIM_OUTPUT.write_sys_log_to_terminal = self.settings.write_sys_log_to_terminal SIM_OUTPUT.sys_log_level = self.settings.sys_log_level + SIM_OUTPUT.agent_log_level = self.settings.agent_log_level def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: """Create a folder for the session and return the path to it.""" @@ -115,6 +124,9 @@ class PrimaiteIO: if config.get("sys_log_level"): config["sys_log_level"] = LogLevel[config["sys_log_level"].upper()] # convert to enum + if config.get("agent_log_level"): + config["agent_log_level"] = LogLevel[config["agent_log_level"].upper()] # convert to enum + new = cls(settings=cls.Settings(**config)) return new diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index e5fe3cb7..487e7c5e 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -36,8 +36,11 @@ class _SimOutput: self._path = path self._save_pcap_logs: bool = False self._save_sys_logs: bool = False + self._save_agent_logs: bool = False self._write_sys_log_to_terminal: bool = False + self._write_agent_log_to_terminal: bool = False self._sys_log_level: LogLevel = LogLevel.WARNING # default log level is at WARNING + self._agent_log_level: LogLevel = LogLevel.WARNING @property def path(self) -> Path: @@ -81,6 +84,16 @@ class _SimOutput: def save_sys_logs(self, save_sys_logs: bool) -> None: self._save_sys_logs = save_sys_logs + @property + def save_agent_logs(self) -> bool: + if is_dev_mode(): + return PRIMAITE_CONFIG.get("developer_mode").get("output_agent_logs") + return self._save_agent_logs + + @save_agent_logs.setter + def save_agent_logs(self, save_agent_logs: bool) -> None: + self._save_agent_logs = save_agent_logs + @property def write_sys_log_to_terminal(self) -> bool: if is_dev_mode(): @@ -91,6 +104,17 @@ class _SimOutput: def write_sys_log_to_terminal(self, write_sys_log_to_terminal: bool) -> None: self._write_sys_log_to_terminal = write_sys_log_to_terminal + # Should this be separate from sys_log? + @property + def write_agent_log_to_terminal(self) -> bool: + if is_dev_mode(): + return PRIMAITE_CONFIG.get("developer_mode").get("output_to_terminal") + return self._write_agent_log_to_terminal + + @write_agent_log_to_terminal.setter + def write_agent_log_to_terminal(self, write_agent_log_to_terminal: bool) -> None: + self._write_agent_log_to_terminal = write_agent_log_to_terminal + @property def sys_log_level(self) -> LogLevel: if is_dev_mode(): @@ -101,5 +125,15 @@ class _SimOutput: def sys_log_level(self, sys_log_level: LogLevel) -> None: self._sys_log_level = sys_log_level + @property + def agent_log_level(self) -> LogLevel: + if is_dev_mode(): + return LogLevel[PRIMAITE_CONFIG.get("developer_mode").get("agent_log_level")] + return self._agent_log_level + + @agent_log_level.setter + def agent_log_level(self, agent_log_level: LogLevel) -> None: + self._agent_log_level = agent_log_level + SIM_OUTPUT = _SimOutput() diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index 7d40075d..69187fa3 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -9,6 +9,9 @@ io_settings: save_pcap_logs: true save_sys_logs: true sys_log_level: WARNING + agent_log_level: INFO + save_agent_logs: true + write_agent_log_to_terminal: True game: diff --git a/tests/integration_tests/configuration_file_parsing/test_io_settings.py b/tests/integration_tests/configuration_file_parsing/test_io_settings.py index ebaa4956..82977b82 100644 --- a/tests/integration_tests/configuration_file_parsing/test_io_settings.py +++ b/tests/integration_tests/configuration_file_parsing/test_io_settings.py @@ -35,3 +35,7 @@ def test_io_settings(): assert env.io.settings.save_step_metadata is False assert env.io.settings.write_sys_log_to_terminal is False # false by default + + assert env.io.settings.save_agent_logs is True + assert env.io.settings.agent_log_level is LogLevel.INFO + assert env.io.settings.write_agent_log_to_terminal is True # Set to True by the config file. diff --git a/tests/unit_tests/_primaite/_game/_agent/test_agent_log.py b/tests/unit_tests/_primaite/_game/_agent/test_agent_log.py new file mode 100644 index 00000000..a7932cb7 --- /dev/null +++ b/tests/unit_tests/_primaite/_game/_agent/test_agent_log.py @@ -0,0 +1,137 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from uuid import uuid4 + +import pytest + +from primaite import PRIMAITE_CONFIG +from primaite.game.agent.agent_log import AgentLog +from primaite.simulator import LogLevel, SIM_OUTPUT + + +@pytest.fixture(autouse=True) +def override_dev_mode_temporarily(): + """Temporarily turn off dev mode for this test.""" + primaite_dev_mode = PRIMAITE_CONFIG["developer_mode"]["enabled"] + PRIMAITE_CONFIG["developer_mode"]["enabled"] = False + yield # run tests + PRIMAITE_CONFIG["developer_mode"]["enabled"] = primaite_dev_mode + + +@pytest.fixture(scope="function") +def agentlog() -> AgentLog: + return AgentLog(agent_name="test_agent") + + +def test_debug_agent_log_level(agentlog, capsys): + """Test that the debug log level logs debug agent logs and above.""" + SIM_OUTPUT.agent_log_level = LogLevel.DEBUG + SIM_OUTPUT.write_agent_log_to_terminal = True + + test_string = str(uuid4()) + + agentlog.debug(msg=test_string, time_step=0) + agentlog.info(msg=test_string, time_step=0) + agentlog.warning(msg=test_string, time_step=0) + agentlog.error(msg=test_string, time_step=0) + agentlog.critical(msg=test_string, time_step=0) + + captured = "".join(capsys.readouterr()) + + assert test_string in captured + assert "DEBUG" in captured + assert "INFO" in captured + assert "WARNING" in captured + assert "ERROR" in captured + assert "CRITICAL" in captured + + +def test_info_agent_log_level(agentlog, capsys): + """Test that the debug log level logs debug agent logs and above.""" + SIM_OUTPUT.agent_log_level = LogLevel.INFO + SIM_OUTPUT.write_agent_log_to_terminal = True + + test_string = str(uuid4()) + + agentlog.debug(msg=test_string, time_step=0) + agentlog.info(msg=test_string, time_step=0) + agentlog.warning(msg=test_string, time_step=0) + agentlog.error(msg=test_string, time_step=0) + agentlog.critical(msg=test_string, time_step=0) + + captured = "".join(capsys.readouterr()) + + assert test_string in captured + assert "DEBUG" not in captured + assert "INFO" in captured + assert "WARNING" in captured + assert "ERROR" in captured + assert "CRITICAL" in captured + + +def test_warning_agent_log_level(agentlog, capsys): + """Test that the debug log level logs debug agent logs and above.""" + SIM_OUTPUT.agent_log_level = LogLevel.WARNING + SIM_OUTPUT.write_agent_log_to_terminal = True + + test_string = str(uuid4()) + + agentlog.debug(msg=test_string, time_step=0) + agentlog.info(msg=test_string, time_step=0) + agentlog.warning(msg=test_string, time_step=0) + agentlog.error(msg=test_string, time_step=0) + agentlog.critical(msg=test_string, time_step=0) + + captured = "".join(capsys.readouterr()) + + assert test_string in captured + assert "DEBUG" not in captured + assert "INFO" not in captured + assert "WARNING" in captured + assert "ERROR" in captured + assert "CRITICAL" in captured + + +def test_error_agent_log_level(agentlog, capsys): + """Test that the debug log level logs debug agent logs and above.""" + SIM_OUTPUT.agent_log_level = LogLevel.ERROR + SIM_OUTPUT.write_agent_log_to_terminal = True + + test_string = str(uuid4()) + + agentlog.debug(msg=test_string, time_step=0) + agentlog.info(msg=test_string, time_step=0) + agentlog.warning(msg=test_string, time_step=0) + agentlog.error(msg=test_string, time_step=0) + agentlog.critical(msg=test_string, time_step=0) + + captured = "".join(capsys.readouterr()) + + assert test_string in captured + assert "DEBUG" not in captured + assert "INFO" not in captured + assert "WARNING" not in captured + assert "ERROR" in captured + assert "CRITICAL" in captured + + +def test_critical_agent_log_level(agentlog, capsys): + """Test that the debug log level logs debug agent logs and above.""" + SIM_OUTPUT.agent_log_level = LogLevel.CRITICAL + SIM_OUTPUT.write_agent_log_to_terminal = True + + test_string = str(uuid4()) + + agentlog.debug(msg=test_string, time_step=0) + agentlog.info(msg=test_string, time_step=0) + agentlog.warning(msg=test_string, time_step=0) + agentlog.error(msg=test_string, time_step=0) + agentlog.critical(msg=test_string, time_step=0) + + captured = "".join(capsys.readouterr()) + + assert test_string in captured + assert "DEBUG" not in captured + assert "INFO" not in captured + assert "WARNING" not in captured + assert "ERROR" not in captured + assert "CRITICAL" in captured From 2eb9d970bf5a84be9736a11a34c0047392983c36 Mon Sep 17 00:00:00 2001 From: Christopher McCarthy Date: Tue, 9 Jul 2024 19:25:13 +0000 Subject: [PATCH 18/35] Apply suggestions from code review --- .azure/azure-benchmark-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.azure/azure-benchmark-pipeline.yaml b/.azure/azure-benchmark-pipeline.yaml index 7eab2114..8bd7d08e 100644 --- a/.azure/azure-benchmark-pipeline.yaml +++ b/.azure/azure-benchmark-pipeline.yaml @@ -17,7 +17,7 @@ variables: jobs: - job: PrimAITE_Benchmark - timeoutInMinutes: 0 # Set to unlimited timeout + timeoutInMinutes: 360 # 6-hour maximum pool: vmImage: ubuntu-latest workspace: From 7201b7b8e096247621eecce69e078a9ebb14a6ed Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 10 Jul 2024 11:01:42 +0100 Subject: [PATCH 19/35] 2623 Add e2e tests for action masking --- src/primaite/game/game.py | 2 +- src/primaite/notebooks/Action-masking.ipynb | 53 +- src/primaite/session/ray_envs.py | 2 +- tests/assets/configs/multi_agent_session.yaml | 995 +++++++++++++----- .../assets/configs/test_primaite_session.yaml | 1 + .../action_masking/__init__.py | 1 + .../test_agents_use_action_masks.py | 160 +++ .../actions/test_configure_actions.py | 2 +- 8 files changed, 897 insertions(+), 319 deletions(-) create mode 100644 tests/e2e_integration_tests/action_masking/__init__.py create mode 100644 tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index e7d13061..252d1ce2 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -208,7 +208,7 @@ class PrimaiteGame: for i, action in agent.action_manager.action_map.items(): request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1]) mask[i] = self.simulation._request_manager.check_valid(request, {}) - return np.asarray(mask) + return np.asarray(mask, dtype=np.int8) def close(self) -> None: """Close the game, this will close the simulation.""" diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb index 8090dacc..0e067b26 100644 --- a/src/primaite/notebooks/Action-masking.ipynb +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -17,7 +17,7 @@ "source": [ "from primaite.session.environment import PrimaiteGymEnv\n", "from primaite.config.load import data_manipulation_config_path\n", - "from prettytable import PrettyTable" + "from prettytable import PrettyTable\n" ] }, { @@ -99,7 +99,9 @@ "from primaite.session.ray_envs import PrimaiteRayEnv\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "import yaml\n", - "from ray import air, tune\n" + "from ray import air, tune\n", + "from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n", + "from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n" ] }, { @@ -124,25 +126,15 @@ "source": [ "config = (\n", " PPOConfig()\n", - " .environment(env=PrimaiteRayEnv, env_config=cfg)\n", + " .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n", + " .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key=\"action_mask\")\n", + " .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class = ActionMaskingTorchRLModule))\n", " .env_runners(num_env_runners=0)\n", " .training(train_batch_size=128)\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tune.Tuner(\n", - " \"PPO\",\n", - " run_config=air.RunConfig(\n", - " stop={\"timesteps_total\": 512}\n", - " ),\n", - " param_space=config\n", - ").fit()\n" + ")\n", + "algo = config.build()\n", + "for i in range(2):\n", + " results = algo.train()" ] }, { @@ -159,6 +151,7 @@ "metadata": {}, "outputs": [], "source": [ + "from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec\n", "from primaite.session.ray_envs import PrimaiteRayMARLEnv\n", "from primaite.config.load import data_manipulation_marl_config_path" ] @@ -184,20 +177,20 @@ " PPOConfig()\n", " .multi_agent(\n", " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", - " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", + " policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,\n", " )\n", - " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n", + " .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key=\"action_mask\")\n", + " .rl_module(rl_module_spec=MultiAgentRLModuleSpec(module_specs={\n", + " \"defender_1\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n", + " \"defender_2\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n", + " }))\n", " .env_runners(num_env_runners=0)\n", " .training(train_batch_size=128)\n", - " )\n", - "\n", - "tune.Tuner(\n", - " \"PPO\",\n", - " run_config=air.RunConfig(\n", - " stop={\"timesteps_total\": 5 * 128},\n", - " ),\n", - " param_space=config\n", - ").fit()" + ")\n", + "algo = config.build()\n", + "for i in range(2):\n", + " results = algo.train()" ] } ], diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 12167f89..1adc324c 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env): # if action masking is enabled, intercept the step method and add action mask to observation if self.env.agent.action_masking: obs, *_ = self.env.step(action) - new_obs = {"action_mask": self.env.action_masks(), "observations": obs} + new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs} return new_obs, *_ else: return self.env.step(action) diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 971f36f8..a2d64605 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -1,3 +1,10 @@ +io_settings: + save_agent_actions: false + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + + game: max_episode_length: 128 ports: @@ -13,31 +20,105 @@ game: agents: - ref: client_2_green_user team: GREEN - type: PeriodicAgent + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: null action_space: action_list: - type: DONOTHING - type: NODE_APPLICATION_EXECUTE - options: nodes: - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 reward_function: reward_components: - - type: DUMMY + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + - agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 - ref: data_manipulation_attacker team: RED @@ -57,6 +138,9 @@ agents: - node_name: client_1 applications: - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 @@ -71,7 +155,7 @@ agents: frequency: 20 variance: 5 - - ref: defender1 + - ref: defender_1 team: BLUE type: ProxyAgent @@ -194,318 +278,425 @@ agents: 3: action: "NODE_SERVICE_START" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 4: action: "NODE_SERVICE_PAUSE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 5: action: "NODE_SERVICE_RESUME" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 6: action: "NODE_SERVICE_RESTART" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 7: action: "NODE_SERVICE_DISABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 8: action: "NODE_SERVICE_ENABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 9: # check database.db file action: "NODE_FILE_SCAN" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 10: - action: "NODE_FILE_CHECKHASH" + action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 11: action: "NODE_FILE_DELETE" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 12: action: "NODE_FILE_REPAIR" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 13: action: "NODE_SERVICE_FIX" options: - node_id: 2 - service_id: 0 + node_id: 2 + service_id: 0 14: action: "NODE_FOLDER_SCAN" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 15: - action: "NODE_FOLDER_CHECKHASH" + action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 18: action: "NODE_OS_SCAN" options: - node_id: 2 - 19: # shutdown client 1 + node_id: 0 + 19: action: "NODE_SHUTDOWN" options: - node_id: 5 + node_id: 0 20: - action: "NODE_STARTUP" + action: NODE_STARTUP options: - node_id: 5 + node_id: 0 21: - action: "NODE_RESET" + action: NODE_RESET options: - node_id: 5 - 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + node_id: 0 + 22: + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 1 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 23: + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 2 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 24: # block tcp traffic from client 1 to web app - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 24: + action: NODE_STARTUP options: - target_router: router_1 - position: 3 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 25: # block tcp traffic from client 2 to web app - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 25: + action: NODE_RESET options: - target_router: router_1 - position: 4 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 3 # web server - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 - 26: - action: "ROUTER_ACL_ADDRULE" + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 5 - permission: 2 - source_ip_id: 7 # client 1 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + node_id: 2 27: - action: "ROUTER_ACL_ADDRULE" + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 6 - permission: 2 - source_ip_id: 8 # client 2 - dest_ip_id: 4 # database - source_port_id: 1 - dest_port_id: 1 - protocol_id: 3 - source_wildcard_id: 0 - dest_wildcard_id: 0 + node_id: 2 28: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_STARTUP options: - target_router: router_1 - position: 0 + node_id: 2 29: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_RESET options: - target_router: router_1 - position: 1 + node_id: 2 30: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 2 + node_id: 3 31: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 3 + node_id: 3 32: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_STARTUP options: - target_router: router_1 - position: 4 + node_id: 3 33: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_RESET options: - target_router: router_1 - position: 5 + node_id: 3 34: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_OS_SCAN" options: - target_router: router_1 - position: 6 + node_id: 4 35: - action: "ROUTER_ACL_REMOVERULE" + action: "NODE_SHUTDOWN" options: - target_router: router_1 - position: 7 + node_id: 4 36: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_STARTUP options: - target_router: router_1 - position: 8 + node_id: 4 37: - action: "ROUTER_ACL_REMOVERULE" + action: NODE_RESET options: - target_router: router_1 - position: 9 + node_id: 4 38: - action: "HOST_NIC_DISABLE" + action: "NODE_OS_SCAN" options: - node_id: 0 - nic_id: 0 - 39: - action: "HOST_NIC_ENABLE" + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" options: - node_id: 0 - nic_id: 0 - 40: - action: "HOST_NIC_DISABLE" + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP options: - node_id: 1 - nic_id: 0 - 41: - action: "HOST_NIC_ENABLE" + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET options: - node_id: 1 - nic_id: 0 + node_id: 5 42: - action: "HOST_NIC_DISABLE" + action: "NODE_OS_SCAN" options: - node_id: 2 - nic_id: 0 + node_id: 6 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 50: # old action num: 26 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 51: # old action num: 27 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 52: # old action num: 28 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 0 + 53: # old action num: 29 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 1 + 54: # old action num: 30 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 2 + 55: # old action num: 31 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 3 + 56: # old action num: 32 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 4 + 57: # old action num: 33 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 5 + 58: # old action num: 34 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 6 + 59: # old action num: 35 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 7 + 60: # old action num: 36 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 8 + 61: # old action num: 37 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 action: "HOST_NIC_ENABLE" options: node_id: 2 nic_id: 0 - 44: + 68: # old action num: 44 action: "HOST_NIC_DISABLE" options: node_id: 3 nic_id: 0 - 45: + 69: # old action num: 45 action: "HOST_NIC_ENABLE" options: node_id: 3 nic_id: 0 - 46: + 70: # old action num: 46 action: "HOST_NIC_DISABLE" options: node_id: 4 nic_id: 0 - 47: + 71: # old action num: 47 action: "HOST_NIC_ENABLE" options: node_id: 4 nic_id: 0 - 48: + 72: # old action num: 48 action: "HOST_NIC_DISABLE" options: node_id: 4 nic_id: 1 - 49: + 73: # old action num: 49 action: "HOST_NIC_ENABLE" options: node_id: 4 nic_id: 1 - 50: + 74: # old action num: 50 action: "HOST_NIC_DISABLE" options: node_id: 5 nic_id: 0 - 51: + 75: # old action num: 51 action: "HOST_NIC_ENABLE" options: node_id: 5 nic_id: 0 - 52: + 76: # old action num: 52 action: "HOST_NIC_DISABLE" options: node_id: 6 nic_id: 0 - 53: + 77: # old action num: 53 action: "HOST_NIC_ENABLE" options: node_id: 6 nic_id: 0 - options: nodes: - node_name: domain_controller - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService - node_name: backup_server - node_name: security_suite - node_name: client_1 - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 @@ -521,27 +712,30 @@ agents: - 192.168.10.22 - 192.168.10.110 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: web_server - service_name: web_server_web_service + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user agent_settings: - # ... + flatten_obs: true + action_masking: true - - ref: defender2 + - ref: defender_2 team: BLUE type: ProxyAgent @@ -640,7 +834,11 @@ agents: - type: NODE_STARTUP - type: NODE_RESET - type: ROUTER_ACL_ADDRULE + options: + target_router: router_1 - type: ROUTER_ACL_REMOVERULE + options: + target_router: router_1 - type: HOST_NIC_ENABLE - type: HOST_NIC_DISABLE @@ -664,99 +862,196 @@ agents: 3: action: "NODE_SERVICE_START" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 4: action: "NODE_SERVICE_PAUSE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 5: action: "NODE_SERVICE_RESUME" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 6: action: "NODE_SERVICE_RESTART" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 7: action: "NODE_SERVICE_DISABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 8: action: "NODE_SERVICE_ENABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 9: # check database.db file action: "NODE_FILE_SCAN" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 10: - action: "NODE_FILE_CHECKHASH" + action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 11: action: "NODE_FILE_DELETE" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 12: action: "NODE_FILE_REPAIR" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 13: action: "NODE_SERVICE_FIX" options: - node_id: 2 - service_id: 0 + node_id: 2 + service_id: 0 14: action: "NODE_FOLDER_SCAN" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 15: - action: "NODE_FOLDER_CHECKHASH" + action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 18: action: "NODE_OS_SCAN" options: - node_id: 2 - 19: # shutdown client 1 + node_id: 0 + 19: action: "NODE_SHUTDOWN" options: - node_id: 5 + node_id: 0 20: - action: "NODE_STARTUP" + action: NODE_STARTUP options: - node_id: 5 + node_id: 0 21: - action: "NODE_RESET" + action: NODE_RESET options: - node_id: 5 - 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -769,7 +1064,7 @@ agents: protocol_id: 1 source_wildcard_id: 0 dest_wildcard_id: 0 - 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -782,7 +1077,7 @@ agents: protocol_id: 1 source_wildcard_id: 0 dest_wildcard_id: 0 - 24: # block tcp traffic from client 1 to web app + 48: # old action num: 24 # block tcp traffic from client 1 to web app action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -795,7 +1090,7 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 25: # block tcp traffic from client 2 to web app + 49: # old action num: 25 # block tcp traffic from client 2 to web app action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -808,7 +1103,7 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 26: + 50: # old action num: 26 action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -821,7 +1116,7 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 27: + 51: # old action num: 27 action: "ROUTER_ACL_ADDRULE" options: target_router: router_1 @@ -834,67 +1129,159 @@ agents: protocol_id: 3 source_wildcard_id: 0 dest_wildcard_id: 0 - 28: + 52: # old action num: 28 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 0 - 29: + 53: # old action num: 29 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 1 - 30: + 54: # old action num: 30 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 2 - 31: + 55: # old action num: 31 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 3 - 32: + 56: # old action num: 32 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 4 - 33: + 57: # old action num: 33 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 5 - 34: + 58: # old action num: 34 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 6 - 35: + 59: # old action num: 35 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 7 - 36: + 60: # old action num: 36 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 8 - 37: + 61: # old action num: 37 action: "ROUTER_ACL_REMOVERULE" options: target_router: router_1 position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "HOST_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "HOST_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "HOST_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "HOST_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "HOST_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "HOST_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "HOST_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + options: nodes: - node_name: domain_controller - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService - node_name: backup_server - node_name: security_suite - node_name: client_1 - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 @@ -913,50 +1300,63 @@ agents: reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.40 options: node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: web_server - service_name: web_server_web_service + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user agent_settings: - # ... - + flatten_obs: true + action_masking: true simulation: network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE nodes: - - type: router - hostname: router_1 + - hostname: router_1 + type: router num_ports: 5 ports: 1: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 2: - ip_address: 192.168.1.1 + ip_address: 192.168.10.1 subnet_mask: 255.255.255.0 acl: - 0: + 18: action: PERMIT src_port: POSTGRES_SERVER dst_port: POSTGRES_SERVER - 1: + 19: action: PERMIT src_port: DNS dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP 22: action: PERMIT src_port: ARP @@ -965,16 +1365,16 @@ simulation: action: PERMIT protocol: ICMP - - type: switch - hostname: switch_1 + - hostname: switch_1 + type: switch num_ports: 8 - - type: switch - hostname: switch_2 + - hostname: switch_2 + type: switch num_ports: 8 - - type: server - hostname: domain_controller + - hostname: domain_controller + type: server ip_address: 192.168.1.10 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -984,8 +1384,8 @@ simulation: domain_mapping: arcd.com: 192.168.1.12 # web server - - type: server - hostname: web_server + - hostname: web_server + type: server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -997,17 +1397,21 @@ simulation: options: db_server_ip: 192.168.1.14 - - type: server - hostname: database_server + + - hostname: database_server + type: server ip_address: 192.168.1.14 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - type: FTPClient - - type: server - hostname: backup_server + - hostname: backup_server + type: server ip_address: 192.168.1.16 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1015,8 +1419,8 @@ simulation: services: - type: FTPServer - - type: server - hostname: security_suite + - hostname: security_suite + type: server ip_address: 192.168.1.110 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1026,8 +1430,8 @@ simulation: ip_address: 192.168.10.110 subnet_mask: 255.255.255.0 - - type: computer - hostname: client_1 + - hostname: client_1 + type: computer ip_address: 192.168.10.21 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 @@ -1035,24 +1439,43 @@ simulation: applications: - type: DataManipulationBot options: - port_scan_p_of_success: 0.1 - data_manipulation_p_of_success: 0.1 + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - type: DNSClient - - type: computer - hostname: client_2 + - hostname: client_2 + type: computer ip_address: 192.168.10.22 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 dns_server: 192.168.1.10 applications: - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - type: DNSClient + + links: - endpoint_a_hostname: router_1 endpoint_a_port: 1 diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 54143af0..7c894ba0 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -557,6 +557,7 @@ agents: agent_settings: flatten_obs: true + action_masking: true diff --git a/tests/e2e_integration_tests/action_masking/__init__.py b/tests/e2e_integration_tests/action_masking/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/tests/e2e_integration_tests/action_masking/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py new file mode 100644 index 00000000..3efda71a --- /dev/null +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -0,0 +1,160 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import importlib +from typing import Dict + +import yaml +from ray import air, init, tune +from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule +from sb3_contrib import MaskablePPO + +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv +from tests import TEST_ASSETS_ROOT + +init(local_mode=True) + +CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" +MARL_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" + + +def test_sb3_action_masking(monkeypatch): + # There's no simple way of capturing what the action mask was at every step, therefore we are mocking the action + # mask function here to save the output of the action mask method and pass through the result back to the agent. + old_action_mask_method = PrimaiteGame.action_mask + mask_history = [] + + def cache_action_mask(obj, agent_name): + mask = old_action_mask_method(obj, agent_name) + mask_history.append(mask) + return mask + + # Even though it's easy to know which CAOS action the agent took by looking at agent history, we don't know which + # action map action integer that was, therefore we cache it by using monkeypatch + action_num_history = [] + + def cache_step(env, action: int): + action_num_history.append(action) + return PrimaiteGymEnv.step(env, action) + + monkeypatch.setattr(PrimaiteGame, "action_mask", cache_action_mask) + env = PrimaiteGymEnv(CFG_PATH) + monkeypatch.setattr(env, "step", lambda action: cache_step(env, action)) + + model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32) + model.learn(512) + + assert len(action_num_history) == len(mask_history) > 0 + # Make sure the masks had at least some False entries, if it was all True then the mask was disabled + assert any([not all(x) for x in mask_history]) + # When the agent takes action N from its action map, we need to have a look at the action mask and make sure that + # the N-th entry was True, meaning that it was a valid action at that step. + # This plucks out the mask history at step i, and at action entry a and checks that it's set to True, and this + # happens for all steps i in the episode + assert all(mask_history[i][a] for i, a in enumerate(action_num_history)) + monkeypatch.undo() + + +def test_ray_single_agent_action_masking(monkeypatch): + """Check that a Ray agent uses the action mask and never chooses invalid actions.""" + with open(CFG_PATH, "r") as f: + cfg = yaml.safe_load(f) + for agent in cfg["agents"]: + if agent["ref"] == "defender": + agent["agent_settings"]["flatten_obs"] = True + + # There's no simple way of capturing what the action mask was at every step, therefore we are mocking the step + # function to save the action mask and the agent's chosen action to a local variable. + old_step_method = PrimaiteRayEnv.step + action_num_history = [] + mask_history = [] + + def cache_step(self, action: int): + action_num_history.append(action) + obs, *_ = old_step_method(self, action) + action_mask = obs["action_mask"] + mask_history.append(action_mask) + return obs, *_ + + monkeypatch.setattr(PrimaiteRayEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs)) + + # Configure Ray PPO to use action masking by using the ActionMaskingTorchRLModule + config = ( + PPOConfig() + .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True) + .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key="action_mask") + .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule)) + .env_runners(num_env_runners=0) + .training(train_batch_size=128) + ) + algo = config.build() + algo.train() + + assert len(action_num_history) == len(mask_history) > 0 + # Make sure the masks had at least some False entries, if it was all True then the mask was disabled + assert any([not all(x) for x in mask_history]) + # When the agent takes action N from its action map, we need to have a look at the action mask and make sure that + # the N-th action was valid. + # The first step uses the action mask provided by the reset method, so we are only checking from the second step + # onward, that's why we need to use mask_history[:-1] and action_num_history[1:] + assert all(mask_history[:-1][i][a] for i, a in enumerate(action_num_history[1:])) + monkeypatch.undo() + + +def test_ray_multi_agent_action_masking(monkeypatch): + """Check that Ray agents never take invalid actions when using MARL.""" + with open(MARL_PATH, "r") as f: + cfg = yaml.safe_load(f) + + old_step_method = PrimaiteRayMARLEnv.step + action_num_history = {"defender_1": [], "defender_2": []} + mask_history = {"defender_1": [], "defender_2": []} + + def cache_step(self, actions: Dict[str, int]): + for agent_name, action in actions.items(): + action_num_history[agent_name].append(action) + obs, *_ = old_step_method(self, actions) + for ( + agent_name, + o, + ) in obs.items(): + mask_history[agent_name].append(o["action_mask"]) + return obs, *_ + + monkeypatch.setattr(PrimaiteRayMARLEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs)) + + config = ( + PPOConfig() + .multi_agent( + policies={ + "defender_1", + "defender_2", + }, # These names are the same as the agents defined in the example config. + policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id, + ) + .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True) + .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key="action_mask") + .rl_module( + rl_module_spec=MultiAgentRLModuleSpec( + module_specs={ + "defender_1": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule), + "defender_2": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule), + } + ) + ) + .env_runners(num_env_runners=0) + .training(train_batch_size=128) + ) + algo = config.build() + algo.train() + + for agent_name in ["defender_1", "defender_2"]: + act_hist = action_num_history[agent_name] + mask_hist = mask_history[agent_name] + assert len(act_hist) == len(mask_hist) > 0 + assert any([not all(x) for x in mask_hist]) + assert all(mask_hist[:-1][i][a] for i, a in enumerate(act_hist[1:])) + monkeypatch.undo() diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index b7acc8a8..0c9ec6f0 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -99,7 +99,7 @@ class TestConfigureDatabaseAction: game.step() assert db_client.server_ip_address == old_ip - assert db_client.server_password is "admin123" + assert db_client.server_password == "admin123" class TestConfigureRansomwareScriptAction: From caa6a4809c724bc1cecec90c3b1afd57dd900243 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Wed, 10 Jul 2024 13:14:22 +0100 Subject: [PATCH 20/35] #2740: tests + implement file request validators --- .../simulator/file_system/file_system.py | 5 +- src/primaite/simulator/file_system/folder.py | 5 +- .../actions/test_file_request_permission.py | 83 ++++++++++++++++++- 3 files changed, 82 insertions(+), 11 deletions(-) diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index 42aa0573..bcb0334e 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -693,7 +693,4 @@ class FileSystem(SimComponent): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return ( - f"Cannot perform request on application '{self.application.name}' because it is not in the " - f"{self.state.name} state." - ) + return "Cannot perform request on file that does not exist." diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index af7cc660..d891641e 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -489,7 +489,4 @@ class Folder(FileSystemItemABC): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return ( - f"Cannot perform request on application '{self.application.name}' because it is not in the " - f"{self.state.name} state." - ) + return "Cannot perform request on file that does not exist." diff --git a/tests/integration_tests/game_layer/actions/test_file_request_permission.py b/tests/integration_tests/game_layer/actions/test_file_request_permission.py index c422ad43..1c143aed 100644 --- a/tests/integration_tests/game_layer/actions/test_file_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_file_request_permission.py @@ -22,7 +22,7 @@ def game_and_agent_fixture(game_and_agent): def test_create_file(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): - """Test that the validator allows a folder to be created.""" + """Test that the validator allows a files to be created.""" game, agent = game_and_agent_fixture client_1 = game.simulation.network.get_node_by_hostname("client_1") @@ -43,7 +43,7 @@ def test_create_file(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): def test_file_delete_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): - """Test that the validator allows a folder to be created.""" + """Test that the validator allows a file to be deleted.""" game, agent = game_and_agent_fixture client_1 = game.simulation.network.get_node_by_hostname("client_1") @@ -61,7 +61,7 @@ def test_file_delete_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAge def test_file_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): - """Test that the validator allows a folder to be created.""" + """Test that the validator allows a file to be scanned.""" game, agent = game_and_agent_fixture client_1 = game.simulation.network.get_node_by_hostname("client_1") @@ -80,3 +80,80 @@ def test_file_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent assert file.health_status == FileSystemItemHealthStatus.CORRUPT assert file.visible_health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_file_repair_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator allows a folder to be created.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png") + + file.corrupt() + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + + action = ( + "NODE_FILE_REPAIR", + {"node_id": 0, "folder_id": 0, "file_id": 0}, + ) + agent.store_action(action) + game.step() + + assert file.health_status == FileSystemItemHealthStatus.GOOD + + +def test_file_restore_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator allows a file to be restored.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png") + + file.corrupt() + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + + action = ( + "NODE_FILE_RESTORE", + {"node_id": 0, "folder_id": 0, "file_id": 0}, + ) + agent.store_action(action) + game.step() + + assert file.health_status == FileSystemItemHealthStatus.GOOD + + +def test_file_corrupt_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator allows a file to be corrupted.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png") + + assert file.health_status == FileSystemItemHealthStatus.GOOD + + action = ( + "NODE_FILE_CORRUPT", + {"node_id": 0, "folder_id": 0, "file_id": 0}, + ) + agent.store_action(action) + game.step() + + assert file.health_status == FileSystemItemHealthStatus.CORRUPT + + +def test_file_access_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the validator allows a file to be accessed.""" + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png") + assert file.num_access == 0 + + action = ( + "NODE_FILE_ACCESS", + {"node_id": 0, "folder_name": file.folder_name, "file_name": file.name}, + ) + agent.store_action(action) + game.step() + + assert file.num_access == 1 From 239f5b86c0fd33744b693f332038267ba6615fb5 Mon Sep 17 00:00:00 2001 From: "Archer.Bowen" Date: Wed, 10 Jul 2024 13:36:37 +0100 Subject: [PATCH 21/35] #2716 Agent logging now sits outside of the simulation output log directory, updated dev-mode CLI to include agent logging and added additional tests. --- CHANGELOG.md | 2 +- src/primaite/game/agent/agent_log.py | 50 +++++++++++------ .../scripted_agents/data_manipulation_bot.py | 6 +-- .../scripted_agents/probabilistic_agent.py | 2 +- src/primaite/game/game.py | 9 ++++ src/primaite/session/io.py | 17 +++--- .../setup/_package_data/primaite_config.yaml | 2 + src/primaite/simulator/__init__.py | 23 ++++++++ src/primaite/utils/cli/dev_cli.py | 27 ++++++++++ tests/integration_tests/cli/test_dev_cli.py | 54 ++++++++++++++++++- .../_primaite/_game/_agent/test_agent_log.py | 50 ++++++++--------- 11 files changed, 187 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index beec6d11..63a4bfe8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Activating dev-mode will change the location where the sessions will be output - by default will output where the PrimAITE repository is located - Refactored all air-space usage to that a new instance of AirSpace is created for each instance of Network. This 1:1 relationship between network and airspace will allow parallelization. - Added notebook to demonstrate use of SubprocVecEnv from SB3 to vectorise environments to speed up training. -- Added a new agent simulation log which are more human friendly than agent action logging. Includes timesteps so that the agent action log can be cross referenced. These Logs are found in simulation_output directory, similar to that of sys_logs and can be enabled in the I/O settings in a yaml configuration file. +- Added a new agent behaviour log which are more human friendly than agent history. These Logs are found in session log directory and can be enabled in the I/O settings in a yaml configuration file. ## [Unreleased] - Made requests fail to reach their target if the node is off diff --git a/src/primaite/game/agent/agent_log.py b/src/primaite/game/agent/agent_log.py index 1e51dcad..62ef4884 100644 --- a/src/primaite/game/agent/agent_log.py +++ b/src/primaite/game/agent/agent_log.py @@ -22,7 +22,7 @@ class AgentLog: """ A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent. - Each log message is written to a file located at: /agent/agent_name.log + Each log message is written to a file located at: /agent_name/agent_name.log """ def __init__(self, agent_name: str): @@ -33,8 +33,28 @@ class AgentLog: """ self.agent_name = agent_name self.current_episode: int = 1 + self.current_timestep: int = 0 self.setup_logger() + @property + def timestep(self) -> int: + """Returns the current timestep. Used for log indexing. + + :return: The current timestep as an Int. + """ + return self.current_timestep + + def update_timestep(self, new_timestep: int): + """ + Updates the self.current_timestep attribute with the given parameter. + + This method is called within .step() to ensure that all instances of Agent Logs + are in sync with one another. + + :param new_timestep: The new timestep. + """ + self.current_timestep = new_timestep + def setup_logger(self): """ Configures the logger for this Agent Log instance. @@ -60,19 +80,19 @@ class AgentLog: def _get_log_path(self) -> Path: """ - Constructs the path for the log file based on the hostname. + Constructs the path for the log file based on the agent name. :return: Path object representing the location of the log file. """ - root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / "agent_logs" / self.agent_name + root = SIM_OUTPUT.agent_behaviour_path / f"episode_{self.current_episode}" / self.agent_name root.mkdir(exist_ok=True, parents=True) return root / f"{self.agent_name}.log" - def _write_to_terminal(self, msg: str, timestep: int, level: str, to_terminal: bool = False): + def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False): if to_terminal or SIM_OUTPUT.write_agent_log_to_terminal: - print(f"{self.agent_name}: ({timestep}) ({level}) {msg}") + print(f"{self.agent_name}: ({ self.timestep}) ({level}) {msg}") - def debug(self, msg: str, time_step: int, to_terminal: bool = False): + def debug(self, msg: str, to_terminal: bool = False): """ Logs a message with the DEBUG level. @@ -83,10 +103,10 @@ class AgentLog: return if SIM_OUTPUT.save_agent_logs: - self.logger.debug(msg, extra={"timestep": time_step}) + self.logger.debug(msg, extra={"timestep": self.timestep}) self._write_to_terminal(msg, "DEBUG", to_terminal) - def info(self, msg: str, time_step: int, to_terminal: bool = False): + def info(self, msg: str, to_terminal: bool = False): """ Logs a message with the INFO level. @@ -98,10 +118,10 @@ class AgentLog: return if SIM_OUTPUT.save_agent_logs: - self.logger.info(msg, extra={"timestep": time_step}) + self.logger.info(msg, extra={"timestep": self.timestep}) self._write_to_terminal(msg, "INFO", to_terminal) - def warning(self, msg: str, time_step: int, to_terminal: bool = False): + def warning(self, msg: str, to_terminal: bool = False): """ Logs a message with the WARNING level. @@ -113,10 +133,10 @@ class AgentLog: return if SIM_OUTPUT.save_agent_logs: - self.logger.warning(msg, extra={"timestep": time_step}) + self.logger.warning(msg, extra={"timestep": self.timestep}) self._write_to_terminal(msg, "WARNING", to_terminal) - def error(self, msg: str, time_step: int, to_terminal: bool = False): + def error(self, msg: str, to_terminal: bool = False): """ Logs a message with the ERROR level. @@ -128,10 +148,10 @@ class AgentLog: return if SIM_OUTPUT.save_agent_logs: - self.logger.error(msg, extra={"timestep": time_step}) + self.logger.error(msg, extra={"timestep": self.timestep}) self._write_to_terminal(msg, "ERROR", to_terminal) - def critical(self, msg: str, time_step: int, to_terminal: bool = False): + def critical(self, msg: str, to_terminal: bool = False): """ Logs a message with the CRITICAL level. @@ -143,7 +163,7 @@ class AgentLog: return if SIM_OUTPUT.save_agent_logs: - self.logger.critical(msg, extra={"timestep": time_step}) + self.logger.critical(msg, extra={"timestep": self.timestep}) self._write_to_terminal(msg, "CRITICAL", to_terminal) def show(self, last_n: int = 10, markdown: bool = False): diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index cd72e001..129fac1a 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -38,11 +38,11 @@ class DataManipulationAgent(AbstractScriptedAgent): :rtype: Tuple[str, Dict] """ if timestep < self.next_execution_timestep: - self.logger.debug(msg="Performing do NOTHING", time_step=timestep) + self.logger.debug(msg="Performing do NOTHING") return "DONOTHING", {} self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) - self.logger.info(msg="Performing a data manipulation attack!", time_step=timestep) + self.logger.info(msg="Performing a data manipulation attack!") return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} def setup_agent(self) -> None: @@ -55,4 +55,4 @@ class DataManipulationAgent(AbstractScriptedAgent): # we are assuming that every node in the node manager has a data manipulation application at idx 0 num_nodes = len(self.action_manager.node_names) self.starting_node_idx = random.randint(0, num_nodes - 1) - self.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}", time_step=0) + self.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}") diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index e0f41302..f5905ad0 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -85,5 +85,5 @@ class ProbabilisticAgent(AbstractScriptedAgent): :rtype: Tuple[str, Dict] """ choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities) - self.logger.info(f"Performing Action: {choice}", time_step=timestep) + self.logger.info(f"Performing Action: {choice}") return self.action_manager.get_action(choice) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 3dc9571f..cb787e68 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -15,6 +15,7 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort +from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC @@ -164,6 +165,8 @@ class PrimaiteGame: for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation action_choice, parameters = agent.get_action(obs, timestep=self.step_counter) + if SIM_OUTPUT.save_agent_logs: + agent.logger.debug(f"Chosen Action: {action_choice}") request = agent.format_request(action_choice, parameters) response = self.simulation.apply_request(request) agent.process_action_response( @@ -182,8 +185,14 @@ class PrimaiteGame: """Advance timestep.""" self.step_counter += 1 _LOGGER.debug(f"Advancing timestep to {self.step_counter} ") + self.update_agent_loggers() self.simulation.apply_timestep(self.step_counter) + def update_agent_loggers(self) -> None: + """Updates Agent Loggers with new timestep.""" + for agent in self.agents.values(): + agent.logger.update_timestep(self.step_counter) + def calculate_truncated(self) -> bool: """Calculate whether the episode is truncated.""" current_step = self.step_counter diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 05a5ee09..78d7cb3c 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -57,6 +57,7 @@ class PrimaiteIO: self.session_path: Path = self.generate_session_path() # set global SIM_OUTPUT path SIM_OUTPUT.path = self.session_path / "simulation_output" + SIM_OUTPUT.agent_behaviour_path = self.session_path / "agent_behaviour" SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs SIM_OUTPUT.save_agent_logs = self.settings.save_agent_logs @@ -67,20 +68,20 @@ class PrimaiteIO: def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: """Create a folder for the session and return the path to it.""" - if timestamp is None: - timestamp = datetime.now() - date_str = timestamp.strftime("%Y-%m-%d") - time_str = timestamp.strftime("%H-%M-%S") - - session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str + session_path = PRIMAITE_PATHS.user_sessions_path / SIM_OUTPUT.date_str / SIM_OUTPUT.time_str # check if running in dev mode if is_dev_mode(): - session_path = _PRIMAITE_ROOT.parent.parent / "sessions" / date_str / time_str + session_path = _PRIMAITE_ROOT.parent.parent / "sessions" / SIM_OUTPUT.date_str / SIM_OUTPUT.time_str # check if there is an output directory set in config if PRIMAITE_CONFIG["developer_mode"]["output_dir"]: - session_path = Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"]) / "sessions" / date_str / time_str + session_path = ( + Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"]) + / "sessions" + / SIM_OUTPUT.date_str + / SIM_OUTPUT.time_str + ) session_path.mkdir(exist_ok=True, parents=True) return session_path diff --git a/src/primaite/setup/_package_data/primaite_config.yaml b/src/primaite/setup/_package_data/primaite_config.yaml index c1caf1f4..e08d951e 100644 --- a/src/primaite/setup/_package_data/primaite_config.yaml +++ b/src/primaite/setup/_package_data/primaite_config.yaml @@ -3,6 +3,8 @@ developer_mode: enabled: False # not enabled by default sys_log_level: DEBUG # level of output for system logs, DEBUG by default + agent_log_level: DEBUG # level of output for agent logs, DEBUG by default + output_agent_logs: False # level of output for system logs, DEBUG by default output_sys_logs: False # system logs not output by default output_pcap_logs: False # pcap logs not output by default output_to_terminal: False # do not output to terminal by default diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index 487e7c5e..ade1a73b 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -34,6 +34,7 @@ class _SimOutput: path = PRIMAITE_PATHS.user_sessions_path / self.date_str / self.time_str self._path = path + self._agent_behaviour_path = path self._save_pcap_logs: bool = False self._save_sys_logs: bool = False self._save_agent_logs: bool = False @@ -64,6 +65,28 @@ class _SimOutput: self._path = new_path self._path.mkdir(exist_ok=True, parents=True) + @property + def agent_behaviour_path(self) -> Path: + if is_dev_mode(): + # if dev mode is enabled, if output dir is not set, print to primaite repo root + path: Path = _PRIMAITE_ROOT.parent.parent / "sessions" / self.date_str / self.time_str / "agent_behaviour" + # otherwise print to output dir + if PRIMAITE_CONFIG["developer_mode"]["output_dir"]: + path: Path = ( + Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"]) + / "sessions" + / self.date_str + / self.time_str + / "agent_behaviour" + ) + self._agent_behaviour_path = path + return self._agent_behaviour_path + + @agent_behaviour_path.setter + def agent_behaviour_path(self, new_path: Path) -> None: + self._agent_behaviour_path = new_path + self._agent_behaviour_path.mkdir(exist_ok=True, parents=True) + @property def save_pcap_logs(self) -> bool: if is_dev_mode(): diff --git a/src/primaite/utils/cli/dev_cli.py b/src/primaite/utils/cli/dev_cli.py index 15adacb3..0dd9f9cc 100644 --- a/src/primaite/utils/cli/dev_cli.py +++ b/src/primaite/utils/cli/dev_cli.py @@ -82,12 +82,31 @@ def config_callback( show_default=False, ), ] = None, + agent_log_level: Annotated[ + LogLevel, + typer.Option( + "--agent-log-level", + "-level", + click_type=click.Choice(LogLevel._member_names_, case_sensitive=False), + help="The level of agent behaviour logs to output.", + show_default=False, + ), + ] = None, output_sys_logs: Annotated[ bool, typer.Option( "--output-sys-logs/--no-sys-logs", "-sys/-nsys", help="Output system logs to file.", show_default=False ), ] = None, + output_agent_logs: Annotated[ + bool, + typer.Option( + "--output-agent-logs/--no-agent-logs", + "-agent/-nagent", + help="Output agent logs to file.", + show_default=False, + ), + ] = None, output_pcap_logs: Annotated[ bool, typer.Option( @@ -109,10 +128,18 @@ def config_callback( PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] = ctx.params.get("sys_log_level") print(f"PrimAITE dev-mode config updated sys_log_level={ctx.params.get('sys_log_level')}") + if ctx.params.get("agent_log_level") is not None: + PRIMAITE_CONFIG["developer_mode"]["agent_log_level"] = ctx.params.get("agent_log_level") + print(f"PrimAITE dev-mode config updated agent_log_level={ctx.params.get('agent_log_level')}") + if output_sys_logs is not None: PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] = output_sys_logs print(f"PrimAITE dev-mode config updated {output_sys_logs=}") + if output_agent_logs is not None: + PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"] = output_agent_logs + print(f"PrimAITE dev-mode config updated {output_agent_logs=}") + if output_pcap_logs is not None: PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] = output_pcap_logs print(f"PrimAITE dev-mode config updated {output_pcap_logs=}") diff --git a/tests/integration_tests/cli/test_dev_cli.py b/tests/integration_tests/cli/test_dev_cli.py index 43f623a5..19559e7c 100644 --- a/tests/integration_tests/cli/test_dev_cli.py +++ b/tests/integration_tests/cli/test_dev_cli.py @@ -67,7 +67,7 @@ def test_dev_mode_config_sys_log_level(): # check defaults assert PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] == "DEBUG" # DEBUG by default - result = cli(["dev-mode", "config", "-level", "WARNING"]) + result = cli(["dev-mode", "config", "--sys-log-level", "WARNING"]) assert "sys_log_level=WARNING" in result.output # should print correct value @@ -78,10 +78,30 @@ def test_dev_mode_config_sys_log_level(): assert "sys_log_level=INFO" in result.output # should print correct value - # config should reflect that log level is WARNING + # config should reflect that log level is INFO assert PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] == "INFO" +def test_dev_mode_config_agent_log_level(): + """Check that the agent log level can be changed via CLI.""" + # check defaults + assert PRIMAITE_CONFIG["developer_mode"]["agent_log_level"] == "DEBUG" # DEBUG by default + + result = cli(["dev-mode", "config", "-level", "WARNING"]) + + assert "agent_log_level=WARNING" in result.output # should print correct value + + # config should reflect that log level is WARNING + assert PRIMAITE_CONFIG["developer_mode"]["agent_log_level"] == "WARNING" + + result = cli(["dev-mode", "config", "--agent-log-level", "INFO"]) + + assert "agent_log_level=INFO" in result.output # should print correct value + + # config should reflect that log level is INFO + assert PRIMAITE_CONFIG["developer_mode"]["agent_log_level"] == "INFO" + + def test_dev_mode_config_sys_logs_enable_disable(): """Test that the system logs output can be enabled or disabled.""" # check defaults @@ -112,6 +132,36 @@ def test_dev_mode_config_sys_logs_enable_disable(): assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] is False +def test_dev_mode_config_agent_logs_enable_disable(): + """Test that the agent logs output can be enabled or disabled.""" + # check defaults + assert PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"] is False # False by default + + result = cli(["dev-mode", "config", "--output-agent-logs"]) + assert "output_agent_logs=True" in result.output # should print correct value + + # config should reflect that output_agent_logs is True + assert PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"] + + result = cli(["dev-mode", "config", "--no-agent-logs"]) + assert "output_agent_logs=False" in result.output # should print correct value + + # config should reflect that output_agent_logs is True + assert PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"] is False + + result = cli(["dev-mode", "config", "-agent"]) + assert "output_agent_logs=True" in result.output # should print correct value + + # config should reflect that output_agent_logs is True + assert PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"] + + result = cli(["dev-mode", "config", "-nagent"]) + assert "output_agent_logs=False" in result.output # should print correct value + + # config should reflect that output_agent_logs is True + assert PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"] is False + + def test_dev_mode_config_pcap_logs_enable_disable(): """Test that the pcap logs output can be enabled or disabled.""" # check defaults diff --git a/tests/unit_tests/_primaite/_game/_agent/test_agent_log.py b/tests/unit_tests/_primaite/_game/_agent/test_agent_log.py index a7932cb7..d61e1a23 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_agent_log.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_agent_log.py @@ -29,11 +29,11 @@ def test_debug_agent_log_level(agentlog, capsys): test_string = str(uuid4()) - agentlog.debug(msg=test_string, time_step=0) - agentlog.info(msg=test_string, time_step=0) - agentlog.warning(msg=test_string, time_step=0) - agentlog.error(msg=test_string, time_step=0) - agentlog.critical(msg=test_string, time_step=0) + agentlog.debug(msg=test_string) + agentlog.info(msg=test_string) + agentlog.warning(msg=test_string) + agentlog.error(msg=test_string) + agentlog.critical(msg=test_string) captured = "".join(capsys.readouterr()) @@ -52,11 +52,11 @@ def test_info_agent_log_level(agentlog, capsys): test_string = str(uuid4()) - agentlog.debug(msg=test_string, time_step=0) - agentlog.info(msg=test_string, time_step=0) - agentlog.warning(msg=test_string, time_step=0) - agentlog.error(msg=test_string, time_step=0) - agentlog.critical(msg=test_string, time_step=0) + agentlog.debug(msg=test_string) + agentlog.info(msg=test_string) + agentlog.warning(msg=test_string) + agentlog.error(msg=test_string) + agentlog.critical(msg=test_string) captured = "".join(capsys.readouterr()) @@ -75,11 +75,11 @@ def test_warning_agent_log_level(agentlog, capsys): test_string = str(uuid4()) - agentlog.debug(msg=test_string, time_step=0) - agentlog.info(msg=test_string, time_step=0) - agentlog.warning(msg=test_string, time_step=0) - agentlog.error(msg=test_string, time_step=0) - agentlog.critical(msg=test_string, time_step=0) + agentlog.debug(msg=test_string) + agentlog.info(msg=test_string) + agentlog.warning(msg=test_string) + agentlog.error(msg=test_string) + agentlog.critical(msg=test_string) captured = "".join(capsys.readouterr()) @@ -98,11 +98,11 @@ def test_error_agent_log_level(agentlog, capsys): test_string = str(uuid4()) - agentlog.debug(msg=test_string, time_step=0) - agentlog.info(msg=test_string, time_step=0) - agentlog.warning(msg=test_string, time_step=0) - agentlog.error(msg=test_string, time_step=0) - agentlog.critical(msg=test_string, time_step=0) + agentlog.debug(msg=test_string) + agentlog.info(msg=test_string) + agentlog.warning(msg=test_string) + agentlog.error(msg=test_string) + agentlog.critical(msg=test_string) captured = "".join(capsys.readouterr()) @@ -121,11 +121,11 @@ def test_critical_agent_log_level(agentlog, capsys): test_string = str(uuid4()) - agentlog.debug(msg=test_string, time_step=0) - agentlog.info(msg=test_string, time_step=0) - agentlog.warning(msg=test_string, time_step=0) - agentlog.error(msg=test_string, time_step=0) - agentlog.critical(msg=test_string, time_step=0) + agentlog.debug(msg=test_string) + agentlog.info(msg=test_string) + agentlog.warning(msg=test_string) + agentlog.error(msg=test_string) + agentlog.critical(msg=test_string) captured = "".join(capsys.readouterr()) From 0c58c3969a901a8d508b2a429080905a311811ee Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 10 Jul 2024 13:46:30 +0100 Subject: [PATCH 22/35] 2623 - finish testing action mask --- .../simulator/system/services/service.py | 8 +- .../assets/configs/test_primaite_session.yaml | 27 ++- tests/conftest.py | 2 + .../test_agents_use_action_masks.py | 2 - .../test_rllib_multi_agent_environment.py | 3 - .../test_rllib_single_agent_environment.py | 4 - .../e2e_integration_tests/test_environment.py | 18 +- .../game_layer/test_action_mask.py | 161 ++++++++++++++++++ 8 files changed, 198 insertions(+), 27 deletions(-) create mode 100644 tests/integration_tests/game_layer/test_action_mask.py diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 8167a8a9..5adea6e7 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -92,6 +92,7 @@ class Service(IOSoftware): _is_service_running = Service._StateValidator(service=self, state=ServiceOperatingState.RUNNING) _is_service_stopped = Service._StateValidator(service=self, state=ServiceOperatingState.STOPPED) _is_service_paused = Service._StateValidator(service=self, state=ServiceOperatingState.PAUSED) + _is_service_disabled = Service._StateValidator(service=self, state=ServiceOperatingState.DISABLED) rm = super()._init_request_manager() rm.add_request( @@ -131,7 +132,12 @@ class Service(IOSoftware): ), ) rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable()))) - rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable()))) + rm.add_request( + "enable", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.enable()), validator=_is_service_disabled + ), + ) rm.add_request( "fix", RequestType( diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 7c894ba0..c435fe44 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -243,25 +243,25 @@ agents: action: "NODE_FILE_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 11: action: "NODE_FILE_DELETE" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 12: action: "NODE_FILE_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 13: action: "NODE_SERVICE_FIX" @@ -272,22 +272,22 @@ agents: action: "NODE_FOLDER_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: node_id: 2 - folder_id: 1 + folder_id: 0 18: action: "NODE_OS_SCAN" options: @@ -518,11 +518,22 @@ agents: nodes: - node_name: domain_controller - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService - node_name: backup_server - node_name: security_suite - node_name: client_1 - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 diff --git a/tests/conftest.py b/tests/conftest.py index e36a2460..adfa7724 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Tuple import pytest +import ray import yaml from primaite import getLogger, PRIMAITE_PATHS @@ -29,6 +30,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer from tests import TEST_ASSETS_ROOT +ray.init(local_mode=True) ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py index 3efda71a..a299b913 100644 --- a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -15,8 +15,6 @@ from primaite.session.environment import PrimaiteGymEnv from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv from tests import TEST_ASSETS_ROOT -init(local_mode=True) - CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" MARL_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py index 96ec799c..e015c33c 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -16,8 +16,6 @@ def test_rllib_multi_agent_compatibility(): with open(MULTI_AGENT_PATH, "r") as f: cfg = yaml.safe_load(f) - ray.init() - config = ( PPOConfig() .environment(env=PrimaiteRayMARLEnv, env_config=cfg) @@ -39,4 +37,3 @@ def test_rllib_multi_agent_compatibility(): ), param_space=config, ).fit() - ray.shutdown() diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py index d6cacfd2..a02a078c 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -20,9 +20,6 @@ def test_rllib_single_agent_compatibility(): game = PrimaiteGame.from_config(cfg) - ray.shutdown() - ray.init() - env_config = {"game": game} config = { "env": PrimaiteRayEnv, @@ -41,4 +38,3 @@ def test_rllib_single_agent_compatibility(): assert save_file.exists() save_file.unlink() # clean up - ray.shutdown() diff --git a/tests/e2e_integration_tests/test_environment.py b/tests/e2e_integration_tests/test_environment.py index c8238aba..253bd396 100644 --- a/tests/e2e_integration_tests/test_environment.py +++ b/tests/e2e_integration_tests/test_environment.py @@ -65,25 +65,25 @@ class TestPrimaiteEnvironment: cfg = yaml.safe_load(f) env = PrimaiteRayMARLEnv(env_config=cfg) - assert set(env._agent_ids) == {"defender1", "defender2"} + assert set(env._agent_ids) == {"defender_1", "defender_2"} assert len(env.agents) == 2 - defender1 = env.agents["defender1"] - defender2 = env.agents["defender2"] - assert (num_actions_1 := len(defender1.action_manager.action_map)) == 54 - assert (num_actions_2 := len(defender2.action_manager.action_map)) == 38 + defender_1 = env.agents["defender_1"] + defender_2 = env.agents["defender_2"] + assert (num_actions_1 := len(defender_1.action_manager.action_map)) == 74 + assert (num_actions_2 := len(defender_2.action_manager.action_map)) == 74 # ensure we can run all valid actions without error for act_1 in range(num_actions_1): - env.step({"defender1": act_1, "defender2": 0}) + env.step({"defender_1": act_1, "defender_2": 0}) for act_2 in range(num_actions_2): - env.step({"defender1": 0, "defender2": act_2}) + env.step({"defender_1": 0, "defender_2": act_2}) # ensure we get error when taking an invalid action with pytest.raises(KeyError): - env.step({"defender1": num_actions_1, "defender2": 0}) + env.step({"defender_1": num_actions_1, "defender_2": 0}) with pytest.raises(KeyError): - env.step({"defender1": 0, "defender2": num_actions_2}) + env.step({"defender_1": 0, "defender_2": num_actions_2}) def test_error_thrown_on_bad_configuration(self): """Make sure we throw an error when the config is bad.""" diff --git a/tests/integration_tests/game_layer/test_action_mask.py b/tests/integration_tests/game_layer/test_action_mask.py new file mode 100644 index 00000000..64464724 --- /dev/null +++ b/tests/integration_tests/game_layer/test_action_mask.py @@ -0,0 +1,161 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode +from primaite.simulator.system.services.service import ServiceOperatingState +from tests.conftest import TEST_ASSETS_ROOT + +CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" + + +def test_mask_contents_correct(): + env = PrimaiteGymEnv(CFG_PATH) + game = env.game + sim = game.simulation + net = sim.network + mask = game.action_mask("defender") + agent = env.agent + node_list = agent.action_manager.node_names + action_map = agent.action_manager.action_map + + # CHECK NIC ENABLE/DISABLE ACTIONS + for action_num, action in action_map.items(): + mask = game.action_mask("defender") + act_type, act_params = action + + if act_type == "NODE_NIC_ENABLE": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + nic_obj = node_obj.network_interface[act_params["nic_id"] + 1] + assert nic_obj.enabled + assert not mask[action_num] + nic_obj.disable() + mask = game.action_mask("defender") + assert mask[action_num] + nic_obj.enable() + + if act_type == "NODE_NIC_DISABLE": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + nic_obj = node_obj.network_interface[act_params["nic_id"] + 1] + assert nic_obj.enabled + assert mask[action_num] + nic_obj.disable() + mask = game.action_mask("defender") + assert not mask[action_num] + nic_obj.enable() + + if act_type == "ROUTER_ACL_ADDRULE": + assert mask[action_num] + + if act_type == "ROUTER_ACL_REMOVERULE": + assert mask[action_num] + + if act_type == "NODE_RESET": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "NODE_SHUTDOWN": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "NODE_OS_SCAN": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "NODE_STARTUP": + node_name = node_list[act_params["node_id"]] + node_obj = net.get_node_by_hostname(node_name) + assert node_obj.operating_state is NodeOperatingState.ON + assert not mask[action_num] + node_obj.operating_state = NodeOperatingState.OFF + mask = game.action_mask("defender") + assert mask[action_num] + node_obj.operating_state = NodeOperatingState.ON + + if act_type == "DONOTHING": + assert mask[action_num] + + if act_type == "NODE_SERVICE_DISABLE": + assert mask[action_num] + + if act_type in ["NODE_SERVICE_SCAN", "NODE_SERVICE_STOP", "NODE_SERVICE_PAUSE"]: + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.DISABLED + mask = game.action_mask("defender") + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type == "NODE_SERVICE_RESUME": + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.PAUSED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type == "NODE_SERVICE_START": + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.STOPPED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type == "NODE_SERVICE_ENABLE": + node_name = node_list[act_params["node_id"]] + service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]] + node_obj = net.get_node_by_hostname(node_name) + service_obj = node_obj.software_manager.software.get(service_name) + assert service_obj.operating_state is ServiceOperatingState.RUNNING + assert not mask[action_num] + service_obj.operating_state = ServiceOperatingState.DISABLED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING + + if act_type in ["NODE_FILE_SCAN", "NODE_FILE_CHECKHASH", "NODE_FILE_DELETE"]: + node_name = node_list[act_params["node_id"]] + folder_name = agent.action_manager.get_folder_name_by_idx(act_params["node_id"], act_params["folder_id"]) + file_name = agent.action_manager.get_file_name_by_idx( + act_params["node_id"], act_params["folder_id"], act_params["file_id"] + ) + node_obj = net.get_node_by_hostname(node_name) + file_obj = node_obj.file_system.get_file(folder_name, file_name, include_deleted=True) + assert not file_obj.deleted + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.DISABLED + mask = game.action_mask("defender") + assert mask[action_num] + service_obj.operating_state = ServiceOperatingState.RUNNING From 2bedf0c36301cd07a22267065703fedff02ebf37 Mon Sep 17 00:00:00 2001 From: "Archer.Bowen" Date: Wed, 10 Jul 2024 13:56:33 +0100 Subject: [PATCH 23/35] Updated changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d5418bdf..2d9aed82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Bandwidth Tracking**: Tracks data transmission across each frequency/channel width pairing. - **Configuration Support for Wireless Routers**: `channel_width` can now be configured in the config file under `wireless_access_point`. - **New Tests**: Added to validate the respect of bandwidth capacities and the correct parsing of airspace configurations from YAML files. +- **New Logging**: Added a new agent behaviour log which are more human friendly than agent history. These Logs are found in session log directory and can be enabled in the I/O settings in a yaml configuration file. ### Changed @@ -58,7 +59,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Activating dev-mode will change the location where the sessions will be output - by default will output where the PrimAITE repository is located - Refactored all air-space usage to that a new instance of AirSpace is created for each instance of Network. This 1:1 relationship between network and airspace will allow parallelization. - Added notebook to demonstrate use of SubprocVecEnv from SB3 to vectorise environments to speed up training. -- Added a new agent behaviour log which are more human friendly than agent history. These Logs are found in session log directory and can be enabled in the I/O settings in a yaml configuration file. ## [Unreleased] - Made requests fail to reach their target if the node is off From 9e7fd017df0944f52c8103f0b9e85eb5bc5665a1 Mon Sep 17 00:00:00 2001 From: "Archer.Bowen" Date: Wed, 10 Jul 2024 14:04:52 +0100 Subject: [PATCH 24/35] #2716 Fixed game.py pre-commit issue. --- src/primaite/game/game.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index ec59a784..1e6aeae0 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -15,8 +15,8 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort -from primaite.simulator.network.airspace import AirspaceEnvironmentType from primaite.simulator import SIM_OUTPUT +from primaite.simulator.network.airspace import AirspaceEnvironmentType from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC From aa425a528443d14529f0873843d7138042023fe2 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Wed, 10 Jul 2024 14:40:25 +0100 Subject: [PATCH 25/35] #2740: fix tests affected by request permissions --- .../_simulator/_file_system/test_file_actions.py | 14 +++++++------- .../_simulator/_file_system/test_folder_actions.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py index 295bca08..594c7afe 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_actions.py @@ -26,7 +26,7 @@ def test_file_scan_request(populated_file_system): assert file.health_status == FileSystemItemHealthStatus.CORRUPT assert file.visible_health_status == FileSystemItemHealthStatus.GOOD - fs.apply_request(request=["file", file.name, "scan"]) + fs.apply_request(request=["folder", folder.name, "file", file.name, "scan"]) assert file.health_status == FileSystemItemHealthStatus.CORRUPT assert file.visible_health_status == FileSystemItemHealthStatus.CORRUPT @@ -37,12 +37,12 @@ def test_file_checkhash_request(populated_file_system): """Test that an agent can request a file hash check.""" fs, folder, file = populated_file_system - fs.apply_request(request=["file", file.name, "checkhash"]) + fs.apply_request(request=["folder", folder.name, "file", file.name, "checkhash"]) assert file.health_status == FileSystemItemHealthStatus.GOOD file.sim_size = 0 - fs.apply_request(request=["file", file.name, "checkhash"]) + fs.apply_request(request=["folder", folder.name, "file", file.name, "checkhash"]) assert file.health_status == FileSystemItemHealthStatus.CORRUPT @@ -54,7 +54,7 @@ def test_file_repair_request(populated_file_system): file.corrupt() assert file.health_status == FileSystemItemHealthStatus.CORRUPT - fs.apply_request(request=["file", file.name, "repair"]) + fs.apply_request(request=["folder", folder.name, "file", file.name, "repair"]) assert file.health_status == FileSystemItemHealthStatus.GOOD @@ -71,7 +71,7 @@ def test_file_restore_request(populated_file_system): assert fs.get_file(folder_name=folder.name, file_name=file.name) is not None assert fs.get_file(folder_name=folder.name, file_name=file.name).deleted is False - fs.apply_request(request=["file", file.name, "corrupt"]) + fs.apply_request(request=["folder", folder.name, "file", file.name, "corrupt"]) assert fs.get_file(folder_name=folder.name, file_name=file.name).health_status == FileSystemItemHealthStatus.CORRUPT fs.apply_request(request=["restore", "file", folder.name, file.name]) @@ -81,7 +81,7 @@ def test_file_restore_request(populated_file_system): def test_file_corrupt_request(populated_file_system): """Test that an agent can request a file corruption.""" fs, folder, file = populated_file_system - fs.apply_request(request=["file", file.name, "corrupt"]) + fs.apply_request(request=["folder", folder.name, "file", file.name, "corrupt"]) assert file.health_status == FileSystemItemHealthStatus.CORRUPT @@ -90,7 +90,7 @@ def test_deleted_file_cannot_be_interacted_with(populated_file_system): fs, folder, file = populated_file_system assert fs.get_file(folder_name=folder.name, file_name=file.name) is not None - fs.apply_request(request=["file", file.name, "corrupt"]) + fs.apply_request(request=["folder", folder.name, "file", file.name, "corrupt"]) assert fs.get_file(folder_name=folder.name, file_name=file.name).health_status == FileSystemItemHealthStatus.CORRUPT assert ( fs.get_file(folder_name=folder.name, file_name=file.name).visible_health_status diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py index 40f1e78b..00148311 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py @@ -166,13 +166,13 @@ def test_deleted_folder_and_its_files_cannot_be_interacted_with(populated_file_s fs, folder, file = populated_file_system assert fs.get_file(folder_name=folder.name, file_name=file.name) is not None - fs.apply_request(request=["file", file.name, "corrupt"]) + fs.apply_request(request=["folder", folder.name, "file", file.name, "corrupt"]) assert fs.get_file(folder_name=folder.name, file_name=file.name).health_status == FileSystemItemHealthStatus.CORRUPT fs.apply_request(request=["delete", "folder", folder.name]) assert fs.get_file(folder_name=folder.name, file_name=file.name) is None - fs.apply_request(request=["file", file.name, "repair"]) + fs.apply_request(request=["folder", folder.name, "file", file.name, "repair"]) deleted_folder = fs.deleted_folders.get(folder.uuid) deleted_file = deleted_folder.deleted_files.get(file.uuid) From bdf5ff21671a5799c894f04c5db7361f045a56e1 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 11 Jul 2024 10:01:33 +0100 Subject: [PATCH 26/35] 2623 Add docs for action masking --- docs/index.rst | 1 + docs/source/action_masking.rst | 80 ++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 docs/source/action_masking.rst diff --git a/docs/index.rst b/docs/index.rst index 5749ad56..431dea28 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -123,6 +123,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! source/environment source/customising_scenarios source/varying_config_files + source/action_masking .. toctree:: :caption: Notebooks: diff --git a/docs/source/action_masking.rst b/docs/source/action_masking.rst new file mode 100644 index 00000000..3e5b967b --- /dev/null +++ b/docs/source/action_masking.rst @@ -0,0 +1,80 @@ +.. only:: comment + + © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + +Action Masking +************** +The PrimAITE simulation is able to provide action masks in the environment output. These action masks let the agents know +about which actions are invalid based on the current environment state. For instance, it's not possible to install +software on a node that is turned off. Therefore, if an agent has a NODE_SOFTWARE_INSTALL in it's action map for that node, +the action mask will show `0` in the corresponding entry. + +Configuration +============= +Action masking is supported for agents that use the `ProxyAgent` class (the class used for connecting to RL algorithms). +In order to use action masking, set the agent_settings.action_masking parameter to True in the config file. + +Masking Logic +============= +The following logic is applied: + +* **DONOTHING** : Always possible +* **NODE_HOST_SERVICE_SCAN** : Node is on. Service is running. +* **NODE_HOST_SERVICE_STOP** : Node is on. Service is running. +* **NODE_HOST_SERVICE_START** : Node is on. Service is stopped. +* **NODE_HOST_SERVICE_PAUSE** : Node is on. Service is running. +* **NODE_HOST_SERVICE_RESUME** : Node is on. Service is paused. +* **NODE_HOST_SERVICE_RESTART** : Node is on. Service is running. +* **NODE_HOST_SERVICE_DISABLE** : Node is on. +* **NODE_HOST_SERVICE_ENABLE** : Node is on. Service is disabled. +* **NODE_HOST_SERVICE_FIX** : Node is on. Service is running. +* **NODE_HOST_APPLICATION_EXECUTE** : Node is on. +* **NODE_HOST_APPLICATION_SCAN** : Node is on. Application is running. +* **NODE_HOST_APPLICATION_CLOSE** : Node is on. Application is running. +* **NODE_HOST_APPLICATION_FIX** : Node is on. Application is running. +* **NODE_HOST_APPLICATION_INSTALL** : Node is on. +* **NODE_HOST_APPLICATION_REMOVE** : Node is on. +* **NODE_HOST_FILE_SCAN** : Node is on. File exists. File not deleted. +* **NODE_HOST_FILE_CREATE** : Node is on. +* **NODE_HOST_FILE_CHECKHASH** : Node is on. File exists. File not deleted. +* **NODE_HOST_FILE_DELETE** : Node is on. File exists. +* **NODE_HOST_FILE_REPAIR** : Node is on. File exists. File not deleted. +* **NODE_HOST_FILE_RESTORE** : Node is on. File exists. File is deleted. +* **NODE_HOST_FILE_CORRUPT** : Node is on. File exists. File not deleted. +* **NODE_HOST_FILE_ACCESS** : Node is on. File exists. File not deleted. +* **NODE_HOST_FOLDER_CREATE** : Node is on. +* **NODE_HOST_FOLDER_SCAN** : Node is on. Folder exists. Folder not deleted. +* **NODE_HOST_FOLDER_CHECKHASH** : Node is on. Folder exists. Folder not deleted. +* **NODE_HOST_FOLDER_REPAIR** : Node is on. Folder exists. Folder not deleted. +* **NODE_HOST_FOLDER_RESTORE** : Node is on. Folder exists. Folder is deleted. +* **NODE_HOST_OS_SCAN** : Node is on. +* **NODE_HOST_NIC_ENABLE** : NIC is disabled. Node is on. +* **NODE_HOST_NIC_DISABLE** : NIC is enabled. Node is on. +* **NODE_HOST_SHUTDOWN** : Node is on. +* **NODE_HOST_STARTUP** : Node is off. +* **NODE_HOST_RESET** : Node is on. +* **NODE_HOST_NMAP_PING_SCAN** : Node is on. +* **NODE_HOST_NMAP_PORT_SCAN** : Node is on. +* **NODE_HOST_NMAP_NETWORK_SERVICE_RECON** : Node is on. +* **NODE_ROUTER_PORT_ENABLE** : Router is on. +* **NODE_ROUTER_PORT_DISABLE** : Router is on. +* **NODE_ROUTER_ACL_ADDRULE** : Router is on. +* **NODE_ROUTER_ACL_REMOVERULE** : Router is on. +* **NODE_FIREWALL_PORT_ENABLE** : Firewall is on. +* **NODE_FIREWALL_PORT_DISABLE** : Firewall is on. +* **NODE_FIREWALL_ACL_ADDRULE** : Firewall is on. +* **NODE_FIREWALL_ACL_REMOVERULE** : Firewall is on. + + +Mechanism +========= +The environment iterates over the RL agent's ``action_map`` and generates the corresponding simulator request string. +It uses the ``RequestManager.check_valid()`` method to invoke the relevant ``RequestPermissionValidator`` without +actually running the request on the simulation. + +Current Limitations +=================== +Currently, action masking only considers whether the action as a whole is possible, it doesn't verify that the exact +parameter combination passed to the action make sense in the current context. For instance, if ACL rule 3 on router_1 is +already populated, the action for adding another rule at position 3 will be available regardless, as long as that router +is turned on. This will never block valid actions. It will just occasionally allow invalid actions. From 579469d1c3fc30af8484e192f7f0c17724e96aee Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 11 Jul 2024 11:25:38 +0100 Subject: [PATCH 27/35] 2623 add sb3 contrib dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a0c2e3eb..9e919604 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ rl = [ "ray[rllib] >= 2.20.0, < 3", "tensorflow==2.12.0", "stable-baselines3[extra]==2.1.0", + "sb3-contrib==2.1.0", ] dev = [ "build==0.10.0", From b8c62386104ff5dc2169abe5cb29a9d2e87ec9a9 Mon Sep 17 00:00:00 2001 From: "Archer.Bowen" Date: Thu, 11 Jul 2024 11:55:03 +0100 Subject: [PATCH 28/35] #2740 Fixed Nmap Test Failure. --- tests/integration_tests/system/test_nmap.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index bbfa4f43..2b8691cc 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -101,6 +101,7 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network): actual_result = client_1_nmap.port_scan( target_ip_address=IPv4Network("192.168.10.0/24"), + target_port=[Port.ARP, Port.HTTP, Port.FTP, Port.DNS, Port.NTP], ) expected_result = { From 1d466d6807f7bab422316c719d5a977e9fb12ac0 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 11 Jul 2024 12:19:27 +0100 Subject: [PATCH 29/35] #2740: unit tests + a minor fix to nic test --- .../simulator/file_system/file_system.py | 4 +-- src/primaite/simulator/file_system/folder.py | 27 ++++++++++++-- .../observations/test_nic_observations.py | 4 +-- .../_file_system/test_file_system_actions.py | 36 +++++++++++++++++++ .../_file_system/test_folder_actions.py | 25 +++++++++++++ .../test_network_interface_actions.py | 34 ++++++++++++++++++ .../_network/_hardware/test_node_actions.py | 36 +++++++++++++++++++ .../_applications/test_application_actions.py | 14 ++++++++ .../_system/_services/test_service_actions.py | 20 ++++++++++- 9 files changed, 193 insertions(+), 7 deletions(-) create mode 100644 tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index bcb0334e..ba72e1ac 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -653,7 +653,7 @@ class FileSystem(SimComponent): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return "Cannot perform request on folder because it does not exist" + return "Cannot perform request on folder because it does not exist." class _FolderNotDeletedValidator(RequestPermissionValidator): """ @@ -693,4 +693,4 @@ class FileSystem(SimComponent): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return "Cannot perform request on file that does not exist." + return "Cannot perform request on a file that does not exist." diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index d891641e..c98e4492 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -56,6 +56,7 @@ class Folder(FileSystemItemABC): More information in user guide and docstring for SimComponent._init_request_manager. """ self._file_exists = Folder._FileExistsValidator(folder=self) + self._file_not_deleted = Folder._FileNotDeletedValidator(folder=self) rm = super()._init_request_manager() rm.add_request( @@ -67,7 +68,9 @@ class Folder(FileSystemItemABC): self._file_request_manager = RequestManager() rm.add_request( name="file", - request_type=RequestType(func=self._file_request_manager, validator=self._file_exists), + request_type=RequestType( + func=self._file_request_manager, validator=self._file_exists + self._file_not_deleted + ), ) return rm @@ -489,4 +492,24 @@ class Folder(FileSystemItemABC): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return "Cannot perform request on file that does not exist." + return "Cannot perform request on a file that does not exist." + + class _FileNotDeletedValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the File is not deleted. + + Actions cannot be performed on a deleted file. + """ + + folder: Folder + """Save a reference to the Folder instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Returns True if file exists and is not deleted.""" + file = self.folder.get_file(file_name=request[0]) + return file is not None and not file.deleted + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return "Cannot perform request on a file that is deleted." diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index f1a8ea92..88dd2bd5 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -155,7 +155,7 @@ def test_nic_monitored_traffic(simulation): assert traffic_obs["icmp"]["outbound"] == 0 # send a ping - pc.ping(target_ip_address=pc2.network_interface[1].ip_address) + assert pc.ping(target_ip_address=pc2.network_interface[1].ip_address) traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC") assert traffic_obs["icmp"]["inbound"] == 1 @@ -178,7 +178,7 @@ def test_nic_monitored_traffic(simulation): traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC") assert traffic_obs["icmp"]["inbound"] == 0 assert traffic_obs["icmp"]["outbound"] == 0 - assert traffic_obs["tcp"][53]["inbound"] == 0 + assert traffic_obs["tcp"][53]["inbound"] == 1 assert traffic_obs["tcp"][53]["outbound"] == 1 # getting a webpage sent a dns request out simulation.pre_timestep(2) # apply timestep to whole sim diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_actions.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_actions.py index 209668c4..7d022ea4 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_actions.py @@ -39,3 +39,39 @@ def test_folder_delete_request(populated_file_system): assert fs.get_file_by_id(folder_uuid=folder.uuid, file_uuid=file.uuid) is None fs.show(full=True) + + +def test_folder_exists_request_validator(populated_file_system): + """Tests that the _FolderExistsValidator works as intended.""" + fs, folder, file = populated_file_system + validator = FileSystem._FolderExistsValidator(file_system=fs) + + assert validator(request=["test_folder"], context={}) # test_folder exists + assert validator(request=["fake_folder"], context={}) is False # fake_folder does not exist + + assert validator.fail_message == "Cannot perform request on folder because it does not exist." + + +def test_file_exists_request_validator(populated_file_system): + """Tests that the _FolderExistsValidator works as intended.""" + fs, folder, file = populated_file_system + validator = FileSystem._FileExistsValidator(file_system=fs) + + assert validator(request=["test_folder", "test_file.txt"], context={}) # test_file.txt exists + assert validator(request=["test_folder", "fake_file.txt"], context={}) is False # fake_file.txt does not exist + + assert validator.fail_message == "Cannot perform request on a file that does not exist." + + +def test_folder_not_deleted_request_validator(populated_file_system): + """Tests that the _FolderExistsValidator works as intended.""" + fs, folder, file = populated_file_system + validator = FileSystem._FolderNotDeletedValidator(file_system=fs) + + assert validator(request=["test_folder"], context={}) # test_folder is not deleted + + fs.delete_folder(folder_name="test_folder") + + assert validator(request=["test_folder"], context={}) is False # test_folder is deleted + + assert validator.fail_message == "Cannot perform request on folder because it is deleted." diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py index 00148311..4a561b97 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_folder_actions.py @@ -178,3 +178,28 @@ def test_deleted_folder_and_its_files_cannot_be_interacted_with(populated_file_s deleted_file = deleted_folder.deleted_files.get(file.uuid) assert deleted_file.health_status is not FileSystemItemHealthStatus.GOOD + + +def test_file_exists_request_validator(populated_file_system): + """Tests that the _FolderExistsValidator works as intended.""" + fs, folder, file = populated_file_system + validator = Folder._FileExistsValidator(folder=folder) + + assert validator(request=["test_file.txt"], context={}) # test_file.txt exists + assert validator(request=["fake_file.txt"], context={}) is False # fake_file.txt does not exist + + assert validator.fail_message == "Cannot perform request on a file that does not exist." + + +def test_file_not_deleted_request_validator(populated_file_system): + """Tests that the _FolderExistsValidator works as intended.""" + fs, folder, file = populated_file_system + validator = Folder._FileNotDeletedValidator(folder=folder) + + assert validator(request=["test_file.txt"], context={}) # test_file.txt is not deleted + + fs.delete_file(folder_name="test_folder", file_name="test_file.txt") + + assert validator(request=["fake_file.txt"], context={}) is False # test_file.txt is deleted + + assert validator.fail_message == "Cannot perform request on a file that is deleted." diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py new file mode 100644 index 00000000..f35cf171 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py @@ -0,0 +1,34 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.simulator.network.hardware.base import NetworkInterface, Node +from primaite.simulator.network.hardware.nodes.host.computer import Computer + + +@pytest.fixture +def node() -> Node: + return Computer(hostname="test", ip_address="192.168.1.2", subnet_mask="255.255.255.0") + + +def test_nic_enabled_validator(node): + """Test the NetworkInterface enabled validator.""" + network_interface = node.network_interface[1] + validator = NetworkInterface._EnabledValidator(network_interface=network_interface) + + assert validator(request=[], context={}) is False # not enabled + + network_interface.enabled = True + + assert validator(request=[], context={}) # enabled + + +def test_nic_disabled_validator(node): + """Test the NetworkInterface enabled validator.""" + network_interface = node.network_interface[1] + validator = NetworkInterface._DisabledValidator(network_interface=network_interface) + + assert validator(request=[], context={}) # not enabled + + network_interface.enabled = True + + assert validator(request=[], context={}) is False # enabled diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py index 57d6cecb..9b37ac80 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py @@ -155,3 +155,39 @@ def test_reset_node(node): assert node.operating_state == NodeOperatingState.BOOTING assert node.operating_state == NodeOperatingState.ON + + +def test_node_is_on_validator(node): + """Test that the node is on validator.""" + node.power_on() + + for i in range(node.start_up_duration + 1): + node.apply_timestep(i) + + validator = Node._NodeIsOnValidator(node=node) + + assert validator(request=[], context={}) + + node.power_off() + for i in range(node.shut_down_duration + 1): + node.apply_timestep(i) + + assert validator(request=[], context={}) is False + + +def test_node_is_off_validator(node): + """Test that the node is on validator.""" + node.power_on() + + for i in range(node.start_up_duration + 1): + node.apply_timestep(i) + + validator = Node._NodeIsOffValidator(node=node) + + assert validator(request=[], context={}) is False + + node.power_off() + for i in range(node.shut_down_duration + 1): + node.apply_timestep(i) + + assert validator(request=[], context={}) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_actions.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_actions.py index be6c00e7..0e9c536c 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_actions.py @@ -1 +1,15 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.simulator.system.applications.application import Application, ApplicationOperatingState + + +def test_application_state_validator(application): + """Test the application state validator.""" + validator = Application._StateValidator(application=application, state=ApplicationOperatingState.CLOSED) + assert validator(request=[], context={}) # application is closed + application.run() + assert validator(request=[], context={}) is False # application is running - expecting closed + + validator = Application._StateValidator(application=application, state=ApplicationOperatingState.RUNNING) + assert validator(request=[], context={}) # application is running + application.close() + assert validator(request=[], context={}) is False # application is closed - expecting running diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_service_actions.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_service_actions.py index 2d9a6c52..537beb8b 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_service_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_service_actions.py @@ -1,5 +1,5 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState @@ -92,3 +92,21 @@ def test_service_fix(service): assert service.health_state_actual == SoftwareHealthState.FIXING service.apply_timestep(2) assert service.health_state_actual == SoftwareHealthState.GOOD + + +def test_service_state_validator(service): + """Test the service state validator.""" + validator = Service._StateValidator(service=service, state=ServiceOperatingState.STOPPED) + assert validator(request=[], context={}) # service is stopped + service.start() + assert validator(request=[], context={}) is False # service is running - expecting stopped + + validator = Service._StateValidator(service=service, state=ServiceOperatingState.RUNNING) + assert validator(request=[], context={}) # service is running + service.pause() + assert validator(request=[], context={}) is False # service is paused - expecting running + + validator = Service._StateValidator(service=service, state=ServiceOperatingState.PAUSED) + assert validator(request=[], context={}) # service is paused + service.resume() + assert validator(request=[], context={}) is False # service is running - expecting paused From 9636882ed1de0e392f80671d021a9dfab46b02fa Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 11 Jul 2024 13:57:23 +0100 Subject: [PATCH 30/35] #2740: refactor variable --- src/primaite/simulator/file_system/file_system.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index ba72e1ac..2162915f 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -43,7 +43,7 @@ class FileSystem(SimComponent): More information in user guide and docstring for SimComponent._init_request_manager. """ self._folder_exists = FileSystem._FolderExistsValidator(file_system=self) - self._folder_deleted = FileSystem._FolderNotDeletedValidator(file_system=self) + self._folder_not_deleted = FileSystem._FolderNotDeletedValidator(file_system=self) self._file_exists = FileSystem._FileExistsValidator(file_system=self) rm = super()._init_request_manager() @@ -152,7 +152,7 @@ class FileSystem(SimComponent): self._folder_request_manager = RequestManager() rm.add_request( "folder", - RequestType(func=self._folder_request_manager, validator=self._folder_exists + self._folder_deleted), + RequestType(func=self._folder_request_manager, validator=self._folder_exists + self._folder_not_deleted), ) self._file_request_manager = RequestManager() From cf563149ec54e10b00a66a3562474bb7b058744f Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 11 Jul 2024 15:07:58 +0100 Subject: [PATCH 31/35] #2745 carried over changes from internal that backtracked on the complex channel width stuff for now and focussed on getting a stable data rate baked in for each frequency --- CHANGELOG.md | 16 +- .../network/airspace.rst | 68 +-- .../network/nodes/wireless_router.rst | 3 - src/primaite/game/game.py | 5 - src/primaite/simulator/network/airspace.py | 495 ++---------------- .../hardware/nodes/network/wireless_router.py | 11 +- ...s_wan_wifi_5_80_channel_width_blocked.yaml | 2 - ...ess_wan_wifi_5_80_channel_width_urban.yaml | 2 - .../test_airspace_capacity_configuration.py | 106 ---- ...ndwidth_load_checks_before_transmission.py | 24 - 10 files changed, 61 insertions(+), 671 deletions(-) delete mode 100644 tests/integration_tests/network/test_airspace_capacity_configuration.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d9aed82..515be435 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,17 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added -- **AirSpaceEnvironmentType Enum Class**: Introduced in `airspace.py` to define different environmental settings affecting wireless network behavior. -- **ChannelWidth Enum Class**: Added in `airspace.py` to specify channel width options for wireless network interfaces. -- **Channel Width Attribute**: Incorporated into the `WirelessNetworkInterface` class to allow dynamic setting based on `AirSpaceFrequency` and `AirSpaceEnvironmentType`. -- **SNR and Capacity Calculation Functions**: Functions `estimate_snr` and `calculate_total_channel_capacity` added to `airspace.py` for computing signal-to-noise ratio and capacity based on frequency and channel width. -- **Dynamic Speed Setting**: WirelessInterface speed attribute now dynamically adjusts based on the operational environment, frequency, and channel width. -- **airspace_key Attribute**: Added to `WirelessNetworkInterface` as a tuple of frequency and channel width, serving as a key for bandwidth/channel management. -- **airspace_environment_type Attribute**: Determines the environmental type for the airspace, influencing data rate calculations and capacity sharing. -- **show_bandwidth_load Function**: Displays current bandwidth load for each frequency and channel width in the airspace. -- **Configuration Schema Update**: The `simulation.network` config file now includes settings for the `airspace_environment_type`. -- **Bandwidth Tracking**: Tracks data transmission across each frequency/channel width pairing. -- **Configuration Support for Wireless Routers**: `channel_width` can now be configured in the config file under `wireless_access_point`. +- **show_bandwidth_load Function**: Displays current bandwidth load for each frequency in the airspace. +- **Bandwidth Tracking**: Tracks data transmission across each frequency. - **New Tests**: Added to validate the respect of bandwidth capacities and the correct parsing of airspace configurations from YAML files. - **New Logging**: Added a new agent behaviour log which are more human friendly than agent history. These Logs are found in session log directory and can be enabled in the I/O settings in a yaml configuration file. @@ -27,9 +18,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **NetworkInterface Speed Type**: The `speed` attribute of `NetworkInterface` has been changed from `int` to `float`. - **Transmission Feasibility Check**: Updated `_can_transmit` function in `Link` to account for current load and total bandwidth capacity, ensuring transmissions do not exceed limits. - **Frame Size Details**: Frame `size` attribute now includes both core size and payload size in bytes. -- **WirelessRouter Configuration Function**: `configure_wireless_access_point` function now accepts `channel_width` as a parameter. -- **Interface Grouping**: `WirelessNetworkInterfaces` are now grouped by both `AirSpaceFrequency` and `ChannelWidth`. -- **Interface Frequency/Channel Width Adjustment**: Changing an interface's settings now involves removal from the airspace, recalculation of its data rate, and re-addition under new settings. - **Transmission Blocking**: Enhanced `AirSpace` logic to block transmissions that would exceed the available capacity. ### Fixed diff --git a/docs/source/simulation_components/network/airspace.rst b/docs/source/simulation_components/network/airspace.rst index dcd762d4..06a884a7 100644 --- a/docs/source/simulation_components/network/airspace.rst +++ b/docs/source/simulation_components/network/airspace.rst @@ -22,79 +22,21 @@ The AirSpace is a virtual representation of a physical wireless environment, man ^^^^^^^^^^^^^^^^^^ - **Wireless Network Interfaces**: Representations of network interfaces connected physical devices like routers, computers, or IoT devices that can send and receive data wirelessly. -- **Environmental Settings**: Different types of environments (e.g., urban, rural) that affect signal propagation and interference. -- **Channel Management**: Handles channels and their widths (e.g., 20 MHz, 40 MHz) to determine data transmission over different frequencies. -- **Bandwidth Management**: Tracks data transmission over channels to prevent overloading and simulate real-world network congestion. +- **Bandwidth Management**: Tracks data transmission over frequencies to prevent overloading and simulate real-world network congestion. -3. AirSpace Environment Types ------------------------------ -The AirspaceEnvironmentType is a critical component that simulates different physical environments: - -- Urban, Suburban, Rural, etc. -- Each type simulates different levels of electromagnetic interference and signal propagation characteristics. -- Changing the AirspaceEnvironmentType impacts data rates by affecting the signal-to-noise ratio (SNR). - -4. Simulation of Environment Changes ------------------------------------- - -When an AirspaceEnvironmentType is set or changed, the AirSpace: - -1. Recalculates the maximum data transmission capacities for all managed frequencies and channel widths. -2. Updates all wireless interfaces to reflect new capacities. - -5. Managing Wireless Network Interfaces +3. Managing Wireless Network Interfaces --------------------------------------- - Interfaces can be dynamically added or removed. - Configurations can be changed in real-time. -- The AirSpace handles data transmissions, ensuring data sent by an interface is received by all other interfaces on the same frequency and channel. +- The AirSpace handles data transmissions, ensuring data sent by an interface is received by all other interfaces on the same frequency. -6. Signal-to-Noise Ratio (SNR) Calculation ------------------------------------------- -SNR is crucial in determining the quality of a wireless communication channel: - -.. math:: - - SNR = \frac{\text{Signal Power}}{\text{Noise Power}} - -- Impacted by environment type, frequency, and channel width -- Higher SNR indicates a clearer signal, leading to higher data transmission rates - -7. Total Channel Capacity Calculation -------------------------------------- - -Channel capacity is calculated using the Shannon-Hartley theorem: - -.. math:: - - C = B \cdot \log_2(1 + SNR) - -Where: - -- C: channel capacity in bits per second (bps) -- B: bandwidth of the channel in hertz (Hz) -- SNR: signal-to-noise ratio - -Implementation in AirSpace: - -1. Convert channel width from MHz to Hz. -2. Recalculate SNR based on new environment or interface settings. -3. Apply Shannon-Hartley theorem to determine new maximum channel capacity in Mbps. - -8. Shared Maximum Capacity Across Devices ------------------------------------------ - -While individual devices have theoretical maximum data rates, the actual achievable rate is often less due to: - -- Shared wireless medium among all devices on the same frequency and channel width -- Interference and congestion from multiple devices transmitting simultaneously - -9. AirSpace Inspection +4. AirSpace Inspection ---------------------- The AirSpace class provides methods for visualizing network behavior: - ``show_wireless_interfaces()``: Displays current state of all interfaces -- ``show_bandwidth_load()``: Shows channel loads and bandwidth utilization +- ``show_bandwidth_load()``: Shows bandwidth utilisation diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index eb7f95e3..c78c8419 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -50,7 +50,6 @@ additional steps to configure wireless settings: port=1, ip_address="192.168.2.1", subnet_mask="255.255.255.0", frequency=AirSpaceFrequency.WIFI_2_4, - channel_width=ChannelWidth.ChannelWidth.WIDTH_40_MHZ ) @@ -132,14 +131,12 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. ip_address="192.168.1.1", subnet_mask="255.255.255.0", frequency=AirSpaceFrequency.WIFI_2_4, - channel_width=ChannelWidth.ChannelWidth.WIDTH_40_MHZ ) router_2.configure_wireless_access_point( port=1, ip_address="192.168.1.2", subnet_mask="255.255.255.0", frequency=AirSpaceFrequency.WIFI_2_4, - channel_width=ChannelWidth.ChannelWidth.WIDTH_40_MHZ ) # Configure routes for inter-router communication diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 1e6aeae0..b976e55f 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -16,7 +16,6 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT -from primaite.simulator.network.airspace import AirspaceEnvironmentType from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC @@ -238,10 +237,6 @@ class PrimaiteGame: simulation_config = cfg.get("simulation", {}) network_config = simulation_config.get("network", {}) airspace_cfg = network_config.get("airspace", {}) - airspace_environment_type_str = airspace_cfg.get("airspace_environment_type", "urban") - - airspace_environment_type: AirspaceEnvironmentType = AirspaceEnvironmentType(airspace_environment_type_str) - net.airspace.airspace_environment_type = airspace_environment_type nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 2ac11a20..5019385a 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -3,11 +3,10 @@ from __future__ import annotations from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional -import numpy as np from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, computed_field, Field, model_validator +from pydantic import BaseModel, Field from primaite import getLogger from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface @@ -58,228 +57,17 @@ class AirSpaceFrequency(Enum): return f"WiFi {hertz_str}" return "Unknown Frequency" - -class ChannelWidth(Enum): - """ - Enumeration representing the available channel widths in MHz for wireless communications. - - This enum facilitates standardising and validating channel width configurations. - - Attributes: - WIDTH_20_MHZ (int): Represents a channel width of 20 MHz, commonly used for basic - Wi-Fi connectivity with standard range and interference resistance. - WIDTH_40_MHZ (int): Represents a channel width of 40 MHz, offering higher data - throughput at the expense of potentially increased interference. - WIDTH_80_MHZ (int): Represents a channel width of 80 MHz, typically used in modern - Wi-Fi setups for high data rate applications but with higher susceptibility to interference. - WIDTH_160_MHZ (int): Represents a channel width of 160 MHz, used for ultra-high-speed - network applications, providing maximum data throughput with significant - requirements on the spectral environment to minimize interference. - """ - - WIDTH_20_MHZ = 20 - """ - Represents a channel width of 20 MHz, commonly used for basic Wi-Fi connectivity with standard range and - interference resistance - """ - - WIDTH_40_MHZ = 40 - """ - Represents a channel width of 40 MHz, offering higher data throughput at the expense of potentially increased - interference. - """ - - WIDTH_80_MHZ = 80 - """ - Represents a channel width of 80 MHz, typically used in modern Wi-Fi setups for high data rate applications but - with higher susceptibility to interference. - """ - - WIDTH_160_MHZ = 160 - """ - Represents a channel width of 160 MHz, used for ultra-high-speed network applications, providing maximum data - throughput with significant requirements on the spectral environment to minimize interference. - """ - - def __str__(self) -> str: - """ - Returns a string representation of the channel width. - - :return: String in the format of " MHz" indicating the channel width. - """ - return f"{self.value} MHz" - - -AirSpaceKeyType = Tuple[AirSpaceFrequency, ChannelWidth] - - -class AirspaceEnvironmentType(Enum): - """Enum representing different types of airspace environments which affect wireless communication signals.""" - - RURAL = "rural" - """ - A rural environment offers clear channel conditions due to low population density and minimal electronic device - presence. - """ - - OUTDOOR = "outdoor" - """ - Outdoor environments like parks or fields have minimal electronic interference. - """ - - SUBURBAN = "suburban" - """ - Suburban environments strike a balance with fewer electronic interferences than urban but more than rural. - """ - - OFFICE = "office" - """ - Office environments have moderate interference from numerous electronic devices and overlapping networks. - """ - - URBAN = "urban" - """ - Urban environments are characterized by tall buildings and a high density of electronic devices, leading to - significant interference. - """ - - INDUSTRIAL = "industrial" - """ - Industrial areas face high interference from heavy machinery and numerous electronic devices. - """ - - TRANSPORT = "transport" - """ - Environments such as subways and buses where metal structures and high mobility create complex interference - patterns. - """ - - DENSE_URBAN = "dense_urban" - """ - Dense urban areas like city centers have the highest level of signal interference due to the very high density of - buildings and devices. - """ - - JAMMING_ZONE = "jamming_zone" - """ - A jamming zone environment where signals are actively interfered with, typically through the use of signal jammers - or scrambling devices. This represents the environment with the highest level of interference. - """ - - BLOCKED = "blocked" - """ - A jamming zone environment with total levels of interference. Airspace is completely blocked. - """ + @property + def maximum_data_rate_bps(self) -> float: + if self == AirSpaceFrequency.WIFI_2_4: + return 100_000_000.0 # 100 Megabits per second + if self == AirSpaceFrequency.WIFI_5: + return 500_000_000.0 # 500 Megabits per second + return 0.0 @property - def snr_impact(self) -> int: - """ - Returns the SNR impact associated with the environment. - - :return: SNR impact in dB. - """ - impacts = { - AirspaceEnvironmentType.RURAL: 0, - AirspaceEnvironmentType.OUTDOOR: 1, - AirspaceEnvironmentType.SUBURBAN: -5, - AirspaceEnvironmentType.OFFICE: -7, - AirspaceEnvironmentType.URBAN: -10, - AirspaceEnvironmentType.INDUSTRIAL: -15, - AirspaceEnvironmentType.TRANSPORT: -12, - AirspaceEnvironmentType.DENSE_URBAN: -20, - AirspaceEnvironmentType.JAMMING_ZONE: -40, - AirspaceEnvironmentType.BLOCKED: -100, - } - return impacts[self] - - def __str__(self) -> str: - return f"{self.value.title()} Environment (SNR Impact: {self.snr_impact})" - - -def estimate_snr( - frequency: AirSpaceFrequency, environment_type: AirspaceEnvironmentType, channel_width: ChannelWidth -) -> float: - """ - Estimate the Signal-to-Noise Ratio (SNR) based on the communication frequency, environment, and channel width. - - This function considers both the base SNR value dependent on the frequency and the impact of environmental - factors and channel width on the SNR. - - The SNR is adjusted by reducing it for wider channels, reflecting the increased noise floor from a broader - frequency range. - - :param frequency: The operating frequency as defined by AirSpaceFrequency enum, influencing the base SNR. Higher - frequencies like 5 GHz generally start with a higher base SNR due to less noise. - :param environment_type: The type of environment from AirspaceEnvironmentType enum, which adjusts the SNR based on - expected environmental noise and interference levels. - :param channel_width: The channel width from ChannelWidth enum, where wider channels (80 MHz and 160 MHz) decrease - the SNR slightly due to an increased noise floor. - :return: Estimated SNR in dB, calculated as the base SNR modified by environmental and channel width impacts. - """ - base_snr = 40 if frequency == AirSpaceFrequency.WIFI_5 else 30 - snr_impact = environment_type.snr_impact - - # Adjust SNR impact based on channel width - if channel_width == ChannelWidth.WIDTH_80_MHZ or channel_width == ChannelWidth.WIDTH_160_MHZ: - snr_impact -= 3 # Assume wider channels have slightly lower SNR due to increased noise floor - - return base_snr + snr_impact - - -def calculate_total_channel_capacity( - channel_width: ChannelWidth, frequency: AirSpaceFrequency, environment_type: AirspaceEnvironmentType -) -> float: - """ - Calculate the total theoretical data rate for the channel using the Shannon-Hartley theorem. - - This function determines the channel's capacity by considering the bandwidth (derived from channel width), - and the signal-to-noise ratio (SNR) adjusted by frequency and environmental conditions. - - The Shannon-Hartley theorem states that channel capacity C (in bits per second) can be calculated as: - ``C = B * log2(1 + SNR)`` where B is the bandwidth in Hertz and SNR is the signal-to-noise ratio. - - :param channel_width: The width of the channel as defined by ChannelWidth enum, converted to Hz for calculation. - :param frequency: The operating frequency as defined by AirSpaceFrequency enum, influencing the base SNR and part - of the SNR estimation. - :param environment_type: The type of environment as defined by AirspaceEnvironmentType enum, used in SNR estimation. - :return: Theoretical total data rate in Mbps for the entire channel. - """ - bandwidth_hz = channel_width.value * 1_000_000 # Convert MHz to Hz - snr_db = estimate_snr(frequency, environment_type, channel_width) - snr_linear = 10 ** (snr_db / 10) - - total_capacity_bps = bandwidth_hz * np.log2(1 + snr_linear) - total_capacity_mbps = total_capacity_bps / 1_000_000 - - return total_capacity_mbps - - -def calculate_individual_device_rate( - channel_width: ChannelWidth, - frequency: AirSpaceFrequency, - environment_type: AirspaceEnvironmentType, - device_count: int, -) -> float: - """ - Calculate the theoretical data rate available to each individual device on the channel. - - This function first calculates the total channel capacity and then divides this capacity by the number - of active devices to estimate each device's share of the bandwidth. This reflects the practical limitation - that multiple devices must share the same channel resources. - - :param channel_width: The channel width as defined by ChannelWidth enum, used in total capacity calculation. - :param frequency: The operating frequency as defined by AirSpaceFrequency enum, used in total capacity calculation. - :param environment_type: The environment type as defined by AirspaceEnvironmentType enum, impacting SNR and - capacity. - :param device_count: The number of devices sharing the channel. If zero, returns zero to avoid division by zero. - :return: Theoretical data rate in Mbps available per device, based on shared channel capacity. - """ - total_capacity_mbps = calculate_total_channel_capacity(channel_width, frequency, environment_type) - if device_count == 0: - return 0 # Avoid division by zero - individual_device_rate_mbps = total_capacity_mbps / device_count - - return individual_device_rate_mbps + def maximum_data_rate_mbps(self) -> float: + return self.maximum_data_rate_bps / 1_000_000.0 class AirSpace(BaseModel): @@ -287,105 +75,62 @@ class AirSpace(BaseModel): Represents a wireless airspace, managing wireless network interfaces and handling wireless transmission. This class provides functionalities to manage a collection of wireless network interfaces, each associated with - specific frequencies and channel widths. It includes methods to calculate and manage bandwidth loads, add and - remove wireless interfaces, and handle data transmission across these interfaces. + specific frequencies. It includes methods to add and remove wireless interfaces, and handle data transmission + across these interfaces. """ - airspace_environment_type_: AirspaceEnvironmentType = AirspaceEnvironmentType.URBAN wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {}) - wireless_interfaces_by_frequency_channel_width: Dict[AirSpaceKeyType, List[WirelessNetworkInterface]] = Field( + wireless_interfaces_by_frequency: Dict[AirSpaceFrequency, List[WirelessNetworkInterface]] = Field( default_factory=lambda: {} ) - bandwidth_load: Dict[AirSpaceKeyType, float] = Field(default_factory=lambda: {}) - frequency_channel_width_max_capacity_mbps: Dict[AirSpaceKeyType, float] = Field(default_factory=lambda: {}) + bandwidth_load: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) + frequency_max_capacity_mbps: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) def model_post_init(self, __context: Any) -> None: """ Initialize the airspace metadata after instantiation. - This method is called to set up initial configurations like the maximum capacity of each channel width and - frequency based on the current environment setting. + This method is called to set up initial configurations like the maximum capacity of each frequency. :param __context: Contextual data or settings, typically used for further initializations beyond the basic constructor. """ - self._set_frequency_channel_width_max_capacity_mbps() + self.set_frequency_max_capacity_mbps() - def _set_frequency_channel_width_max_capacity_mbps(self): + def set_frequency_max_capacity_mbps(self, capacity_config: Optional[Dict[AirSpaceFrequency, float]] = None): """ - Private method to compute and set the maximum channel capacity in Mbps for each frequency and channel width. - - Based on the airspace environment type, this method calculates the maximum possible data transmission - capacity for each combination of frequency and channel width available and stores these values. - These capacities are critical for managing and limiting bandwidth load during operations. + Set the maximum channel capacity in Mbps for each frequency. """ - print( - f"Rebuilding the frequency channel width maximum capacity dictionary based on " - f"airspace environment type {self.airspace_environment_type_}" - ) + if capacity_config is None: + capacity_config = {} for frequency in AirSpaceFrequency: - for channel_width in ChannelWidth: - max_capacity = calculate_total_channel_capacity( - frequency=frequency, channel_width=channel_width, environment_type=self.airspace_environment_type - ) - self.frequency_channel_width_max_capacity_mbps[frequency, channel_width] = max_capacity - - @computed_field - @property - def airspace_environment_type(self) -> AirspaceEnvironmentType: - """ - Gets the current environment type of the airspace. - - :return: The AirspaceEnvironmentType representing the current environment type. - """ - return self.airspace_environment_type_ - - @airspace_environment_type.setter - def airspace_environment_type(self, value: AirspaceEnvironmentType) -> None: - """ - Sets a new environment type for the airspace and updates related configurations. - - Changing the environment type triggers a re-calculation of the maximum channel capacities and - adjustments to the current setup of wireless interfaces to ensure they are aligned with the - new environment settings. - - :param value: The new environment type as an AirspaceEnvironmentType. - """ - if value != self.airspace_environment_type_: - print(f"Setting airspace_environment_type to {value}") - self.airspace_environment_type_ = value - self._set_frequency_channel_width_max_capacity_mbps() - wireless_interface_keys = list(self.wireless_interfaces.keys()) - for wireless_interface_key in wireless_interface_keys: - wireless_interface = self.wireless_interfaces[wireless_interface_key] - self.remove_wireless_interface(wireless_interface) - self.add_wireless_interface(wireless_interface) + max_capacity = capacity_config.get(frequency, frequency.maximum_data_rate_mbps) + self.frequency_max_capacity_mbps[frequency] = max_capacity def show_bandwidth_load(self, markdown: bool = False): """ - Prints a table of the current bandwidth load for each frequency and channel width combination on the airspace. + Prints a table of the current bandwidth load for each frequency on the airspace. - This method prints a tabulated view showing the utilisation of available bandwidth capacities for all configured - frequency and channel width pairings. The table includes the current capacity usage as a percentage of the - maximum capacity, alongside the absolute maximum capacity values in Mbps. + This method prints a tabulated view showing the utilisation of available bandwidth capacities for all + frequencies. The table includes the current capacity usage as a percentage of the maximum capacity, alongside + the absolute maximum capacity values in Mbps. :param markdown: Flag indicating if output should be in markdown format. """ - headers = ["Frequency", "Channel Width", "Current Capacity (%)", "Maximum Capacity (Mbit)"] + if not self.frequency_max_capacity_mbps: + self.set_frequency_max_capacity_mbps() + headers = ["Frequency", "Current Capacity (%)", "Maximum Capacity (Mbit)"] table = PrettyTable(headers) if markdown: table.set_style(MARKDOWN) table.align = "l" table.title = "Airspace Frequency Channel Loads" - for key, load in self.bandwidth_load.items(): - frequency, channel_width = key - maximum_capacity = self.frequency_channel_width_max_capacity_mbps[key] + for frequency, load in self.bandwidth_load.items(): + maximum_capacity = self.frequency_max_capacity_mbps[frequency] load_percent = load / maximum_capacity if load_percent > 1.0: load_percent = 1.0 - table.add_row( - [format_hertz(frequency.value), str(channel_width), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"] - ) + table.add_row([format_hertz(frequency.value), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"]) print(table) def show_wireless_interfaces(self, markdown: bool = False): @@ -400,7 +145,6 @@ class AirSpace(BaseModel): "IP Address", "Subnet Mask", "Frequency", - "Channel Width", "Speed (Mbps)", "Status", ] @@ -408,7 +152,7 @@ class AirSpace(BaseModel): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"Devices on Air Space - {self.airspace_environment_type}" + table.title = f"Devices on Air Space" for interface in self.wireless_interfaces.values(): status = "Enabled" if interface.enabled else "Disabled" @@ -419,7 +163,6 @@ class AirSpace(BaseModel): interface.ip_address if hasattr(interface, "ip_address") else None, interface.subnet_mask if hasattr(interface, "subnet_mask") else None, format_hertz(interface.frequency.value), - str(interface.channel_width), f"{interface.speed:.3f}", status, ] @@ -431,8 +174,8 @@ class AirSpace(BaseModel): Prints a summary of the current state of the airspace, including both wireless interfaces and bandwidth loads. This method is a convenient wrapper that calls two separate methods to display detailed tables: one for - wireless interfaces and another for bandwidth load across all frequencies and channel widths managed within the - airspace. It provides a holistic view of the operational status and performance metrics of the airspace. + wireless interfaces and another for bandwidth load across all frequencies managed within the airspace. It + provides a holistic view of the operational status and performance metrics of the airspace. :param markdown: Flag indicating if output should be in markdown format. """ @@ -447,15 +190,9 @@ class AirSpace(BaseModel): """ if wireless_interface.mac_address not in self.wireless_interfaces: self.wireless_interfaces[wireless_interface.mac_address] = wireless_interface - if wireless_interface.airspace_key not in self.wireless_interfaces_by_frequency_channel_width: - self.wireless_interfaces_by_frequency_channel_width[wireless_interface.airspace_key] = [] - self.wireless_interfaces_by_frequency_channel_width[wireless_interface.airspace_key].append( - wireless_interface - ) - speed = calculate_total_channel_capacity( - wireless_interface.channel_width, wireless_interface.frequency, self.airspace_environment_type - ) - wireless_interface.set_speed(speed) + if wireless_interface.frequency not in self.wireless_interfaces_by_frequency: + self.wireless_interfaces_by_frequency[wireless_interface.frequency] = [] + self.wireless_interfaces_by_frequency[wireless_interface.frequency].append(wireless_interface) def remove_wireless_interface(self, wireless_interface: WirelessNetworkInterface): """ @@ -465,9 +202,7 @@ class AirSpace(BaseModel): """ if wireless_interface.mac_address in self.wireless_interfaces: self.wireless_interfaces.pop(wireless_interface.mac_address) - self.wireless_interfaces_by_frequency_channel_width[wireless_interface.airspace_key].remove( - wireless_interface - ) + self.wireless_interfaces_by_frequency[wireless_interface.frequency].remove(wireless_interface) def clear(self): """ @@ -477,7 +212,7 @@ class AirSpace(BaseModel): occur until new interfaces are added again. """ self.wireless_interfaces.clear() - self.wireless_interfaces_by_frequency_channel_width.clear() + self.wireless_interfaces_by_frequency.clear() def reset_bandwidth_load(self): """ @@ -500,11 +235,13 @@ class AirSpace(BaseModel): relevant frequency and its current bandwidth load. :return: True if the frame can be transmitted within the bandwidth limit, False if it would exceed the limit. """ - if sender_network_interface.airspace_key not in self.bandwidth_load: - self.bandwidth_load[sender_network_interface.airspace_key] = 0.0 + if not self.frequency_max_capacity_mbps: + self.set_frequency_max_capacity_mbps() + if sender_network_interface.frequency not in self.bandwidth_load: + self.bandwidth_load[sender_network_interface.frequency] = 0.0 return ( - self.bandwidth_load[sender_network_interface.airspace_key] + frame.size_Mbits - <= self.frequency_channel_width_max_capacity_mbps[sender_network_interface.airspace_key] + self.bandwidth_load[sender_network_interface.frequency] + frame.size_Mbits + <= self.frequency_max_capacity_mbps[sender_network_interface.frequency] ) def transmit(self, frame: Frame, sender_network_interface: WirelessNetworkInterface): @@ -517,9 +254,9 @@ class AirSpace(BaseModel): :param sender_network_interface: The wireless network interface sending the frame. This interface will be excluded from the list of receivers to prevent it from receiving its own transmission. """ - self.bandwidth_load[sender_network_interface.airspace_key] += frame.size_Mbits - for wireless_interface in self.wireless_interfaces_by_frequency_channel_width.get( - sender_network_interface.airspace_key, [] + self.bandwidth_load[sender_network_interface.frequency] += frame.size_Mbits + for wireless_interface in self.wireless_interfaces_by_frequency.get( + sender_network_interface.frequency, [] ): if wireless_interface != sender_network_interface and wireless_interface.enabled: wireless_interface.receive_frame(frame) @@ -546,135 +283,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC): """ airspace: AirSpace - frequency_: AirSpaceFrequency = AirSpaceFrequency.WIFI_2_4 - channel_width_: ChannelWidth = ChannelWidth.WIDTH_40_MHZ - - @model_validator(mode="after") # noqa - def validate_channel_width_for_2_4_ghz(self) -> "WirelessNetworkInterface": - """ - Validate the wireless interface's channel width settings after model changes. - - This method serves as a model validator to ensure that the channel width settings for the 2.4 GHz frequency - comply with accepted standards (either 20 MHz or 40 MHz). It's triggered after model instantiation. - - Ensures that the channel width is appropriate for the current frequency setting, particularly checking - and adjusting the settings for the 2.4 GHz frequency band to not exceed 40 MHz. This is crucial for - avoiding interference and ensuring optimal performance in densely populated wireless environments. - """ - self._check_wifi_24_channel_width() - return self - - def model_post_init(self, __context: Any) -> None: - """Initialise the model after its creation, setting the speed based on the calculated channel capacity.""" - speed = calculate_total_channel_capacity( - channel_width=self.channel_width, - frequency=self.frequency, - environment_type=self.airspace.airspace_environment_type, - ) - self.set_speed(speed) - - def _check_wifi_24_channel_width(self) -> None: - """ - Ensures that the channel width for 2.4 GHz frequency does not exceed 40 MHz. - - This method checks the current frequency and channel width settings and adjusts the channel width - to 40 MHz if the frequency is set to 2.4 GHz and the channel width exceeds 40 MHz. This is done to - comply with typical Wi-Fi standards for 2.4 GHz frequencies, which commonly support up to 40 MHz. - - Logs a SysLog warning if the channel width had to be adjusted, logging this change either to the connected - node's system log or the global logger, depending on whether the interface is connected to a node. - """ - if self.frequency_ == AirSpaceFrequency.WIFI_2_4 and self.channel_width_.value > 40: - self.channel_width_ = ChannelWidth.WIDTH_40_MHZ - msg = ( - f"Channel width must be either 20 Mhz or 40 Mhz when using {AirSpaceFrequency.WIFI_2_4}. " - f"Overriding value to use {ChannelWidth.WIDTH_40_MHZ}." - ) - if self._connected_node: - self._connected_node.sys_log.warning(f"Wireless Interface {self.port_num}: {msg}") - else: - _LOGGER.warning(msg) - - @computed_field - @property - def frequency(self) -> AirSpaceFrequency: - """ - Get the current operating frequency of the wireless interface. - - :return: The current frequency as an AirSpaceFrequency enum value. - """ - return self.frequency_ - - @frequency.setter - def frequency(self, value: AirSpaceFrequency) -> None: - """ - Set the operating frequency of the wireless interface and update the network configuration. - - This setter updates the frequency of the wireless interface if the new value differs from the current setting. - It handles the update by first removing the interface from the current airspace management to avoid conflicts, - setting the new frequency, ensuring the channel width remains compliant, and then re-adding the interface - to the airspace with the new settings. - - :param value: The new frequency to set, as an AirSpaceFrequency enum value. - """ - if value != self.frequency_: - self.airspace.remove_wireless_interface(self) - self.frequency_ = value - self._check_wifi_24_channel_width() - self.airspace.add_wireless_interface(self) - - @computed_field - @property - def channel_width(self) -> ChannelWidth: - """ - Get the current channel width setting of the wireless interface. - - :return: The current channel width as a ChannelWidth enum value. - """ - return self.channel_width_ - - @channel_width.setter - def channel_width(self, value: ChannelWidth) -> None: - """ - Set the channel width of the wireless interface and manage configuration compliance. - - Updates the channel width of the wireless interface. If the new channel width is different from the existing - one, it first removes the interface from the airspace to prevent configuration conflicts, sets the new channel - width, checks and adjusts it if necessary (especially for 2.4 GHz frequency to comply with typical standards), - and then re-registers the interface in the airspace with updated settings. - - :param value: The new channel width to set, as a ChannelWidth enum value. - """ - if value != self.channel_width_: - self.airspace.remove_wireless_interface(self) - self.channel_width_ = value - self._check_wifi_24_channel_width() - self.airspace.add_wireless_interface(self) - - @property - def airspace_key(self) -> tuple: - """ - The airspace bandwidth/channel identifier for the wireless interface based on its frequency and channel width. - - :return: A tuple containing the frequency and channel width, serving as a bandwidth/channel key. - """ - return self.frequency_, self.channel_width_ - - def set_speed(self, speed: float): - """ - Sets the network interface speed to the specified value and logs this action. - - This method updates the speed attribute of the network interface to the given value, reflecting - the theoretical maximum data rate that the interface can support based on the current settings. - It logs the new speed to the system log of the connected node if available. - - :param speed: The speed in Mbps to be set for the network interface. - """ - self.speed = speed - if self._connected_node: - self._connected_node.sys_log.info( - f"Wireless Interface {self.port_num}: Setting theoretical maximum data rate to {speed:.3f} Mbps." - ) + frequency: AirSpaceFrequency = AirSpaceFrequency.WIFI_2_4 def enable(self): """Attempt to enable the network interface.""" diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index dda9e4f8..5ded993e 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union from pydantic import validate_call -from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, ChannelWidth, IPWirelessNetworkInterface +from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, IPWirelessNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface from primaite.simulator.network.transmission.data_link_layer import Frame @@ -154,7 +154,6 @@ class WirelessRouter(Router): ip_address: IPV4Address, subnet_mask: IPV4Address, frequency: Optional[AirSpaceFrequency] = AirSpaceFrequency.WIFI_2_4, - channel_width: Optional[ChannelWidth] = ChannelWidth.WIDTH_40_MHZ, ): """ Configures a wireless access point (WAP). @@ -173,8 +172,6 @@ class WirelessRouter(Router): """ if not frequency: frequency = AirSpaceFrequency.WIFI_2_4 - if not channel_width: - channel_width = ChannelWidth.WIDTH_40_MHZ self.sys_log.info("Configuring wireless access point") self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration @@ -185,7 +182,6 @@ class WirelessRouter(Router): network_interface.subnet_mask = subnet_mask self.wireless_access_point.frequency = frequency # Set operating frequency - self.wireless_access_point.channel_width = channel_width self.wireless_access_point.enable() # Re-enable the WAP with new settings self.sys_log.info(f"Configured WAP {network_interface}") @@ -269,11 +265,8 @@ class WirelessRouter(Router): ip_address = cfg["wireless_access_point"]["ip_address"] subnet_mask = cfg["wireless_access_point"]["subnet_mask"] frequency = AirSpaceFrequency[cfg["wireless_access_point"]["frequency"]] - channel_width = cfg["wireless_access_point"].get("channel_width") - if channel_width: - channel_width = ChannelWidth(channel_width) router.configure_wireless_access_point( - ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency, channel_width=channel_width + ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency ) if "acl" in cfg: diff --git a/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_blocked.yaml b/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_blocked.yaml index 21b0fe5e..5aed49cb 100644 --- a/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_blocked.yaml +++ b/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_blocked.yaml @@ -38,7 +38,6 @@ simulation: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 frequency: WIFI_5 - channel_width: 80 acl: 1: action: PERMIT @@ -60,7 +59,6 @@ simulation: ip_address: 192.168.1.2 subnet_mask: 255.255.255.0 frequency: WIFI_5 - channel_width: 80 acl: 1: action: PERMIT diff --git a/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_urban.yaml b/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_urban.yaml index ed27cd35..d2e64720 100644 --- a/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_urban.yaml +++ b/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_urban.yaml @@ -38,7 +38,6 @@ simulation: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 frequency: WIFI_5 - channel_width: 80 acl: 1: action: PERMIT @@ -60,7 +59,6 @@ simulation: ip_address: 192.168.1.2 subnet_mask: 255.255.255.0 frequency: WIFI_5 - channel_width: 80 acl: 1: action: PERMIT diff --git a/tests/integration_tests/network/test_airspace_capacity_configuration.py b/tests/integration_tests/network/test_airspace_capacity_configuration.py deleted file mode 100644 index f91f1290..00000000 --- a/tests/integration_tests/network/test_airspace_capacity_configuration.py +++ /dev/null @@ -1,106 +0,0 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -import yaml - -from primaite.game.game import PrimaiteGame -from primaite.simulator.network.airspace import ( - AirspaceEnvironmentType, - AirSpaceFrequency, - calculate_total_channel_capacity, - ChannelWidth, -) -from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from tests import TEST_ASSETS_ROOT - - -def test_wireless_wan_wifi_5_80_channel_width_urban(): - config_path = TEST_ASSETS_ROOT / "configs" / "wireless_wan_wifi_5_80_channel_width_urban.yaml" - - with open(config_path, "r") as f: - config_dict = yaml.safe_load(f) - network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - - airspace = network.airspace - - assert airspace.airspace_environment_type == AirspaceEnvironmentType.URBAN - - router_1: WirelessRouter = network.get_node_by_hostname("router_1") - router_2: WirelessRouter = network.get_node_by_hostname("router_2") - - expected_speed = calculate_total_channel_capacity( - channel_width=ChannelWidth.WIDTH_80_MHZ, - frequency=AirSpaceFrequency.WIFI_5, - environment_type=AirspaceEnvironmentType.URBAN, - ) - - assert router_1.wireless_access_point.speed == expected_speed - assert router_2.wireless_access_point.speed == expected_speed - - pc_a = network.get_node_by_hostname("pc_a") - pc_b = network.get_node_by_hostname("pc_b") - - assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully." - assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully." - - assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." - assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." - - -def test_wireless_wan_wifi_5_80_channel_width_blocked(): - config_path = TEST_ASSETS_ROOT / "configs" / "wireless_wan_wifi_5_80_channel_width_blocked.yaml" - - with open(config_path, "r") as f: - config_dict = yaml.safe_load(f) - network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - - airspace = network.airspace - - assert airspace.airspace_environment_type == AirspaceEnvironmentType.BLOCKED - - router_1: WirelessRouter = network.get_node_by_hostname("router_1") - router_2: WirelessRouter = network.get_node_by_hostname("router_2") - - expected_speed = calculate_total_channel_capacity( - channel_width=ChannelWidth.WIDTH_80_MHZ, - frequency=AirSpaceFrequency.WIFI_5, - environment_type=AirspaceEnvironmentType.BLOCKED, - ) - - assert router_1.wireless_access_point.speed == expected_speed - assert router_2.wireless_access_point.speed == expected_speed - - pc_a = network.get_node_by_hostname("pc_a") - pc_b = network.get_node_by_hostname("pc_b") - - assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully." - assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully." - - assert not pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers unsuccessfully." - assert not pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers unsuccessfully." - - -def test_wireless_wan_blocking_and_unblocking_airspace(): - config_path = TEST_ASSETS_ROOT / "configs" / "wireless_wan_wifi_5_80_channel_width_urban.yaml" - - with open(config_path, "r") as f: - config_dict = yaml.safe_load(f) - network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - - airspace = network.airspace - - assert airspace.airspace_environment_type == AirspaceEnvironmentType.URBAN - - pc_a = network.get_node_by_hostname("pc_a") - pc_b = network.get_node_by_hostname("pc_b") - - assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." - assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." - - airspace.airspace_environment_type = AirspaceEnvironmentType.BLOCKED - - assert not pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers unsuccessfully." - assert not pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers unsuccessfully." - - airspace.airspace_environment_type = AirspaceEnvironmentType.URBAN - - assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." - assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." diff --git a/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py b/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py index cf03ea8e..b7317c3d 100644 --- a/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py +++ b/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py @@ -40,30 +40,6 @@ def test_wireless_link_loading(wireless_wan_network): client.network_interface[1]._connected_link.pre_timestep(1) server.network_interface[1]._connected_link.pre_timestep(1) - assert ftp_client.send_file( - src_file_name="mixtape.mp3", - src_folder_name="music", - dest_ip_address=server.network_interface[1].ip_address, - dest_file_name="mixtape1.mp3", - dest_folder_name="music", - ) - - # Reset the physical links between the host nodes and the routers - client.network_interface[1]._connected_link.pre_timestep(1) - server.network_interface[1]._connected_link.pre_timestep(1) - - assert ftp_client.send_file( - src_file_name="mixtape.mp3", - src_folder_name="music", - dest_ip_address=server.network_interface[1].ip_address, - dest_file_name="mixtape2.mp3", - dest_folder_name="music", - ) - - # Reset the physical links between the host nodes and the routers - client.network_interface[1]._connected_link.pre_timestep(1) - server.network_interface[1]._connected_link.pre_timestep(1) - assert not ftp_client.send_file( src_file_name="mixtape.mp3", src_folder_name="music", From abbfc869425bdcde77a432372299c939879f9335 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 11 Jul 2024 15:50:35 +0100 Subject: [PATCH 32/35] 2623 update defaults --- src/primaite/game/agent/interface.py | 2 +- src/primaite/simulator/core.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 01b7fb0a..c00fd9d4 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -69,7 +69,7 @@ class AgentSettings(BaseModel): "Configuration for when an agent begins performing it's actions" flatten_obs: bool = True "Whether to flatten the observation space before passing it to the agent. True by default." - action_masking: bool = True + action_masking: bool = False "Whether to return action masks at each step." @classmethod diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 7653a3ab..574fcc19 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -174,6 +174,7 @@ class RequestManager(BaseModel): return requests def show(self) -> None: + """Display all currently available requests and whether they are valid.""" table = PrettyTable(["request", "valid"]) table.align = "l" table.add_rows(self.get_request_types_recursively()) From e759ae59904725695e31eb20095a2b1de3daf0e9 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 11 Jul 2024 17:44:31 +0100 Subject: [PATCH 33/35] 2623 fix issues with tests and revert request show method --- src/primaite/simulator/core.py | 11 +++++------ tests/assets/configs/test_primaite_session.yaml | 2 ++ tests/e2e_integration_tests/test_environment.py | 4 ++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 67108d5d..a5e39cc8 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -165,7 +165,7 @@ class RequestManager(BaseModel): self.request_types.pop(name) - def get_request_types_recursively(self, _parent_valid: bool = True) -> List[Tuple[RequestFormat, bool]]: + def get_request_types_recursively(self) -> List[RequestFormat]: """ Recursively generate request tree for this component. @@ -178,18 +178,17 @@ class RequestManager(BaseModel): """ requests = [] for req_name, req in self.request_types.items(): - valid = req.validator([], {}) and _parent_valid # if parent is invalid, all children are invalid if isinstance(req.func, RequestManager): - sub_requests = req.func.get_request_types_recursively(valid) # recurse - sub_requests = [([req_name] + a, valid) for a, valid in sub_requests] # prepend parent request to leaf + sub_requests = req.func.get_request_types_recursively() # recurse + sub_requests = [([req_name] + a) for a in sub_requests] # prepend parent request to leaf requests.extend(sub_requests) else: # leaf node found - requests.append(([req_name], valid)) + requests.append(req_name) return requests def show(self) -> None: """Display all currently available requests and whether they are valid.""" - table = PrettyTable(["request", "valid"]) + table = PrettyTable(["request"]) table.align = "l" table.add_rows(self.get_request_types_recursively()) print(table) diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index c435fe44..eb8103e8 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -646,6 +646,8 @@ simulation: dns_server: 192.168.1.10 services: - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 - type: server hostname: backup_server diff --git a/tests/e2e_integration_tests/test_environment.py b/tests/e2e_integration_tests/test_environment.py index 253bd396..dcd51193 100644 --- a/tests/e2e_integration_tests/test_environment.py +++ b/tests/e2e_integration_tests/test_environment.py @@ -70,8 +70,8 @@ class TestPrimaiteEnvironment: assert len(env.agents) == 2 defender_1 = env.agents["defender_1"] defender_2 = env.agents["defender_2"] - assert (num_actions_1 := len(defender_1.action_manager.action_map)) == 74 - assert (num_actions_2 := len(defender_2.action_manager.action_map)) == 74 + assert (num_actions_1 := len(defender_1.action_manager.action_map)) == 78 + assert (num_actions_2 := len(defender_2.action_manager.action_map)) == 78 # ensure we can run all valid actions without error for act_1 in range(num_actions_1): From cde632066cf262c2e17b7705362bb9fd1530b50a Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 11 Jul 2024 21:11:27 +0100 Subject: [PATCH 34/35] #2745 implemented overriding of frequency max capacities on the airspace. updated documentation to reflect the changes in airspace.py. --- docs/source/configuration/simulation.rst | 36 ++++----- src/primaite/game/game.py | 6 ++ src/primaite/simulator/network/airspace.py | 75 +++++++++++-------- .../hardware/nodes/network/wireless_router.py | 4 +- .../configs/wireless_wan_network_config.yaml | 2 - ...wan_network_config_freq_max_override.yaml} | 8 +- ...ork_config_freq_max_override_blocked.yaml} | 8 +- .../network/test_airspace_config.py | 44 +++++++++++ 8 files changed, 120 insertions(+), 63 deletions(-) rename tests/assets/configs/{wireless_wan_wifi_5_80_channel_width_urban.yaml => wireless_wan_network_config_freq_max_override.yaml} (92%) rename tests/assets/configs/{wireless_wan_wifi_5_80_channel_width_blocked.yaml => wireless_wan_network_config_freq_max_override_blocked.yaml} (92%) create mode 100644 tests/integration_tests/network/test_airspace_config.py diff --git a/docs/source/configuration/simulation.rst b/docs/source/configuration/simulation.rst index bd66914d..48b857d9 100644 --- a/docs/source/configuration/simulation.rst +++ b/docs/source/configuration/simulation.rst @@ -108,31 +108,23 @@ This is an integer value specifying the allowed bandwidth across the connection. ``airspace`` ------------ -This section configures settings specific to the wireless network's virtual airspace. It defines how wireless interfaces within the simulation will interact and perform under various environmental conditions. +This section configures settings specific to the wireless network's virtual airspace. -``airspace_environment_type`` +``frequency_max_capacity_mbps`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -This setting specifies the environmental conditions of the airspace which affect the propagation and interference characteristics of wireless signals. Changing this environment type impacts how signal noise and interference are calculated, thus affecting the overall network performance, including data transmission rates and signal quality. +This setting allows the user to override the default maximum bandwidth capacity set for each frequency. The key should +be the AirSpaceFrequency name and the value be the desired maximum bandwidth capacity in mbps (megabits per second) for +a single timestep. -**Configurable Options** +The below example would permit 123.45 megabits to be transmit across the WiFi 2.4 GHz frequency in a single timestep. +Setting a frequencies max capacity to 0.0 blocks that frequency on the airspace. -- **rural**: A rural environment offers clear channel conditions due to low population density and minimal electronic device presence. +.. code-block:: yaml -- **outdoor**: Outdoor environments like parks or fields have minimal electronic interference. - -- **suburban**: Suburban environments strike a balance with fewer electronic interferences than urban but more than rural. - -- **office**: Office environments have moderate interference from numerous electronic devices and overlapping networks. - -- **urban**: Urban environments are characterized by tall buildings and a high density of electronic devices, leading to significant interference. - -- **industrial**: Industrial areas face high interference from heavy machinery and numerous electronic devices. - -- **transport**: Environments such as subways and buses where metal structures and high mobility create complex interference patterns. - -- **dense_urban**: Dense urban areas like city centers have the highest level of signal interference due to the very high density of buildings and devices. - -- **jamming_zone**: A jamming zone environment where signals are actively interfered with, typically through the use of signal jammers or scrambling devices. This represents the environment with the highest level of interference. - -- **blocked**: A jamming zone environment with total levels of interference. Airspace is completely blocked. + simulation: + network: + airspace: + frequency_max_capacity_mbps: + WIFI_2_4: 123.45 + WIFI_5: 0.0 diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index b976e55f..9eadc360 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -16,6 +16,7 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT +from primaite.simulator.network.airspace import AirSpaceFrequency from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC @@ -237,6 +238,11 @@ class PrimaiteGame: simulation_config = cfg.get("simulation", {}) network_config = simulation_config.get("network", {}) airspace_cfg = network_config.get("airspace", {}) + frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {}) + + frequency_max_capacity_mbps_cfg = {AirSpaceFrequency[k]: v for k, v in frequency_max_capacity_mbps_cfg.items()} + + net.airspace.frequency_max_capacity_mbps_ = frequency_max_capacity_mbps_cfg nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 5019385a..9c736383 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel, Field @@ -59,6 +59,15 @@ class AirSpaceFrequency(Enum): @property def maximum_data_rate_bps(self) -> float: + """ + Retrieves the maximum data transmission rate in bits per second (bps) for the frequency. + + The maximum rates are predefined for known frequencies: + - For WIFI_2_4, it returns 100,000,000 bps (100 Mbps). + - For WIFI_5, it returns 500,000,000 bps (500 Mbps). + + :return: The maximum data rate in bits per second. If the frequency is not recognized, returns 0.0. + """ if self == AirSpaceFrequency.WIFI_2_4: return 100_000_000.0 # 100 Megabits per second if self == AirSpaceFrequency.WIFI_5: @@ -67,6 +76,14 @@ class AirSpaceFrequency(Enum): @property def maximum_data_rate_mbps(self) -> float: + """ + Retrieves the maximum data transmission rate in megabits per second (Mbps). + + This is derived by converting the maximum data rate from bits per second, as defined + in `maximum_data_rate_bps`, to megabits per second. + + :return: The maximum data rate in megabits per second. + """ return self.maximum_data_rate_bps / 1_000_000.0 @@ -84,28 +101,33 @@ class AirSpace(BaseModel): default_factory=lambda: {} ) bandwidth_load: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) - frequency_max_capacity_mbps: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) + frequency_max_capacity_mbps_: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) - def model_post_init(self, __context: Any) -> None: + def get_frequency_max_capacity_mbps(self, frequency: AirSpaceFrequency) -> float: """ - Initialize the airspace metadata after instantiation. + Retrieves the maximum data transmission capacity for a specified frequency. - This method is called to set up initial configurations like the maximum capacity of each frequency. + This method checks a dictionary holding custom maximum capacities. If the frequency is found, it returns the + custom set maximum capacity. If the frequency is not found in the dictionary, it defaults to the standard + maximum data rate associated with that frequency. - :param __context: Contextual data or settings, typically used for further initializations beyond - the basic constructor. - """ - self.set_frequency_max_capacity_mbps() + :param frequency: The frequency for which the maximum capacity is queried. - def set_frequency_max_capacity_mbps(self, capacity_config: Optional[Dict[AirSpaceFrequency, float]] = None): + :return: The maximum capacity in Mbps for the specified frequency. """ - Set the maximum channel capacity in Mbps for each frequency. + if frequency in self.frequency_max_capacity_mbps_: + return self.frequency_max_capacity_mbps_[frequency] + return frequency.maximum_data_rate_mbps + + def set_frequency_max_capacity_mbps(self, cfg: Dict[AirSpaceFrequency, float]): """ - if capacity_config is None: - capacity_config = {} - for frequency in AirSpaceFrequency: - max_capacity = capacity_config.get(frequency, frequency.maximum_data_rate_mbps) - self.frequency_max_capacity_mbps[frequency] = max_capacity + Sets custom maximum data transmission capacities for multiple frequencies. + + :param cfg: A dictionary mapping frequencies to their new maximum capacities in Mbps. + """ + self.frequency_max_capacity_mbps_ = cfg + for freq, mbps in cfg.items(): + print(f"Overriding {freq} max capacity as {mbps:.3f} mbps") def show_bandwidth_load(self, markdown: bool = False): """ @@ -117,8 +139,6 @@ class AirSpace(BaseModel): :param markdown: Flag indicating if output should be in markdown format. """ - if not self.frequency_max_capacity_mbps: - self.set_frequency_max_capacity_mbps() headers = ["Frequency", "Current Capacity (%)", "Maximum Capacity (Mbit)"] table = PrettyTable(headers) if markdown: @@ -126,8 +146,8 @@ class AirSpace(BaseModel): table.align = "l" table.title = "Airspace Frequency Channel Loads" for frequency, load in self.bandwidth_load.items(): - maximum_capacity = self.frequency_max_capacity_mbps[frequency] - load_percent = load / maximum_capacity + maximum_capacity = self.get_frequency_max_capacity_mbps(frequency) + load_percent = load / maximum_capacity if maximum_capacity > 0 else 0.0 if load_percent > 1.0: load_percent = 1.0 table.add_row([format_hertz(frequency.value), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"]) @@ -152,7 +172,7 @@ class AirSpace(BaseModel): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"Devices on Air Space" + table.title = "Devices on Air Space" for interface in self.wireless_interfaces.values(): status = "Enabled" if interface.enabled else "Disabled" @@ -235,14 +255,11 @@ class AirSpace(BaseModel): relevant frequency and its current bandwidth load. :return: True if the frame can be transmitted within the bandwidth limit, False if it would exceed the limit. """ - if not self.frequency_max_capacity_mbps: - self.set_frequency_max_capacity_mbps() if sender_network_interface.frequency not in self.bandwidth_load: self.bandwidth_load[sender_network_interface.frequency] = 0.0 - return ( - self.bandwidth_load[sender_network_interface.frequency] + frame.size_Mbits - <= self.frequency_max_capacity_mbps[sender_network_interface.frequency] - ) + return self.bandwidth_load[ + sender_network_interface.frequency + ] + frame.size_Mbits <= self.get_frequency_max_capacity_mbps(sender_network_interface.frequency) def transmit(self, frame: Frame, sender_network_interface: WirelessNetworkInterface): """ @@ -255,9 +272,7 @@ class AirSpace(BaseModel): excluded from the list of receivers to prevent it from receiving its own transmission. """ self.bandwidth_load[sender_network_interface.frequency] += frame.size_Mbits - for wireless_interface in self.wireless_interfaces_by_frequency.get( - sender_network_interface.frequency, [] - ): + for wireless_interface in self.wireless_interfaces_by_frequency.get(sender_network_interface.frequency, []): if wireless_interface != sender_network_interface and wireless_interface.enabled: wireless_interface.receive_frame(frame) diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 5ded993e..3cb4c515 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -265,9 +265,7 @@ class WirelessRouter(Router): ip_address = cfg["wireless_access_point"]["ip_address"] subnet_mask = cfg["wireless_access_point"]["subnet_mask"] frequency = AirSpaceFrequency[cfg["wireless_access_point"]["frequency"]] - router.configure_wireless_access_point( - ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency - ) + router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency) if "acl" in cfg: for r_num, r_cfg in cfg["acl"].items(): diff --git a/tests/assets/configs/wireless_wan_network_config.yaml b/tests/assets/configs/wireless_wan_network_config.yaml index 7172f66d..c8f61bad 100644 --- a/tests/assets/configs/wireless_wan_network_config.yaml +++ b/tests/assets/configs/wireless_wan_network_config.yaml @@ -9,8 +9,6 @@ game: simulation: network: - airspace: - airspace_environment_type: urban nodes: - type: computer hostname: pc_a diff --git a/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_urban.yaml b/tests/assets/configs/wireless_wan_network_config_freq_max_override.yaml similarity index 92% rename from tests/assets/configs/wireless_wan_wifi_5_80_channel_width_urban.yaml rename to tests/assets/configs/wireless_wan_network_config_freq_max_override.yaml index d2e64720..a327b0f5 100644 --- a/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_urban.yaml +++ b/tests/assets/configs/wireless_wan_network_config_freq_max_override.yaml @@ -10,7 +10,9 @@ game: simulation: network: airspace: - airspace_environment_type: urban + frequency_max_capacity_mbps: + WIFI_2_4: 123.45 + WIFI_5: 0.0 nodes: - type: computer hostname: pc_a @@ -37,7 +39,7 @@ simulation: wireless_access_point: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 - frequency: WIFI_5 + frequency: WIFI_2_4 acl: 1: action: PERMIT @@ -58,7 +60,7 @@ simulation: wireless_access_point: ip_address: 192.168.1.2 subnet_mask: 255.255.255.0 - frequency: WIFI_5 + frequency: WIFI_2_4 acl: 1: action: PERMIT diff --git a/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_blocked.yaml b/tests/assets/configs/wireless_wan_network_config_freq_max_override_blocked.yaml similarity index 92% rename from tests/assets/configs/wireless_wan_wifi_5_80_channel_width_blocked.yaml rename to tests/assets/configs/wireless_wan_network_config_freq_max_override_blocked.yaml index 5aed49cb..ff048c92 100644 --- a/tests/assets/configs/wireless_wan_wifi_5_80_channel_width_blocked.yaml +++ b/tests/assets/configs/wireless_wan_network_config_freq_max_override_blocked.yaml @@ -10,7 +10,9 @@ game: simulation: network: airspace: - airspace_environment_type: blocked + frequency_max_capacity_mbps: + WIFI_2_4: 0.0 + WIFI_5: 0.0 nodes: - type: computer hostname: pc_a @@ -37,7 +39,7 @@ simulation: wireless_access_point: ip_address: 192.168.1.1 subnet_mask: 255.255.255.0 - frequency: WIFI_5 + frequency: WIFI_2_4 acl: 1: action: PERMIT @@ -58,7 +60,7 @@ simulation: wireless_access_point: ip_address: 192.168.1.2 subnet_mask: 255.255.255.0 - frequency: WIFI_5 + frequency: WIFI_2_4 acl: 1: action: PERMIT diff --git a/tests/integration_tests/network/test_airspace_config.py b/tests/integration_tests/network/test_airspace_config.py new file mode 100644 index 00000000..78d00b47 --- /dev/null +++ b/tests/integration_tests/network/test_airspace_config.py @@ -0,0 +1,44 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import yaml + +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.airspace import AirSpaceFrequency +from tests import TEST_ASSETS_ROOT + + +def test_override_freq_max_capacity_mbps(): + config_path = TEST_ASSETS_ROOT / "configs" / "wireless_wan_network_config_freq_max_override.yaml" + + with open(config_path, "r") as f: + config_dict = yaml.safe_load(f) + network = PrimaiteGame.from_config(cfg=config_dict).simulation.network + + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 123.45 + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0 + + pc_a = network.get_node_by_hostname("pc_a") + pc_b = network.get_node_by_hostname("pc_b") + + assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should be able to ping PC B" + assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should be able to ping PC A" + + network.airspace.show() + + +def test_override_freq_max_capacity_mbps_blocked(): + config_path = TEST_ASSETS_ROOT / "configs" / "wireless_wan_network_config_freq_max_override_blocked.yaml" + + with open(config_path, "r") as f: + config_dict = yaml.safe_load(f) + network = PrimaiteGame.from_config(cfg=config_dict).simulation.network + + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0 + + pc_a = network.get_node_by_hostname("pc_a") + pc_b = network.get_node_by_hostname("pc_b") + + assert not pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should not be able to ping PC B" + assert not pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should not be able to ping PC A" + + network.airspace.show() From 199cd0d9dfe2ae40310491b07aa3fba2b169f5bc Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 12 Jul 2024 11:23:41 +0100 Subject: [PATCH 35/35] fix test problems and slowness --- src/primaite/simulator/core.py | 14 +++++++------- tests/conftest.py | 4 ++-- .../test_agents_use_action_masks.py | 4 +--- .../test_rllib_multi_agent_environment.py | 16 ++-------------- .../test_rllib_single_agent_environment.py | 1 - .../environments/test_sb3_environment.py | 2 +- 6 files changed, 13 insertions(+), 28 deletions(-) diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index a5e39cc8..848570fe 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -179,18 +179,18 @@ class RequestManager(BaseModel): requests = [] for req_name, req in self.request_types.items(): if isinstance(req.func, RequestManager): - sub_requests = req.func.get_request_types_recursively() # recurse - sub_requests = [([req_name] + a) for a in sub_requests] # prepend parent request to leaf + sub_requests = req.func.get_request_types_recursively() + sub_requests = [[req_name] + a for a in sub_requests] requests.extend(sub_requests) - else: # leaf node found - requests.append(req_name) + else: + requests.append([req_name]) return requests def show(self) -> None: - """Display all currently available requests and whether they are valid.""" - table = PrettyTable(["request"]) + """Display all currently available requests.""" + table = PrettyTable(["requests"]) table.align = "l" - table.add_rows(self.get_request_types_recursively()) + table.add_rows([[x] for x in self.get_request_types_recursively()]) print(table) def check_valid(self, request: RequestFormat, context: Dict) -> bool: diff --git a/tests/conftest.py b/tests/conftest.py index b8b50182..54519e2b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,8 +2,8 @@ from typing import Any, Dict, Tuple import pytest -import ray import yaml +from ray import init as rayinit from primaite import getLogger, PRIMAITE_PATHS from primaite.game.agent.actions import ActionManager @@ -30,7 +30,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer from tests import TEST_ASSETS_ROOT -ray.init(local_mode=True) +rayinit(local_mode=True) ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py index a299b913..745e280b 100644 --- a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -1,9 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -import importlib from typing import Dict import yaml -from ray import air, init, tune from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec @@ -43,7 +41,7 @@ def test_sb3_action_masking(monkeypatch): monkeypatch.setattr(env, "step", lambda action: cache_step(env, action)) model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32) - model.learn(512) + model.learn(256) assert len(action_num_history) == len(mask_history) > 0 # Make sure the masks had at least some False entries, if it was all True then the mask was disabled diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py index e015c33c..26e690d0 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -1,7 +1,5 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -import ray import yaml -from ray import air, tune from ray.rllib.algorithms.ppo import PPOConfig from primaite.session.ray_envs import PrimaiteRayMARLEnv @@ -12,7 +10,6 @@ MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" def test_rllib_multi_agent_compatibility(): """Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system.""" - with open(MULTI_AGENT_PATH, "r") as f: cfg = yaml.safe_load(f) @@ -26,14 +23,5 @@ def test_rllib_multi_agent_compatibility(): ) .training(train_batch_size=128) ) - - tune.Tuner( - "PPO", - run_config=air.RunConfig( - stop={"training_iteration": 128}, - checkpoint_config=air.CheckpointConfig( - checkpoint_frequency=10, - ), - ), - param_space=config, - ).fit() + algo = config.build() + algo.train() diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py index a02a078c..265257e4 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -3,7 +3,6 @@ import tempfile from pathlib import Path import pytest -import ray import yaml from ray.rllib.algorithms import ppo diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index 27fb134b..a07d5d2e 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -20,7 +20,7 @@ def test_sb3_compatibility(): gym = PrimaiteGymEnv(env_config=cfg) model = PPO("MlpPolicy", gym) - model.learn(total_timesteps=1000) + model.learn(total_timesteps=256) save_path = Path(tempfile.gettempdir()) / "model.zip" model.save(save_path)