diff --git a/CHANGELOG.md b/CHANGELOG.md index f51fd648..e54f32e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,12 +19,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added reward calculation details to AgentHistoryItem. - Added a new Privilege-Escalation-and Data-Loss-Example.ipynb notebook with a realistic cyber scenario focusing on internal privilege escalation and data loss through the manipulation of SSH access and Access Control Lists (ACLs). +- Added a new extensible `NetworkNodeAdder` class for convenient addition of sets of nodes based on a simplified config. ### Changed - File and folder observations can now be configured to always show the true health status, or require scanning like before. - It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where agent don't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty` - Node observations can now be configured to show the number of active local and remote logins. -- Ports, IP Protocols, and airspace frequencies no longer use enums. They are defined in dictionary lookups and are handled by custom validation to enable extendability with plugins. +- Ports and IP Protocols no longer use enums. They are defined in dictionary lookups and are handled by custom validation to enable extensibility with plugins. +- Changed AirSpaceFrequency to a data transfer object with a registry to allow extensibility +- Changed the Office LAN creation convenience function to follow the new `NetworkNodeAdder` pattern. Office LANs can now also be defined in YAML config. ### Fixed - Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) @@ -32,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 and `uninstall` methods in the `Node` class. - Updated the `receive_payload_from_session_manager` method in `SoftwareManager` so that it now sends a copy of the payload to any software listening on the destination port of the `Frame`. +- Made the `show` method of `Network` show all node types, including ones registered at runtime ### Removed - Removed the `install` and `uninstall` methods in the `Node` class. diff --git a/docs/index.rst b/docs/index.rst index ff97f60d..1da15b8c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,6 +30,7 @@ What is PrimAITE? source/varying_config_files source/environment source/action_masking + source/node_sets .. toctree:: :caption: Notebooks: diff --git a/docs/source/node_sets.rst b/docs/source/node_sets.rst new file mode 100644 index 00000000..866f0139 --- /dev/null +++ b/docs/source/node_sets.rst @@ -0,0 +1,115 @@ +.. only:: comment + + © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + +.. _network_node_adder: + +Network Node Adder Module +######################### + +This module provides a framework for adding nodes to a network in a standardised way. It defines a base class ``NetworkNodeAdder``, which can be extended to create specific node adders, and utility functions to calculate network infrastructure requirements. + +The module allows you to use the pre-defined node adders, ``OfficeLANAdder``, or create custom ones by extending the base class. + +How It Works +============ + +The main class in the module is ``NetworkNodeAdder``, which defines the interface for adding nodes to a network. Child classes are expected to: + +1. Define a ``ConfigSchema`` nested class to define configuration options. +2. Implement the ``add_nodes_to_net(config, network)`` method, which adds the nodes to the network according to the configuration object. + +The ``NetworkNodeAdder`` base class handles node adders defined in the primAITE config YAML file as well. It does this by keeping a registry of node adder classes, and uses the ``type`` field of the config to select the appropriate class to which to pass the configuration. + +Example Usage +============= + +Via Python API +-------------- + +Adding nodes to a network can be done using the python API by constructing the relevant ``ConfigSchema`` object like this: + +.. code-block:: python + + net = Network() + + office_lan_config = OfficeLANAdder.ConfigSchema( + lan_name="CORP-LAN", + subnet_base=2, + pcs_ip_block_start=10, + num_pcs=8, + include_router=False, + bandwidth=150, + ) + OfficeLANAdder.add_nodes_to_net(config=office_lan_config, network=net) + +In this example, a network with 8 computers connected by a switch will be added to the network object. + + +Via YAML Config +--------------- + +.. code-block:: yaml + simulation: + network: + nodes: + # ... nodes go here + node_sets: + - type: office_lan + lan_name: CORP_LAN + subnet_base: 2 + pcs_ip_block_start: 10 + num_pcs: 8 + include_router: False + bandwidth: 150 + # ... additional node sets can be added below + +``NetworkNodeAdder`` reads the ``type`` property of the config, then constructs and passes the configuration to ``OfficeLANAdder.add_nodes_to_net()``. + +In this example, a network with 8 computers connected by a switch will be added to the network object. Equivalent to the above. + + +Creating Custom Node Adders +=========================== +To create a custom node adder, subclass NetworkNodeAdder and define: + +* A ConfigSchema class that defines the configuration schema for the node adder. +* The add_nodes_to_net method that implements how nodes should be added to the network. + +Example: DataCenterAdder +------------------------ +Here is an example of creating a custom node adder, DataCenterAdder: + +.. code-block:: python + + class DataCenterAdder(NetworkNodeAdder, identifier="data_center"): + class ConfigSchema(NetworkNodeAdder.ConfigSchema): + type: Literal["data_center"] = "data_center" + num_servers: int + data_center_name: str + + @classmethod + def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None: + for i in range(config.num_servers): + server = Computer( + hostname=f"server_{i}_{config.data_center_name}", + ip_address=f"192.168.100.{i + 8}", + subnet_mask="255.255.255.0", + default_gateway="192.168.100.1", + start_up_duration=0 + ) + server.power_on() + network.add_node(server) + +**Using the Custom Node Adder:** + +.. code-block:: python + + config = { + "type": "data_center", + "num_servers": 5, + "data_center_name": "dc1" + } + + network = Network() + DataCenterAdder.from_config(config, network) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 6d1c0920..691ac2a1 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -17,7 +17,8 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT -from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager +from primaite.simulator.network.creation import NetworkNodeAdder +from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC from primaite.simulator.network.hardware.nodes.host.server import Printer, Server @@ -270,6 +271,7 @@ class PrimaiteGame: nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) + node_sets_cfg = network_config.get("node_sets", []) # Set the NMNE capture config NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {})) @@ -277,22 +279,8 @@ class PrimaiteGame: n_type = node_cfg["type"] new_node = None - # Handle extended nodes - if n_type.lower() in HostNode._registry: - new_node = HostNode._registry[n_type]( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type in NetworkNode._registry: - new_node = NetworkNode._registry[n_type](**node_cfg) # Default PrimAITE nodes - elif n_type == "computer": + if n_type == "computer": new_node = Computer( hostname=node_cfg["hostname"], ip_address=node_cfg["ip_address"], @@ -337,6 +325,20 @@ class PrimaiteGame: if not (p := node_cfg.get("operating_state")) else NodeOperatingState[p.upper()], ) + # Handle extended nodes + elif n_type.lower() in Node._registry: + new_node = HostNode._registry[n_type]( + hostname=node_cfg["hostname"], + ip_address=node_cfg["ip_address"], + subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), + default_gateway=node_cfg.get("default_gateway"), + dns_server=node_cfg.get("dns_server", None), + operating_state=NodeOperatingState.ON + if not (p := node_cfg.get("operating_state")) + else NodeOperatingState[p.upper()], + ) + elif n_type in NetworkNode._registry: + new_node = NetworkNode._registry[n_type](**node_cfg) else: msg = f"invalid node type {n_type} in config" _LOGGER.error(msg) @@ -505,6 +507,10 @@ class PrimaiteGame: new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3)) new_node.shut_down_duration = int(node_cfg.get("shut_down_duration", 3)) + # 1.1 Create Node Sets + for node_set_cfg in node_sets_cfg: + NetworkNodeAdder.from_config(node_set_cfg, network=net) + # 2. create links between nodes for link_cfg in links_cfg: node_a = net.get_node_by_hostname(link_cfg["endpoint_a_hostname"]) diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 03d43130..2705c108 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -1,12 +1,11 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -import copy from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, ClassVar, Dict, List from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, Field, validate_call +from pydantic import BaseModel, ConfigDict, Field, validate_call from primaite import getLogger from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface @@ -41,30 +40,28 @@ def format_hertz(hertz: float, format_terahertz: bool = False, decimals: int = 3 return format_str.format(hertz) + " Hz" -_default_frequency_set: Dict[str, Dict] = { - "WIFI_2_4": {"frequency": 2.4e9, "data_rate_bps": 100_000_000.0}, - "WIFI_5": {"frequency": 5e9, "data_rate_bps": 500_000_000.0}, -} -"""Frequency configuration that is automatically used for any new airspace.""" +class AirSpaceFrequency(BaseModel): + """Data transfer object for defining properties of an airspace frequency.""" + + model_config = ConfigDict(extra="forbid") + name: str + """Alias for frequency.""" + frequency_hz: int + """This acts as the primary key. If two names are mapped to the same frequency, they will share a bandwidth.""" + data_rate_bps: float + """How much data can be transmitted on this frequency per second.""" + + _registry: ClassVar[Dict[str, AirSpaceFrequency]] = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.name in self._registry: + raise RuntimeError(f"Frequency {self.name} is already registered. Cannot register it again.") + self._registry[self.name] = self -def register_default_frequency(freq_name: str, freq_hz: float, data_rate_bps: float) -> None: - """Add to the default frequency configuration. This is intended as a plugin hook. - - If your plugin makes use of bespoke frequencies for wireless communication, you should make a call to this method - wherever you define components that rely on the bespoke frequencies. That way, as soon as your components are - imported, this function automatically updates the default frequency set. - - This should also be run before instances of AirSpace are created. - - :param freq_name: The frequency name. If this clashes with an existing frequency name, it will be overwritten. - :type freq_name: str - :param freq_hz: The frequency itself, measured in Hertz. - :type freq_hz: float - :param data_rate_bps: The transmission capacity over this frequency, in bits per second. - :type data_rate_bps: float - """ - _default_frequency_set.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}}) +FREQ_WIFI_2_4 = AirSpaceFrequency(name="WIFI_2_4", frequency_hz=2.4e9, data_rate_bps=100_000_000.0) +FREQ_WIFI_5 = AirSpaceFrequency(name="WIFI_5", frequency_hz=5e9, data_rate_bps=500_000_000.0) class AirSpace(BaseModel): @@ -79,7 +76,7 @@ class AirSpace(BaseModel): wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {}) wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field(default_factory=lambda: {}) bandwidth_load: Dict[int, float] = Field(default_factory=lambda: {}) - frequencies: Dict[str, Dict] = Field(default_factory=lambda: copy.deepcopy(_default_frequency_set)) + frequencies: Dict[str, AirSpaceFrequency] = AirSpaceFrequency._registry @validate_call def get_frequency_max_capacity_mbps(self, freq_name: str) -> float: @@ -90,7 +87,7 @@ class AirSpace(BaseModel): :return: The maximum capacity in Mbps for the specified frequency. """ if freq_name in self.frequencies: - return self.frequencies[freq_name]["data_rate_bps"] / (1024.0 * 1024.0) + return self.frequencies[freq_name].data_rate_bps / (1024.0 * 1024.0) return 0.0 def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]) -> None: @@ -100,7 +97,7 @@ class AirSpace(BaseModel): :param cfg: A dictionary mapping frequencies to their new maximum capacities in Mbps. """ for freq, mbps in cfg.items(): - self.frequencies[freq]["data_rate_bps"] = mbps * 1024 * 1024 + self.frequencies[freq].data_rate_bps = mbps * 1024 * 1024 print(f"Overriding {freq} max capacity as {mbps:.3f} mbps") def register_frequency(self, freq_name: str, freq_hz: float, data_rate_bps: float) -> None: @@ -117,10 +114,12 @@ class AirSpace(BaseModel): if freq_name in self.frequencies: _LOGGER.info( f"Overwriting Air space frequency {freq_name}. " - f"Previous data rate: {self.frequencies[freq_name]['data_rate_bps']}. " + f"Previous data rate: {self.frequencies[freq_name].data_rate_bps}. " f"Current data rate: {data_rate_bps}." ) - self.frequencies.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}}) + self.frequencies.update( + {freq_name: AirSpaceFrequency(name=freq_name, frequency_hz=freq_hz, data_rate_bps=data_rate_bps)} + ) def show_bandwidth_load(self, markdown: bool = False): """ @@ -145,7 +144,7 @@ class AirSpace(BaseModel): load_percent = 1.0 table.add_row( [ - format_hertz(self.frequencies[frequency]["frequency"]), + format_hertz(self.frequencies[frequency].frequency_hz), f"{load_percent:.0%}", f"{maximum_capacity:.3f}", ] @@ -181,7 +180,7 @@ class AirSpace(BaseModel): interface.mac_address, interface.ip_address if hasattr(interface, "ip_address") else None, interface.subnet_mask if hasattr(interface, "subnet_mask") else None, - format_hertz(self.frequencies[interface.frequency]["frequency"]), + format_hertz(self.frequencies[interface.frequency].frequency_hz), f"{interface.speed:.3f}", status, ] @@ -209,9 +208,9 @@ class AirSpace(BaseModel): """ if wireless_interface.mac_address not in self.wireless_interfaces: self.wireless_interfaces[wireless_interface.mac_address] = wireless_interface - if wireless_interface.frequency not in self.wireless_interfaces_by_frequency: - self.wireless_interfaces_by_frequency[wireless_interface.frequency] = [] - self.wireless_interfaces_by_frequency[wireless_interface.frequency].append(wireless_interface) + if wireless_interface.frequency.frequency_hz not in self.wireless_interfaces_by_frequency: + self.wireless_interfaces_by_frequency[wireless_interface.frequency.frequency_hz] = [] + self.wireless_interfaces_by_frequency[wireless_interface.frequency.frequency_hz].append(wireless_interface) def remove_wireless_interface(self, wireless_interface: WirelessNetworkInterface): """ @@ -221,7 +220,7 @@ class AirSpace(BaseModel): """ if wireless_interface.mac_address in self.wireless_interfaces: self.wireless_interfaces.pop(wireless_interface.mac_address) - self.wireless_interfaces_by_frequency[wireless_interface.frequency].remove(wireless_interface) + self.wireless_interfaces_by_frequency[wireless_interface.frequency.frequency_hz].remove(wireless_interface) def clear(self): """ @@ -297,7 +296,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC): """ airspace: AirSpace - frequency: str = "WIFI_2_4" + frequency: AirSpaceFrequency = FREQ_WIFI_2_4 def enable(self): """Attempt to enable the network interface.""" diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 6e019f32..1082e172 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -179,9 +179,8 @@ class Network(SimComponent): table.set_style(MARKDOWN) table.align = "l" table.title = "Nodes" - for node_type, nodes in nodes_type_map.items(): - for node in nodes: - table.add_row([node.hostname, node_type, node.operating_state.name]) + for node in self.nodes.values(): + table.add_row((node.hostname, type(node)._identifier, node.operating_state.name)) print(table) if ip_addresses: diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index 891c445e..8cc9a493 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -1,6 +1,9 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from abc import ABC, abstractmethod from ipaddress import IPv4Address -from typing import Optional +from typing import Any, ClassVar, Dict, Literal, Self, Type + +from pydantic import BaseModel, model_validator from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -10,6 +13,219 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import PORT_LOOKUP +class NetworkNodeAdder(BaseModel): + """ + Base class for adding a set of related nodes to a network in a standardised way. + + Child classes should define a ConfigSchema nested class that subclasses NetworkNodeAdder.ConfigSchema and a __call__ + method which performs the node addition to the network. + + Here is a template that users can use to define custom node adders: + ``` + class YourNodeAdder(NetworkNodeAdder, identifier="your_name"): + class ConfigSchema(NetworkNodeAdder.ConfigSchema): + property_1 : str + property_2 : int + + @classmethod + def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None: + node_1 = Node(property_1, ...) + node_2 = Node(...) + network.connect(node_1.network_interface[1], node_2.network_interface[1]) + ... + ``` + """ + + class ConfigSchema(BaseModel, ABC): + """ + Base schema for node adders. + + Child classes of NetworkNodeAdder must define a schema which inherits from this schema. The identifier is used + by the from_config method to select the correct node adder at runtime. + """ + + type: str + """Uniquely identifies the node adder class to use for adding nodes to network.""" + + _registry: ClassVar[Dict[str, Type["NetworkNodeAdder"]]] = {} + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + """ + Register a network node adder class. + + :param identifier: Unique name for the node adder to use for matching against primaite config entries. + :type identifier: str + :raises ValueError: When attempting to register a name that is already reserved. + """ + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Duplicate node adder {identifier}") + cls._registry[identifier] = cls + + @classmethod + @abstractmethod + def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None: + """ + Add nodes to the network. + + Abstract method that must be overwritten by child classes. Use the config definition to create nodes and add + them to the network that is passed in. + + :param config: Config object that defines how to create and add nodes to the network + :type config: ConfigSchema + :param network: PrimAITE network object to which to add nodes. + :type network: Network + """ + pass + + @classmethod + def from_config(cls, config: Dict, network: Network) -> None: + """ + Accept a config, find the relevant node adder class, and call it to add nodes to the network. + + Child classes do not need to define this method. + + :param config: Configuration object for the child adder class + :type config: Dict + :param network: The Network object to which to add nodes + :type network: Network + """ + if config["type"] not in cls._registry: + raise ValueError(f"Invalid node adder type {config['type']}") + adder_class = cls._registry[config["type"]] + adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config), network=network) + + +class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): + """Creates an office LAN.""" + + class ConfigSchema(NetworkNodeAdder.ConfigSchema): + """Configuration schema for OfficeLANAdder.""" + + type: Literal["office_lan"] = "office_lan" + lan_name: str + """Name of lan used for generating hostnames for new nodes.""" + subnet_base: int + """Used as the third octet of IP addresses for nodes in the network.""" + pcs_ip_block_start: int + """Starting point for the fourth octet of IP addresses of nodes in the network.""" + num_pcs: int + """The number of hosts to generate.""" + include_router: bool = True + """Whether to include a router in the new office LAN.""" + bandwidth: int = 100 + """Data bandwidth to the LAN measured in Mbps.""" + + @model_validator(mode="after") + def check_ip_range(self) -> Self: + """Make sure the ip addresses of hosts don't exceed the maximum possible ip address.""" + if self.pcs_ip_block_start + self.num_pcs >= 254: + raise ValueError( + f"Cannot create {self.num_pcs} pcs starting at ip block {self.pcs_ip_block_start} " + f"because ip address octets cannot exceed 254." + ) + return self + + @classmethod + def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None: + """ + Add an office lan to the network according to the config definition. + + This method creates a number of hosts and enough switches such that all hosts can be connected to a switch. + Optionally, a router is added to connect the switches together. All the nodes and networking devices are added + to the provided network. + + :param config: Configuration object specifying office LAN parameters + :type config: OfficeLANAdder.ConfigSchema + :param network: The PrimAITE network to which to add the office LAN. + :type network: Network + :raises ValueError: upon invalid configuration + """ + # Calculate the required number of switches + num_of_switches = num_of_switches_required(num_nodes=config.num_pcs) + effective_network_interface = 23 # One port less for router connection + if config.pcs_ip_block_start <= num_of_switches: + raise ValueError( + f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}" + ) + + # Create a core switch if more than one edge switch is needed + if num_of_switches > 1: + core_switch = Switch(hostname=f"switch_core_{config.lan_name}", start_up_duration=0) + core_switch.power_on() + network.add_node(core_switch) + core_switch_port = 1 + + # Initialise the default gateway to None + default_gateway = None + + # Optionally include a router in the LAN + if config.include_router: + default_gateway = IPv4Address(f"192.168.{config.subnet_base}.1") + router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0) + router.power_on() + router.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + network.add_node(router) + router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0") + router.enable_port(1) + + # Initialise the first edge switch and connect to the router or core switch + switch_port = 0 + switch_n = 1 + switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0) + switch.power_on() + network.add_node(switch) + if num_of_switches > 1: + network.connect( + core_switch.network_interface[core_switch_port], + switch.network_interface[24], + bandwidth=config.bandwidth, + ) + else: + network.connect(router.network_interface[1], switch.network_interface[24], bandwidth=config.bandwidth) + + # Add PCs to the LAN and connect them to switches + for i in range(1, config.num_pcs + 1): + # Add a new edge switch if the current one is full + if switch_port == effective_network_interface: + switch_n += 1 + switch_port = 0 + switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0) + switch.power_on() + network.add_node(switch) + # Connect the new switch to the router or core switch + if num_of_switches > 1: + core_switch_port += 1 + network.connect( + core_switch.network_interface[core_switch_port], + switch.network_interface[24], + bandwidth=config.bandwidth, + ) + else: + network.connect( + router.network_interface[1], switch.network_interface[24], bandwidth=config.bandwidth + ) + + # Create and add a PC to the network + pc = Computer( + hostname=f"pc_{i}_{config.lan_name}", + ip_address=f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}", + subnet_mask="255.255.255.0", + default_gateway=default_gateway, + start_up_duration=0, + ) + pc.power_on() + network.add_node(pc) + + # Connect the PC to the switch + switch_port += 1 + network.connect(switch.network_interface[switch_port], pc.network_interface[1], bandwidth=config.bandwidth) + switch.network_interface[switch_port].enable() + + def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int: """ Calculate the minimum number of network switches required to connect a given number of nodes. @@ -42,115 +258,3 @@ def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> # Return the total number of switches required return full_switches + (1 if extra_pcs > 0 else 0) - - -def create_office_lan( - lan_name: str, - subnet_base: int, - pcs_ip_block_start: int, - num_pcs: int, - network: Optional[Network] = None, - include_router: bool = True, - bandwidth: int = 100, -) -> Network: - """ - Creates a 2-Tier or 3-Tier office local area network (LAN). - - The LAN is configured with a specified number of personal computers (PCs), optionally including a router, - and multiple edge switches to connect them. A core switch is added only if more than one edge switch is required. - The network topology involves edge switches connected either directly to the router in a 2-Tier setup or - to a core switch in a 3-Tier setup. If a router is included, it is connected to the core switch (if present) - and configured with basic access control list (ACL) rules. PCs are distributed across the edge switches. - - - :param str lan_name: The name to be assigned to the LAN. - :param int subnet_base: The subnet base number to be used in the IP addresses. - :param int pcs_ip_block_start: The starting block for assigning IP addresses to PCs. - :param int num_pcs: The number of PCs to be added to the LAN. - :param Optional[Network] network: The network to which the LAN components will be added. If None, a new network is - created. - :param bool include_router: Flag to determine if a router should be included in the LAN. Defaults to True. - :return: The network object with the LAN components added. - :raises ValueError: If pcs_ip_block_start is less than or equal to the number of required switches. - """ - # Initialise the network if not provided - if not network: - network = Network() - - # Calculate the required number of switches - num_of_switches = num_of_switches_required(num_nodes=num_pcs) - effective_network_interface = 23 # One port less for router connection - if pcs_ip_block_start <= num_of_switches: - raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}") - - # Create a core switch if more than one edge switch is needed - if num_of_switches > 1: - core_switch = Switch(hostname=f"switch_core_{lan_name}", start_up_duration=0) - core_switch.power_on() - network.add_node(core_switch) - core_switch_port = 1 - - # Initialise the default gateway to None - default_gateway = None - - # Optionally include a router in the LAN - if include_router: - default_gateway = IPv4Address(f"192.168.{subnet_base}.1") - router = Router(hostname=f"router_{lan_name}", start_up_duration=0) - router.power_on() - router.acl.add_rule( - action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 - ) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) - network.add_node(router) - router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0") - router.enable_port(1) - - # Initialise the first edge switch and connect to the router or core switch - switch_port = 0 - switch_n = 1 - switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0) - switch.power_on() - network.add_node(switch) - if num_of_switches > 1: - network.connect( - core_switch.network_interface[core_switch_port], switch.network_interface[24], bandwidth=bandwidth - ) - else: - network.connect(router.network_interface[1], switch.network_interface[24], bandwidth=bandwidth) - - # Add PCs to the LAN and connect them to switches - for i in range(1, num_pcs + 1): - # Add a new edge switch if the current one is full - if switch_port == effective_network_interface: - switch_n += 1 - switch_port = 0 - switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0) - switch.power_on() - network.add_node(switch) - # Connect the new switch to the router or core switch - if num_of_switches > 1: - core_switch_port += 1 - network.connect( - core_switch.network_interface[core_switch_port], switch.network_interface[24], bandwidth=bandwidth - ) - else: - network.connect(router.network_interface[1], switch.network_interface[24], bandwidth=bandwidth) - - # Create and add a PC to the network - pc = Computer( - hostname=f"pc_{i}_{lan_name}", - ip_address=f"192.168.{subnet_base}.{i+pcs_ip_block_start-1}", - subnet_mask="255.255.255.0", - default_gateway=default_gateway, - start_up_duration=0, - ) - pc.power_on() - network.add_node(pc) - - # Connect the PC to the switch - switch_port += 1 - network.connect(switch.network_interface[switch_port], pc.network_interface[1], bandwidth=bandwidth) - switch.network_interface[switch_port].enable() - - return network diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 050f4667..51e200e7 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1539,6 +1539,29 @@ class Node(SimComponent): SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {} "Base system software that must be preinstalled." + _registry: ClassVar[Dict[str, Type["Node"]]] = {} + """Registry of application types. Automatically populated when subclasses are defined.""" + + _identifier: ClassVar[str] = "unknown" + """Identifier for this particular class, used for printing and logging. Each subclass redefines this.""" + + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: + """ + Register a node type. + + :param identifier: Uniquely specifies an node class by name. Used for finding items by config. + :type identifier: str + :raises ValueError: When attempting to register an node with a name that is already allocated. + """ + if identifier == "default": + return + identifier = identifier.lower() + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Tried to define new node {identifier}, but this name is already reserved.") + cls._registry[identifier] = cls + cls._identifier = identifier + def __init__(self, **kwargs): """ Initialize the Node with various components and managers. diff --git a/src/primaite/simulator/network/hardware/nodes/host/computer.py b/src/primaite/simulator/network/hardware/nodes/host/computer.py index 68c72554..4253d15c 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/computer.py +++ b/src/primaite/simulator/network/hardware/nodes/host/computer.py @@ -5,7 +5,7 @@ from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.system.services.ftp.ftp_client import FTPClient -class Computer(HostNode): +class Computer(HostNode, identifier="computer"): """ A basic Computer class. diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 5699721b..0c309136 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -2,7 +2,7 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, ClassVar, Dict, Optional, Type +from typing import Any, ClassVar, Dict, Optional from primaite import getLogger from primaite.simulator.network.hardware.base import ( @@ -262,7 +262,7 @@ class NIC(IPWiredNetworkInterface): return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}/{self.ip_address}" -class HostNode(Node): +class HostNode(Node, identifier="HostNode"): """ Represents a host node in the network. @@ -325,30 +325,10 @@ class HostNode(Node): network_interface: Dict[int, NIC] = {} "The NICs on the node by port id." - _registry: ClassVar[Dict[str, Type["HostNode"]]] = {} - """Registry of application types. Automatically populated when subclasses are defined.""" - def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) - def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: - """ - Register a hostnode type. - - :param identifier: Uniquely specifies an hostnode class by name. Used for finding items by config. - :type identifier: str - :raises ValueError: When attempting to register an hostnode with a name that is already allocated. - """ - if identifier == "default": - return - # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. - identifier = identifier.lower() - super().__init_subclass__(**kwargs) - if identifier in cls._registry: - raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.") - cls._registry[identifier] = cls - @property def nmap(self) -> Optional[NMAP]: """ diff --git a/src/primaite/simulator/network/hardware/nodes/host/server.py b/src/primaite/simulator/network/hardware/nodes/host/server.py index 379c9927..bf1ef39b 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/server.py +++ b/src/primaite/simulator/network/hardware/nodes/host/server.py @@ -2,7 +2,7 @@ from primaite.simulator.network.hardware.nodes.host.host_node import HostNode -class Server(HostNode): +class Server(HostNode, identifier="server"): """ A basic Server class. @@ -31,7 +31,7 @@ class Server(HostNode): """ -class Printer(HostNode): +class Printer(HostNode, identifier="printer"): """Printer? I don't even know her!.""" # TODO: Implement printer-specific behaviour diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 47cfae57..84cf8530 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -27,7 +27,7 @@ DMZ_PORT_ID: Final[int] = 3 """The Firewall port ID of the DMZ port.""" -class Firewall(Router): +class Firewall(Router, identifier="firewall"): """ A Firewall class that extends the functionality of a Router. diff --git a/src/primaite/simulator/network/hardware/nodes/network/network_node.py b/src/primaite/simulator/network/hardware/nodes/network/network_node.py index a0cb63e1..a5b8544f 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/network_node.py +++ b/src/primaite/simulator/network/hardware/nodes/network/network_node.py @@ -1,13 +1,13 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from abc import abstractmethod -from typing import Any, ClassVar, Dict, Optional, Type +from typing import Optional from primaite.simulator.network.hardware.base import NetworkInterface, Node from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.system.services.arp.arp import ARP -class NetworkNode(Node): +class NetworkNode(Node, identifier="NetworkNode"): """ Represents an abstract base class for a network node that can receive and process network frames. @@ -16,25 +16,6 @@ class NetworkNode(Node): provide functionality for receiving and processing frames received on their network interfaces. """ - _registry: ClassVar[Dict[str, Type["NetworkNode"]]] = {} - """Registry of application types. Automatically populated when subclasses are defined.""" - - def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: - """ - Register a networknode type. - - :param identifier: Uniquely specifies an networknode class by name. Used for finding items by config. - :type identifier: str - :raises ValueError: When attempting to register an networknode with a name that is already allocated. - """ - if identifier == "default": - return - identifier = identifier.lower() - super().__init_subclass__(**kwargs) - if identifier in cls._registry: - raise ValueError(f"Tried to define new networknode {identifier}, but this name is already reserved.") - cls._registry[identifier] = cls - @abstractmethod def receive_frame(self, frame: Frame, from_network_interface: NetworkInterface): """ diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 1080dca8..e921faff 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1184,7 +1184,7 @@ class RouterSessionManager(SessionManager): return outbound_network_interface, dst_mac_address, dst_ip_address, src_port, dst_port, protocol, is_broadcast -class Router(NetworkNode): +class Router(NetworkNode, identifier="router"): """ Represents a network router, managing routing and forwarding of IP packets across network interfaces. diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 4324ac94..d29152a4 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -87,7 +87,7 @@ class SwitchPort(WiredNetworkInterface): return False -class Switch(NetworkNode): +class Switch(NetworkNode, identifier="switch"): """ A class representing a Layer 2 network switch. diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 27a13154..aed314d2 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union from pydantic import validate_call -from primaite.simulator.network.airspace import AirSpace, IPWirelessNetworkInterface +from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, FREQ_WIFI_2_4, IPWirelessNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface from primaite.simulator.network.transmission.data_link_layer import Frame @@ -91,7 +91,7 @@ class WirelessAccessPoint(IPWirelessNetworkInterface): ) -class WirelessRouter(Router): +class WirelessRouter(Router, identifier="wireless_router"): """ A WirelessRouter class that extends the functionality of a standard Router to include wireless capabilities. @@ -153,7 +153,7 @@ class WirelessRouter(Router): self, ip_address: IPV4Address, subnet_mask: IPV4Address, - frequency: Optional[str] = "WIFI_2_4", + frequency: Optional[AirSpaceFrequency] = FREQ_WIFI_2_4, ): """ Configures a wireless access point (WAP). @@ -171,7 +171,7 @@ class WirelessRouter(Router): communication. Default is "WIFI_2_4". """ if not frequency: - frequency = "WIFI_2_4" + frequency = FREQ_WIFI_2_4 self.sys_log.info("Configuring wireless access point") self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration @@ -264,7 +264,7 @@ class WirelessRouter(Router): if "wireless_access_point" in cfg: ip_address = cfg["wireless_access_point"]["ip_address"] subnet_mask = cfg["wireless_access_point"]["subnet_mask"] - frequency = cfg["wireless_access_point"]["frequency"] + frequency = AirSpaceFrequency._registry[cfg["wireless_access_point"]["frequency"]] router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency) if "acl" in cfg: diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_creation.py b/tests/unit_tests/_primaite/_simulator/_network/test_creation.py new file mode 100644 index 00000000..2e86ebbc --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_network/test_creation.py @@ -0,0 +1,69 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.simulator.network.container import Network +from primaite.simulator.network.creation import NetworkNodeAdder, OfficeLANAdder + +param_names = ("lan_name", "subnet_base", "pcs_ip_block_start", "num_pcs", "include_router", "bandwidth") +param_vals = ( + ("CORP-NETWORK", 3, 10, 6, True, 45), + ("OTHER-NETWORK", 10, 25, 26, True, 100), + ("OTHER-NETWORK", 10, 25, 55, False, 100), +) +param_dicts = [dict(zip(param_names, vals)) for vals in param_vals] + + +def _assert_valid_creation(net: Network, lan_name, subnet_base, pcs_ip_block_start, num_pcs, include_router, bandwidth): + """Assert that the network contains the correct nodes as described by config items""" + num_switches = 1 if num_pcs <= 23 else num_pcs // 23 + 2 + num_routers = 1 if include_router else 0 + total_nodes = num_pcs + num_switches + num_routers + + assert all((n.hostname.endswith(lan_name) for n in net.nodes.values())) + assert len(net.computer_nodes) == num_pcs + assert len(net.switch_nodes) == num_switches + assert len(net.router_nodes) == num_routers + assert len(net.nodes) == total_nodes + assert all( + [str(n.network_interface[1].ip_address).startswith(f"192.168.{subnet_base}") for n in net.computer_nodes] + ) + # check that computers occupy address range 192.168.3.10 - 192.168.3.16 + computer_ip_last_octets = {str(n.network_interface[1].ip_address).split(".")[-1] for n in net.computer_nodes} + assert computer_ip_last_octets == {str(i) for i in range(pcs_ip_block_start, pcs_ip_block_start + num_pcs)} + + +@pytest.mark.parametrize("kwargs", param_dicts) +def test_office_lan_adder(kwargs): + """Assert that adding an office lan via the python API works correctly.""" + net = Network() + + office_lan_config = OfficeLANAdder.ConfigSchema( + lan_name=kwargs["lan_name"], + subnet_base=kwargs["subnet_base"], + pcs_ip_block_start=kwargs["pcs_ip_block_start"], + num_pcs=kwargs["num_pcs"], + include_router=kwargs["include_router"], + bandwidth=kwargs["bandwidth"], + ) + OfficeLANAdder.add_nodes_to_net(config=office_lan_config, network=net) + + _assert_valid_creation(net=net, **kwargs) + + +@pytest.mark.parametrize("kwargs", param_dicts) +def test_office_lan_from_config(kwargs): + """Assert that the base class can add an office lan given a config dict.""" + net = Network() + + config = dict( + type="office_lan", + lan_name=kwargs["lan_name"], + subnet_base=kwargs["subnet_base"], + pcs_ip_block_start=kwargs["pcs_ip_block_start"], + num_pcs=kwargs["num_pcs"], + include_router=kwargs["include_router"], + bandwidth=kwargs["bandwidth"], + ) + + NetworkNodeAdder.from_config(config=config, network=net) + _assert_valid_creation(net=net, **kwargs)