Merged PR 561: Make it possible to add sets of nodes to the network

## Summary
* Changed the office LAN convenience function to a class with a registry. Now, plugin can register custom node adders.
* Added ability to define `node_sets` in the config that map to `NetworkNodeAdder` subclasses
* Made airspacefrequency into a DTO class again to make management simpler.
* Moved the node registry out of `HostNode` and `NetworkNode` into `Node`
* Changed game.py to check the hardcoded node types before the node registry (this will change once I add ConfigSchema to all node subclasses)
* Made `show` method of the network container show all nodes, including ones registered at runtime.

## Test process
* Existing tests passed.
* Added unit tests for node adders

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [X] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [X] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code
This commit is contained in:
Marek Wolan
2024-10-09 14:56:57 +00:00
17 changed files with 505 additions and 224 deletions

View File

@@ -19,12 +19,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added reward calculation details to AgentHistoryItem.
- Added a new Privilege-Escalation-and Data-Loss-Example.ipynb notebook with a realistic cyber scenario focusing on
internal privilege escalation and data loss through the manipulation of SSH access and Access Control Lists (ACLs).
- Added a new extensible `NetworkNodeAdder` class for convenient addition of sets of nodes based on a simplified config.
### Changed
- File and folder observations can now be configured to always show the true health status, or require scanning like before.
- It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where agent don't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty`
- Node observations can now be configured to show the number of active local and remote logins.
- Ports, IP Protocols, and airspace frequencies no longer use enums. They are defined in dictionary lookups and are handled by custom validation to enable extendability with plugins.
- Ports and IP Protocols no longer use enums. They are defined in dictionary lookups and are handled by custom validation to enable extensibility with plugins.
- Changed AirSpaceFrequency to a data transfer object with a registry to allow extensibility
- Changed the Office LAN creation convenience function to follow the new `NetworkNodeAdder` pattern. Office LANs can now also be defined in YAML config.
### Fixed
- Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config)
@@ -32,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
and `uninstall` methods in the `Node` class.
- Updated the `receive_payload_from_session_manager` method in `SoftwareManager` so that it now sends a copy of the
payload to any software listening on the destination port of the `Frame`.
- Made the `show` method of `Network` show all node types, including ones registered at runtime
### Removed
- Removed the `install` and `uninstall` methods in the `Node` class.

View File

@@ -30,6 +30,7 @@ What is PrimAITE?
source/varying_config_files
source/environment
source/action_masking
source/node_sets
.. toctree::
:caption: Notebooks:

115
docs/source/node_sets.rst Normal file
View File

@@ -0,0 +1,115 @@
.. only:: comment
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
.. _network_node_adder:
Network Node Adder Module
#########################
This module provides a framework for adding nodes to a network in a standardised way. It defines a base class ``NetworkNodeAdder``, which can be extended to create specific node adders, and utility functions to calculate network infrastructure requirements.
The module allows you to use the pre-defined node adders, ``OfficeLANAdder``, or create custom ones by extending the base class.
How It Works
============
The main class in the module is ``NetworkNodeAdder``, which defines the interface for adding nodes to a network. Child classes are expected to:
1. Define a ``ConfigSchema`` nested class to define configuration options.
2. Implement the ``add_nodes_to_net(config, network)`` method, which adds the nodes to the network according to the configuration object.
The ``NetworkNodeAdder`` base class handles node adders defined in the primAITE config YAML file as well. It does this by keeping a registry of node adder classes, and uses the ``type`` field of the config to select the appropriate class to which to pass the configuration.
Example Usage
=============
Via Python API
--------------
Adding nodes to a network can be done using the python API by constructing the relevant ``ConfigSchema`` object like this:
.. code-block:: python
net = Network()
office_lan_config = OfficeLANAdder.ConfigSchema(
lan_name="CORP-LAN",
subnet_base=2,
pcs_ip_block_start=10,
num_pcs=8,
include_router=False,
bandwidth=150,
)
OfficeLANAdder.add_nodes_to_net(config=office_lan_config, network=net)
In this example, a network with 8 computers connected by a switch will be added to the network object.
Via YAML Config
---------------
.. code-block:: yaml
simulation:
network:
nodes:
# ... nodes go here
node_sets:
- type: office_lan
lan_name: CORP_LAN
subnet_base: 2
pcs_ip_block_start: 10
num_pcs: 8
include_router: False
bandwidth: 150
# ... additional node sets can be added below
``NetworkNodeAdder`` reads the ``type`` property of the config, then constructs and passes the configuration to ``OfficeLANAdder.add_nodes_to_net()``.
In this example, a network with 8 computers connected by a switch will be added to the network object. Equivalent to the above.
Creating Custom Node Adders
===========================
To create a custom node adder, subclass NetworkNodeAdder and define:
* A ConfigSchema class that defines the configuration schema for the node adder.
* The add_nodes_to_net method that implements how nodes should be added to the network.
Example: DataCenterAdder
------------------------
Here is an example of creating a custom node adder, DataCenterAdder:
.. code-block:: python
class DataCenterAdder(NetworkNodeAdder, identifier="data_center"):
class ConfigSchema(NetworkNodeAdder.ConfigSchema):
type: Literal["data_center"] = "data_center"
num_servers: int
data_center_name: str
@classmethod
def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None:
for i in range(config.num_servers):
server = Computer(
hostname=f"server_{i}_{config.data_center_name}",
ip_address=f"192.168.100.{i + 8}",
subnet_mask="255.255.255.0",
default_gateway="192.168.100.1",
start_up_duration=0
)
server.power_on()
network.add_node(server)
**Using the Custom Node Adder:**
.. code-block:: python
config = {
"type": "data_center",
"num_servers": 5,
"data_center_name": "dc1"
}
network = Network()
DataCenterAdder.from_config(config, network)

View File

@@ -17,7 +17,8 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager
from primaite.simulator.network.creation import NetworkNodeAdder
from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
@@ -270,6 +271,7 @@ class PrimaiteGame:
nodes_cfg = network_config.get("nodes", [])
links_cfg = network_config.get("links", [])
node_sets_cfg = network_config.get("node_sets", [])
# Set the NMNE capture config
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
@@ -277,22 +279,8 @@ class PrimaiteGame:
n_type = node_cfg["type"]
new_node = None
# Handle extended nodes
if n_type.lower() in HostNode._registry:
new_node = HostNode._registry[n_type](
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type in NetworkNode._registry:
new_node = NetworkNode._registry[n_type](**node_cfg)
# Default PrimAITE nodes
elif n_type == "computer":
if n_type == "computer":
new_node = Computer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
@@ -337,6 +325,20 @@ class PrimaiteGame:
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
# Handle extended nodes
elif n_type.lower() in Node._registry:
new_node = HostNode._registry[n_type](
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type in NetworkNode._registry:
new_node = NetworkNode._registry[n_type](**node_cfg)
else:
msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg)
@@ -505,6 +507,10 @@ class PrimaiteGame:
new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3))
new_node.shut_down_duration = int(node_cfg.get("shut_down_duration", 3))
# 1.1 Create Node Sets
for node_set_cfg in node_sets_cfg:
NetworkNodeAdder.from_config(node_set_cfg, network=net)
# 2. create links between nodes
for link_cfg in links_cfg:
node_a = net.get_node_by_hostname(link_cfg["endpoint_a_hostname"])

View File

@@ -1,12 +1,11 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
import copy
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, ClassVar, Dict, List
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field, validate_call
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface
@@ -41,30 +40,28 @@ def format_hertz(hertz: float, format_terahertz: bool = False, decimals: int = 3
return format_str.format(hertz) + " Hz"
_default_frequency_set: Dict[str, Dict] = {
"WIFI_2_4": {"frequency": 2.4e9, "data_rate_bps": 100_000_000.0},
"WIFI_5": {"frequency": 5e9, "data_rate_bps": 500_000_000.0},
}
"""Frequency configuration that is automatically used for any new airspace."""
class AirSpaceFrequency(BaseModel):
"""Data transfer object for defining properties of an airspace frequency."""
model_config = ConfigDict(extra="forbid")
name: str
"""Alias for frequency."""
frequency_hz: int
"""This acts as the primary key. If two names are mapped to the same frequency, they will share a bandwidth."""
data_rate_bps: float
"""How much data can be transmitted on this frequency per second."""
_registry: ClassVar[Dict[str, AirSpaceFrequency]] = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.name in self._registry:
raise RuntimeError(f"Frequency {self.name} is already registered. Cannot register it again.")
self._registry[self.name] = self
def register_default_frequency(freq_name: str, freq_hz: float, data_rate_bps: float) -> None:
"""Add to the default frequency configuration. This is intended as a plugin hook.
If your plugin makes use of bespoke frequencies for wireless communication, you should make a call to this method
wherever you define components that rely on the bespoke frequencies. That way, as soon as your components are
imported, this function automatically updates the default frequency set.
This should also be run before instances of AirSpace are created.
:param freq_name: The frequency name. If this clashes with an existing frequency name, it will be overwritten.
:type freq_name: str
:param freq_hz: The frequency itself, measured in Hertz.
:type freq_hz: float
:param data_rate_bps: The transmission capacity over this frequency, in bits per second.
:type data_rate_bps: float
"""
_default_frequency_set.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}})
FREQ_WIFI_2_4 = AirSpaceFrequency(name="WIFI_2_4", frequency_hz=2.4e9, data_rate_bps=100_000_000.0)
FREQ_WIFI_5 = AirSpaceFrequency(name="WIFI_5", frequency_hz=5e9, data_rate_bps=500_000_000.0)
class AirSpace(BaseModel):
@@ -79,7 +76,7 @@ class AirSpace(BaseModel):
wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {})
wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field(default_factory=lambda: {})
bandwidth_load: Dict[int, float] = Field(default_factory=lambda: {})
frequencies: Dict[str, Dict] = Field(default_factory=lambda: copy.deepcopy(_default_frequency_set))
frequencies: Dict[str, AirSpaceFrequency] = AirSpaceFrequency._registry
@validate_call
def get_frequency_max_capacity_mbps(self, freq_name: str) -> float:
@@ -90,7 +87,7 @@ class AirSpace(BaseModel):
:return: The maximum capacity in Mbps for the specified frequency.
"""
if freq_name in self.frequencies:
return self.frequencies[freq_name]["data_rate_bps"] / (1024.0 * 1024.0)
return self.frequencies[freq_name].data_rate_bps / (1024.0 * 1024.0)
return 0.0
def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]) -> None:
@@ -100,7 +97,7 @@ class AirSpace(BaseModel):
:param cfg: A dictionary mapping frequencies to their new maximum capacities in Mbps.
"""
for freq, mbps in cfg.items():
self.frequencies[freq]["data_rate_bps"] = mbps * 1024 * 1024
self.frequencies[freq].data_rate_bps = mbps * 1024 * 1024
print(f"Overriding {freq} max capacity as {mbps:.3f} mbps")
def register_frequency(self, freq_name: str, freq_hz: float, data_rate_bps: float) -> None:
@@ -117,10 +114,12 @@ class AirSpace(BaseModel):
if freq_name in self.frequencies:
_LOGGER.info(
f"Overwriting Air space frequency {freq_name}. "
f"Previous data rate: {self.frequencies[freq_name]['data_rate_bps']}. "
f"Previous data rate: {self.frequencies[freq_name].data_rate_bps}. "
f"Current data rate: {data_rate_bps}."
)
self.frequencies.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}})
self.frequencies.update(
{freq_name: AirSpaceFrequency(name=freq_name, frequency_hz=freq_hz, data_rate_bps=data_rate_bps)}
)
def show_bandwidth_load(self, markdown: bool = False):
"""
@@ -145,7 +144,7 @@ class AirSpace(BaseModel):
load_percent = 1.0
table.add_row(
[
format_hertz(self.frequencies[frequency]["frequency"]),
format_hertz(self.frequencies[frequency].frequency_hz),
f"{load_percent:.0%}",
f"{maximum_capacity:.3f}",
]
@@ -181,7 +180,7 @@ class AirSpace(BaseModel):
interface.mac_address,
interface.ip_address if hasattr(interface, "ip_address") else None,
interface.subnet_mask if hasattr(interface, "subnet_mask") else None,
format_hertz(self.frequencies[interface.frequency]["frequency"]),
format_hertz(self.frequencies[interface.frequency].frequency_hz),
f"{interface.speed:.3f}",
status,
]
@@ -209,9 +208,9 @@ class AirSpace(BaseModel):
"""
if wireless_interface.mac_address not in self.wireless_interfaces:
self.wireless_interfaces[wireless_interface.mac_address] = wireless_interface
if wireless_interface.frequency not in self.wireless_interfaces_by_frequency:
self.wireless_interfaces_by_frequency[wireless_interface.frequency] = []
self.wireless_interfaces_by_frequency[wireless_interface.frequency].append(wireless_interface)
if wireless_interface.frequency.frequency_hz not in self.wireless_interfaces_by_frequency:
self.wireless_interfaces_by_frequency[wireless_interface.frequency.frequency_hz] = []
self.wireless_interfaces_by_frequency[wireless_interface.frequency.frequency_hz].append(wireless_interface)
def remove_wireless_interface(self, wireless_interface: WirelessNetworkInterface):
"""
@@ -221,7 +220,7 @@ class AirSpace(BaseModel):
"""
if wireless_interface.mac_address in self.wireless_interfaces:
self.wireless_interfaces.pop(wireless_interface.mac_address)
self.wireless_interfaces_by_frequency[wireless_interface.frequency].remove(wireless_interface)
self.wireless_interfaces_by_frequency[wireless_interface.frequency.frequency_hz].remove(wireless_interface)
def clear(self):
"""
@@ -297,7 +296,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
"""
airspace: AirSpace
frequency: str = "WIFI_2_4"
frequency: AirSpaceFrequency = FREQ_WIFI_2_4
def enable(self):
"""Attempt to enable the network interface."""

View File

@@ -179,9 +179,8 @@ class Network(SimComponent):
table.set_style(MARKDOWN)
table.align = "l"
table.title = "Nodes"
for node_type, nodes in nodes_type_map.items():
for node in nodes:
table.add_row([node.hostname, node_type, node.operating_state.name])
for node in self.nodes.values():
table.add_row((node.hostname, type(node)._identifier, node.operating_state.name))
print(table)
if ip_addresses:

View File

@@ -1,6 +1,9 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Optional
from typing import Any, ClassVar, Dict, Literal, Self, Type
from pydantic import BaseModel, model_validator
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -10,6 +13,219 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
class NetworkNodeAdder(BaseModel):
"""
Base class for adding a set of related nodes to a network in a standardised way.
Child classes should define a ConfigSchema nested class that subclasses NetworkNodeAdder.ConfigSchema and a __call__
method which performs the node addition to the network.
Here is a template that users can use to define custom node adders:
```
class YourNodeAdder(NetworkNodeAdder, identifier="your_name"):
class ConfigSchema(NetworkNodeAdder.ConfigSchema):
property_1 : str
property_2 : int
@classmethod
def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None:
node_1 = Node(property_1, ...)
node_2 = Node(...)
network.connect(node_1.network_interface[1], node_2.network_interface[1])
...
```
"""
class ConfigSchema(BaseModel, ABC):
"""
Base schema for node adders.
Child classes of NetworkNodeAdder must define a schema which inherits from this schema. The identifier is used
by the from_config method to select the correct node adder at runtime.
"""
type: str
"""Uniquely identifies the node adder class to use for adding nodes to network."""
_registry: ClassVar[Dict[str, Type["NetworkNodeAdder"]]] = {}
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register a network node adder class.
:param identifier: Unique name for the node adder to use for matching against primaite config entries.
:type identifier: str
:raises ValueError: When attempting to register a name that is already reserved.
"""
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Duplicate node adder {identifier}")
cls._registry[identifier] = cls
@classmethod
@abstractmethod
def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None:
"""
Add nodes to the network.
Abstract method that must be overwritten by child classes. Use the config definition to create nodes and add
them to the network that is passed in.
:param config: Config object that defines how to create and add nodes to the network
:type config: ConfigSchema
:param network: PrimAITE network object to which to add nodes.
:type network: Network
"""
pass
@classmethod
def from_config(cls, config: Dict, network: Network) -> None:
"""
Accept a config, find the relevant node adder class, and call it to add nodes to the network.
Child classes do not need to define this method.
:param config: Configuration object for the child adder class
:type config: Dict
:param network: The Network object to which to add nodes
:type network: Network
"""
if config["type"] not in cls._registry:
raise ValueError(f"Invalid node adder type {config['type']}")
adder_class = cls._registry[config["type"]]
adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config), network=network)
class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"):
"""Creates an office LAN."""
class ConfigSchema(NetworkNodeAdder.ConfigSchema):
"""Configuration schema for OfficeLANAdder."""
type: Literal["office_lan"] = "office_lan"
lan_name: str
"""Name of lan used for generating hostnames for new nodes."""
subnet_base: int
"""Used as the third octet of IP addresses for nodes in the network."""
pcs_ip_block_start: int
"""Starting point for the fourth octet of IP addresses of nodes in the network."""
num_pcs: int
"""The number of hosts to generate."""
include_router: bool = True
"""Whether to include a router in the new office LAN."""
bandwidth: int = 100
"""Data bandwidth to the LAN measured in Mbps."""
@model_validator(mode="after")
def check_ip_range(self) -> Self:
"""Make sure the ip addresses of hosts don't exceed the maximum possible ip address."""
if self.pcs_ip_block_start + self.num_pcs >= 254:
raise ValueError(
f"Cannot create {self.num_pcs} pcs starting at ip block {self.pcs_ip_block_start} "
f"because ip address octets cannot exceed 254."
)
return self
@classmethod
def add_nodes_to_net(cls, config: ConfigSchema, network: Network) -> None:
"""
Add an office lan to the network according to the config definition.
This method creates a number of hosts and enough switches such that all hosts can be connected to a switch.
Optionally, a router is added to connect the switches together. All the nodes and networking devices are added
to the provided network.
:param config: Configuration object specifying office LAN parameters
:type config: OfficeLANAdder.ConfigSchema
:param network: The PrimAITE network to which to add the office LAN.
:type network: Network
:raises ValueError: upon invalid configuration
"""
# Calculate the required number of switches
num_of_switches = num_of_switches_required(num_nodes=config.num_pcs)
effective_network_interface = 23 # One port less for router connection
if config.pcs_ip_block_start <= num_of_switches:
raise ValueError(
f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}"
)
# Create a core switch if more than one edge switch is needed
if num_of_switches > 1:
core_switch = Switch(hostname=f"switch_core_{config.lan_name}", start_up_duration=0)
core_switch.power_on()
network.add_node(core_switch)
core_switch_port = 1
# Initialise the default gateway to None
default_gateway = None
# Optionally include a router in the LAN
if config.include_router:
default_gateway = IPv4Address(f"192.168.{config.subnet_base}.1")
router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0)
router.power_on()
router.acl.add_rule(
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22
)
router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
network.add_node(router)
router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0")
router.enable_port(1)
# Initialise the first edge switch and connect to the router or core switch
switch_port = 0
switch_n = 1
switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0)
switch.power_on()
network.add_node(switch)
if num_of_switches > 1:
network.connect(
core_switch.network_interface[core_switch_port],
switch.network_interface[24],
bandwidth=config.bandwidth,
)
else:
network.connect(router.network_interface[1], switch.network_interface[24], bandwidth=config.bandwidth)
# Add PCs to the LAN and connect them to switches
for i in range(1, config.num_pcs + 1):
# Add a new edge switch if the current one is full
if switch_port == effective_network_interface:
switch_n += 1
switch_port = 0
switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0)
switch.power_on()
network.add_node(switch)
# Connect the new switch to the router or core switch
if num_of_switches > 1:
core_switch_port += 1
network.connect(
core_switch.network_interface[core_switch_port],
switch.network_interface[24],
bandwidth=config.bandwidth,
)
else:
network.connect(
router.network_interface[1], switch.network_interface[24], bandwidth=config.bandwidth
)
# Create and add a PC to the network
pc = Computer(
hostname=f"pc_{i}_{config.lan_name}",
ip_address=f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}",
subnet_mask="255.255.255.0",
default_gateway=default_gateway,
start_up_duration=0,
)
pc.power_on()
network.add_node(pc)
# Connect the PC to the switch
switch_port += 1
network.connect(switch.network_interface[switch_port], pc.network_interface[1], bandwidth=config.bandwidth)
switch.network_interface[switch_port].enable()
def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int:
"""
Calculate the minimum number of network switches required to connect a given number of nodes.
@@ -42,115 +258,3 @@ def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) ->
# Return the total number of switches required
return full_switches + (1 if extra_pcs > 0 else 0)
def create_office_lan(
lan_name: str,
subnet_base: int,
pcs_ip_block_start: int,
num_pcs: int,
network: Optional[Network] = None,
include_router: bool = True,
bandwidth: int = 100,
) -> Network:
"""
Creates a 2-Tier or 3-Tier office local area network (LAN).
The LAN is configured with a specified number of personal computers (PCs), optionally including a router,
and multiple edge switches to connect them. A core switch is added only if more than one edge switch is required.
The network topology involves edge switches connected either directly to the router in a 2-Tier setup or
to a core switch in a 3-Tier setup. If a router is included, it is connected to the core switch (if present)
and configured with basic access control list (ACL) rules. PCs are distributed across the edge switches.
:param str lan_name: The name to be assigned to the LAN.
:param int subnet_base: The subnet base number to be used in the IP addresses.
:param int pcs_ip_block_start: The starting block for assigning IP addresses to PCs.
:param int num_pcs: The number of PCs to be added to the LAN.
:param Optional[Network] network: The network to which the LAN components will be added. If None, a new network is
created.
:param bool include_router: Flag to determine if a router should be included in the LAN. Defaults to True.
:return: The network object with the LAN components added.
:raises ValueError: If pcs_ip_block_start is less than or equal to the number of required switches.
"""
# Initialise the network if not provided
if not network:
network = Network()
# Calculate the required number of switches
num_of_switches = num_of_switches_required(num_nodes=num_pcs)
effective_network_interface = 23 # One port less for router connection
if pcs_ip_block_start <= num_of_switches:
raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}")
# Create a core switch if more than one edge switch is needed
if num_of_switches > 1:
core_switch = Switch(hostname=f"switch_core_{lan_name}", start_up_duration=0)
core_switch.power_on()
network.add_node(core_switch)
core_switch_port = 1
# Initialise the default gateway to None
default_gateway = None
# Optionally include a router in the LAN
if include_router:
default_gateway = IPv4Address(f"192.168.{subnet_base}.1")
router = Router(hostname=f"router_{lan_name}", start_up_duration=0)
router.power_on()
router.acl.add_rule(
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22
)
router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
network.add_node(router)
router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0")
router.enable_port(1)
# Initialise the first edge switch and connect to the router or core switch
switch_port = 0
switch_n = 1
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
switch.power_on()
network.add_node(switch)
if num_of_switches > 1:
network.connect(
core_switch.network_interface[core_switch_port], switch.network_interface[24], bandwidth=bandwidth
)
else:
network.connect(router.network_interface[1], switch.network_interface[24], bandwidth=bandwidth)
# Add PCs to the LAN and connect them to switches
for i in range(1, num_pcs + 1):
# Add a new edge switch if the current one is full
if switch_port == effective_network_interface:
switch_n += 1
switch_port = 0
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
switch.power_on()
network.add_node(switch)
# Connect the new switch to the router or core switch
if num_of_switches > 1:
core_switch_port += 1
network.connect(
core_switch.network_interface[core_switch_port], switch.network_interface[24], bandwidth=bandwidth
)
else:
network.connect(router.network_interface[1], switch.network_interface[24], bandwidth=bandwidth)
# Create and add a PC to the network
pc = Computer(
hostname=f"pc_{i}_{lan_name}",
ip_address=f"192.168.{subnet_base}.{i+pcs_ip_block_start-1}",
subnet_mask="255.255.255.0",
default_gateway=default_gateway,
start_up_duration=0,
)
pc.power_on()
network.add_node(pc)
# Connect the PC to the switch
switch_port += 1
network.connect(switch.network_interface[switch_port], pc.network_interface[1], bandwidth=bandwidth)
switch.network_interface[switch_port].enable()
return network

View File

@@ -1539,6 +1539,29 @@ class Node(SimComponent):
SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {}
"Base system software that must be preinstalled."
_registry: ClassVar[Dict[str, Type["Node"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
_identifier: ClassVar[str] = "unknown"
"""Identifier for this particular class, used for printing and logging. Each subclass redefines this."""
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
"""
Register a node type.
:param identifier: Uniquely specifies an node class by name. Used for finding items by config.
:type identifier: str
:raises ValueError: When attempting to register an node with a name that is already allocated.
"""
if identifier == "default":
return
identifier = identifier.lower()
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Tried to define new node {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls
cls._identifier = identifier
def __init__(self, **kwargs):
"""
Initialize the Node with various components and managers.

View File

@@ -5,7 +5,7 @@ from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
class Computer(HostNode):
class Computer(HostNode, identifier="computer"):
"""
A basic Computer class.

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Optional, Type
from typing import Any, ClassVar, Dict, Optional
from primaite import getLogger
from primaite.simulator.network.hardware.base import (
@@ -262,7 +262,7 @@ class NIC(IPWiredNetworkInterface):
return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}/{self.ip_address}"
class HostNode(Node):
class HostNode(Node, identifier="HostNode"):
"""
Represents a host node in the network.
@@ -325,30 +325,10 @@ class HostNode(Node):
network_interface: Dict[int, NIC] = {}
"The NICs on the node by port id."
_registry: ClassVar[Dict[str, Type["HostNode"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
"""
Register a hostnode type.
:param identifier: Uniquely specifies an hostnode class by name. Used for finding items by config.
:type identifier: str
:raises ValueError: When attempting to register an hostnode with a name that is already allocated.
"""
if identifier == "default":
return
# Enforce lowercase registry entries because it makes comparisons everywhere else much easier.
identifier = identifier.lower()
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls
@property
def nmap(self) -> Optional[NMAP]:
"""

View File

@@ -2,7 +2,7 @@
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
class Server(HostNode):
class Server(HostNode, identifier="server"):
"""
A basic Server class.
@@ -31,7 +31,7 @@ class Server(HostNode):
"""
class Printer(HostNode):
class Printer(HostNode, identifier="printer"):
"""Printer? I don't even know her!."""
# TODO: Implement printer-specific behaviour

View File

@@ -27,7 +27,7 @@ DMZ_PORT_ID: Final[int] = 3
"""The Firewall port ID of the DMZ port."""
class Firewall(Router):
class Firewall(Router, identifier="firewall"):
"""
A Firewall class that extends the functionality of a Router.

View File

@@ -1,13 +1,13 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from typing import Any, ClassVar, Dict, Optional, Type
from typing import Optional
from primaite.simulator.network.hardware.base import NetworkInterface, Node
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.system.services.arp.arp import ARP
class NetworkNode(Node):
class NetworkNode(Node, identifier="NetworkNode"):
"""
Represents an abstract base class for a network node that can receive and process network frames.
@@ -16,25 +16,6 @@ class NetworkNode(Node):
provide functionality for receiving and processing frames received on their network interfaces.
"""
_registry: ClassVar[Dict[str, Type["NetworkNode"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
"""
Register a networknode type.
:param identifier: Uniquely specifies an networknode class by name. Used for finding items by config.
:type identifier: str
:raises ValueError: When attempting to register an networknode with a name that is already allocated.
"""
if identifier == "default":
return
identifier = identifier.lower()
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Tried to define new networknode {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls
@abstractmethod
def receive_frame(self, frame: Frame, from_network_interface: NetworkInterface):
"""

View File

@@ -1184,7 +1184,7 @@ class RouterSessionManager(SessionManager):
return outbound_network_interface, dst_mac_address, dst_ip_address, src_port, dst_port, protocol, is_broadcast
class Router(NetworkNode):
class Router(NetworkNode, identifier="router"):
"""
Represents a network router, managing routing and forwarding of IP packets across network interfaces.

View File

@@ -87,7 +87,7 @@ class SwitchPort(WiredNetworkInterface):
return False
class Switch(NetworkNode):
class Switch(NetworkNode, identifier="switch"):
"""
A class representing a Layer 2 network switch.

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union
from pydantic import validate_call
from primaite.simulator.network.airspace import AirSpace, IPWirelessNetworkInterface
from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, FREQ_WIFI_2_4, IPWirelessNetworkInterface
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface
from primaite.simulator.network.transmission.data_link_layer import Frame
@@ -91,7 +91,7 @@ class WirelessAccessPoint(IPWirelessNetworkInterface):
)
class WirelessRouter(Router):
class WirelessRouter(Router, identifier="wireless_router"):
"""
A WirelessRouter class that extends the functionality of a standard Router to include wireless capabilities.
@@ -153,7 +153,7 @@ class WirelessRouter(Router):
self,
ip_address: IPV4Address,
subnet_mask: IPV4Address,
frequency: Optional[str] = "WIFI_2_4",
frequency: Optional[AirSpaceFrequency] = FREQ_WIFI_2_4,
):
"""
Configures a wireless access point (WAP).
@@ -171,7 +171,7 @@ class WirelessRouter(Router):
communication. Default is "WIFI_2_4".
"""
if not frequency:
frequency = "WIFI_2_4"
frequency = FREQ_WIFI_2_4
self.sys_log.info("Configuring wireless access point")
self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration
@@ -264,7 +264,7 @@ class WirelessRouter(Router):
if "wireless_access_point" in cfg:
ip_address = cfg["wireless_access_point"]["ip_address"]
subnet_mask = cfg["wireless_access_point"]["subnet_mask"]
frequency = cfg["wireless_access_point"]["frequency"]
frequency = AirSpaceFrequency._registry[cfg["wireless_access_point"]["frequency"]]
router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency)
if "acl" in cfg:

View File

@@ -0,0 +1,69 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.creation import NetworkNodeAdder, OfficeLANAdder
param_names = ("lan_name", "subnet_base", "pcs_ip_block_start", "num_pcs", "include_router", "bandwidth")
param_vals = (
("CORP-NETWORK", 3, 10, 6, True, 45),
("OTHER-NETWORK", 10, 25, 26, True, 100),
("OTHER-NETWORK", 10, 25, 55, False, 100),
)
param_dicts = [dict(zip(param_names, vals)) for vals in param_vals]
def _assert_valid_creation(net: Network, lan_name, subnet_base, pcs_ip_block_start, num_pcs, include_router, bandwidth):
"""Assert that the network contains the correct nodes as described by config items"""
num_switches = 1 if num_pcs <= 23 else num_pcs // 23 + 2
num_routers = 1 if include_router else 0
total_nodes = num_pcs + num_switches + num_routers
assert all((n.hostname.endswith(lan_name) for n in net.nodes.values()))
assert len(net.computer_nodes) == num_pcs
assert len(net.switch_nodes) == num_switches
assert len(net.router_nodes) == num_routers
assert len(net.nodes) == total_nodes
assert all(
[str(n.network_interface[1].ip_address).startswith(f"192.168.{subnet_base}") for n in net.computer_nodes]
)
# check that computers occupy address range 192.168.3.10 - 192.168.3.16
computer_ip_last_octets = {str(n.network_interface[1].ip_address).split(".")[-1] for n in net.computer_nodes}
assert computer_ip_last_octets == {str(i) for i in range(pcs_ip_block_start, pcs_ip_block_start + num_pcs)}
@pytest.mark.parametrize("kwargs", param_dicts)
def test_office_lan_adder(kwargs):
"""Assert that adding an office lan via the python API works correctly."""
net = Network()
office_lan_config = OfficeLANAdder.ConfigSchema(
lan_name=kwargs["lan_name"],
subnet_base=kwargs["subnet_base"],
pcs_ip_block_start=kwargs["pcs_ip_block_start"],
num_pcs=kwargs["num_pcs"],
include_router=kwargs["include_router"],
bandwidth=kwargs["bandwidth"],
)
OfficeLANAdder.add_nodes_to_net(config=office_lan_config, network=net)
_assert_valid_creation(net=net, **kwargs)
@pytest.mark.parametrize("kwargs", param_dicts)
def test_office_lan_from_config(kwargs):
"""Assert that the base class can add an office lan given a config dict."""
net = Network()
config = dict(
type="office_lan",
lan_name=kwargs["lan_name"],
subnet_base=kwargs["subnet_base"],
pcs_ip_block_start=kwargs["pcs_ip_block_start"],
num_pcs=kwargs["num_pcs"],
include_router=kwargs["include_router"],
bandwidth=kwargs["bandwidth"],
)
NetworkNodeAdder.from_config(config=config, network=net)
_assert_valid_creation(net=net, **kwargs)