From 08f1cf1fbd67f5fa8158876915f9ea940def0c7c Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 19 Sep 2024 15:06:29 +0100 Subject: [PATCH] Fix airspace and remaining port problems from refactor --- .../network/nodes/wireless_router.rst | 6 +- src/primaite/config/load.py | 5 +- .../agent/observations/acl_observation.py | 4 +- .../agent/observations/host_observations.py | 10 +- .../agent/observations/nic_observations.py | 17 +-- .../agent/observations/node_observations.py | 10 +- src/primaite/game/game.py | 26 ++--- src/primaite/simulator/network/airspace.py | 101 +++++++++++------- src/primaite/simulator/network/container.py | 8 +- .../network/hardware/nodes/host/host_node.py | 4 +- .../hardware/nodes/network/network_node.py | 4 +- .../network/hardware/nodes/network/router.py | 11 +- .../hardware/nodes/network/wireless_router.py | 14 +-- .../network/transmission/network_layer.py | 10 +- .../network/transmission/transport_layer.py | 61 ++++++----- .../system/applications/application.py | 4 +- .../red_applications/c2/c2_beacon.py | 1 - .../simulator/system/core/session_manager.py | 4 +- .../system/services/ftp/ftp_service.py | 1 - .../simulator/system/services/service.py | 4 +- src/primaite/simulator/system/software.py | 1 - .../extensions/nodes/giga_switch.py | 3 +- .../extensions/nodes/super_computer.py | 4 +- .../extensions/services/extended_service.py | 12 ++- .../extensions/test_extendable_config.py | 18 ++-- .../observations/test_acl_observations.py | 4 +- .../observations/test_firewall_observation.py | 4 +- .../observations/test_nic_observations.py | 7 +- .../observations/test_router_observation.py | 4 +- .../game_layer/test_rewards.py | 4 +- .../network/test_airspace_config.py | 9 +- .../network/test_firewall.py | 16 ++- tests/integration_tests/system/test_nmap.py | 4 +- .../_utils/test_dict_enum_keys_conversion.py | 9 +- 34 files changed, 227 insertions(+), 177 deletions(-) diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index bd361afa..436852ea 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -49,7 +49,7 @@ additional steps to configure wireless settings: wireless_router.configure_wireless_access_point( port=1, ip_address="192.168.2.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency["WIFI_2_4"], + frequency="WIFI_2_4", ) @@ -130,13 +130,13 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency["WIFI_2_4"], + frequency="WIFI_2_4", ) router_2.configure_wireless_access_point( port=1, ip_address="192.168.1.2", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency["WIFI_2_4"], + frequency="WIFI_2_4", ) # Configure routes for inter-router communication diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py index b00c26f6..39040d76 100644 --- a/src/primaite/config/load.py +++ b/src/primaite/config/load.py @@ -60,9 +60,10 @@ def data_manipulation_marl_config_path() -> Path: raise FileNotFoundError(msg) return path + def get_extended_config_path() -> Path: """ - Get the path to an 'extended' example config that contains nodes using the extension framework + Get the path to an 'extended' example config that contains nodes using the extension framework. :return: Path to the extended example config :rtype: Path @@ -72,4 +73,4 @@ def get_extended_config_path() -> Path: msg = f"Example config does not exist: {path}. Have you run `primaite setup`?" _LOGGER.error(msg) raise FileNotFoundError(msg) - return path \ No newline at end of file + return path diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index abb6f1f8..41af5a8f 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -10,8 +10,6 @@ from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port _LOGGER = getLogger(__name__) @@ -63,7 +61,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)} self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)} self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)} - self.protocol_to_id: Dict[str, int] = {IPProtocol[p]: i + 2 for i, p in enumerate(protocol_list)} + self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { i + 1: { diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 05b25952..30ccd195 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -58,8 +58,14 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_users: Optional[bool] = True """If True, report user session information.""" - @field_validator('monitored_traffic', mode='before') - def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + @field_validator("monitored_traffic", mode="before") + def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: + """ + Convert monitored_traffic by lookup against Port and Protocol dicts. + + This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. + This method will be removed in PrimAITE >= 4.0 + """ if val is None: return val new_val = {} diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 200187f5..296ce04c 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -26,8 +26,14 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): monitored_traffic: Optional[Dict] = None """A dict containing which traffic types are to be included in the observation.""" - @field_validator('monitored_traffic', mode='before') - def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + @field_validator("monitored_traffic", mode="before") + def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: + """ + Convert monitored_traffic by lookup against Port and Protocol dicts. + + This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. + This method will be removed in PrimAITE >= 4.0 + """ if val is None: return val new_val = {} @@ -41,7 +47,6 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): new_val[proto].append(port) return new_val - def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: """ Initialise a network interface observation instance. @@ -76,7 +81,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): def _default_monitored_traffic_observation(self, monitored_traffic_config: Dict) -> Dict: default_traffic_obs = {"TRAFFIC": {}} - for protocol in monitored_traffic_config: + for protocol in self.monitored_traffic: protocol = str(protocol).lower() default_traffic_obs["TRAFFIC"][protocol] = {} @@ -84,8 +89,8 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): default_traffic_obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0} else: default_traffic_obs["TRAFFIC"][protocol] = {} - for port in monitored_traffic_config[protocol]: - default_traffic_obs["TRAFFIC"][protocol] = {"inbound": 0, "outbound": 0} + for port in self.monitored_traffic[protocol]: + default_traffic_obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0} return default_traffic_obs diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 3e51c3b3..054ffcdb 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -63,8 +63,14 @@ class NodesObservation(AbstractObservation, identifier="NODES"): num_rules: Optional[int] = None """Number of rules ACL rules to show.""" - @field_validator('monitored_traffic', mode='before') - def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + @field_validator("monitored_traffic", mode="before") + def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: + """ + Convert monitored_traffic by lookup against Port and Protocol dicts. + + This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. + This method will be removed in PrimAITE >= 4.0 + """ if val is None: return val new_val = {} diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index e8329c63..8e0abb1e 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -17,10 +17,9 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT -from primaite.simulator.network.airspace import AirSpaceFrequency from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC from primaite.simulator.network.hardware.nodes.host.server import Printer, Server from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode @@ -89,8 +88,8 @@ class PrimaiteGameOptions(BaseModel): thresholds: Optional[Dict] = {} """A dict containing the thresholds used for determining what is acceptable during observations.""" - @field_validator('ports', mode='before') - def ports_str2int(cls, vals:Union[List[str],List[int]]) -> List[int]: + @field_validator("ports", mode="before") + def ports_str2int(cls, vals: Union[List[str], List[int]]) -> List[int]: """ Convert named port strings to port integer values. Integer ports remain unaffected. @@ -102,8 +101,8 @@ class PrimaiteGameOptions(BaseModel): vals[i] = Port[port_val] return vals - @field_validator('protocols', mode='before') - def protocols_str2int(cls, vals:List[str]) -> List[str]: + @field_validator("protocols", mode="before") + def protocols_str2int(cls, vals: List[str]) -> List[str]: """ Convert old-style named protocols to their proper values. @@ -116,7 +115,6 @@ class PrimaiteGameOptions(BaseModel): return vals - class PrimaiteGame: """ Primaite game encapsulates the simulation and agents which interact with it. @@ -294,10 +292,7 @@ class PrimaiteGame: network_config = simulation_config.get("network", {}) airspace_cfg = network_config.get("airspace", {}) frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {}) - - frequency_max_capacity_mbps_cfg = {AirSpaceFrequency[k]: v for k, v in frequency_max_capacity_mbps_cfg.items()} - - net.airspace.frequency_max_capacity_mbps_ = frequency_max_capacity_mbps_cfg + net.airspace.set_frequency_max_capacity_mbps(frequency_max_capacity_mbps_cfg) nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) @@ -318,11 +313,10 @@ class PrimaiteGame: dns_server=node_cfg.get("dns_server", None), operating_state=NodeOperatingState.ON if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()]) - elif n_type in NetworkNode._registry: - new_node = NetworkNode._registry[n_type]( - **node_cfg + else NodeOperatingState[p.upper()], ) + elif n_type in NetworkNode._registry: + new_node = NetworkNode._registry[n_type](**node_cfg) # Default PrimAITE nodes elif n_type == "computer": new_node = Computer( @@ -502,7 +496,7 @@ class PrimaiteGame: opt = application_cfg["options"] new_application.configure( target_ip_address=IPv4Address(opt.get("target_ip_address")), - target_port = Port[opt.get("target_port", "POSTGRES_SERVER")], + target_port=Port[opt.get("target_port", "POSTGRES_SERVER")], payload=opt.get("payload"), repeat=bool(opt.get("repeat")), port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 29326df8..65dceeb1 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -1,14 +1,12 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations +import copy from abc import ABC, abstractmethod -from enum import Enum -from typing import Any, ClassVar, Dict, List, Type -from pydantic._internal._generics import PydanticGenericMetadata -from typing_extensions import Unpack +from typing import Any, Dict, List from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, Field, validate_call from primaite import getLogger from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface @@ -42,29 +40,31 @@ def format_hertz(hertz: float, format_terahertz: bool = False, decimals: int = 3 else: # Hertz return format_str.format(hertz) + " Hz" -AirSpaceFrequencyRegistry: Dict[str,Dict] = { - "WIFI_2_4" : {'frequency': 2.4e9, 'data_rate_bps':100_000_000.0}, - "WIFI_5" : {'frequency': 5e9, 'data_rate_bps':500_000_000.0}, + +_default_frequency_set: Dict[str, Dict] = { + "WIFI_2_4": {"frequency": 2.4e9, "data_rate_bps": 100_000_000.0}, + "WIFI_5": {"frequency": 5e9, "data_rate_bps": 500_000_000.0}, } +"""Frequency configuration that is automatically used for any new airspace.""" -def register_frequency(freq_name: str, freq_hz: int, data_rate_bps: int) -> None: - if freq_name in AirSpaceFrequencyRegistry: - raise RuntimeError(f"Cannot register new frequency {freq_name} because it's already registered.") - AirSpaceFrequencyRegistry.update({freq_name:{'frequency': freq_hz, 'data_rate_bps':data_rate_bps}}) -def maximum_data_rate_mbps(frequency_name:str) -> float: +def register_default_frequency(freq_name: str, freq_hz: float, data_rate_bps: float): + """Add to the default frequency configuration. This is intended as a plugin hook. + + If your plugin makes use of bespoke frequencies for wireless communication, you should make a call to this method + whereever you define components that rely on the bespoke frequencies. That way, as soon as your components are + imported, this function automatically updates the default frequency set. + + This should also be run before instances of AirSpace are created. + + :param freq_name: The frequency name. If this clashes with an existing frequency name, it will be overwritten. + :type freq_name: str + :param freq_hz: The frequency itself, measured in Hertz. + :type freq_hz: float + :param data_rate_bps: The transmission capacity over this frequency, in bits per second. + :type data_rate_bps: float """ - Retrieves the maximum data transmission rate in megabits per second (Mbps). - - This is derived by converting the maximum data rate from bits per second, as defined - in `maximum_data_rate_bps`, to megabits per second. - - :return: The maximum data rate in megabits per second. - """ - return AirSpaceFrequencyRegistry[frequency_name]['data_rate_bps'] - return data_rate / 1_000_000.0 - - + _default_frequency_set.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}}) class AirSpace(BaseModel): @@ -77,27 +77,21 @@ class AirSpace(BaseModel): """ wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {}) - wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field( - default_factory=lambda: {} - ) + wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field(default_factory=lambda: {}) bandwidth_load: Dict[int, float] = Field(default_factory=lambda: {}) - frequency_max_capacity_mbps_: Dict[int, float] = Field(default_factory=lambda: {}) + frequencies: Dict[str, Dict] = Field(default_factory=lambda: copy.deepcopy(_default_frequency_set)) - def get_frequency_max_capacity_mbps(self, frequency: str) -> float: + @validate_call + def get_frequency_max_capacity_mbps(self, freq_name: str) -> float: """ Retrieves the maximum data transmission capacity for a specified frequency. - This method checks a dictionary holding custom maximum capacities. If the frequency is found, it returns the - custom set maximum capacity. If the frequency is not found in the dictionary, it defaults to the standard - maximum data rate associated with that frequency. - - :param frequency: The frequency for which the maximum capacity is queried. - + :param freq_name: The frequency for which the maximum capacity is queried. :return: The maximum capacity in Mbps for the specified frequency. """ - if frequency in self.frequency_max_capacity_mbps_: - return self.frequency_max_capacity_mbps_[frequency] - return maximum_data_rate_mbps(frequency) + if freq_name in self.frequencies: + return self.frequencies[freq_name]["data_rate_bps"] / (1024.0 * 1024.0) + return 0.0 def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]): """ @@ -105,10 +99,29 @@ class AirSpace(BaseModel): :param cfg: A dictionary mapping frequencies to their new maximum capacities in Mbps. """ - self.frequency_max_capacity_mbps_ = cfg for freq, mbps in cfg.items(): + self.frequencies[freq]["data_rate_bps"] = mbps * 1024 * 1024 print(f"Overriding {freq} max capacity as {mbps:.3f} mbps") + def register_frequency(self, freq_name: str, freq_hz: float, data_rate_bps: float) -> None: + """ + Define a new frequency for this airspace. + + :param freq_name: The frequency name. If this clashes with an existing frequency name, it will be overwritten. + :type freq_name: str + :param freq_hz: The frequency itself, measured in Hertz. + :type freq_hz: float + :param data_rate_bps: The transmission capacity over this frequency, in bits per second. + :type data_rate_bps: float + """ + if freq_name in self.frequencies: + _LOGGER.info( + f"Overwriting Air space frequency {freq_name}. " + f"Previous data rate: {self.frequencies[freq_name]['data_rate_bps']}. " + f"Current data rate: {data_rate_bps}." + ) + self.frequencies.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}}) + def show_bandwidth_load(self, markdown: bool = False): """ Prints a table of the current bandwidth load for each frequency on the airspace. @@ -130,7 +143,13 @@ class AirSpace(BaseModel): load_percent = load / maximum_capacity if maximum_capacity > 0 else 0.0 if load_percent > 1.0: load_percent = 1.0 - table.add_row([format_hertz(frequency), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"]) + table.add_row( + [ + format_hertz(self.frequencies[frequency]["frequency"]), + f"{load_percent:.0%}", + f"{maximum_capacity:.3f}", + ] + ) print(table) def show_wireless_interfaces(self, markdown: bool = False): @@ -162,7 +181,7 @@ class AirSpace(BaseModel): interface.mac_address, interface.ip_address if hasattr(interface, "ip_address") else None, interface.subnet_mask if hasattr(interface, "subnet_mask") else None, - format_hertz(interface.frequency), + format_hertz(self.frequencies[interface.frequency]["frequency"]), f"{interface.speed:.3f}", status, ] diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 39fbe783..6e019f32 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -130,15 +130,15 @@ class Network(SimComponent): def firewall_nodes(self) -> List[Node]: """The Firewalls in the Network.""" return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"] - + @property def extended_hostnodes(self) -> List[Node]: - """Extended nodes that inherited HostNode in the network""" + """Extended nodes that inherited HostNode in the network.""" return [node for node in self.nodes.values() if node.__class__.__name__.lower() in HostNode._registry] - + @property def extended_networknodes(self) -> List[Node]: - """Extended nodes that inherited NetworkNode in the network""" + """Extended nodes that inherited NetworkNode in the network.""" return [node for node in self.nodes.values() if node.__class__.__name__.lower() in NetworkNode._registry] @property diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index ea162e88..8a420e44 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -332,7 +332,7 @@ class HostNode(Node): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register a hostnode type. @@ -340,7 +340,7 @@ class HostNode(Node): :type identifier: str :raises ValueError: When attempting to register an hostnode with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. identifier = identifier.lower() diff --git a/src/primaite/simulator/network/hardware/nodes/network/network_node.py b/src/primaite/simulator/network/hardware/nodes/network/network_node.py index 6515bb02..a0cb63e1 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/network_node.py +++ b/src/primaite/simulator/network/hardware/nodes/network/network_node.py @@ -19,7 +19,7 @@ class NetworkNode(Node): _registry: ClassVar[Dict[str, Type["NetworkNode"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register a networknode type. @@ -27,7 +27,7 @@ class NetworkNode(Node): :type identifier: str :raises ValueError: When attempting to register an networknode with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return identifier = identifier.lower() super().__init_subclass__(**kwargs) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 013c473e..fded23f9 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -130,19 +130,20 @@ class ACLRule(SimComponent): dst_port: Optional[int] = None match_count: int = 0 - @field_validator('protocol', mode='before') - def protocol_valid(cls, val:Optional[str]) -> Optional[str]: + @field_validator("protocol", mode="before") + def protocol_valid(cls, val: Optional[str]) -> Optional[str]: + """Assert that the protocol for the rule is predefined in the IPProtocol lookup.""" if val is not None: assert val in IPProtocol.values(), f"Cannot create ACL rule with invalid protocol {val}" return val - @field_validator('src_port', 'dst_port', mode='before') - def ports_valid(cls, val:Optional[int]) -> Optional[int]: + @field_validator("src_port", "dst_port", mode="before") + def ports_valid(cls, val: Optional[int]) -> Optional[int]: + """Assert that the port for the rule is predefined in the Port lookup.""" if val is not None: assert val in Port.values(), f"Cannot create ACL rule with invalid port {val}" return val - def __str__(self) -> str: rule_strings = [] for key, value in self.model_dump(exclude={"uuid", "request_manager"}).items(): diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index d73bc756..1969a121 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union from pydantic import validate_call -from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, IPWirelessNetworkInterface +from primaite.simulator.network.airspace import AirSpace, IPWirelessNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface from primaite.simulator.network.transmission.data_link_layer import Frame @@ -116,7 +116,7 @@ class WirelessRouter(Router): >>> wireless_router.configure_wireless_access_point( ... ip_address="10.10.10.1", ... subnet_mask="255.255.255.0" - ... frequency=AirSpaceFrequency["WIFI_2_4"] + ... frequency="WIFI_2_4" ... ) """ @@ -153,7 +153,7 @@ class WirelessRouter(Router): self, ip_address: IPV4Address, subnet_mask: IPV4Address, - frequency: Optional[int] = AirSpaceFrequency["WIFI_2_4"], + frequency: Optional[str] = "WIFI_2_4", ): """ Configures a wireless access point (WAP). @@ -166,12 +166,12 @@ class WirelessRouter(Router): :param ip_address: The IP address to be assigned to the wireless access point. :param subnet_mask: The subnet mask associated with the IP address - :param frequency: The operating frequency of the wireless access point, defined by the AirSpaceFrequency + :param frequency: The operating frequency of the wireless access point, defined by the air space frequency enum. This determines the frequency band (e.g., 2.4 GHz or 5 GHz) the access point will use for wireless - communication. Default is AirSpaceFrequency["WIFI_2_4"]. + communication. Default is "WIFI_2_4". """ if not frequency: - frequency = AirSpaceFrequency["WIFI_2_4"] + frequency = "WIFI_2_4" self.sys_log.info("Configuring wireless access point") self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration @@ -264,7 +264,7 @@ class WirelessRouter(Router): if "wireless_access_point" in cfg: ip_address = cfg["wireless_access_point"]["ip_address"] subnet_mask = cfg["wireless_access_point"]["subnet_mask"] - frequency = AirSpaceFrequency[cfg["wireless_access_point"]["frequency"]] + frequency = cfg["wireless_access_point"]["frequency"] router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency) if "acl" in cfg: diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index 36ff2751..a01b7f42 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -9,11 +9,11 @@ from primaite.utils.validators import IPV4Address _LOGGER = getLogger(__name__) -IPProtocol : dict[str, str] = dict( - NONE = "none", - TCP = "tcp", - UDP = "udp", - ICMP = "icmp", +IPProtocol: dict[str, str] = dict( + NONE="none", + TCP="tcp", + UDP="udp", + ICMP="icmp", ) # class IPProtocol(Enum): diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index c77ef532..60f2f070 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -1,40 +1,39 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from enum import Enum -from typing import List, Union +from typing import List from pydantic import BaseModel - Port: dict[str, int] = dict( - UNUSED = -1, - NONE = 0, - WOL = 9, - FTP_DATA = 20, - FTP = 21, - SSH = 22, - SMTP = 25, - DNS = 53, - HTTP = 80, - POP3 = 110, - SFTP = 115, - NTP = 123, - IMAP = 143, - SNMP = 161, - SNMP_TRAP = 162, - ARP = 219, - LDAP = 389, - HTTPS = 443, - SMB = 445, - IPP = 631, - SQL_SERVER = 1433, - MYSQL = 3306, - RDP = 3389, - RTP = 5004, - RTP_ALT = 5005, - DNS_ALT = 5353, - HTTP_ALT = 8080, - HTTPS_ALT = 8443, - POSTGRES_SERVER = 5432, + UNUSED=-1, + NONE=0, + WOL=9, + FTP_DATA=20, + FTP=21, + SSH=22, + SMTP=25, + DNS=53, + HTTP=80, + POP3=110, + SFTP=115, + NTP=123, + IMAP=143, + SNMP=161, + SNMP_TRAP=162, + ARP=219, + LDAP=389, + HTTPS=443, + SMB=445, + IPP=631, + SQL_SERVER=1433, + MYSQL=3306, + RDP=3389, + RTP=5004, + RTP_ALT=5005, + DNS_ALT=5353, + HTTP_ALT=8080, + HTTPS_ALT=8443, + POSTGRES_SERVER=5432, ) # class Port(): diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index b5284968..a7871315 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -44,7 +44,7 @@ class Application(IOSoftware): _registry: ClassVar[Dict[str, Type["Application"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register an application type. @@ -52,7 +52,7 @@ class Application(IOSoftware): :type identifier: str :raises ValueError: When attempting to register an application with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return super().__init_subclass__(**kwargs) if identifier in cls._registry: diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index 06453330..9178e68a 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -1,5 +1,4 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from enum import Enum from ipaddress import IPv4Address from typing import Dict, Optional diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 172be453..33de3443 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -76,9 +76,7 @@ class SessionManager: """ def __init__(self, sys_log: SysLog): - self.sessions_by_key: Dict[ - Tuple[str, IPv4Address, IPv4Address, Optional[int], Optional[int]], Session - ] = {} + self.sessions_by_key: Dict[Tuple[str, IPv4Address, IPv4Address, Optional[int], Optional[int]], Session] = {} self.sessions_by_uuid: Dict[str, Session] = {} self.sys_log: SysLog = sys_log self.software_manager: SoftwareManager = None # Noqa diff --git a/src/primaite/simulator/system/services/ftp/ftp_service.py b/src/primaite/simulator/system/services/ftp/ftp_service.py index 36245e0f..49678c82 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_service.py +++ b/src/primaite/simulator/system/services/ftp/ftp_service.py @@ -5,7 +5,6 @@ from typing import Dict, Optional from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 74dcb506..4f0b879c 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -52,7 +52,7 @@ class Service(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register a hostnode type. @@ -60,7 +60,7 @@ class Service(IOSoftware): :type identifier: str :raises ValueError: When attempting to register an hostnode with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. identifier = identifier.lower() diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 1880d244..084bdaf6 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -14,7 +14,6 @@ from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog diff --git a/tests/integration_tests/extensions/nodes/giga_switch.py b/tests/integration_tests/extensions/nodes/giga_switch.py index b86bea7d..e4100741 100644 --- a/tests/integration_tests/extensions/nodes/giga_switch.py +++ b/tests/integration_tests/extensions/nodes/giga_switch.py @@ -1,3 +1,4 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from typing import Dict from prettytable import MARKDOWN, PrettyTable @@ -27,7 +28,7 @@ class GigaSwitch(NetworkNode, identifier="gigaswitch"): "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." def __init__(self, **kwargs): - print('--- Extended Component: GigaSwitch ---') + print("--- Extended Component: GigaSwitch ---") super().__init__(**kwargs) for i in range(1, self.num_ports + 1): self.connect_nic(SwitchPort()) diff --git a/tests/integration_tests/extensions/nodes/super_computer.py b/tests/integration_tests/extensions/nodes/super_computer.py index 8a1465e9..55bdce09 100644 --- a/tests/integration_tests/extensions/nodes/super_computer.py +++ b/tests/integration_tests/extensions/nodes/super_computer.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from typing import ClassVar, Dict -from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.utils.validators import IPV4Address @@ -37,7 +37,7 @@ class SuperComputer(HostNode, identifier="supercomputer"): SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): - print('--- Extended Component: SuperComputer ---') + print("--- Extended Component: SuperComputer ---") super().__init__(ip_address=ip_address, subnet_mask=subnet_mask, **kwargs) pass diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index d4af600f..b745b774 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -17,7 +17,7 @@ from primaite.simulator.system.software import SoftwareHealthState _LOGGER = getLogger(__name__) -class ExtendedService(Service, identifier='extendedservice'): +class ExtendedService(Service, identifier="extendedservice"): """ A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE. @@ -42,7 +42,7 @@ class ExtendedService(Service, identifier='extendedservice'): kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self._create_db_file() - if kwargs.get('options'): + if kwargs.get("options"): opt = kwargs["options"] self.password = opt.get("db_password", None) if "backup_server_ip" in opt: @@ -139,7 +139,9 @@ class ExtendedService(Service, identifier='extendedservice'): old_visible_state = SoftwareHealthState.GOOD # get db file regardless of whether or not it was deleted - db_file = self.file_system.get_file(folder_name="database", file_name="extended_service_database.db", include_deleted=True) + db_file = self.file_system.get_file( + folder_name="database", file_name="extended_service_database.db", include_deleted=True + ) if db_file is None: self.sys_log.warning("Database file not initialised.") @@ -153,7 +155,9 @@ class ExtendedService(Service, identifier='extendedservice'): self.file_system.delete_file(folder_name="database", file_name="extended_service_database.db") # replace db file - self.file_system.copy_file(src_folder_name="downloads", src_file_name="extended_service_database.db", dst_folder_name="database") + self.file_system.copy_file( + src_folder_name="downloads", src_file_name="extended_service_database.db", dst_folder_name="database" + ) if self.db_file is None: self.sys_log.error("Copying database backup failed.") diff --git a/tests/integration_tests/extensions/test_extendable_config.py b/tests/integration_tests/extensions/test_extendable_config.py index 5d8af64d..8467151b 100644 --- a/tests/integration_tests/extensions/test_extendable_config.py +++ b/tests/integration_tests/extensions/test_extendable_config.py @@ -1,22 +1,22 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import os + from primaite.config.load import get_extended_config_path from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config -import os +from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication +from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch # Import the extended components so that PrimAITE registers them from tests.integration_tests.extensions.nodes.super_computer import SuperComputer -from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch from tests.integration_tests.extensions.services.extended_service import ExtendedService -from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication def test_extended_example_config(): - """Test that the example config can be parsed properly.""" - config_path = os.path.join( "tests", "assets", "configs", "extended_config.yaml") + config_path = os.path.join("tests", "assets", "configs", "extended_config.yaml") game = load_config(config_path) network: Network = game.simulation.network @@ -25,8 +25,8 @@ def test_extended_example_config(): assert len(network.router_nodes) == 1 # 1 router in network assert len(network.switch_nodes) == 1 # 1 switches in network assert len(network.server_nodes) == 5 # 5 servers in network - assert len(network.extended_hostnodes) == 1 # One extended node based on HostNode - assert len(network.extended_networknodes) == 1 # One extended node based on NetworkNode + assert len(network.extended_hostnodes) == 1 # One extended node based on HostNode + assert len(network.extended_networknodes) == 1 # One extended node based on NetworkNode - assert 'ExtendedApplication' in network.extended_hostnodes[0].software_manager.software - assert 'ExtendedService' in network.extended_hostnodes[0].software_manager.software + assert "ExtendedApplication" in network.extended_hostnodes[0].software_manager.software + assert "ExtendedService" in network.extended_hostnodes[0].software_manager.software diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index 398c43a9..28f9ac5a 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -38,8 +38,8 @@ def test_acl_observations(simulation): acl_obs = ACLObservation( where=["network", "nodes", router.hostname, "acl", "acl"], ip_list=[], - port_list=["NTP", "HTTP", "POSTGRES_SERVER"], - protocol_list=["TCP", "UDP", "ICMP"], + port_list=[123, 80, 5432], + protocol_list=["tcp", "udp", "icmp"], num_rules=10, wildcard_list=[], ) diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 68506d59..21fe4bed 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -31,8 +31,8 @@ def test_firewall_observation(): num_rules=7, ip_list=["10.0.0.1", "10.0.0.2"], wildcard_list=["0.0.0.255", "0.0.0.1"], - port_list=["HTTP", "DNS"], - protocol_list=["TCP"], + port_list=[80, 53], + protocol_list=["tcp"], include_users=False, ) diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index bd8dfc4e..8254dad2 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -152,7 +152,12 @@ def test_config_nic_categories(simulation): def test_nic_monitored_traffic(simulation): - monitored_traffic = {"icmp": ["NONE"], "tcp": [53,]} + monitored_traffic = { + "icmp": ["NONE"], + "tcp": [ + 53, + ], + } pc: Computer = simulation.network.get_node_by_hostname("client_1") pc2: Computer = simulation.network.get_node_by_hostname("client_2") diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index 937bb061..c28e1bb8 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -24,8 +24,8 @@ def test_router_observation(): num_rules=7, ip_list=["10.0.0.1", "10.0.0.2"], wildcard_list=["0.0.0.255", "0.0.0.1"], - port_list=["HTTP", "DNS"], - protocol_list=["TCP"], + port_list=[80, 53], + protocol_list=["tcp"], ) router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False) diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index d872c2b0..570c4ad6 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -65,7 +65,9 @@ def test_uc2_rewards(game_and_agent): db_client.run() router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=2) + router.acl.add_rule( + ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=2 + ) comp = GreenAdminDatabaseUnreachablePenalty("client_1") diff --git a/tests/integration_tests/network/test_airspace_config.py b/tests/integration_tests/network/test_airspace_config.py index 1794c4bc..e000f6ae 100644 --- a/tests/integration_tests/network/test_airspace_config.py +++ b/tests/integration_tests/network/test_airspace_config.py @@ -2,7 +2,6 @@ import yaml from primaite.game.game import PrimaiteGame -from primaite.simulator.network.airspace import AirSpaceFrequency from tests import TEST_ASSETS_ROOT @@ -13,8 +12,8 @@ def test_override_freq_max_capacity_mbps(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_2_4"]) == 123.45 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_5"]) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_2_4") == 123.45 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_5") == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") @@ -32,8 +31,8 @@ def test_override_freq_max_capacity_mbps_blocked(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_2_4"]) == 0.0 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_5"]) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_2_4") == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_5") == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index 8e06ccfb..44b660cf 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -73,8 +73,12 @@ def dmz_external_internal_network() -> Network: firewall_node.external_outbound_acl.add_rule( action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 ) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + firewall_node.dmz_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + ) + firewall_node.dmz_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + ) # external node external_node = Computer( @@ -262,8 +266,12 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not internal_ntp_client.time - firewall.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) - firewall.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) + firewall.internal_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1 + ) + firewall.internal_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1 + ) internal_ntp_client.request_time() diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index 1064ed1b..9d92b660 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -73,7 +73,9 @@ def test_port_scan_one_node_one_port(example_network): client_2 = network.get_node_by_hostname("client_2") actual_result = client_1_nmap.port_scan( - target_ip_address=client_2.network_interface[1].ip_address, target_port=Port["DNS"], target_protocol=IPProtocol["TCP"] + target_ip_address=client_2.network_interface[1].ip_address, + target_port=Port["DNS"], + target_protocol=IPProtocol["TCP"], ) expected_result = {IPv4Address("192.168.10.22"): {IPProtocol["TCP"]: [Port["DNS"]]}} diff --git a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py index 4e40bbd8..8becc6ae 100644 --- a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py +++ b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py @@ -66,7 +66,9 @@ def test_nested_dicts(): The expected output should have string values of enums as keys at all levels. """ original_dict = { - IPProtocol["UDP"]: {Port["ARP"]: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol["TCP"]: {"latency": "low"}}}} + IPProtocol["UDP"]: { + Port["ARP"]: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol["TCP"]: {"latency": "low"}}} + } } expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0, "details": {"tcp": {"latency": "low"}}}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -79,6 +81,9 @@ def test_non_dict_values(): The original dictionary contains lists and tuples as values. The expected output should preserve these non-dictionary values while converting enum keys to string values. """ - original_dict = {IPProtocol["UDP"]: [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])} + original_dict = { + IPProtocol["UDP"]: [Port["ARP"], Port["HTTP"]], + "protocols": (IPProtocol["TCP"], IPProtocol["UDP"]), + } expected_dict = {"udp": [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict