diff --git a/src/primaite/exceptions.py b/src/primaite/exceptions.py index 3b4058ac..025f6d41 100644 --- a/src/primaite/exceptions.py +++ b/src/primaite/exceptions.py @@ -9,3 +9,9 @@ class RLlibAgentError(PrimaiteError): """Raised when there is a generic error with a RLlib agent that is specific to PRimAITE.""" pass + + +class NetworkError(PrimaiteError): + """Raised when an error occurs at the network level.""" + + pass diff --git a/src/primaite/simulator/network/physical_layer.py b/src/primaite/simulator/network/physical_layer.py index bb4b120e..6d268b59 100644 --- a/src/primaite/simulator/network/physical_layer.py +++ b/src/primaite/simulator/network/physical_layer.py @@ -2,15 +2,19 @@ from __future__ import annotations import re import secrets -from ipaddress import IPv4Address -from typing import Dict, List, Optional +from ipaddress import IPv4Address, IPv4Network +from typing import Any, Dict, List, Optional, Union +from primaite import getLogger +from primaite.exceptions import NetworkError from primaite.simulator.core import SimComponent +_LOGGER = getLogger(__name__) + def generate_mac_address(oui: Optional[str] = None) -> str: """ - Generate a random MAC Address.. + Generate a random MAC Address. :Example: @@ -29,9 +33,8 @@ def generate_mac_address(oui: Optional[str] = None) -> str: if oui: oui_pattern = re.compile(r"^([0-9A-Fa-f]{2}[:-]){2}[0-9A-Fa-f]{2}$") if not oui_pattern.match(oui): - raise ValueError( - f"Invalid oui. The oui should be in the format 'xx:xx:xx', where x is a hexadecimal digit, got '{oui}'." - ) + msg = f"Invalid oui. The oui should be in the format xx:xx:xx, where x is a hexadecimal digit, got '{oui}'" + raise ValueError(msg) oui_bytes = [int(chunk, 16) for chunk in oui.split(":")] mac = oui_bytes + random_bytes[len(oui_bytes) :] else: @@ -54,26 +57,21 @@ class Link(SimComponent): endpoint_a: NIC endpoint_b: NIC - bandwidth: int + bandwidth: int = 100 current_load: int = 0 - def __init__(self, endpoint_a: NIC, endpoint_b: NIC, bandwidth: int = 100): + def model_post_init(self, __context: Any) -> None: """ - Initialize the Link instance. + Ensure that endpoint_a and endpoint_b are not the same :class:`~primaite.simulator.network.physical_layer.NIC`. - When a Link is created, it automatically connects the endpoints to itself. - - :param endpoint_a: The first NIC connected to the link. - :type endpoint_a: NIC - :param endpoint_b: The second NIC connected to the link. - :type endpoint_b: NIC - :param bandwidth: The bandwidth of the link in Mbps (default is 100 Mbps). - :type bandwidth: int - :raise ValueError: If endpoint_a equals endpoint_b. + :raises ValueError: If endpoint_a and endpoint_b are the same NIC. """ - super().__init__(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth) if self.endpoint_a == self.endpoint_b: - raise ValueError("endpoint_a and endpoint_b cannot be the same NIC") + msg = "endpoint_a and endpoint_b cannot be the same NIC" + _LOGGER.error(msg) + raise ValueError(msg) + self.endpoint_a.connect_link(self) + self.endpoint_b.connect_link(self) def send_frame(self, sender_nic: NIC, frame): """ @@ -132,11 +130,11 @@ class NIC(SimComponent): :param enabled: Indicates whether the NIC is enabled. """ - ip_address: IPv4Address + ip_address: Union[str, IPv4Address] "The IP address assigned to the NIC for communication on an IP-based network." subnet_mask: str "The subnet mask assigned to the NIC." - gateway: IPv4Address + gateway: Union[str, IPv4Address] "The default gateway IP address for forwarding network traffic to other networks. Randomly generated upon creation." mac_address: str = generate_mac_address() "The MAC address of the NIC. Defaults to a randomly set MAC address." @@ -153,14 +151,56 @@ class NIC(SimComponent): enabled: bool = False "Indicates whether the NIC is enabled." + def model_post_init(self, __context: Any) -> None: + """ + Post init function converts string IPs to IPv$Address and checks for proper IP address and gateway config. + + :raises ValueError: When the ip_address and gateway are the same. And when the ip_address/subnet mask are a + network address. + """ + if not isinstance(self.ip_address, IPv4Address): + self.ip_address: IPv4Address = IPv4Address(self.ip_address) + if not isinstance(self.gateway, IPv4Address): + self.gateway: IPv4Address = IPv4Address(self.gateway) + if self.ip_address == self.gateway: + msg = f"NIC ip address {self.ip_address} cannot be the same as the gateway {self.gateway}" + _LOGGER.error(msg) + raise ValueError(msg) + if self.ip_network.network_address == self.ip_address: + msg = ( + f"Failed to set IP address {self.ip_address} and subnet mask {self.subnet_mask} as it is a " + f"network address {self.ip_network.network_address}" + ) + _LOGGER.error(msg) + raise ValueError(msg) + + @property + def ip_network(self) -> IPv4Network: + """ + Return the IPv4Network of the NIC. + + :return: The IPv4Network from the ip_address/subnet mask. + """ + return IPv4Network(f"{self.ip_address}/{self.subnet_mask}", strict=False) + def connect_link(self, link: Link): """ Connect the NIC to a link. :param link: The link to which the NIC is connected. :type link: :class:`~primaite.simulator.network.physical_layer.Link` + :raise NetworkError: When an attempt to connect a Link is made while the NIC has a connected Link. """ - pass + if not self.connected_link: + if self.connected_link != link: + # TODO: Inform the Node that a link has been connected + self.connected_link = link + else: + _LOGGER.warning(f"Cannot connect link to NIC ({self.mac_address}) as it is already connected") + else: + msg = f"Cannot connect link to NIC ({self.mac_address}) as it already has a connection" + _LOGGER.error(msg) + raise NetworkError(msg) def disconnect_link(self): """Disconnect the NIC from the connected :class:`~primaite.simulator.network.physical_layer.Link`.""" diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/network/__init__.py b/tests/integration_tests/network/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/network/test_nic_link_connection.py b/tests/integration_tests/network/test_nic_link_connection.py new file mode 100644 index 00000000..1a191200 --- /dev/null +++ b/tests/integration_tests/network/test_nic_link_connection.py @@ -0,0 +1,14 @@ +import pytest + +from primaite.simulator.network.physical_layer import Link, NIC + + +def test_link_fails_with_same_nic(): + """Tests Link creation fails with endpoint_a and endpoint_b are the same NIC.""" + with pytest.raises(ValueError): + nic_a = NIC( + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + gateway="192.168.0.1", + ) + Link(endpoint_a=nic_a, endpoint_b=nic_a) diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_physical_layer.py b/tests/unit_tests/_primaite/_simulator/_network/test_physical_layer.py index 137e2cd6..ad1226a6 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_physical_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_physical_layer.py @@ -1,8 +1,9 @@ import re +from ipaddress import IPv4Address import pytest -from primaite.simulator.network.physical_layer import generate_mac_address +from primaite.simulator.network.physical_layer import generate_mac_address, NIC def test_mac_address_generation(): @@ -24,3 +25,47 @@ def test_invalid_oui_mac_address(): invalid_oui = "aa-bb-cc" with pytest.raises(ValueError): generate_mac_address(oui=invalid_oui) + + +def test_nic_ip_address_type_conversion(): + """Tests NIC IP and gateway address is converted to IPv4Address is originally a string.""" + nic = NIC( + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + gateway="192.168.0.1", + ) + assert isinstance(nic.ip_address, IPv4Address) + assert isinstance(nic.gateway, IPv4Address) + + +def test_nic_deserialize(): + """Tests NIC serialization and deserialization.""" + nic = NIC( + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + gateway="192.168.0.1", + ) + + nic_json = nic.model_dump_json() + deserialized_nic = NIC.model_validate_json(nic_json) + assert nic == deserialized_nic + + +def test_nic_ip_address_as_gateway_fails(): + """Tests NIC creation fails if ip address is the same as the gateway.""" + with pytest.raises(ValueError): + NIC( + ip_address="192.168.0.1", + subnet_mask="255.255.255.0", + gateway="192.168.0.1", + ) + + +def test_nic_ip_address_as_network_address_fails(): + """Tests NIC creation fails if ip address and subnet mask are a network address.""" + with pytest.raises(ValueError): + NIC( + ip_address="192.168.0.0", + subnet_mask="255.255.255.0", + gateway="192.168.0.1", + )