#1724 - Added the primaite/simulator/network/transmission sub-package with modules for each layer. They come together to build a minimal but fairly realistic network Frame. A custom PrimaiteHeader has been included to hold primaite specific metadata required in transmission for reward function and RL agent downstream. Added some basic tests that check the proper configuration of Frames with matching headers for protocols. Updated the frame typehints in NIC and Link classes.

This commit is contained in:
Chris McCarthy
2023-08-01 22:25:00 +01:00
parent f41fc241b7
commit 9d17a9b0d3
11 changed files with 577 additions and 9 deletions

View File

@@ -0,0 +1,100 @@
from typing import Any, Optional
from pydantic import BaseModel
from primaite import getLogger
from primaite.simulator.network.transmission.network_layer import ICMPHeader, IPPacket, IPProtocol
from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader
from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader
_LOGGER = getLogger(__name__)
class EthernetHeader(BaseModel):
"""
Represents the Ethernet layer of a network frame.
:param src_mac_addr: Source MAC address.
:param dst_mac_addr: Destination MAC address.
:Example:
>>> ethernet = EthernetHeader(
... src_mac_addr='AA:BB:CC:DD:EE:FF',
... dst_mac_addr='11:22:33:44:55:66'
... )
"""
src_mac_addr: str
"Source MAC address."
dst_mac_addr: str
"Destination MAC address."
class Frame(BaseModel):
"""
Represents a complete network frame with all layers.
:param ethernet: Ethernet layer.
:param ip: IP layer.
:param tcp: TCP layer.
:param payload: Payload data in the frame.
:Example:
>>> from ipaddress import IPv4Address
>>> frame=Frame(
... ethernet=EthernetHeader(
... src_mac_addr='AA:BB:CC:DD:EE:FF',
... dst_mac_addr='11:22:33:44:55:66'
... ),
... ip=IPPacket(
... src_ip=IPv4Address('192.168.0.1'),
... dst_ip=IPv4Address('10.0.0.1'),
... ),
... tcp=TCPHeader(
... src_port=8080,
... dst_port=80,
... ),
... payload=b"Hello, World!"
... )
"""
def __init__(self, **kwargs):
if kwargs.get("tcp") and kwargs.get("udp"):
msg = "Network Frame cannot have both a TCP header and a UDP header"
_LOGGER.error(msg)
raise ValueError(msg)
if kwargs["ip"].protocol == IPProtocol.TCP and not kwargs.get("tcp"):
msg = "Cannot build a Frame using the TCP IP Protocol without a TCPHeader"
_LOGGER.error(msg)
raise ValueError(msg)
if kwargs["ip"].protocol == IPProtocol.UDP and not kwargs.get("UDP"):
msg = "Cannot build a Frame using the UDP IP Protocol without a UDPHeader"
_LOGGER.error(msg)
raise ValueError(msg)
if kwargs["ip"].protocol == IPProtocol.ICMP and not kwargs.get("icmp"):
msg = "Cannot build a Frame using the ICMP IP Protocol without a ICMPHeader"
_LOGGER.error(msg)
raise ValueError(msg)
super().__init__(**kwargs)
ethernet: EthernetHeader
"Ethernet header."
ip: IPPacket
"IP packet."
tcp: Optional[TCPHeader] = None
"TCP header."
udp: Optional[UDPHeader] = None
"UDP header."
icmp: Optional[ICMPHeader] = None
"ICMP header."
primaite_header: PrimaiteHeader = PrimaiteHeader()
"PrimAITE header."
payload: Optional[Any] = None
"Raw data payload."
@property
def size(self) -> int:
"""The size of the Frame in Bytes."""
return len(self.model_dump_json().encode("utf-8"))

View File

@@ -0,0 +1,194 @@
import secrets
from enum import Enum
from ipaddress import IPv4Address
from typing import Union
from pydantic import BaseModel, field_validator, validate_call
from pydantic_core.core_schema import FieldValidationInfo
from primaite import getLogger
_LOGGER = getLogger(__name__)
class IPProtocol(Enum):
"""Enum representing transport layer protocols in IP header."""
TCP = "tcp"
UDP = "udp"
ICMP = "icmp"
class Precedence(Enum):
"""
Enum representing the Precedence levels in Quality of Service (QoS) for IP packets.
Precedence values range from 0 to 7, indicating different levels of priority.
Members:
- ROUTINE: 0 - Lowest priority level, used for ordinary data traffic that does not require special treatment.
- PRIORITY: 1 - Higher priority than ROUTINE, used for traffic that needs a bit more importance.
- IMMEDIATE: 2 - Used for more urgent traffic that requires immediate handling and minimal delay.
- FLASH: 3 - Used for highly urgent and important traffic that should be processed with high priority.
- FLASH_OVERRIDE: 4 - Higher priority than FLASH, used for critical traffic that takes precedence over most traffic.
- CRITICAL: 5 - Reserved for critical commands or control messages that are vital to the operation of the network.
- INTERNET: 6 - Used for network control messages, such as routing updates, for maintaining network operations.
- NETWORK: 7 - Highest priority for the most critical network control messages, such as routing protocol hellos.
"""
ROUTINE = 0
"Lowest priority level, used for ordinary data traffic that does not require special treatment."
PRIORITY = 1
"Higher priority than ROUTINE, used for traffic that needs a bit more importance."
IMMEDIATE = 2
"Used for more urgent traffic that requires immediate handling and minimal delay."
FLASH = 3
"Used for highly urgent and important traffic that should be processed with high priority."
FLASH_OVERRIDE = 4
"Has higher priority than FLASH, used for critical traffic that takes precedence over most other traffic."
CRITICAL = 5
"Reserved for critical commands or emergency control messages that are vital to the operation of the network."
INTERNET = 6
"Used for network control messages, such as routing updates, essential for maintaining network operations."
NETWORK = 7
"Highest priority level, used for the most critical network control messages, such as routing protocol hellos."
class ICMPType(Enum):
"""Enumeration of common ICMP (Internet Control Message Protocol) types."""
ECHO_REPLY = 0
"Echo Reply message."
DESTINATION_UNREACHABLE = 3
"Destination Unreachable."
REDIRECT = 5
"Redirect."
ECHO_REQUEST = 8
"Echo Request (ping)."
ROUTER_ADVERTISEMENT = 10
"Router Advertisement."
ROUTER_SOLICITATION = 11
"Router discovery/selection/solicitation."
TIME_EXCEEDED = 11
"Time Exceeded."
TIMESTAMP_REQUEST = 13
"Timestamp Request."
TIMESTAMP_REPLY = 14
"Timestamp Reply."
@validate_call
def get_icmp_type_code_description(icmp_type: ICMPType, icmp_code: int) -> Union[str, None]:
"""
Maps ICMPType and code pairings to their respective description.
:param icmp_type: An ICMPType.
:param icmp_code: An icmp code.
:return: The icmp type and code pairing description if it exists, otherwise returns None.
"""
icmp_code_descriptions = {
ICMPType.ECHO_REPLY: {0: "Echo reply"},
ICMPType.DESTINATION_UNREACHABLE: {
0: "Destination network unreachable",
1: "Destination host unreachable",
2: "Destination protocol unreachable",
3: "Destination port unreachable",
4: "Fragmentation required",
5: "Source route failed",
6: "Destination network unknown",
7: "Destination host unknown",
8: "Source host isolated",
9: "Network administratively prohibited",
10: "Host administratively prohibited",
11: "Network unreachable for ToS",
12: "Host unreachable for ToS",
13: "Communication administratively prohibited",
14: "Host Precedence Violation",
15: "Precedence cutoff in effect",
},
ICMPType.REDIRECT: {
0: "Redirect Datagram for the Network",
1: "Redirect Datagram for the Host",
},
ICMPType.ECHO_REQUEST: {0: "Echo request"},
ICMPType.ROUTER_ADVERTISEMENT: {0: "Router Advertisement"},
ICMPType.ROUTER_SOLICITATION: {0: "Router discovery/selection/solicitation"},
ICMPType.TIME_EXCEEDED: {0: "TTL expired in transit", 1: "Fragment reassembly time exceeded"},
ICMPType.TIMESTAMP_REQUEST: {0: "Timestamp Request"},
ICMPType.TIMESTAMP_REPLY: {0: "Timestamp reply"},
}
return icmp_code_descriptions[icmp_type].get(icmp_code)
class ICMPHeader(BaseModel):
"""Models an ICMP Header."""
icmp_type: ICMPType = ICMPType.ECHO_REQUEST
"ICMP Type."
icmp_code: int = 0
"ICMP Code."
identifier: str = secrets.randbits(16)
"ICMP identifier (16 bits randomly generated)."
sequence: int = 1
"ICMP message sequence number."
@field_validator("icmp_code") # noqa
@classmethod
def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int:
"""Validates the icmp_type and icmp_code."""
icmp_type = info.data["icmp_type"]
if get_icmp_type_code_description(icmp_type, v):
return v
msg = f"No Matching ICMP code for type:{icmp_type.name}, code:{v}"
_LOGGER.error(msg)
raise ValueError(msg)
def code_description(self) -> str:
"""The icmp_code description."""
description = get_icmp_type_code_description(self.icmp_type, self.icmp_code)
if description:
return description
msg = f"No Matching ICMP code for type:{self.icmp_type.name}, code:{self.icmp_code}"
_LOGGER.error(msg)
raise ValueError(msg)
class IPPacket(BaseModel):
"""
Represents the IP layer of a network frame.
:param src_ip: Source IP address.
:param dst_ip: Destination IP address.
:param protocol: The IP protocol (default is TCP).
:param ttl: Time to Live (TTL) for the packet.
:param precedence: Precedence level for Quality of Service (QoS).
:Example:
>>> from ipaddress import IPv4Address
>>> ip_packet = IPPacket(
... src_ip=IPv4Address('192.168.0.1'),
... dst_ip=IPv4Address('10.0.0.1'),
... protocol=IPProtocol.TCP,
... ttl=64,
... precedence=Precedence.CRITICAL
... )
"""
src_ip: IPv4Address
"Source IP address."
dst_ip: IPv4Address
"Destination IP address."
protocol: IPProtocol = IPProtocol.TCP
"IPProtocol."
ttl: int = 64
"Time to Live (TTL) for the packet."
precedence: Precedence = Precedence.ROUTINE
"Precedence level for Quality of Service (default is Precedence.ROUTINE)."
def __init__(self, **kwargs):
if not isinstance(kwargs["src_ip"], IPv4Address):
kwargs["src_ip"] = IPv4Address(kwargs["src_ip"])
if not isinstance(kwargs["dst_ip"], IPv4Address):
kwargs["dst_ip"] = IPv4Address(kwargs["dst_ip"])
super().__init__(**kwargs)

View File

@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.simulator.core import SimComponent
from primaite.simulator.network.transmission.data_link_layer import Frame
_LOGGER = getLogger(__name__)
@@ -121,7 +122,7 @@ class NIC(SimComponent):
Connect the NIC to a link.
:param link: The link to which the NIC is connected.
:type link: :class:`~primaite.simulator.network.physical_layer.Link`
:type link: :class:`~primaite.simulator.network.transmission.physical_layer.Link`
:raise NetworkError: When an attempt to connect a Link is made while the NIC has a connected Link.
"""
if not self.connected_link:
@@ -136,7 +137,7 @@ class NIC(SimComponent):
raise NetworkError(msg)
def disconnect_link(self):
"""Disconnect the NIC from the connected :class:`~primaite.simulator.network.physical_layer.Link`."""
"""Disconnect the NIC from the connected Link."""
if self.connected_link.endpoint_a == self:
self.connected_link.endpoint_a = None
if self.connected_link.endpoint_b == self:
@@ -161,7 +162,7 @@ class NIC(SimComponent):
"""
pass
def send_frame(self, frame: Any):
def send_frame(self, frame: Frame):
"""
Send a network frame from the NIC to the connected link.
@@ -170,7 +171,7 @@ class NIC(SimComponent):
"""
pass
def receive_frame(self, frame: Any):
def receive_frame(self, frame: Frame):
"""
Receive a network frame from the connected link.
@@ -222,7 +223,7 @@ class Link(SimComponent):
def model_post_init(self, __context: Any) -> None:
"""
Ensure that endpoint_a and endpoint_b are not the same :class:`~primaite.simulator.network.physical_layer.NIC`.
Ensure that endpoint_a and endpoint_b are not the same NIC.
:raises ValueError: If endpoint_a and endpoint_b are the same NIC.
"""
@@ -233,7 +234,7 @@ class Link(SimComponent):
self.endpoint_a.connect_link(self)
self.endpoint_b.connect_link(self)
def send_frame(self, sender_nic: NIC, frame: Any):
def send_frame(self, sender_nic: NIC, frame: Frame):
"""
Send a network frame from one NIC to another connected NIC.
@@ -244,7 +245,7 @@ class Link(SimComponent):
"""
pass
def receive_frame(self, sender_nic: NIC, frame: Any):
def receive_frame(self, sender_nic: NIC, frame: Frame):
"""
Receive a network frame from a connected NIC.

View File

@@ -0,0 +1,40 @@
from enum import Enum
from pydantic import BaseModel
class DataStatus(Enum):
"""
The status of the data in transmission.
Members:
- GOOD: 1
- COMPROMISED: 2
- CORRUPT: 3
"""
GOOD = 1
COMPROMISED = 2
CORRUPT = 3
class AgentSource(Enum):
"""
The agent source of the transmission.
Members:
- RED: 1
- GREEN: 2
- BLUE: 3
"""
RED = 1
GREEN = 2
BLUE = 3
class PrimaiteHeader(BaseModel):
"""A custom header for carrying PrimAITE transmission metadata required for RL."""
agent_source: AgentSource = AgentSource.GREEN
data_status: DataStatus = DataStatus.GOOD

View File

@@ -0,0 +1,119 @@
from enum import Enum
from typing import List, Union
from pydantic import BaseModel
class Port(Enum):
"""Enumeration of common known TCP/UDP ports used by protocols for operation of network applications."""
WOL = 9
"Wake-on-Lan (WOL) - Used to turn or awaken a computer from sleep mode by a network message."
FTP_DATA = 20
"File Transfer [Default Data]"
FTP = 21
"File Transfer Protocol (FTP) - FTP control (command)"
SSH = 22
"Secure Shell (SSH) - Used for secure remote access and command execution."
SMTP = 25
"Simple Mail Transfer Protocol (SMTP) - Used for email delivery between servers."
DNS = 53
"Domain Name System (DNS) - Used for translating domain names to IP addresses."
HTTP = 80
"HyperText Transfer Protocol (HTTP) - Used for web traffic."
POP3 = 110
"Post Office Protocol version 3 (POP3) - Used for retrieving emails from a mail server."
SFTP = 115
"Secure File Transfer Protocol (SFTP) - Used for secure file transfer over SSH."
NTP = 123
"Network Time Protocol (NTP) - Used for clock synchronization between computer systems."
IMAP = 143
"Internet Message Access Protocol (IMAP) - Used for retrieving emails from a mail server."
SNMP = 161
"Simple Network Management Protocol (SNMP) - Used for network device management."
SNMP_TRAP = 162
"SNMP Trap - Used for sending SNMP notifications (traps) to a network management system."
LDAP = 389
"Lightweight Directory Access Protocol (LDAP) - Used for accessing and modifying directory information."
HTTPS = 443
"HyperText Transfer Protocol Secure (HTTPS) - Used for secure web traffic."
SMB = 445
"Server Message Block (SMB) - Used for file sharing and printer sharing in Windows environments."
IPP = 631
"Internet Printing Protocol (IPP) - Used for printing over the internet or an intranet."
SQL_SERVER = 1433
"Microsoft SQL Server Database Engine - Used for communication with the SQL Server."
MYSQL = 3306
"MySQL Database Server - Used for MySQL database communication."
RDP = 3389
"Remote Desktop Protocol (RDP) - Used for remote desktop access to Windows machines."
RTP = 5004
"Real-time Transport Protocol (RTP) - Used for transmitting real-time media, e.g., audio and video."
RTP_ALT = 5005
"Alternative port for RTP (RTP_ALT) - Used in some configurations for transmitting real-time media."
DNS_ALT = 5353
"Alternative port for DNS (DNS_ALT) - Used in some configurations for DNS service."
HTTP_ALT = 8080
"Alternative port for HTTP (HTTP_ALT) - Often used as an alternative HTTP port for web applications."
HTTPS_ALT = 8443
"Alternative port for HTTPS (HTTPS_ALT) - Used in some configurations for secure web traffic."
class UDPHeader(BaseModel):
"""
Represents a UDP header for the transport layer of a Network Frame.
:param src_port: Source port.
:param dst_port: Destination port.
:Example:
>>> udp_header = UDPHeader(
... src_port=Port.HTTP_ALT,
... dst_port=Port.HTTP,
... )
"""
src_port: Union[Port, int]
dst_port: Union[Port, int]
class TCPFlags(Enum):
"""
Enum representing TCP control flags used in a TCP connection.
Flags are used to indicate a particular state of the connection or provide additional information.
Members:
- SYN: (1) - Used in the first step of connection establishment phase or 3-way handshake process between two hosts.
- ACK: (2) - Used to acknowledge packets that are successfully received by the host.
- FIN: (4) - Used to request connection termination when there is no more data from the sender.
- RST: (8) - Used to terminate the connection if there is an issue with the TCP connection.
"""
SYN = 1
ACK = 2
FIN = 4
RST = 8
class TCPHeader(BaseModel):
"""
Represents a TCP header for the transport layer of a Network Frame.
:param src_port: Source port.
:param dst_port: Destination port.
:param flags: TCP flags (list of TCPFlags members).
:Example:
>>> tcp_header = TCPHeader(
... src_port=Port.HTTP_ALT,
... dst_port=Port.HTTP,
... flags=[TCPFlags.SYN, TCPFlags.ACK]
... )
"""
src_port: int
dst_port: int
flags: List[TCPFlags] = [TCPFlags.SYN]

View File

@@ -1,6 +1,6 @@
import pytest
from primaite.simulator.network.physical_layer import Link, NIC
from primaite.simulator.network.transmission.physical_layer import Link, NIC
def test_link_fails_with_same_nic():

View File

@@ -0,0 +1,90 @@
import pytest
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPHeader, IPPacket, IPProtocol, Precedence
from primaite.simulator.network.transmission.primaite_layer import AgentSource, DataStatus
from primaite.simulator.network.transmission.transport_layer import Port, TCPFlags, TCPHeader, UDPHeader
def test_frame_minimal_instantiation():
"""Tests that the minimum frame (TCP SYN) using default values."""
frame = Frame(
ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"),
ip=IPPacket(src_ip="192.168.0.10", dst_ip="192.168.0.20"),
tcp=TCPHeader(
src_port=8080,
dst_port=80,
),
)
# Check network layer default values
assert frame.ip.protocol == IPProtocol.TCP
assert frame.ip.ttl == 64
assert frame.ip.precedence == Precedence.ROUTINE
# Check transport layer default values
assert frame.tcp.flags == [TCPFlags.SYN]
# Check primaite custom header default values
assert frame.primaite_header.agent_source == AgentSource.GREEN
assert frame.primaite_header.data_status == DataStatus.GOOD
# Check that model can be dumped down to json and returned as size in Bytes
assert frame.size
def test_frame_creation_fails_tcp_without_header():
"""Tests Frame creation fails if the IPProtocol is TCP but there is no TCPHeader."""
with pytest.raises(ValueError):
Frame(
ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"),
ip=IPPacket(src_ip="192.168.0.10", dst_ip="192.168.0.20", protocol=IPProtocol.TCP),
)
def test_frame_creation_fails_udp_without_header():
"""Tests Frame creation fails if the IPProtocol is UDP but there is no UDPHeader."""
with pytest.raises(ValueError):
Frame(
ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"),
ip=IPPacket(src_ip="192.168.0.10", dst_ip="192.168.0.20", protocol=IPProtocol.UDP),
)
def test_frame_creation_fails_tcp_with_udp_header():
"""Tests Frame creation fails if the IPProtocol is TCP but there is a UDPHeader."""
with pytest.raises(ValueError):
Frame(
ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"),
ip=IPPacket(src_ip="192.168.0.10", dst_ip="192.168.0.20", protocol=IPProtocol.TCP),
udp=UDPHeader(src_port=8080, dst_port=80),
)
def test_frame_creation_fails_udp_with_tcp_header():
"""Tests Frame creation fails if the IPProtocol is UDP but there is a TCPHeader."""
with pytest.raises(ValueError):
Frame(
ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"),
ip=IPPacket(src_ip="192.168.0.10", dst_ip="192.168.0.20", protocol=IPProtocol.UDP),
udp=TCPHeader(src_port=8080, dst_port=80),
)
def test_icmp_frame_creation():
"""Tests Frame creation for ICMP."""
frame = Frame(
ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"),
ip=IPPacket(src_ip="192.168.0.10", dst_ip="192.168.0.20", protocol=IPProtocol.ICMP),
icmp=ICMPHeader(),
)
assert frame
def test_icmp_frame_creation_fails_without_icmp_header():
"""Tests Frame creation for ICMP."""
with pytest.raises(ValueError):
Frame(
ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"),
ip=IPPacket(src_ip="192.168.0.10", dst_ip="192.168.0.20", protocol=IPProtocol.ICMP),
)

View File

@@ -0,0 +1,24 @@
import pytest
from primaite.simulator.network.transmission.network_layer import ICMPHeader, ICMPType
def test_icmp_minimal_header_creation():
"""Checks the minimal ICMPHeader (ping 1 request) creation using default values."""
ping = ICMPHeader()
assert ping.icmp_type == ICMPType.ECHO_REQUEST
assert ping.icmp_code == 0
assert ping.identifier
assert ping.sequence == 1
def test_valid_icmp_type_code_pairing():
"""Tests ICMPHeader creation with valid type and code pairing."""
assert ICMPHeader(icmp_type=ICMPType.DESTINATION_UNREACHABLE, icmp_code=6)
def test_invalid_icmp_type_code_pairing():
"""Tests ICMPHeader creation fails with invalid type and code pairing."""
with pytest.raises(ValueError):
assert ICMPHeader(icmp_type=ICMPType.DESTINATION_UNREACHABLE, icmp_code=16)

View File

@@ -3,7 +3,7 @@ from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.physical_layer import generate_mac_address, NIC
from primaite.simulator.network.transmission.physical_layer import generate_mac_address, NIC
def test_mac_address_generation():