Merge branch 'feature/1789-add-convenience-methods-to-network-object' into feature/1801-Database

This commit is contained in:
Marek Wolan
2023-08-24 14:41:09 +01:00
8 changed files with 309 additions and 50 deletions

1
.gitignore vendored
View File

@@ -150,3 +150,4 @@ src/primaite/outputs/
# benchmark session outputs
benchmark/output
src/primaite/notebooks/scratch.ipynb

File diff suppressed because one or more lines are too long

View File

@@ -137,6 +137,7 @@ class SimComponent(BaseModel):
kwargs["uuid"] = str(uuid4())
super().__init__(**kwargs)
self.action_manager: Optional[ActionManager] = None
self._parent: Optional["SimComponent"] = None
@abstractmethod
def describe_state(self) -> Dict:
@@ -187,3 +188,24 @@ class SimComponent(BaseModel):
Override this method with anything that needs to happen within the component for it to be reset.
"""
pass
@property
def parent(self) -> "SimComponent":
"""Reference to the parent object which manages this object.
:return: Parent object.
:rtype: SimComponent
"""
return self._parent
@parent.setter
def parent(self, new_parent: "SimComponent") -> None:
if self._parent:
msg = f"Overwriting parent of {self}, {self._parent} with {new_parent}"
_LOGGER.warn(msg)
raise RuntimeWarning(msg)
self._parent = new_parent
@parent.deleter
def parent(self) -> None:
self._parent = None

View File

@@ -1,10 +1,13 @@
from typing import Dict
from typing import Any, Dict, Union
from primaite import getLogger
from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent
from primaite.simulator.network.hardware.base import Link, Node
from primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort
_LOGGER = getLogger(__name__)
class NetworkContainer(SimComponent):
class Network(SimComponent):
"""Top level container object representing the physical network."""
nodes: Dict[str, Node] = {}
@@ -40,3 +43,75 @@ class NetworkContainer(SimComponent):
}
)
return state
def add_node(self, node: Node) -> None:
"""
Add an existing node to the network.
:param node: Node instance that the network should keep track of.
:type node: Node
"""
if node in self:
msg = f"Can't add node {node}. It is already in the network."
_LOGGER.warning(msg)
raise RuntimeWarning(msg)
self.nodes[node.uuid] = node
node.parent = self
def remove_node(self, node: Node) -> None:
"""
Remove a node from the network.
:param node: Node instance that is currently part of the network that should be removed.
:type node: Node
"""
if node not in self:
msg = f"Can't remove node {node}. It's not in the network."
_LOGGER.warning(msg)
raise RuntimeWarning(msg)
del self.nodes[node.uuid]
del node.parent # misleading?
def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None:
"""Connect two nodes on the network by creating a link between an NIC/SwitchPort of each one.
:param endpoint_a: The endpoint to which to connect the link on the first node
:type endpoint_a: Union[NIC, SwitchPort]
:param endpoint_b: The endpoint to which to connct the link on the second node
:type endpoint_b: Union[NIC, SwitchPort]
:raises RuntimeError: _description_
"""
node_a = endpoint_a.parent
node_b = endpoint_b.parent
msg = ""
if node_a not in self:
msg = f"Cannot create a link to {endpoint_a} because the node is not in the network."
if node_b not in self:
msg = f"Cannot create a link to {endpoint_b} because the node is not in the network."
if node_a is node_b:
msg = f"Cannot link {endpoint_a} to {endpoint_b} because they belong to the same node."
if msg:
_LOGGER.error(msg)
raise RuntimeError(msg)
link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, **kwargs)
self.links[link.uuid] = link
link.parent = self
def remove_link(self, link: Link) -> None:
"""Disconnect a link from the network.
:param link: The link to be removed
:type link: Link
"""
link.endpoint_a.disconnect_link()
link.endpoint_b.disconnect_link()
del self.links[link.uuid]
del link.parent
def __contains__(self, item: Any) -> bool:
if isinstance(item, Node):
return item.uuid in self.nodes
elif isinstance(item, Link):
return item.uuid in self.links
raise TypeError("")

View File

@@ -918,6 +918,7 @@ class Node(SimComponent):
if nic.uuid not in self.nics:
self.nics[nic.uuid] = nic
nic.connected_node = self
nic.parent = self
self.sys_log.info(f"Connected NIC {nic}")
if self.operating_state == NodeOperatingState.ON:
nic.enable()
@@ -938,6 +939,7 @@ class Node(SimComponent):
nic = self.nics.get(nic)
if nic or nic.uuid in self.nics:
self.nics.pop(nic.uuid)
del nic.parent
nic.disable()
self.sys_log.info(f"Disconnected NIC {nic}")
else:
@@ -1009,6 +1011,7 @@ class Switch(Node):
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port.connected_node = self
port.parent = self
port.port_num = port_num
def show(self):

View File

@@ -2,19 +2,19 @@ from typing import Dict
from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent
from primaite.simulator.domain.controller import DomainController
from primaite.simulator.network.container import NetworkContainer
from primaite.simulator.network.container import Network
class Simulation(SimComponent):
"""Top-level simulation object which holds a reference to all other parts of the simulation."""
network: NetworkContainer
network: Network
domain: DomainController
def __init__(self, **kwargs):
"""Initialise the Simulation."""
if not kwargs.get("network"):
kwargs["network"] = NetworkContainer()
kwargs["network"] = Network()
if not kwargs.get("domain"):
kwargs["domain"] = DomainController()

View File

@@ -0,0 +1,82 @@
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import NIC, Node
def test_adding_removing_nodes():
"""Check that we can create and add a node to a network."""
net = Network()
n1 = Node(hostname="computer")
net.add_node(n1)
assert n1.parent is net
assert n1 in net
net.remove_node(n1)
assert n1.parent is None
assert n1 not in net
def test_readding_node():
"""Check that warning is raised when readding a node."""
net = Network()
n1 = Node(hostname="computer")
net.add_node(n1)
with pytest.raises(RuntimeWarning):
net.add_node(n1)
assert n1.parent is net
assert n1 in net
def test_removing_nonexistent_node():
"""Check that warning is raised when trying to remove a node that is not in the network."""
net = Network()
n1 = Node(hostname="computer")
with pytest.raises(RuntimeWarning):
net.remove_node(n1)
assert n1.parent is None
assert n1 not in net
def test_connecting_nodes():
"""Check that two nodes on the network can be connected."""
net = Network()
n1 = Node(hostname="computer")
n1_nic = NIC(ip_address="120.30.0.1", gateway="192.168.0.1", subnet_mask="255.255.255.0")
n1.connect_nic(n1_nic)
n2 = Node(hostname="server")
n2_nic = NIC(ip_address="120.30.0.2", gateway="192.168.0.1", subnet_mask="255.255.255.0")
n2.connect_nic(n2_nic)
net.add_node(n1)
net.add_node(n2)
net.connect(n1.nics[n1_nic.uuid], n2.nics[n2_nic.uuid], bandwidth=30)
assert len(net.links) == 1
link = list(net.links.values())[0]
assert link in net
assert link.parent is net
def test_connecting_node_to_itself():
net = Network()
node = Node(hostname="computer")
nic1 = NIC(ip_address="120.30.0.1", gateway="192.168.0.1", subnet_mask="255.255.255.0")
node.connect_nic(nic1)
nic2 = NIC(ip_address="120.30.0.2", gateway="192.168.0.1", subnet_mask="255.255.255.0")
node.connect_nic(nic2)
net.add_node(node)
with pytest.raises(RuntimeError):
net.connect(node.nics[nic1.uuid], node.nics[nic2.uuid], bandwidth=30)
assert node in net
assert nic1.connected_link is None
assert nic2.connected_link is None
assert len(net.links) == 0
def test_disconnecting_nodes():
...

View File

@@ -0,0 +1,17 @@
import json
from primaite.simulator.network.container import Network
def test_creating_container():
"""Check that we can create a network container"""
net = Network()
assert net.nodes == {}
assert net.links == {}
def test_describe_state():
"""Check that we can describe network state without raising errors, and that the result is JSON serialisable."""
net = Network()
state = net.describe_state()
json.dumps(state) # if this function call raises an error, the test fails, state was not JSON-serialisable