diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 63120ecf..b7dfcf72 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -1,6 +1,6 @@ """Core of the PrimAITE Simulator.""" from abc import ABC, abstractmethod -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Union from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Extra @@ -199,9 +199,9 @@ class SimComponent(BaseModel): return self._parent @parent.setter - def parent(self, new_parent: "SimComponent") -> None: - if self._parent: - msg = f"Overwriting parent of {self}, {self._parent} with {new_parent}" + def parent(self, new_parent: Union["SimComponent", None]) -> None: + if self._parent and new_parent: + msg = f"Overwriting parent of {self.uuid}. Old parent: {self._parent.uuid}, New parent: {new_parent.uuid}" _LOGGER.warn(msg) raise RuntimeWarning(msg) self._parent = new_parent diff --git a/tests/integration_tests/network/test_network_creation.py b/tests/integration_tests/network/test_network_creation.py index 418f5e5f..356eb1db 100644 --- a/tests/integration_tests/network/test_network_creation.py +++ b/tests/integration_tests/network/test_network_creation.py @@ -22,8 +22,7 @@ def test_readding_node(): net = Network() n1 = Node(hostname="computer") net.add_node(n1) - with pytest.raises(RuntimeWarning): - net.add_node(n1) + net.add_node(n1) assert n1.parent is net assert n1 in net @@ -32,8 +31,7 @@ def test_removing_nonexistent_node(): """Check that warning is raised when trying to remove a node that is not in the network.""" net = Network() n1 = Node(hostname="computer") - with pytest.raises(RuntimeWarning): - net.remove_node(n1) + net.remove_node(n1) assert n1.parent is None assert n1 not in net @@ -69,8 +67,7 @@ def test_connecting_node_to_itself(): net.add_node(node) - with pytest.raises(RuntimeError): - net.connect(node.nics[nic1.uuid], node.nics[nic2.uuid], bandwidth=30) + net.connect(node.nics[nic1.uuid], node.nics[nic2.uuid], bandwidth=30) assert node in net assert nic1.connected_link is None @@ -79,4 +76,22 @@ def test_connecting_node_to_itself(): def test_disconnecting_nodes(): - ... + net = Network() + + n1 = Node(hostname="computer") + n1_nic = NIC(ip_address="120.30.0.1", gateway="192.168.0.1", subnet_mask="255.255.255.0") + n1.connect_nic(n1_nic) + net.add_node(n1) + + n2 = Node(hostname="server") + n2_nic = NIC(ip_address="120.30.0.2", gateway="192.168.0.1", subnet_mask="255.255.255.0") + n2.connect_nic(n2_nic) + net.add_node(n2) + + net.connect(n1.nics[n1_nic.uuid], n2.nics[n2_nic.uuid], bandwidth=30) + assert len(net.links) == 1 + + link = list(net.links.values())[0] + net.remove_link(link) + assert link not in net + assert len(net.links) == 0