Files
PrimAITE/src/primaite/simulator/network/container.py

382 lines
15 KiB
Python

# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional
import matplotlib.pyplot as plt
import networkx as nx
from networkx import MultiGraph
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.airspace import AirSpace
from primaite.simulator.network.hardware.base import Link, Node, WiredNetworkInterface
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.hardware.nodes.host.server import Printer
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
class Network(SimComponent):
"""
Top level container object representing the physical network.
This class manages nodes, links, and other network components. It also
offers methods for rendering the network topology and gathering states.
:ivar Dict[str, Node] nodes: Dictionary mapping node UUIDs to Node instances.
:ivar Dict[str, Link] links: Dictionary mapping link UUIDs to Link instances.
"""
nodes: Dict[str, Node] = {}
links: Dict[str, Link] = {}
airspace: AirSpace = Field(default_factory=lambda: AirSpace())
_node_id_map: Dict[int, Node] = {}
_link_id_map: Dict[int, Node] = {}
def __init__(self, **kwargs):
"""
Initialise the network.
Constructs the network and sets up its initial state including
the request manager and an empty MultiGraph for topology representation.
"""
super().__init__(**kwargs)
self._nx_graph = MultiGraph()
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for node in self.nodes.values():
node.setup_for_episode(episode=episode)
for link in self.links.values():
link.setup_for_episode(episode=episode)
for node in self.nodes.values():
node.power_on()
for network_interface in node.network_interfaces.values():
network_interface.enable()
# Reset software
for software in node.software_manager.software.values():
if isinstance(software, Service):
software.start()
elif isinstance(software, Application):
software.run()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
self._node_request_manager = RequestManager()
rm.add_request(
"node",
RequestType(func=self._node_request_manager),
)
return rm
def apply_timestep(self, timestep: int) -> None:
"""Apply a timestep evolution to this the network and its nodes and links."""
super().apply_timestep(timestep=timestep)
# apply timestep to nodes
for node_id in self.nodes:
self.nodes[node_id].apply_timestep(timestep=timestep)
# apply timestep to links
for link_id in self.links:
self.links[link_id].apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
self.airspace.reset_bandwidth_load()
for node in self.nodes.values():
node.pre_timestep(timestep)
for link in self.links.values():
link.pre_timestep(timestep)
@property
def router_nodes(self) -> List[Node]:
"""The Routers in the Network."""
return [node for node in self.nodes.values() if node.__class__.__name__ == "Router"]
@property
def switch_nodes(self) -> List[Node]:
"""The Switches in the Network."""
return [node for node in self.nodes.values() if node.__class__.__name__ == "Switch"]
@property
def computer_nodes(self) -> List[Node]:
"""The Computers in the Network."""
return [node for node in self.nodes.values() if node.__class__.__name__ == "Computer"]
@property
def server_nodes(self) -> List[Node]:
"""The Servers in the Network."""
return [node for node in self.nodes.values() if node.__class__.__name__ == "Server"]
@property
def firewall_nodes(self) -> List[Node]:
"""The Firewalls in the Network."""
return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"]
@property
def extended_hostnodes(self) -> List[Node]:
"""Extended nodes that inherited HostNode in the network."""
return [node for node in self.nodes.values() if node.__class__.__name__.lower() in HostNode._registry]
@property
def extended_networknodes(self) -> List[Node]:
"""Extended nodes that inherited NetworkNode in the network."""
return [node for node in self.nodes.values() if node.__class__.__name__.lower() in NetworkNode._registry]
@property
def printer_nodes(self) -> List[Node]:
"""The printers on the network."""
return [node for node in self.nodes.values() if isinstance(node, Printer)]
@property
def wireless_router_nodes(self) -> List[Node]:
"""The Routers in the Network."""
return [node for node in self.nodes.values() if node.__class__.__name__ == "WirelessRouter"]
def show(self, nodes: bool = True, ip_addresses: bool = True, links: bool = True, markdown: bool = False):
"""
Print tables describing the Network.
Generate and print PrettyTable instances that show details about nodes,
IP addresses, and links in the network. Output can be in Markdown format.
:param nodes: Include node details in the output. Defaults to True.
:param ip_addresses: Include IP address details in the output. Defaults to True.
:param links: Include link details in the output. Defaults to True.
:param markdown: Use Markdown style in table output. Defaults to False.
"""
nodes_type_map = {
"Router": self.router_nodes,
"Firewall": self.firewall_nodes,
"Switch": self.switch_nodes,
"Server": self.server_nodes,
"Computer": self.computer_nodes,
"Printer": self.printer_nodes,
"Wireless Router": self.wireless_router_nodes,
}
if nodes:
table = PrettyTable(["Node", "Type", "Operating State"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = "Nodes"
for node in self.nodes.values():
table.add_row((node.hostname, type(node)._identifier, node.operating_state.name))
print(table)
if ip_addresses:
table = PrettyTable(["Node", "Port", "IP Address", "Subnet Mask", "Default Gateway"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = "IP Addresses"
for nodes in nodes_type_map.values():
for node in nodes:
for i, port in node.network_interface.items():
if hasattr(port, "ip_address"):
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
)
print(table)
if links:
table = PrettyTable(
["Endpoint A", "A Port", "Endpoint B", "B Port", "is Up", "Bandwidth (MBits)", "Current Load"]
)
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = "Links"
links = list(self.links.values())
for nodes in nodes_type_map.values():
for node in nodes:
for link in links[::-1]:
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
table.add_row(
[
link.endpoint_a.parent.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
link.current_load_percent,
]
)
links.remove(link)
print(table)
def clear_links(self):
"""Clear all the links in the network by resetting their component state for the episode."""
for link in self.links.values():
link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here.
def draw(self, seed: int = 123):
"""
Draw the Network using NetworkX and matplotlib.pyplot.
:param seed: An integer seed for reproducible layouts. Default is 123.
"""
pos = nx.spring_layout(self._nx_graph, seed=seed)
nx.draw(self._nx_graph, pos, with_labels=True)
plt.show()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of the Network.
:return: A dictionary capturing the current state of the Network and its child objects.
"""
state = super().describe_state()
state.update(
{
"nodes": {node.hostname: node.describe_state() for node in self.nodes.values()},
"links": {},
}
)
# Update the links one-by-one. The key is a 4-tuple of `hostname_a, port_a, hostname_b, port_b`
for _, link in self.links.items():
node_a = link.endpoint_a._connected_node
node_b = link.endpoint_b._connected_node
hostname_a = node_a.hostname if node_a else None
hostname_b = node_b.hostname if node_b else None
port_a = link.endpoint_a.port_num
port_b = link.endpoint_b.port_num
link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}"
state["links"][link_key] = link.describe_state()
state["links"][link_key]["hostname_a"] = hostname_a
state["links"][link_key]["hostname_b"] = hostname_b
state["links"][link_key]["port_a"] = port_a
state["links"][link_key]["port_b"] = port_b
return state
def add_node(self, node: Node) -> None:
"""
Add an existing node to the network.
.. note:: If the node is already present in the network, a warning is logged.
:param node: Node instance that should be kept track of by the network.
"""
if node in self:
_LOGGER.warning(f"Can't add node {node.uuid}. It is already in the network.")
return
self.nodes[node.uuid] = node
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
self._node_request_manager.add_request(name=node.hostname, request_type=RequestType(func=node._request_manager))
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
Get a Node from the Network by its hostname.
.. note:: Assumes hostnames on the network are unique.
:param hostname: The Node hostname.
:return: The Node if it exists in the network.
"""
for node in self.nodes.values():
if node.hostname == hostname:
return node
def remove_node(self, node: Node) -> None:
"""
Remove a node from the network.
.. note:: If the node is not found in the network, a warning is logged.
:param node: Node instance that is currently part of the network that should be removed.
:type node: Node
"""
if node not in self:
_LOGGER.warning(f"Can't remove node {node.hostname}. It's not in the network.")
return
self.nodes.pop(node.uuid)
for i, _node in self._node_id_map.items():
if node == _node:
self._node_id_map.pop(i)
break
node.parent = None
self._node_request_manager.remove_request(name=node.hostname)
_LOGGER.info(f"Removed node {node.hostname} from network {self.uuid}")
def connect(
self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, bandwidth: int = 100, **kwargs
) -> Optional[Link]:
"""
Connect two endpoints on the network by creating a link between their NICs/SwitchPorts.
.. note:: If the nodes owning the endpoints are not already in the network, they are automatically added.
:param endpoint_a: The first endpoint to connect.
:type endpoint_a: WiredNetworkInterface
:param endpoint_b: The second endpoint to connect.
:type endpoint_b: WiredNetworkInterface
:param bandwidth: bandwidth of new link, default of 100mbps
:type bandwidth: int
:raises RuntimeError: If any validation or runtime checks fail.
"""
node_a: Node = endpoint_a.parent
node_b: Node = endpoint_b.parent
if node_a not in self:
self.add_node(node_a)
if node_b not in self:
self.add_node(node_b)
if node_a is node_b:
_LOGGER.warning(f"Cannot link endpoint {endpoint_a} to {endpoint_b} because they belong to the same node.")
return
link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth, **kwargs)
self.links[link.uuid] = link
self._link_id_map[len(self.links)] = link
self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname)
link.parent = self
_LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
return link
def remove_link(self, link: Link) -> None:
"""Disconnect a link from the network.
:param link: The link to be removed
:type link: Link
"""
link.endpoint_a.disconnect_link()
link.endpoint_b.disconnect_link()
self.links.pop(link.uuid)
for i, _link in self._link_id_map.items():
if link == _link:
self._link_id_map.pop(i)
break
link.parent = None
_LOGGER.info(f"Removed link {link.uuid} from network {self.uuid}.")
def __contains__(self, item: Any) -> bool:
if isinstance(item, Node):
return item.uuid in self.nodes
elif isinstance(item, Link):
return item.uuid in self.links
return False