Merge remote-tracking branch 'origin/dev' into feature/2137-refactor-request-api
This commit is contained in:
@@ -7,11 +7,13 @@ from primaite import _PRIMAITE_ROOT
|
||||
__all__ = ["SIM_OUTPUT"]
|
||||
|
||||
|
||||
class __SimOutput:
|
||||
class _SimOutput:
|
||||
def __init__(self):
|
||||
self._path: Path = (
|
||||
_PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
)
|
||||
self.save_pcap_logs: bool = False
|
||||
self.save_sys_logs: bool = False
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
@@ -23,4 +25,4 @@ class __SimOutput:
|
||||
self._path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
SIM_OUTPUT = __SimOutput()
|
||||
SIM_OUTPUT = _SimOutput()
|
||||
|
||||
@@ -23,7 +23,6 @@ class FileSystem(SimComponent):
|
||||
"List containing all the folders in the file system."
|
||||
deleted_folders: Dict[str, Folder] = {}
|
||||
"List containing all the folders that have been deleted."
|
||||
_folders_by_name: Dict[str, Folder] = {}
|
||||
sys_log: SysLog
|
||||
"Instance of SysLog used to create system logs."
|
||||
sim_root: Path
|
||||
@@ -56,7 +55,6 @@ class FileSystem(SimComponent):
|
||||
folder = self.deleted_folders[uuid]
|
||||
self.deleted_folders.pop(uuid)
|
||||
self.folders[uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
|
||||
# Clear any other deleted folders that aren't original (have been created by agent)
|
||||
self.deleted_folders.clear()
|
||||
@@ -67,7 +65,6 @@ class FileSystem(SimComponent):
|
||||
if uuid not in original_folder_uuids:
|
||||
folder = self.folders[uuid]
|
||||
self.folders.pop(uuid)
|
||||
self._folders_by_name.pop(folder.name)
|
||||
|
||||
# Now reset all remaining folders
|
||||
for folder in self.folders.values():
|
||||
@@ -173,7 +170,6 @@ class FileSystem(SimComponent):
|
||||
folder = Folder(name=folder_name, sys_log=self.sys_log)
|
||||
|
||||
self.folders[folder.uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
self._folder_request_manager.add_request(
|
||||
name=folder.name, request_type=RequestType(func=folder._request_manager)
|
||||
)
|
||||
@@ -188,14 +184,13 @@ class FileSystem(SimComponent):
|
||||
if folder_name == "root":
|
||||
self.sys_log.warning("Cannot delete the root folder.")
|
||||
return
|
||||
folder = self._folders_by_name.get(folder_name)
|
||||
folder = self.get_folder(folder_name)
|
||||
if folder:
|
||||
# set folder to deleted state
|
||||
folder.delete()
|
||||
|
||||
# remove from folder list
|
||||
self.folders.pop(folder.uuid)
|
||||
self._folders_by_name.pop(folder.name)
|
||||
|
||||
# add to deleted list
|
||||
folder.remove_all_files()
|
||||
@@ -221,7 +216,10 @@ class FileSystem(SimComponent):
|
||||
:param folder_name: The folder name.
|
||||
:return: The matching Folder.
|
||||
"""
|
||||
return self._folders_by_name.get(folder_name)
|
||||
for folder in self.folders.values():
|
||||
if folder.name == folder_name:
|
||||
return folder
|
||||
return None
|
||||
|
||||
def get_folder_by_id(self, folder_uuid: str, include_deleted: bool = False) -> Optional[Folder]:
|
||||
"""
|
||||
@@ -261,13 +259,13 @@ class FileSystem(SimComponent):
|
||||
"""
|
||||
if folder_name:
|
||||
# check if file with name already exists
|
||||
folder = self._folders_by_name.get(folder_name)
|
||||
folder = self.get_folder(folder_name)
|
||||
# If not then create it
|
||||
if not folder:
|
||||
folder = self.create_folder(folder_name)
|
||||
else:
|
||||
# Use root folder if folder_name not supplied
|
||||
folder = self._folders_by_name["root"]
|
||||
folder = self.get_folder("root")
|
||||
|
||||
# Create the file and add it to the folder
|
||||
file = File(
|
||||
@@ -474,7 +472,6 @@ class FileSystem(SimComponent):
|
||||
|
||||
folder.restore()
|
||||
self.folders[folder.uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
|
||||
if folder.deleted:
|
||||
self.deleted_folders.pop(folder.uuid)
|
||||
|
||||
@@ -87,7 +87,7 @@ class FileSystemItemABC(SimComponent):
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"}
|
||||
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -17,8 +17,6 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
files: Dict[str, File] = {}
|
||||
"Files stored in the folder."
|
||||
_files_by_name: Dict[str, File] = {}
|
||||
"Files by their name as <file name>.<file type>."
|
||||
deleted_files: Dict[str, File] = {}
|
||||
"Files that have been deleted."
|
||||
|
||||
@@ -78,7 +76,6 @@ class Folder(FileSystemItemABC):
|
||||
file = self.deleted_files[uuid]
|
||||
self.deleted_files.pop(uuid)
|
||||
self.files[uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
|
||||
# Clear any other deleted files that aren't original (have been created by agent)
|
||||
self.deleted_files.clear()
|
||||
@@ -89,7 +86,6 @@ class Folder(FileSystemItemABC):
|
||||
if uuid not in original_file_uuids:
|
||||
file = self.files[uuid]
|
||||
self.files.pop(uuid)
|
||||
self._files_by_name.pop(file.name)
|
||||
|
||||
# Now reset all remaining files
|
||||
for file in self.files.values():
|
||||
@@ -105,7 +101,7 @@ class Folder(FileSystemItemABC):
|
||||
self._file_request_manager = RequestManager()
|
||||
rm.add_request(
|
||||
name="file",
|
||||
request_type=RequestType(func=lambda request, context: self._file_request_manager),
|
||||
request_type=RequestType(func=self._file_request_manager),
|
||||
)
|
||||
return rm
|
||||
|
||||
@@ -219,7 +215,10 @@ class Folder(FileSystemItemABC):
|
||||
:return: The matching File.
|
||||
"""
|
||||
# TODO: Increment read count?
|
||||
return self._files_by_name.get(file_name)
|
||||
for file in self.files.values():
|
||||
if file.name == file_name:
|
||||
return file
|
||||
return None
|
||||
|
||||
def get_file_by_id(self, file_uuid: str, include_deleted: Optional[bool] = False) -> File:
|
||||
"""
|
||||
@@ -250,15 +249,14 @@ class Folder(FileSystemItemABC):
|
||||
raise Exception(f"Invalid file: {file}")
|
||||
|
||||
# check if file with id or name already exists in folder
|
||||
if (force is not True) and file.name in self._files_by_name:
|
||||
if self.get_file(file.name) is not None and not force:
|
||||
raise Exception(f"File with name {file.name} already exists in folder")
|
||||
|
||||
if (force is not True) and file.uuid in self.files:
|
||||
if (file.uuid in self.files) and not force:
|
||||
raise Exception(f"File with uuid {file.uuid} already exists in folder")
|
||||
|
||||
# add to list
|
||||
self.files[file.uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager))
|
||||
file.folder = self
|
||||
|
||||
@@ -275,11 +273,9 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
if self.files.get(file.uuid):
|
||||
self.files.pop(file.uuid)
|
||||
self._files_by_name.pop(file.name)
|
||||
self.deleted_files[file.uuid] = file
|
||||
file.delete()
|
||||
self.sys_log.info(f"Removed file {file.name} (id: {file.uuid})")
|
||||
self._file_request_manager.remove_request(file.uuid)
|
||||
else:
|
||||
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
|
||||
|
||||
@@ -300,7 +296,6 @@ class Folder(FileSystemItemABC):
|
||||
self.deleted_files[file_id] = file
|
||||
|
||||
self.files = {}
|
||||
self._files_by_name = {}
|
||||
|
||||
def restore_file(self, file_uuid: str):
|
||||
"""
|
||||
@@ -316,7 +311,6 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
file.restore()
|
||||
self.files[file.uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
|
||||
if file.deleted:
|
||||
self.deleted_files.pop(file_uuid)
|
||||
|
||||
148
src/primaite/simulator/network/creation.py
Normal file
148
src/primaite/simulator/network/creation.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
|
||||
|
||||
def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
|
||||
"""
|
||||
Calculate the minimum number of network switches required to connect a given number of nodes.
|
||||
|
||||
Each switch is assumed to have one port reserved for connecting to a router, reducing the effective
|
||||
number of ports available for PCs. The function calculates the total number of switches needed
|
||||
to accommodate all nodes under this constraint.
|
||||
|
||||
:param num_nodes: The total number of nodes that need to be connected in the network.
|
||||
:param max_switch_ports: The maximum number of ports available on each switch. Defaults to 24.
|
||||
|
||||
:return: The minimum number of switches required to connect all PCs.
|
||||
|
||||
Example:
|
||||
>>> num_of_switches_required(5)
|
||||
1
|
||||
>>> num_of_switches_required(24,24)
|
||||
2
|
||||
>>> num_of_switches_required(48,24)
|
||||
3
|
||||
>>> num_of_switches_required(25,10)
|
||||
3
|
||||
"""
|
||||
# Reduce the effective number of switch ports by 1 to leave space for the router
|
||||
effective_switch_ports = max_switch_ports - 1
|
||||
|
||||
# Calculate the number of fully utilised switches and any additional switch for remaining PCs
|
||||
full_switches = num_nodes // effective_switch_ports
|
||||
extra_pcs = num_nodes % effective_switch_ports
|
||||
|
||||
# Return the total number of switches required
|
||||
return full_switches + (1 if extra_pcs > 0 else 0)
|
||||
|
||||
|
||||
def create_office_lan(
|
||||
lan_name: str,
|
||||
subnet_base: int,
|
||||
pcs_ip_block_start: int,
|
||||
num_pcs: int,
|
||||
network: Optional[Network] = None,
|
||||
include_router: bool = True,
|
||||
) -> Network:
|
||||
"""
|
||||
Creates a 2-Tier or 3-Tier office local area network (LAN).
|
||||
|
||||
The LAN is configured with a specified number of personal computers (PCs), optionally including a router,
|
||||
and multiple edge switches to connect them. A core switch is added only if more than one edge switch is required.
|
||||
The network topology involves edge switches connected either directly to the router in a 2-Tier setup or
|
||||
to a core switch in a 3-Tier setup. If a router is included, it is connected to the core switch (if present)
|
||||
and configured with basic access control list (ACL) rules. PCs are distributed across the edge switches.
|
||||
|
||||
|
||||
:param str lan_name: The name to be assigned to the LAN.
|
||||
:param int subnet_base: The subnet base number to be used in the IP addresses.
|
||||
:param int pcs_ip_block_start: The starting block for assigning IP addresses to PCs.
|
||||
:param int num_pcs: The number of PCs to be added to the LAN.
|
||||
:param Optional[Network] network: The network to which the LAN components will be added. If None, a new network is
|
||||
created.
|
||||
:param bool include_router: Flag to determine if a router should be included in the LAN. Defaults to True.
|
||||
:return: The network object with the LAN components added.
|
||||
:raises ValueError: If pcs_ip_block_start is less than or equal to the number of required switches.
|
||||
"""
|
||||
# Initialise the network if not provided
|
||||
if not network:
|
||||
network = Network()
|
||||
|
||||
# Calculate the required number of switches
|
||||
num_of_switches = num_of_switches_required(num_nodes=num_pcs)
|
||||
effective_switch_ports = 23 # One port less for router connection
|
||||
if pcs_ip_block_start <= num_of_switches:
|
||||
raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}")
|
||||
|
||||
# Create a core switch if more than one edge switch is needed
|
||||
if num_of_switches > 1:
|
||||
core_switch = Switch(hostname=f"switch_core_{lan_name}", start_up_duration=0)
|
||||
core_switch.power_on()
|
||||
network.add_node(core_switch)
|
||||
core_switch_port = 1
|
||||
|
||||
# Initialise the default gateway to None
|
||||
default_gateway = None
|
||||
|
||||
# Optionally include a router in the LAN
|
||||
if include_router:
|
||||
default_gateway = IPv4Address(f"192.168.{subnet_base}.1")
|
||||
router = Router(hostname=f"router_{lan_name}", start_up_duration=0)
|
||||
router.power_on()
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
network.add_node(router)
|
||||
router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0")
|
||||
router.enable_port(1)
|
||||
|
||||
# Initialise the first edge switch and connect to the router or core switch
|
||||
switch_port = 0
|
||||
switch_n = 1
|
||||
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
|
||||
switch.power_on()
|
||||
network.add_node(switch)
|
||||
if num_of_switches > 1:
|
||||
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
|
||||
else:
|
||||
network.connect(router.ethernet_ports[1], switch.switch_ports[24])
|
||||
|
||||
# Add PCs to the LAN and connect them to switches
|
||||
for i in range(1, num_pcs + 1):
|
||||
# Add a new edge switch if the current one is full
|
||||
if switch_port == effective_switch_ports:
|
||||
switch_n += 1
|
||||
switch_port = 0
|
||||
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
|
||||
switch.power_on()
|
||||
network.add_node(switch)
|
||||
# Connect the new switch to the router or core switch
|
||||
if num_of_switches > 1:
|
||||
core_switch_port += 1
|
||||
network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24])
|
||||
else:
|
||||
network.connect(router.ethernet_ports[1], switch.switch_ports[24])
|
||||
|
||||
# Create and add a PC to the network
|
||||
pc = Computer(
|
||||
hostname=f"pc_{i}_{lan_name}",
|
||||
ip_address=f"192.168.{subnet_base}.{i+pcs_ip_block_start-1}",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway=default_gateway,
|
||||
start_up_duration=0,
|
||||
)
|
||||
pc.power_on()
|
||||
network.add_node(pc)
|
||||
|
||||
# Connect the PC to the switch
|
||||
switch_port += 1
|
||||
network.connect(switch.switch_ports[switch_port], pc.ethernet_port[1])
|
||||
switch.switch_ports[switch_port].enable()
|
||||
|
||||
return network
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import secrets
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
@@ -274,18 +274,40 @@ class NIC(SimComponent):
|
||||
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Receive a network frame from the connected link if the NIC is enabled.
|
||||
Receive a network frame from the connected link, processing it if the NIC is enabled.
|
||||
|
||||
The Frame is passed to the Node.
|
||||
This method decrements the Time To Live (TTL) of the frame, captures it using PCAP (Packet Capture), and checks
|
||||
if the frame is either a broadcast or destined for this NIC. If the frame is acceptable, it is passed to the
|
||||
connected node. The method also handles the discarding of frames with TTL expired and logs this event.
|
||||
|
||||
:param frame: The network frame being received.
|
||||
The frame's reception is based on various conditions:
|
||||
- If the NIC is disabled, the frame is not processed.
|
||||
- If the TTL of the frame reaches zero after decrement, it is discarded and logged.
|
||||
- If the frame is a broadcast or its destination MAC/IP address matches this NIC's, it is accepted.
|
||||
- All other frames are dropped and logged or printed to the console.
|
||||
|
||||
:param frame: The network frame being received. This should be an instance of the Frame class.
|
||||
:return: Returns True if the frame is processed and passed to the node, False otherwise.
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
frame.set_received_timestamp()
|
||||
self.pcap.capture(frame)
|
||||
# If this destination or is broadcast
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
accept_frame = False
|
||||
|
||||
# Check if it's a broadcast:
|
||||
if frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
if frame.ip.dst_ip_address in {self.ip_address, self.ip_network.broadcast_address}:
|
||||
accept_frame = True
|
||||
else:
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address:
|
||||
accept_frame = True
|
||||
|
||||
if accept_frame:
|
||||
self._connected_node.receive_frame(frame=frame, from_nic=self)
|
||||
return True
|
||||
return False
|
||||
@@ -436,6 +458,9 @@ class SwitchPort(SimComponent):
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
self.pcap.capture(frame)
|
||||
connected_node: Node = self._connected_node
|
||||
connected_node.forward_frame(frame=frame, incoming_port=self)
|
||||
@@ -671,17 +696,30 @@ class ARPCache:
|
||||
"""Clear the entire ARP cache, removing all stored entries."""
|
||||
self.arp.clear()
|
||||
|
||||
def send_arp_request(self, target_ip_address: Union[IPv4Address, str]):
|
||||
def send_arp_request(
|
||||
self, target_ip_address: Union[IPv4Address, str], ignore_networks: Optional[List[IPv4Address]] = None
|
||||
):
|
||||
"""
|
||||
Perform a standard ARP request for a given target IP address.
|
||||
|
||||
Broadcasts the request through all enabled NICs to determine the MAC address corresponding to the target IP
|
||||
address.
|
||||
address. This method can be configured to ignore specific networks when sending out ARP requests,
|
||||
which is useful in environments where certain addresses should not be queried.
|
||||
|
||||
:param target_ip_address: The target IP address to send an ARP request for.
|
||||
:param ignore_networks: An optional list of IPv4 addresses representing networks to be excluded from the ARP
|
||||
request broadcast. Each address in this list indicates a network which will not be queried during the ARP
|
||||
request process. This is particularly useful in complex network environments where traffic should be
|
||||
minimized or controlled to specific subnets. It is mainly used by the router to prevent ARP requests being
|
||||
sent back to their source.
|
||||
"""
|
||||
for nic in self.nics.values():
|
||||
if nic.enabled:
|
||||
use_nic = True
|
||||
if ignore_networks:
|
||||
for ipv4 in ignore_networks:
|
||||
if ipv4 in nic.ip_network:
|
||||
use_nic = False
|
||||
if nic.enabled and use_nic:
|
||||
self.sys_log.info(f"Sending ARP request from NIC {nic} for ip {target_ip_address}")
|
||||
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
|
||||
|
||||
@@ -806,7 +844,6 @@ class ICMP:
|
||||
self.arp.send_arp_request(frame.ip.src_ip_address)
|
||||
self.process_icmp(frame=frame, from_nic=from_nic, is_reattempt=True)
|
||||
return
|
||||
tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP)
|
||||
|
||||
# Network Layer
|
||||
ip_packet = IPPacket(
|
||||
@@ -821,9 +858,7 @@ class ICMP:
|
||||
sequence=frame.icmp.sequence + 1,
|
||||
)
|
||||
payload = secrets.token_urlsafe(int(32 / 1.3)) # Standard ICMP 32 bytes size
|
||||
frame = Frame(
|
||||
ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet, payload=payload
|
||||
)
|
||||
frame = Frame(ethernet=ethernet_header, ip=ip_packet, icmp=icmp_reply_packet, payload=payload)
|
||||
self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip_address}")
|
||||
|
||||
src_nic.send_frame(frame)
|
||||
@@ -1275,8 +1310,8 @@ class Node(SimComponent):
|
||||
self.start_up_countdown = self.start_up_duration
|
||||
|
||||
if self.start_up_duration <= 0:
|
||||
self._start_up_actions()
|
||||
self.operating_state = NodeOperatingState.ON
|
||||
self._start_up_actions()
|
||||
self.sys_log.info("Turned on")
|
||||
for nic in self.nics.values():
|
||||
if nic._connected_link:
|
||||
@@ -1450,7 +1485,7 @@ class Node(SimComponent):
|
||||
service.parent = self
|
||||
service.install() # Perform any additional setup, such as creating files for this service on the node.
|
||||
self.sys_log.info(f"Installed service {service.name}")
|
||||
_LOGGER.info(f"Added service {service.name} to node {self.hostname}")
|
||||
_LOGGER.debug(f"Added service {service.name} to node {self.hostname}")
|
||||
self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager))
|
||||
|
||||
def uninstall_service(self, service: Service) -> None:
|
||||
@@ -1485,7 +1520,7 @@ class Node(SimComponent):
|
||||
self.applications[application.uuid] = application
|
||||
application.parent = self
|
||||
self.sys_log.info(f"Installed application {application.name}")
|
||||
_LOGGER.info(f"Added application {application.name} to node {self.hostname}")
|
||||
_LOGGER.debug(f"Added application {application.name} to node {self.hostname}")
|
||||
self._application_request_manager.add_request(application.name, RequestType(func=application._request_manager))
|
||||
|
||||
def uninstall_application(self, application: Application) -> None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
|
||||
@@ -18,8 +19,8 @@ from primaite.simulator.system.core.sys_log import SysLog
|
||||
class ACLAction(Enum):
|
||||
"""Enum for defining the ACL action types."""
|
||||
|
||||
DENY = 0
|
||||
PERMIT = 1
|
||||
DENY = 2
|
||||
|
||||
|
||||
class ACLRule(SimComponent):
|
||||
@@ -65,11 +66,11 @@ class ACLRule(SimComponent):
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.value if self.protocol else None
|
||||
state["protocol"] = self.protocol.name if self.protocol else None
|
||||
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
|
||||
state["src_port"] = self.src_port.value if self.src_port else None
|
||||
state["src_port"] = self.src_port.name if self.src_port else None
|
||||
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
|
||||
state["dst_port"] = self.dst_port.value if self.dst_port else None
|
||||
state["dst_port"] = self.dst_port.name if self.dst_port else None
|
||||
return state
|
||||
|
||||
|
||||
@@ -89,6 +90,8 @@ class AccessControlList(SimComponent):
|
||||
implicit_rule: ACLRule
|
||||
max_acl_rules: int = 25
|
||||
_acl: List[Optional[ACLRule]] = [None] * 24
|
||||
_default_config: Dict[int, dict] = {}
|
||||
"""Config dict describing how the ACL list should look at episode start"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
if not kwargs.get("implicit_action"):
|
||||
@@ -106,10 +109,40 @@ class AccessControlList(SimComponent):
|
||||
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
|
||||
|
||||
for i, rule in enumerate(self._acl):
|
||||
if not rule:
|
||||
continue
|
||||
self._default_config[i] = {"action": rule.action.name}
|
||||
if rule.src_ip_address:
|
||||
self._default_config[i]["src_ip"] = str(rule.src_ip_address)
|
||||
if rule.dst_ip_address:
|
||||
self._default_config[i]["dst_ip"] = str(rule.dst_ip_address)
|
||||
if rule.src_port:
|
||||
self._default_config[i]["src_port"] = rule.src_port.name
|
||||
if rule.dst_port:
|
||||
self._default_config[i]["dst_port"] = rule.dst_port.name
|
||||
if rule.protocol:
|
||||
self._default_config[i]["protocol"] = rule.protocol.name
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.implicit_rule.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
self._reset_rules_to_default()
|
||||
|
||||
def _reset_rules_to_default(self) -> None:
|
||||
"""Clear all ACL rules and set them to the default rules config."""
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
for r_num, r_cfg in self._default_config.items():
|
||||
self.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
@@ -129,9 +162,9 @@ class AccessControlList(SimComponent):
|
||||
func=lambda request, context: self.add_rule(
|
||||
ACLAction[request[0]],
|
||||
None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
IPv4Address(request[2]),
|
||||
None if request[2] == "ALL" else IPv4Address(request[2]),
|
||||
None if request[3] == "ALL" else Port[request[3]],
|
||||
IPv4Address(request[4]),
|
||||
None if request[4] == "ALL" else IPv4Address(request[4]),
|
||||
None if request[5] == "ALL" else Port[request[5]],
|
||||
int(request[6]),
|
||||
)
|
||||
@@ -333,11 +366,10 @@ class RouteEntry(SimComponent):
|
||||
"""
|
||||
Represents a single entry in a routing table.
|
||||
|
||||
Attributes:
|
||||
address (IPv4Address): The destination IP address or network address.
|
||||
subnet_mask (IPv4Address): The subnet mask for the network.
|
||||
next_hop_ip_address (IPv4Address): The next hop IP address to which packets should be forwarded.
|
||||
metric (int): The cost metric for this route. Default is 0.0.
|
||||
:ivar address: The destination IP address or network address.
|
||||
:ivar subnet_mask: The subnet mask for the network.
|
||||
:ivar next_hop_ip_address: The next hop IP address to which packets should be forwarded.
|
||||
:ivar metric: The cost metric for this route. Default is 0.0.
|
||||
|
||||
Example:
|
||||
>>> entry = RouteEntry(
|
||||
@@ -357,12 +389,6 @@ class RouteEntry(SimComponent):
|
||||
metric: float = 0.0
|
||||
"The cost metric for this route. Default is 0.0."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for key in {"address", "subnet_mask", "next_hop_ip_address"}:
|
||||
if not isinstance(kwargs[key], IPv4Address):
|
||||
kwargs[key] = IPv4Address(kwargs[key])
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
|
||||
@@ -397,10 +423,10 @@ class RouteTable(SimComponent):
|
||||
"""
|
||||
|
||||
routes: List[RouteEntry] = []
|
||||
default_route: Optional[RouteEntry] = None
|
||||
sys_log: SysLog
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
self._original_state["routes_orig"] = self.routes
|
||||
@@ -442,12 +468,35 @@ class RouteTable(SimComponent):
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
def set_default_route_next_hop_ip_address(self, ip_address: IPv4Address):
|
||||
"""
|
||||
Sets the next-hop IP address for the default route in a routing table.
|
||||
|
||||
This method checks if a default route (0.0.0.0/0) exists in the routing table. If it does not exist,
|
||||
the method creates a new default route with the specified next-hop IP address. If a default route already
|
||||
exists, it updates the next-hop IP address of the existing default route. After setting the next-hop
|
||||
IP address, the method logs this action.
|
||||
|
||||
:param ip_address: The next-hop IP address to be set for the default route.
|
||||
"""
|
||||
if not self.default_route:
|
||||
self.default_route = RouteEntry(
|
||||
ip_address=IPv4Address("0.0.0.0"),
|
||||
subnet_mask=IPv4Address("0.0.0.0"),
|
||||
next_hop_ip_address=ip_address,
|
||||
)
|
||||
else:
|
||||
self.default_route.next_hop_ip_address = ip_address
|
||||
self.sys_log.info(f"Default configured to use {ip_address} as the next-hop")
|
||||
|
||||
def find_best_route(self, destination_ip: Union[str, IPv4Address]) -> Optional[RouteEntry]:
|
||||
"""
|
||||
Find the best route for a given destination IP.
|
||||
|
||||
This method uses the Longest Prefix Match algorithm and considers metrics to find the best route.
|
||||
|
||||
If no dedicated route exists but a default route does, then the default route is returned as a last resort.
|
||||
|
||||
:param destination_ip: The destination IP to find the route for.
|
||||
:return: The best matching RouteEntry, or None if no route matches.
|
||||
"""
|
||||
@@ -467,6 +516,9 @@ class RouteTable(SimComponent):
|
||||
longest_prefix = prefix_len
|
||||
lowest_metric = route.metric
|
||||
|
||||
if not best_route and self.default_route:
|
||||
best_route = self.default_route
|
||||
|
||||
return best_route
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
@@ -498,12 +550,26 @@ class RouterARPCache(ARPCache):
|
||||
super().__init__(sys_log)
|
||||
self.router: Router = router
|
||||
|
||||
def process_arp_packet(self, from_nic: NIC, frame: Frame):
|
||||
def process_arp_packet(
|
||||
self, from_nic: NIC, frame: Frame, route_table: RouteTable, is_reattempt: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Overridden method to process a received ARP packet in a router-specific way.
|
||||
Processes a received ARP (Address Resolution Protocol) packet in a router-specific way.
|
||||
|
||||
This method is responsible for handling both ARP requests and responses. It processes ARP packets received on a
|
||||
Network Interface Card (NIC) and performs actions based on whether the packet is a request or a reply. This
|
||||
includes updating the ARP cache, forwarding ARP replies, sending ARP requests for unknown destinations, and
|
||||
handling packet TTL (Time To Live).
|
||||
|
||||
The method first checks if the ARP packet is a request or a reply. For ARP replies, it updates the ARP cache
|
||||
and forwards the reply if necessary. For ARP requests, it checks if the target IP matches one of the router's
|
||||
NICs and sends an ARP reply if so. If the destination is not directly connected, it consults the routing table
|
||||
to find the best route and reattempts ARP request processing if needed.
|
||||
|
||||
:param from_nic: The NIC that received the ARP packet.
|
||||
:param frame: The original ARP frame.
|
||||
:param frame: The frame containing the ARP packet.
|
||||
:param route_table: The routing table of the router.
|
||||
:param is_reattempt: Flag to indicate if this is a reattempt of processing the ARP packet, defaults to False.
|
||||
"""
|
||||
arp_packet = frame.arp
|
||||
|
||||
@@ -531,7 +597,11 @@ class RouterARPCache(ARPCache):
|
||||
)
|
||||
arp_packet.sender_mac_addr = nic.mac_address
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
nic.send_frame(frame)
|
||||
return
|
||||
|
||||
# ARP Request
|
||||
self.sys_log.info(
|
||||
@@ -542,16 +612,32 @@ class RouterARPCache(ARPCache):
|
||||
self.add_arp_cache_entry(
|
||||
ip_address=arp_packet.sender_ip_address, mac_address=arp_packet.sender_mac_addr, nic=from_nic
|
||||
)
|
||||
arp_packet = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_packet, from_nic)
|
||||
|
||||
# If the target IP matches one of the router's NICs
|
||||
for nic in self.nics.values():
|
||||
if nic.enabled and nic.ip_address == arp_packet.target_ip_address:
|
||||
if arp_packet.target_ip_address in nic.ip_network:
|
||||
# if nic.enabled and nic.ip_address == arp_packet.target_ip_address:
|
||||
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_reply, from_nic)
|
||||
return
|
||||
|
||||
# Check Route Table
|
||||
route = route_table.find_best_route(arp_packet.target_ip_address)
|
||||
if route:
|
||||
nic = self.get_arp_cache_nic(route.next_hop_ip_address)
|
||||
|
||||
if not nic:
|
||||
if not is_reattempt:
|
||||
self.send_arp_request(route.next_hop_ip_address, ignore_networks=[frame.ip.src_ip_address])
|
||||
return self.process_arp_packet(from_nic, frame, route_table, is_reattempt=True)
|
||||
else:
|
||||
self.sys_log.info("Ignoring ARP request as destination unavailable/No ARP entry found")
|
||||
return
|
||||
else:
|
||||
arp_reply = arp_packet.generate_reply(from_nic.mac_address)
|
||||
self.send_arp_reply(arp_reply, from_nic)
|
||||
return
|
||||
|
||||
|
||||
class RouterICMP(ICMP):
|
||||
"""
|
||||
@@ -622,7 +708,7 @@ class RouterICMP(ICMP):
|
||||
return
|
||||
|
||||
# Route the frame
|
||||
self.router.route_frame(frame, from_nic)
|
||||
self.router.process_frame(frame, from_nic)
|
||||
|
||||
elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY:
|
||||
for nic in self.router.nics.values():
|
||||
@@ -642,7 +728,48 @@ class RouterICMP(ICMP):
|
||||
|
||||
return
|
||||
# Route the frame
|
||||
self.router.route_frame(frame, from_nic)
|
||||
self.router.process_frame(frame, from_nic)
|
||||
|
||||
|
||||
class RouterNIC(NIC):
|
||||
"""
|
||||
A Router-specific Network Interface Card (NIC) that extends the standard NIC functionality.
|
||||
|
||||
This class overrides the standard Node NIC's Layer 3 (L3) broadcast/unicast checks. It is designed
|
||||
to handle network frames in a manner specific to routers, allowing them to efficiently process
|
||||
and route network traffic.
|
||||
"""
|
||||
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Receive and process a network frame from the connected link, provided the NIC is enabled.
|
||||
|
||||
This method is tailored for router behavior. It decrements the frame's Time To Live (TTL), checks for TTL
|
||||
expiration, and captures the frame using PCAP (Packet Capture). The frame is accepted if it is destined for
|
||||
this NIC's MAC address or is a broadcast frame.
|
||||
|
||||
Key Differences from Standard NIC:
|
||||
- Does not perform Layer 3 (IP-based) broadcast checks.
|
||||
- Only checks for Layer 2 (Ethernet) destination MAC address and broadcast frames.
|
||||
|
||||
:param frame: The network frame being received. This should be an instance of the Frame class.
|
||||
:return: Returns True if the frame is processed and passed to the connected node, False otherwise.
|
||||
"""
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
frame.set_received_timestamp()
|
||||
self.pcap.capture(frame)
|
||||
# If this destination or is broadcast
|
||||
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
|
||||
self._connected_node.receive_frame(frame=frame, from_nic=self)
|
||||
return True
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.mac_address}/{self.ip_address}"
|
||||
|
||||
|
||||
class Router(Node):
|
||||
@@ -655,7 +782,7 @@ class Router(Node):
|
||||
"""
|
||||
|
||||
num_ports: int
|
||||
ethernet_ports: Dict[int, NIC] = {}
|
||||
ethernet_ports: Dict[int, RouterNIC] = {}
|
||||
acl: AccessControlList
|
||||
route_table: RouteTable
|
||||
arp: RouterARPCache
|
||||
@@ -674,7 +801,7 @@ class Router(Node):
|
||||
kwargs["icmp"] = RouterICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"), router=self)
|
||||
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
|
||||
for i in range(1, self.num_ports + 1):
|
||||
nic = NIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
|
||||
nic = RouterNIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
|
||||
self.connect_nic(nic)
|
||||
self.ethernet_ports[i] = nic
|
||||
|
||||
@@ -725,13 +852,13 @@ class Router(Node):
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["num_ports"] = (self.num_ports,)
|
||||
state["acl"] = (self.acl.describe_state(),)
|
||||
state["num_ports"] = self.num_ports
|
||||
state["acl"] = self.acl.describe_state()
|
||||
return state
|
||||
|
||||
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
def process_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
"""
|
||||
Route a given frame from a source NIC to its destination.
|
||||
Process a Frame.
|
||||
|
||||
:param frame: The frame to be routed.
|
||||
:param from_nic: The source network interface.
|
||||
@@ -746,25 +873,57 @@ class Router(Node):
|
||||
return
|
||||
|
||||
if not nic:
|
||||
self.arp.send_arp_request(frame.ip.dst_ip_address)
|
||||
return self.route_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
self.arp.send_arp_request(
|
||||
frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address, from_nic.ip_address]
|
||||
)
|
||||
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
|
||||
if not nic.enabled:
|
||||
# TODO: Add sys_log here
|
||||
self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled")
|
||||
return
|
||||
|
||||
if frame.ip.dst_ip_address in nic.ip_network:
|
||||
from_port = self._get_port_of_nic(from_nic)
|
||||
to_port = self._get_port_of_nic(nic)
|
||||
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
|
||||
self.sys_log.info(f"Forwarding frame to internally from port {from_port} to port {to_port}")
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
frame.ethernet.src_mac_addr = nic.mac_address
|
||||
frame.ethernet.dst_mac_addr = target_mac
|
||||
nic.send_frame(frame)
|
||||
return
|
||||
else:
|
||||
pass
|
||||
# TODO: Deal with routing from route tables
|
||||
self._route_frame(frame, from_nic)
|
||||
|
||||
def _route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
route = self.route_table.find_best_route(frame.ip.dst_ip_address)
|
||||
if route:
|
||||
nic = self.arp.get_arp_cache_nic(route.next_hop_ip_address)
|
||||
target_mac = self.arp.get_arp_cache_mac_address(route.next_hop_ip_address)
|
||||
if re_attempt and not nic:
|
||||
self.sys_log.info(f"Destination {frame.ip.dst_ip_address} is unreachable")
|
||||
return
|
||||
|
||||
if not nic:
|
||||
self.arp.send_arp_request(frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address])
|
||||
return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True)
|
||||
|
||||
if not nic.enabled:
|
||||
self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled")
|
||||
return
|
||||
|
||||
from_port = self._get_port_of_nic(from_nic)
|
||||
to_port = self._get_port_of_nic(nic)
|
||||
self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}")
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self.sys_log.info("Frame discarded as TTL limit reached")
|
||||
return
|
||||
frame.ethernet.src_mac_addr = nic.mac_address
|
||||
frame.ethernet.dst_mac_addr = target_mac
|
||||
nic.send_frame(frame)
|
||||
|
||||
def receive_frame(self, frame: Frame, from_nic: NIC):
|
||||
"""
|
||||
@@ -773,7 +932,7 @@ class Router(Node):
|
||||
:param frame: The incoming frame.
|
||||
:param from_nic: The network interface where the frame is coming from.
|
||||
"""
|
||||
route_frame = False
|
||||
process_frame = False
|
||||
protocol = frame.ip.protocol
|
||||
src_ip_address = frame.ip.src_ip_address
|
||||
dst_ip_address = frame.ip.dst_ip_address
|
||||
@@ -805,12 +964,12 @@ class Router(Node):
|
||||
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
|
||||
else:
|
||||
if src_port == Port.ARP:
|
||||
self.arp.process_arp_packet(from_nic=from_nic, frame=frame)
|
||||
self.arp.process_arp_packet(from_nic=from_nic, frame=frame, route_table=self.route_table)
|
||||
else:
|
||||
# All other traffic
|
||||
route_frame = True
|
||||
if route_frame:
|
||||
self.route_frame(frame, from_nic)
|
||||
process_frame = True
|
||||
if process_frame:
|
||||
self.process_frame(frame, from_nic)
|
||||
|
||||
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
|
||||
"""
|
||||
@@ -873,3 +1032,63 @@ class Router(Node):
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "Router":
|
||||
"""Create a router based on a config dict.
|
||||
|
||||
Schema:
|
||||
- hostname (str): unique name for this router.
|
||||
- num_ports (int, optional): Number of network ports on the router. 8 by default
|
||||
- ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying
|
||||
ip_address and subnet_mask assigned to that ports (as strings)
|
||||
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
|
||||
where the rule will be added (lower number is resolved first). The values should describe valid ACL
|
||||
Rules as:
|
||||
- action (str): either PERMIT or DENY
|
||||
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
|
||||
- src_ip_address (str, optional): IP address octet written in base 10
|
||||
- dst_ip_address (str, optional): IP address octet written in base 10
|
||||
|
||||
Example config:
|
||||
```
|
||||
{
|
||||
'hostname': 'router_1',
|
||||
'num_ports': 5,
|
||||
'ports': {
|
||||
1: {
|
||||
'ip_address' : '192.168.1.1',
|
||||
'subnet_mask' : '255.255.255.0',
|
||||
}
|
||||
},
|
||||
'acl' : {
|
||||
21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'},
|
||||
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
|
||||
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
:param cfg: Router config adhering to schema described in main docstring body
|
||||
:type cfg: dict
|
||||
:return: Configured router.
|
||||
:rtype: Router
|
||||
"""
|
||||
new = Router(
|
||||
hostname=cfg["hostname"],
|
||||
num_ports=cfg.get("num_ports"),
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
if "ports" in cfg:
|
||||
for port_num, port_cfg in cfg["ports"].items():
|
||||
new.configure_port(
|
||||
port=port_num,
|
||||
ip_address=port_cfg["ip_address"],
|
||||
subnet_mask=port_cfg["subnet_mask"],
|
||||
)
|
||||
if "acl" in cfg:
|
||||
new.acl._default_config = cfg["acl"] # save the config to allow resetting
|
||||
new.acl._reset_rules_to_default() # read the config and apply rules
|
||||
return new
|
||||
|
||||
@@ -90,12 +90,12 @@ class Switch(Node):
|
||||
self._add_mac_table_entry(src_mac, incoming_port)
|
||||
|
||||
outgoing_port = self.mac_address_table.get(dst_mac)
|
||||
if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff":
|
||||
if outgoing_port and dst_mac.lower() != "ff:ff:ff:ff:ff:ff":
|
||||
outgoing_port.send_frame(frame)
|
||||
else:
|
||||
# If the destination MAC is not in the table, flood to all ports except incoming
|
||||
for port in self.switch_ports.values():
|
||||
if port != incoming_port:
|
||||
if port.enabled and port != incoming_port:
|
||||
port.send_frame(frame)
|
||||
|
||||
def disconnect_link_from_port(self, link: Link, port_number: int):
|
||||
|
||||
@@ -38,9 +38,6 @@ class Application(IOSoftware):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
@@ -95,6 +92,9 @@ class Application(IOSoftware):
|
||||
if self.operating_state == ApplicationOperatingState.CLOSED:
|
||||
self.sys_log.info(f"Running Application {self.name}")
|
||||
self.operating_state = ApplicationOperatingState.RUNNING
|
||||
# set software health state to GOOD if initially set to UNUSED
|
||||
if self.health_state_actual == SoftwareHealthState.UNUSED:
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
def _application_loop(self):
|
||||
"""The main application loop."""
|
||||
|
||||
@@ -72,7 +72,7 @@ class DataManipulationBot(DatabaseClient):
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.attack()))
|
||||
|
||||
return rm
|
||||
|
||||
@@ -83,7 +83,7 @@ class DataManipulationBot(DatabaseClient):
|
||||
payload: Optional[str] = None,
|
||||
port_scan_p_of_success: float = 0.1,
|
||||
data_manipulation_p_of_success: float = 0.1,
|
||||
repeat: bool = False,
|
||||
repeat: bool = True,
|
||||
):
|
||||
"""
|
||||
Configure the DataManipulatorBot to communicate with a DatabaseService.
|
||||
@@ -168,6 +168,12 @@ class DataManipulationBot(DatabaseClient):
|
||||
Calls the parent classes execute method before starting the application loop.
|
||||
"""
|
||||
super().run()
|
||||
|
||||
def attack(self):
|
||||
"""Perform the attack steps after opening the application."""
|
||||
if not self._can_perform_action():
|
||||
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
|
||||
self.run()
|
||||
self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
@@ -198,4 +204,4 @@ class DataManipulationBot(DatabaseClient):
|
||||
|
||||
:param timestep: The timestep value to update the bot's state.
|
||||
"""
|
||||
self._application_loop()
|
||||
pass
|
||||
|
||||
@@ -41,6 +41,9 @@ class PacketCapture:
|
||||
|
||||
def setup_logger(self):
|
||||
"""Set up the logger configuration."""
|
||||
if not SIM_OUTPUT.save_pcap_logs:
|
||||
return
|
||||
|
||||
log_path = self._get_log_path()
|
||||
|
||||
file_handler = logging.FileHandler(filename=log_path)
|
||||
@@ -88,5 +91,6 @@ class PacketCapture:
|
||||
|
||||
:param frame: The PCAP frame to capture.
|
||||
"""
|
||||
msg = frame.model_dump_json()
|
||||
self.logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
if SIM_OUTPUT.save_pcap_logs:
|
||||
msg = frame.model_dump_json()
|
||||
self.logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
@@ -141,41 +141,76 @@ class SessionManager:
|
||||
def receive_payload_from_software_manager(
|
||||
self,
|
||||
payload: Any,
|
||||
dst_ip_address: Optional[IPv4Address] = None,
|
||||
dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dst_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> Union[Any, None]:
|
||||
"""
|
||||
Receive a payload from the SoftwareManager.
|
||||
Receive a payload from the SoftwareManager and send it to the appropriate NIC for transmission.
|
||||
|
||||
If no session_id, a Session is established. Once established, the payload is sent to ``send_payload_to_nic``.
|
||||
This method supports both unicast and Layer 3 broadcast transmissions. If `dst_ip_address` is an
|
||||
IPv4Network, a broadcast is initiated. For unicast, the destination MAC address is resolved via ARP.
|
||||
A new session is established if `session_id` is not provided, and an existing session is used otherwise.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created.
|
||||
:param dst_ip_address: The destination IP address or network for broadcast. Optional.
|
||||
:param dst_port: The destination port for the TCP packet. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:param is_reattempt: Flag to indicate if this is a reattempt after an ARP request. Default is False.
|
||||
:return: The outcome of sending the frame, or None if sending was unsuccessful.
|
||||
"""
|
||||
is_broadcast = False
|
||||
outbound_nic = None
|
||||
dst_mac_address = None
|
||||
|
||||
# Use session details if session_id is provided
|
||||
if session_id:
|
||||
session = self.sessions_by_uuid[session_id]
|
||||
dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address
|
||||
dst_port = self.sessions_by_uuid[session_id].dst_port
|
||||
dst_ip_address = session.with_ip_address
|
||||
dst_port = session.dst_port
|
||||
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
|
||||
# Determine if the payload is for broadcast or unicast
|
||||
|
||||
if dst_mac_address:
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
|
||||
# Handle broadcast transmission
|
||||
if isinstance(dst_ip_address, IPv4Network):
|
||||
is_broadcast = True
|
||||
dst_ip_address = dst_ip_address.broadcast_address
|
||||
if dst_ip_address:
|
||||
# Find a suitable NIC for the broadcast
|
||||
for nic in self.arp_cache.nics.values():
|
||||
if dst_ip_address in nic.ip_network and nic.enabled:
|
||||
dst_mac_address = "ff:ff:ff:ff:ff:ff"
|
||||
outbound_nic = nic
|
||||
else:
|
||||
if not is_reattempt:
|
||||
self.arp_cache.send_arp_request(dst_ip_address)
|
||||
return self.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dst_ip_address=dst_ip_address,
|
||||
dst_port=dst_port,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
else:
|
||||
return
|
||||
# Resolve MAC address for unicast transmission
|
||||
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
|
||||
|
||||
# Resolve outbound NIC for unicast transmission
|
||||
if dst_mac_address:
|
||||
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
|
||||
|
||||
# If MAC address not found, initiate ARP request
|
||||
else:
|
||||
if not is_reattempt:
|
||||
self.arp_cache.send_arp_request(dst_ip_address)
|
||||
# Reattempt payload transmission after ARP request
|
||||
return self.receive_payload_from_software_manager(
|
||||
payload=payload,
|
||||
dst_ip_address=dst_ip_address,
|
||||
dst_port=dst_port,
|
||||
session_id=session_id,
|
||||
is_reattempt=True,
|
||||
)
|
||||
else:
|
||||
# Return None if reattempt fails
|
||||
return
|
||||
|
||||
# Check if outbound NIC and destination MAC address are resolved
|
||||
if not outbound_nic or not dst_mac_address:
|
||||
return False
|
||||
|
||||
# Construct the frame for transmission
|
||||
frame = Frame(
|
||||
ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address),
|
||||
ip=IPPacket(
|
||||
@@ -189,15 +224,17 @@ class SessionManager:
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
if not session_id:
|
||||
# Manage session for unicast transmission
|
||||
if not (is_broadcast and session_id):
|
||||
session_key = self._get_session_key(frame, inbound_frame=False)
|
||||
session = self.sessions_by_key.get(session_key)
|
||||
if not session:
|
||||
# Create new session
|
||||
# Create a new session if it doesn't exist
|
||||
session = Session.from_session_key(session_key)
|
||||
self.sessions_by_key[session_key] = session
|
||||
self.sessions_by_uuid[session.uuid] = session
|
||||
|
||||
# Send the frame through the NIC
|
||||
return outbound_nic.send_frame(frame)
|
||||
|
||||
def receive_frame(self, frame: Frame):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
@@ -130,20 +130,28 @@ class SoftwareManager:
|
||||
def send_payload_to_session_manager(
|
||||
self,
|
||||
payload: Any,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a payload to the SessionManager.
|
||||
Sends a payload to the SessionManager for network transmission.
|
||||
|
||||
This method is responsible for initiating the process of sending network payloads. It supports both
|
||||
unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified
|
||||
as an IPv4Network.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
:param dest_ip_address: The IP address or network (for broadcasts) of the payload destination.
|
||||
:param dest_port: The destination port for the payload. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:return: True if the payload was successfully sent, False otherwise.
|
||||
"""
|
||||
return self.session_manager.receive_payload_from_software_manager(
|
||||
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
|
||||
payload=payload,
|
||||
dst_ip_address=dest_ip_address,
|
||||
dst_port=dest_port,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
|
||||
|
||||
@@ -41,6 +41,9 @@ class SysLog:
|
||||
The logger is set to the DEBUG level, and is equipped with a handler that writes to a file and filters out
|
||||
JSON-like messages.
|
||||
"""
|
||||
if not SIM_OUTPUT.save_sys_logs:
|
||||
return
|
||||
|
||||
log_path = self._get_log_path()
|
||||
file_handler = logging.FileHandler(filename=log_path)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
@@ -91,7 +94,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.debug(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.debug(msg)
|
||||
|
||||
def info(self, msg: str):
|
||||
"""
|
||||
@@ -99,7 +103,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.info(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.info(msg)
|
||||
|
||||
def warning(self, msg: str):
|
||||
"""
|
||||
@@ -107,7 +112,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.warning(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.warning(msg)
|
||||
|
||||
def error(self, msg: str):
|
||||
"""
|
||||
@@ -115,7 +121,8 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.error(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.error(msg)
|
||||
|
||||
def critical(self, msg: str):
|
||||
"""
|
||||
@@ -123,4 +130,5 @@ class SysLog:
|
||||
|
||||
:param msg: The message to be logged.
|
||||
"""
|
||||
self.logger.critical(msg)
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.critical(msg)
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
from primaite.simulator.file_system.folder import Folder
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
@@ -22,7 +24,7 @@ class DatabaseService(Service):
|
||||
|
||||
password: Optional[str] = None
|
||||
|
||||
backup_server: IPv4Address = None
|
||||
backup_server_ip: IPv4Address = None
|
||||
"""IP address of the backup server."""
|
||||
|
||||
latest_backup_directory: str = None
|
||||
@@ -36,7 +38,6 @@ class DatabaseService(Service):
|
||||
kwargs["port"] = Port.POSTGRES_SERVER
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
self._db_file: File
|
||||
self._create_db_file()
|
||||
|
||||
def set_original_state(self):
|
||||
@@ -45,8 +46,8 @@ class DatabaseService(Service):
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"password",
|
||||
"_connections",
|
||||
"backup_server",
|
||||
"connections",
|
||||
"backup_server_ip",
|
||||
"latest_backup_directory",
|
||||
"latest_backup_file_name",
|
||||
}
|
||||
@@ -64,7 +65,7 @@ class DatabaseService(Service):
|
||||
|
||||
:param: backup_server_ip: The IP address of the backup server
|
||||
"""
|
||||
self.backup_server = backup_server
|
||||
self.backup_server_ip = backup_server
|
||||
|
||||
def backup_database(self) -> bool:
|
||||
"""Create a backup of the database to the configured backup server."""
|
||||
@@ -73,7 +74,7 @@ class DatabaseService(Service):
|
||||
return False
|
||||
|
||||
# check if the backup server was configured
|
||||
if self.backup_server is None:
|
||||
if self.backup_server_ip is None:
|
||||
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
|
||||
return False
|
||||
|
||||
@@ -81,10 +82,14 @@ class DatabaseService(Service):
|
||||
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
|
||||
|
||||
# send backup copy of database file to FTP server
|
||||
if not self.db_file:
|
||||
self.sys_log.error("Attempted to backup database file but it doesn't exist.")
|
||||
return False
|
||||
|
||||
response = ftp_client_service.send_file(
|
||||
dest_ip_address=self.backup_server,
|
||||
src_file_name=self._db_file.name,
|
||||
src_folder_name=self.folder.name,
|
||||
dest_ip_address=self.backup_server_ip,
|
||||
src_file_name=self.db_file.name,
|
||||
src_folder_name="database",
|
||||
dest_folder_name=str(self.uuid),
|
||||
dest_file_name="database.db",
|
||||
)
|
||||
@@ -110,7 +115,7 @@ class DatabaseService(Service):
|
||||
src_file_name="database.db",
|
||||
dest_folder_name="downloads",
|
||||
dest_file_name="database.db",
|
||||
dest_ip_address=self.backup_server,
|
||||
dest_ip_address=self.backup_server_ip,
|
||||
)
|
||||
|
||||
if not response:
|
||||
@@ -118,13 +123,10 @@ class DatabaseService(Service):
|
||||
return False
|
||||
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
|
||||
self.file_system.copy_file(
|
||||
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
|
||||
)
|
||||
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
|
||||
self.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
self.file_system.copy_file(src_folder_name="downloads", src_file_name="database.db", dst_folder_name="database")
|
||||
|
||||
if self._db_file is None:
|
||||
if self.db_file is None:
|
||||
self.sys_log.error("Copying database backup failed.")
|
||||
return False
|
||||
|
||||
@@ -134,12 +136,30 @@ class DatabaseService(Service):
|
||||
|
||||
def _create_db_file(self):
|
||||
"""Creates the Simulation File and sqlite file in the file system."""
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
|
||||
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
|
||||
self.file_system.create_file(folder_name="database", file_name="database.db")
|
||||
|
||||
@property
|
||||
def db_file(self) -> File:
|
||||
"""Returns the database file."""
|
||||
return self.file_system.get_file(folder_name="database", file_name="database.db")
|
||||
|
||||
@property
|
||||
def folder(self) -> Folder:
|
||||
"""Returns the database folder."""
|
||||
return self.file_system.get_folder_by_id(self.db_file.folder_id)
|
||||
|
||||
def _process_connect(
|
||||
self, connection_id: str, password: Optional[str] = None
|
||||
) -> Dict[str, Union[int, Dict[str, bool]]]:
|
||||
"""Process an incoming connection request.
|
||||
|
||||
:param connection_id: A unique identifier for the connection
|
||||
:type connection_id: str
|
||||
:param password: Supplied password. It must match self.password for connection success, defaults to None
|
||||
:type password: Optional[str], optional
|
||||
:return: Response to connection request containing success info.
|
||||
:rtype: Dict[str, Union[int, Dict[str, bool]]]
|
||||
"""
|
||||
status_code = 500 # Default internal server error
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
@@ -184,7 +204,7 @@ class DatabaseService(Service):
|
||||
self.sys_log.info(f"{self.name}: Running {query}")
|
||||
|
||||
if query == "SELECT":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
@@ -195,17 +215,8 @@ class DatabaseService(Service):
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
elif query == "DELETE":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
self.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": False,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED
|
||||
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id, "connection_id": connection_id}
|
||||
else:
|
||||
# Invalid query
|
||||
return {"status_code": 500, "data": False}
|
||||
@@ -265,3 +276,19 @@ class DatabaseService(Service):
|
||||
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
|
||||
|
||||
return payload["status_code"] == 200
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a single timestep of simulation dynamics to this service.
|
||||
|
||||
Here at the first step, the database backup is created, in addition to normal service update logic.
|
||||
"""
|
||||
if timestep == 1:
|
||||
self.backup_database()
|
||||
return super().apply_timestep(timestep)
|
||||
|
||||
def _update_patch_status(self) -> None:
|
||||
"""Perform a database restore when the patching countdown is finished."""
|
||||
super()._update_patch_status()
|
||||
if self._patching_countdown is None:
|
||||
self.restore_backup()
|
||||
|
||||
@@ -89,6 +89,7 @@ class FTPClient(FTPServiceABC):
|
||||
f"{self.name}: Successfully connected to FTP Server "
|
||||
f"{dest_ip_address} via port {payload.ftp_command_args.value}"
|
||||
)
|
||||
self.add_connection(connection_id="server_connection", session_id=session_id)
|
||||
return True
|
||||
else:
|
||||
if is_reattempt:
|
||||
|
||||
@@ -99,5 +99,5 @@ class FTPServer(FTPServiceABC):
|
||||
if payload.status_code is not None:
|
||||
return False
|
||||
|
||||
self.send(self._process_ftp_command(payload=payload, session_id=session_id), session_id)
|
||||
self._process_ftp_command(payload=payload, session_id=session_id)
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import shutil
|
||||
from abc import ABC
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
@@ -16,6 +16,10 @@ class FTPServiceABC(Service, ABC):
|
||||
Contains shared methods between both classes.
|
||||
"""
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""Returns a Dict of the FTPService state."""
|
||||
return super().describe_state()
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
"""
|
||||
Process the command in the FTP Packet.
|
||||
@@ -52,10 +56,12 @@ class FTPServiceABC(Service, ABC):
|
||||
folder_name = payload.ftp_command_args["dest_folder_name"]
|
||||
file_size = payload.ftp_command_args["file_size"]
|
||||
real_file_path = payload.ftp_command_args.get("real_file_path")
|
||||
health_status = payload.ftp_command_args["health_status"]
|
||||
is_real = real_file_path is not None
|
||||
file = self.file_system.create_file(
|
||||
file_name=file_name, folder_name=folder_name, size=file_size, real=is_real
|
||||
)
|
||||
file.health_status = health_status
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Created item in {self.sys_log.hostname}: {payload.ftp_command_args['dest_folder_name']}/"
|
||||
f"{payload.ftp_command_args['dest_file_name']}"
|
||||
@@ -110,6 +116,7 @@ class FTPServiceABC(Service, ABC):
|
||||
"dest_file_name": dest_file_name,
|
||||
"file_size": file.sim_size,
|
||||
"real_file_path": file.sim_path if file.real else None,
|
||||
"health_status": file.health_status,
|
||||
},
|
||||
packet_payload_size=file.sim_size,
|
||||
status_code=FTPStatusCode.OK if is_response else None,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -43,9 +44,6 @@ class Service(IOSoftware):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
Checks if the service can perform actions.
|
||||
@@ -98,6 +96,7 @@ class Service(IOSoftware):
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
|
||||
return rm
|
||||
|
||||
@abstractmethod
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -118,7 +117,6 @@ class Service(IOSoftware):
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Stopping service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def start(self, **kwargs) -> None:
|
||||
"""Start the service."""
|
||||
@@ -129,42 +127,39 @@ class Service(IOSoftware):
|
||||
if self.operating_state == ServiceOperatingState.STOPPED:
|
||||
self.sys_log.info(f"Starting service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
# set software health state to GOOD if initially set to UNUSED
|
||||
if self.health_state_actual == SoftwareHealthState.UNUSED:
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the service."""
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.PAUSED
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resume paused service."""
|
||||
if self.operating_state == ServiceOperatingState.PAUSED:
|
||||
self.sys_log.info(f"Resuming service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
|
||||
def restart(self) -> None:
|
||||
"""Restart running service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RESTARTING
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
self.restart_countdown = self.restart_duration
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the service."""
|
||||
self.sys_log.info(f"Disabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.DISABLED
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
|
||||
def enable(self) -> None:
|
||||
"""Enable the disabled service."""
|
||||
if self.operating_state == ServiceOperatingState.DISABLED:
|
||||
self.sys_log.info(f"Enabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
@@ -181,5 +176,4 @@ class Service(IOSoftware):
|
||||
if self.restart_countdown <= 0:
|
||||
_LOGGER.debug(f"Restarting finished for service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
self.restart_countdown -= 1
|
||||
|
||||
@@ -13,6 +13,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -123,7 +124,10 @@ class WebServer(Service):
|
||||
# get all users
|
||||
if db_client.query("SELECT"):
|
||||
# query succeeded
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
response.status_code = HttpStatusCode.OK
|
||||
else:
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
return response
|
||||
except Exception:
|
||||
|
||||
@@ -2,8 +2,8 @@ import copy
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
@@ -38,12 +38,12 @@ class SoftwareHealthState(Enum):
|
||||
"Unused state."
|
||||
GOOD = 1
|
||||
"The software is in a good and healthy condition."
|
||||
COMPROMISED = 2
|
||||
"The software's security has been compromised."
|
||||
OVERWHELMED = 3
|
||||
"he software is overwhelmed and not functioning properly."
|
||||
PATCHING = 4
|
||||
PATCHING = 2
|
||||
"The software is undergoing patching or updates."
|
||||
COMPROMISED = 3
|
||||
"The software's security has been compromised."
|
||||
OVERWHELMED = 4
|
||||
"he software is overwhelmed and not functioning properly."
|
||||
|
||||
|
||||
class SoftwareCriticality(Enum):
|
||||
@@ -71,9 +71,9 @@ class Software(SimComponent):
|
||||
|
||||
name: str
|
||||
"The name of the software."
|
||||
health_state_actual: SoftwareHealthState = SoftwareHealthState.GOOD
|
||||
health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
"The actual health state of the software."
|
||||
health_state_visible: SoftwareHealthState = SoftwareHealthState.GOOD
|
||||
health_state_visible: SoftwareHealthState = SoftwareHealthState.UNUSED
|
||||
"The health state of the software visible to the red agent."
|
||||
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
|
||||
"The criticality level of the software."
|
||||
@@ -195,8 +195,9 @@ class Software(SimComponent):
|
||||
|
||||
def patch(self) -> None:
|
||||
"""Perform a patch on the software."""
|
||||
self._patching_countdown = self.patching_duration
|
||||
self.set_health_state(SoftwareHealthState.PATCHING)
|
||||
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
|
||||
self._patching_countdown = self.patching_duration
|
||||
self.set_health_state(SoftwareHealthState.PATCHING)
|
||||
|
||||
def _update_patch_status(self) -> None:
|
||||
"""Update the patch status of the software."""
|
||||
@@ -282,7 +283,7 @@ class IOSoftware(Software):
|
||||
|
||||
Returns true if the software can perform actions.
|
||||
"""
|
||||
if self.software_manager and self.software_manager.node.operating_state is NodeOperatingState.OFF:
|
||||
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
|
||||
_LOGGER.debug(f"{self.name} Error: {self.software_manager.node.hostname} is not online.")
|
||||
return False
|
||||
return True
|
||||
@@ -303,13 +304,13 @@ class IOSoftware(Software):
|
||||
"""
|
||||
# if over or at capacity, set to overwhelmed
|
||||
if len(self._connections) >= self.max_sessions:
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
self.set_health_state(SoftwareHealthState.OVERWHELMED)
|
||||
self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
return False
|
||||
else:
|
||||
# if service was previously overwhelmed, set to good because there is enough space for connections
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
# check that connection already doesn't exist
|
||||
if not self._connections.get(connection_id):
|
||||
@@ -350,19 +351,22 @@ class IOSoftware(Software):
|
||||
self,
|
||||
payload: Any,
|
||||
session_id: Optional[str] = None,
|
||||
dest_ip_address: Optional[IPv4Address] = None,
|
||||
dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
|
||||
dest_port: Optional[Port] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""
|
||||
Sends a payload to the SessionManager.
|
||||
Sends a payload to the SessionManager for network transmission.
|
||||
|
||||
This method is responsible for initiating the process of sending network payloads. It supports both
|
||||
unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified
|
||||
as an IPv4Network. It delegates the actual sending process to the SoftwareManager.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_ip_address: The ip address of the payload destination.
|
||||
:param dest_port: The port of the payload destination.
|
||||
:param session_id: The Session ID the payload is to originate from. Optional.
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
:param dest_ip_address: The IP address or network (for broadcasts) of the payload destination.
|
||||
:param dest_port: The destination port for the payload. Optional.
|
||||
:param session_id: The Session ID from which the payload originates. Optional.
|
||||
:return: True if the payload was successfully sent, False otherwise.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user