#2248 - All tests (bar the one config file test) now working. Still need to tidy up docstrings and some docs. Almost there

This commit is contained in:
Chris McCarthy
2024-02-07 23:05:34 +00:00
parent 5e25fefa14
commit 0c96fef3ec
29 changed files with 270 additions and 235 deletions

View File

@@ -319,11 +319,11 @@ class PrimaiteGame:
node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
if isinstance(node_a, Switch):
endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
endpoint_a = node_a.network_interface[link_cfg["endpoint_a_port"]]
else:
endpoint_a = node_a.network_interface[link_cfg["endpoint_a_port"]]
if isinstance(node_b, Switch):
endpoint_b = node_b.switch_ports[link_cfg["endpoint_b_port"]]
endpoint_b = node_b.network_interface[link_cfg["endpoint_b_port"]]
else:
endpoint_b = node_b.network_interface[link_cfg["endpoint_b_port"]]
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)

View File

@@ -149,7 +149,8 @@ class Network(SimComponent):
for nodes in nodes_type_map.values():
for node in nodes:
for i, port in node.network_interface.items():
table.add_row([node.hostname, i, port.ip_address, port.subnet_mask, node.default_gateway])
if hasattr(port, "ip_address"):
table.add_row([node.hostname, i, port.ip_address, port.subnet_mask, node.default_gateway])
print(table)
if links:

View File

@@ -109,9 +109,9 @@ def create_office_lan(
switch.power_on()
network.add_node(switch)
if num_of_switches > 1:
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
network.connect(core_switch.network_interface[core_switch_port], switch.network_interface[24])
else:
network.connect(router.network_interface[1], switch.switch_ports[24])
network.connect(router.network_interface[1], switch.network_interface[24])
# Add PCs to the LAN and connect them to switches
for i in range(1, num_pcs + 1):
@@ -125,9 +125,9 @@ def create_office_lan(
# Connect the new switch to the router or core switch
if num_of_switches > 1:
core_switch_port += 1
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
network.connect(core_switch.network_interface[core_switch_port], switch.network_interface[24])
else:
network.connect(router.network_interface[1], switch.switch_ports[24])
network.connect(router.network_interface[1], switch.network_interface[24])
# Create and add a PC to the network
pc = Computer(
@@ -142,7 +142,7 @@ def create_office_lan(
# Connect the PC to the switch
switch_port += 1
network.connect(switch.switch_ports[switch_port], pc.network_interface[1])
switch.switch_ports[switch_port].enable()
network.connect(switch.network_interface[switch_port], pc.network_interface[1])
switch.network_interface[switch_port].enable()
return network

View File

@@ -197,6 +197,12 @@ class WiredNetworkInterface(NetworkInterface, ABC):
)
return
if not self._connected_link:
self._connected_node.sys_log.info(
f"Interface {self} cannot be enabled as there is no Link connected."
)
return
self.enabled = True
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
self.pcap = PacketCapture(hostname=self._connected_node.hostname, interface_num=self.port_num)
@@ -351,6 +357,12 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
Derived classes should define specific behaviors and properties of an IP-capable wired network interface,
customizing it for their specific use cases.
"""
_connected_link: Optional[Link] = None
"The network link to which the network interface is connected."
def model_post_init(self, __context: Any) -> None:
if self.ip_network.network_address == self.ip_address:
raise ValueError(f"{self.ip_address}/{self.subnet_mask} must not be a network address")
def describe_state(self) -> Dict:
"""
@@ -375,7 +387,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
except AttributeError:
pass
@abstractmethod
# @abstractmethod
def receive_frame(self, frame: Frame) -> bool:
"""
Receives a network frame on the network interface.
@@ -819,6 +831,13 @@ class Node(SimComponent):
table.add_row([port.value, port.name])
print(table)
@property
def has_enabled_network_interface(self) -> bool:
for network_interface in self.network_interfaces.values():
if network_interface.enabled:
return True
return False
def show_nic(self, markdown: bool = False):
"""Prints a table of the NICs on the Node."""
table = PrettyTable(["Port", "Type", "MAC Address", "Address", "Speed", "Status"])
@@ -830,7 +849,7 @@ class Node(SimComponent):
table.add_row(
[
port,
network_interface.__name__,
type(network_interface),
network_interface.mac_address,
f"{network_interface.ip_address}/{network_interface.ip_network.prefixlen}",
network_interface.speed,

View File

@@ -1,10 +1,10 @@
from __future__ import annotations
from typing import Dict
from typing import Dict, Any
from typing import Optional
from primaite import getLogger
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, Link
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.transmission.data_link_layer import Frame
@@ -45,7 +45,7 @@ class HostARP(ARP):
:return: The NIC associated with the default gateway if it exists in the ARP cache, otherwise None.
"""
if self.software_manager.node.default_gateway:
if self.software_manager.node.default_gateway and self.software_manager.node.has_enabled_network_interface:
return self.get_arp_cache_network_interface(self.software_manager.node.default_gateway)
def _get_arp_cache_mac_address(
@@ -175,12 +175,14 @@ class NIC(IPWiredNetworkInterface):
and disconnect from network links and to manage the enabled/disabled state of the interface.
- Layer3Interface: Provides properties for Layer 3 network configuration, such as IP address and subnet mask.
"""
_connected_link: Optional[Link] = None
"The network link to which the network interface is connected."
wake_on_lan: bool = False
"Indicates if the NIC supports Wake-on-LAN functionality."
def __init__(self, **kwargs):
super().__init__(**kwargs)
def model_post_init(self, __context: Any) -> None:
if self.ip_network.network_address == self.ip_address:
raise ValueError(f"{self.ip_address}/{self.subnet_mask} must not be a network address")
def describe_state(self) -> Dict:
"""
@@ -353,7 +355,6 @@ class HostNode(Node):
if accept_frame:
self.session_manager.receive_frame(frame, from_network_interface)
else:
# denied as port closed
self.sys_log.info(f"Ignoring frame for port {frame.tcp.dst_port.value} from {frame.ip.src_ip_address}")
self.sys_log.info(f"Ignoring frame from {frame.ip.src_ip_address}")
# TODO: do we need to do anything more here?
pass

View File

@@ -555,6 +555,14 @@ class RouterARP(ARP):
return arp_entry.mac_address
if not is_reattempt:
if self.router.ip_is_in_router_interface_subnet(ip_address):
self.send_arp_request(ip_address)
return self._get_arp_cache_mac_address(
ip_address=ip_address,
is_reattempt=True,
is_default_route_attempt=is_default_route_attempt
)
route = self.router.route_table.find_best_route(ip_address)
if route and route != self.router.route_table.default_route:
self.send_arp_request(route.next_hop_ip_address)
@@ -818,7 +826,7 @@ class Router(NetworkNode):
network_interfaces: Dict[str, RouterInterface] = {}
"The Router Interfaces on the node."
network_interface: Dict[int, RouterInterface] = {}
"The Router Interfaceson the node by port id."
"The Router Interfaces on the node by port id."
acl: AccessControlList
route_table: RouteTable
@@ -885,6 +893,15 @@ class Router(NetworkNode):
return True
return False
def ip_is_in_router_interface_subnet(self, ip_address: IPV4Address, enabled_only: bool = False) -> bool:
for router_interface in self.network_interface.values():
if ip_address in router_interface.ip_network:
if enabled_only:
return router_interface.enabled
else:
return True
return False
def _get_port_of_nic(self, target_nic: RouterInterface) -> Optional[int]:
"""
Retrieve the port number for a given NIC.

View File

@@ -96,16 +96,18 @@ class Switch(NetworkNode):
num_ports: int = 24
"The number of ports on the switch."
switch_ports: Dict[int, SwitchPort] = {}
"The SwitchPorts on the switch."
network_interfaces: Dict[str, SwitchPort] = {}
"The SwitchPorts on the Switch."
network_interface: Dict[int, SwitchPort] = {}
"The SwitchPorts on the Switch by port id."
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
def __init__(self, **kwargs):
super().__init__(**kwargs)
if not self.switch_ports:
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
if not self.network_interface:
self.network_interface = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.network_interface.items():
port._connected_node = self
port.port_num = port_num
port.parent = self
@@ -122,7 +124,7 @@ class Switch(NetworkNode):
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Switch Ports"
for port_num, port in self.switch_ports.items():
for port_num, port in self.network_interface.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)
@@ -133,7 +135,7 @@ class Switch(NetworkNode):
:return: Current state of this object and child objects.
"""
state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()}
state["num_ports"] = self.num_ports # redundant?
state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
return state
@@ -171,7 +173,7 @@ class Switch(NetworkNode):
outgoing_port.send_frame(frame)
else:
# If the destination MAC is not in the table, flood to all ports except incoming
for port in self.switch_ports.values():
for port in self.network_interface.values():
if port.enabled and port != from_network_interface:
port.send_frame(frame)
@@ -183,7 +185,7 @@ class Switch(NetworkNode):
:param port_number: The port number on the switch from where the link should be disconnected.
:raise NetworkError: When an invalid port number is provided or the link does not match the connection.
"""
port = self.switch_ports.get(port_number)
port = self.network_interface.get(port_number)
if port is None:
msg = f"Invalid port number {port_number} on the switch"
_LOGGER.error(msg)

View File

@@ -41,13 +41,13 @@ def client_server_routed() -> Network:
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=6)
switch_1.power_on()
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.switch_ports[6])
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[6])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=6)
switch_2.power_on()
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.switch_ports[6])
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[6])
router_1.enable_port(2)
# Client 1
@@ -56,10 +56,10 @@ def client_server_routed() -> Network:
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
operating_state=NodeOperatingState.ON,
start_up_duration=0
)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.switch_ports[1])
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
# Server 1
server_1 = Server(
@@ -67,10 +67,10 @@ def client_server_routed() -> Network:
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
start_up_duration=0
)
server_1.power_on()
network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.switch_ports[1])
network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1])
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
@@ -119,21 +119,21 @@ def arcd_uc2_network() -> Network:
network = Network()
# Router 1
router_1 = Router(hostname="router_1", num_ports=5, operating_state=NodeOperatingState.ON)
router_1 = Router(hostname="router_1", num_ports=5, start_up_duration=0)
router_1.power_on()
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=8, operating_state=NodeOperatingState.ON)
switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
switch_1.power_on()
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.switch_ports[8])
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=8, operating_state=NodeOperatingState.ON)
switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
switch_2.power_on()
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.switch_ports[8])
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])
router_1.enable_port(2)
# Client 1
@@ -143,10 +143,10 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
start_up_duration=0
)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.switch_ports[1])
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
db_manipulation_bot.configure(
@@ -163,12 +163,12 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
start_up_duration=0
)
client_2.power_on()
web_browser = client_2.software_manager.software.get("WebBrowser")
web_browser.target_url = "http://arcd.com/users/"
network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.switch_ports[2])
network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2])
# Domain Controller
domain_controller = Server(
@@ -176,12 +176,12 @@ def arcd_uc2_network() -> Network:
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
start_up_duration=0
)
domain_controller.power_on()
domain_controller.software_manager.install(DNSServer)
network.connect(endpoint_b=domain_controller.network_interface[1], endpoint_a=switch_1.switch_ports[1])
network.connect(endpoint_b=domain_controller.network_interface[1], endpoint_a=switch_1.network_interface[1])
# Database Server
database_server = Server(
@@ -190,10 +190,10 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
start_up_duration=0
)
database_server.power_on()
network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.switch_ports[3])
network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3])
ddl = """
CREATE TABLE IF NOT EXISTS user (
@@ -264,14 +264,14 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
start_up_duration=0
)
web_server.power_on()
web_server.software_manager.install(DatabaseClient)
database_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
database_client.configure(server_ip_address=IPv4Address("192.168.1.14"))
network.connect(endpoint_b=web_server.network_interface[1], endpoint_a=switch_1.switch_ports[2])
network.connect(endpoint_b=web_server.network_interface[1], endpoint_a=switch_1.network_interface[2])
database_client.run()
database_client.connect()
@@ -279,7 +279,7 @@ def arcd_uc2_network() -> Network:
# register the web_server to a domain
dns_server_service: DNSServer = domain_controller.software_manager.software.get("DNSServer") # noqa
dns_server_service.dns_register("arcd.com", web_server.ip_address)
dns_server_service.dns_register("arcd.com", web_server.network_interface[1].ip_address)
# Backup Server
backup_server = Server(
@@ -288,11 +288,11 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
start_up_duration=0
)
backup_server.power_on()
backup_server.software_manager.install(FTPServer)
network.connect(endpoint_b=backup_server.network_interface[1], endpoint_a=switch_1.switch_ports[4])
network.connect(endpoint_b=backup_server.network_interface[1], endpoint_a=switch_1.network_interface[4])
# Security Suite
security_suite = Server(
@@ -301,12 +301,12 @@ def arcd_uc2_network() -> Network:
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
start_up_duration=0
)
security_suite.power_on()
network.connect(endpoint_b=security_suite.network_interface[1], endpoint_a=switch_1.switch_ports[7])
network.connect(endpoint_b=security_suite.network_interface[1], endpoint_a=switch_1.network_interface[7])
security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0"))
network.connect(endpoint_b=security_suite.network_interface[2], endpoint_a=switch_2.switch_ports[7])
network.connect(endpoint_b=security_suite.network_interface[2], endpoint_a=switch_2.network_interface[7])
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)

View File

@@ -23,6 +23,7 @@ class DatabaseClient(Application):
server_ip_address: Optional[IPv4Address] = None
server_password: Optional[str] = None
connected: bool = False
_query_success_tracker: Dict[str, bool] = {}
def __init__(self, **kwargs):
@@ -73,9 +74,10 @@ class DatabaseClient(Application):
if not connection_id:
connection_id = str(uuid4())
return self._connect(
self.connected = self._connect(
server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
)
return self.connected
def _connect(
self,
@@ -147,6 +149,7 @@ class DatabaseClient(Application):
self.sys_log.info(
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
)
self.connected = False
def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool:
"""

View File

@@ -108,13 +108,14 @@ class NTPClient(Service):
def request_time(self) -> None:
"""Send request to ntp_server."""
self.software_manager.session_manager.receive_payload_from_software_manager(
payload=NTPPacket(),
dst_ip_address=self.ntp_server,
src_port=self.port,
dst_port=self.port,
ip_protocol=self.protocol,
)
if self.ntp_server:
self.software_manager.session_manager.receive_payload_from_software_manager(
payload=NTPPacket(),
dst_ip_address=self.ntp_server,
src_port=self.port,
dst_port=self.port,
ip_protocol=self.protocol,
)
def apply_timestep(self, timestep: int) -> None:
"""