#2453 - Merging in updates on dev branch

This commit is contained in:
Charlie Crane
2024-04-16 16:51:29 +01:00
57 changed files with 1447 additions and 504 deletions

View File

@@ -1,3 +1,4 @@
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional
import matplotlib.pyplot as plt
@@ -86,6 +87,16 @@ class Network(SimComponent):
for link_id in self.links:
self.links[link_id].apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
for node in self.nodes.values():
node.pre_timestep(timestep)
for link in self.links.values():
link.pre_timestep(timestep)
@property
def router_nodes(self) -> List[Node]:
"""The Routers in the Network."""
@@ -163,10 +174,11 @@ class Network(SimComponent):
for node in nodes:
for i, port in node.network_interface.items():
if hasattr(port, "ip_address"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
)
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
)
print(table)
if links:

View File

@@ -9,7 +9,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int:
"""
Calculate the minimum number of network switches required to connect a given number of nodes.
@@ -18,7 +18,7 @@ def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
to accommodate all nodes under this constraint.
:param num_nodes: The total number of nodes that need to be connected in the network.
:param max_switch_ports: The maximum number of ports available on each switch. Defaults to 24.
:param max_network_interface: The maximum number of ports available on each switch. Defaults to 24.
:return: The minimum number of switches required to connect all PCs.
@@ -33,11 +33,11 @@ def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
3
"""
# Reduce the effective number of switch ports by 1 to leave space for the router
effective_switch_ports = max_switch_ports - 1
effective_network_interface = max_network_interface - 1
# Calculate the number of fully utilised switches and any additional switch for remaining PCs
full_switches = num_nodes // effective_switch_ports
extra_pcs = num_nodes % effective_switch_ports
full_switches = num_nodes // effective_network_interface
extra_pcs = num_nodes % effective_network_interface
# Return the total number of switches required
return full_switches + (1 if extra_pcs > 0 else 0)
@@ -77,7 +77,7 @@ def create_office_lan(
# Calculate the required number of switches
num_of_switches = num_of_switches_required(num_nodes=num_pcs)
effective_switch_ports = 23 # One port less for router connection
effective_network_interface = 23 # One port less for router connection
if pcs_ip_block_start <= num_of_switches:
raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}")
@@ -116,7 +116,7 @@ def create_office_lan(
# Add PCs to the LAN and connect them to switches
for i in range(1, num_pcs + 1):
# Add a new edge switch if the current one is full
if switch_port == effective_switch_ports:
if switch_port == effective_network_interface:
switch_n += 1
switch_port = 0
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)

View File

@@ -264,6 +264,9 @@ class NetworkInterface(SimComponent, ABC):
"""
return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}"
def __hash__(self) -> int:
return hash(self.uuid)
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep evolution to this component.
@@ -661,6 +664,10 @@ class Link(SimComponent):
def apply_timestep(self, timestep: int) -> None:
"""Apply a timestep to the simulation."""
super().apply_timestep(timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
self.current_load = 0.0
@@ -895,6 +902,10 @@ class Node(SimComponent):
from primaite.simulator.system.applications.web_browser import WebBrowser
return WebBrowser
elif application_class_str == "RansomwareScript":
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
return RansomwareScript
else:
return 0
@@ -965,12 +976,15 @@ class Node(SimComponent):
table.align = "l"
table.title = f"{self.hostname} Network Interface Cards"
for port, network_interface in self.network_interface.items():
ip_address = ""
if hasattr(network_interface, "ip_address"):
ip_address = f"{network_interface.ip_address}/{network_interface.ip_network.prefixlen}"
table.add_row(
[
port,
network_interface.__class__.__name__,
network_interface.mac_address,
f"{network_interface.ip_address}/{network_interface.ip_network.prefixlen}",
ip_address,
network_interface.speed,
"Enabled" if network_interface.enabled else "Disabled",
]
@@ -1071,6 +1085,23 @@ class Node(SimComponent):
self.file_system.apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
for network_interface in self.network_interfaces.values():
network_interface.pre_timestep(timestep=timestep)
for process_id in self.processes:
self.processes[process_id].pre_timestep(timestep=timestep)
for service_id in self.services:
self.services[service_id].pre_timestep(timestep=timestep)
for application_id in self.applications:
self.applications[application_id].pre_timestep(timestep=timestep)
self.file_system.pre_timestep(timestep=timestep)
def scan(self) -> bool:
"""
Scan the node and all the items within it.
@@ -1341,6 +1372,8 @@ class Node(SimComponent):
application_instance.configure(target_ip_address=IPv4Address(ip_address))
elif application_instance.name == "DataManipulationBot":
application_instance.configure(server_ip_address=IPv4Address(ip_address))
elif application_instance.name == "RansomwareScript":
application_instance.configure(server_ip_address=IPv4Address(ip_address))
else:
pass

View File

@@ -599,7 +599,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -612,7 +614,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -625,7 +629,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -638,7 +644,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -651,7 +659,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -664,7 +674,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)

View File

@@ -322,10 +322,12 @@ class AccessControlList(SimComponent):
action=ACLAction[request[0]],
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
src_port=None if request[3] == "ALL" else Port[request[3]],
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
dst_port=None if request[5] == "ALL" else Port[request[5]],
position=int(request[6]),
src_wildcard_mask=None if request[3] == "NONE" else IPv4Address(request[3]),
src_port=None if request[4] == "ALL" else Port[request[4]],
dst_ip_address=None if request[5] == "ALL" else IPv4Address(request[5]),
dst_wildcard_mask=None if request[6] == "NONE" else IPv4Address(request[6]),
dst_port=None if request[7] == "ALL" else Port[request[7]],
position=int(request[8]),
)
)
),
@@ -772,6 +774,13 @@ class RouterARP(ARP):
is_reattempt=True,
is_default_route_attempt=is_default_route_attempt,
)
elif route and route == self.router.route_table.default_route:
self.send_arp_request(self.router.route_table.default_route.next_hop_ip_address)
return self._get_arp_cache_mac_address(
ip_address=self.router.route_table.default_route.next_hop_ip_address,
is_reattempt=True,
is_default_route_attempt=True,
)
else:
if self.router.route_table.default_route:
if not is_default_route_attempt:
@@ -822,6 +831,12 @@ class RouterARP(ARP):
return network_interface
if not is_reattempt:
if self.router.ip_is_in_router_interface_subnet(ip_address):
self.send_arp_request(ip_address)
return self._get_arp_cache_network_interface(
ip_address=ip_address, is_reattempt=True, is_default_route_attempt=is_default_route_attempt
)
route = self.router.route_table.find_best_route(ip_address)
if route and route != self.router.route_table.default_route:
self.send_arp_request(route.next_hop_ip_address)
@@ -830,6 +845,13 @@ class RouterARP(ARP):
is_reattempt=True,
is_default_route_attempt=is_default_route_attempt,
)
elif route and route == self.router.route_table.default_route:
self.send_arp_request(self.router.route_table.default_route.next_hop_ip_address)
return self._get_arp_cache_network_interface(
ip_address=self.router.route_table.default_route.next_hop_ip_address,
is_reattempt=True,
is_default_route_attempt=True,
)
else:
if self.router.route_table.default_route:
if not is_default_route_attempt:
@@ -1459,6 +1481,8 @@ class Router(NetworkNode):
frame.ethernet.src_mac_addr = network_interface.mac_address
frame.ethernet.dst_mac_addr = target_mac
network_interface.send_frame(frame)
else:
self.sys_log.error(f"Frame dropped as there is no route to {frame.ip.dst_ip_address}")
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
"""
@@ -1539,6 +1563,13 @@ class Router(NetworkNode):
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
- src_ip_address (str, optional): IP address octet written in base 10
- dst_ip_address (str, optional): IP address octet written in base 10
- routes (list[dict]): List of route dicts with values:
- address (str): The destination address of the route.
- subnet_mask (str): The subnet mask of the route.
- next_hop_ip_address (str): The next hop IP for the route.
- metric (int): The metric of the route. Optional.
- default_route:
- next_hop_ip_address (str): The next hop IP for the route.
Example config:
```
@@ -1549,6 +1580,10 @@ class Router(NetworkNode):
1: {
'ip_address' : '192.168.1.1',
'subnet_mask' : '255.255.255.0',
},
2: {
'ip_address' : '192.168.0.1',
'subnet_mask' : '255.255.255.252',
}
},
'acl' : {
@@ -1556,6 +1591,10 @@ class Router(NetworkNode):
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
},
'routes' : [
{'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'}
],
'default_route': {'next_hop_ip_address': '192.168.0.2'}
}
```
@@ -1599,4 +1638,8 @@ class Router(NetworkNode):
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
return router

View File

@@ -100,13 +100,8 @@ class Switch(NetworkNode):
def __init__(self, **kwargs):
super().__init__(**kwargs)
if not self.network_interface:
self.network_interface = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.network_interface.items():
port._connected_node = self
port.port_num = port_num
port.parent = self
port.port_num = port_num
for i in range(1, self.num_ports + 1):
self.connect_nic(SwitchPort())
def show(self, markdown: bool = False):
"""

View File

@@ -8,7 +8,7 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket
from primaite.simulator.network.protocols.packet import DataPacket
from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol
from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader
from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader
from primaite.simulator.network.utils import convert_bytes_to_megabits
_LOGGER = getLogger(__name__)
@@ -141,3 +141,37 @@ class Frame(BaseModel):
def size_Mbits(self) -> float: # noqa - Keep it as MBits as this is how they're expressed
"""The daa transfer size of the Frame in Mbits."""
return convert_bytes_to_megabits(self.size)
@property
def is_broadcast(self) -> bool:
"""
Determines if the Frame is a broadcast frame.
A Frame is considered a broadcast frame if the destination MAC address is set to the broadcast address
"ff:ff:ff:ff:ff:ff".
:return: True if the destination MAC address is a broadcast address, otherwise False.
"""
return self.ethernet.dst_mac_addr.lower() == "ff:ff:ff:ff:ff:ff"
@property
def is_arp(self) -> bool:
"""
Checks if the Frame is an ARP (Address Resolution Protocol) packet.
This is determined by checking if the destination port of the TCP header is equal to the ARP port.
:return: True if the Frame is an ARP packet, otherwise False.
"""
return self.udp.dst_port == Port.ARP
@property
def is_icmp(self) -> bool:
"""
Determines if the Frame is an ICMP (Internet Control Message Protocol) packet.
This check is performed by verifying if the 'icmp' attribute of the Frame instance is present (not None).
:return: True if the Frame is an ICMP packet (i.e., has an ICMP header), otherwise False.
"""
return self.icmp is not None

View File

@@ -11,6 +11,9 @@ class Port(Enum):
.. _List of Ports:
"""
UNUSED = -1
"An unused port stub."
NONE = 0
"Place holder for a non-port."
WOL = 9