diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d08974c..f51fd648 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - File and folder observations can now be configured to always show the true health status, or require scanning like before. - It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where agent don't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty` - Node observations can now be configured to show the number of active local and remote logins. +- Ports, IP Protocols, and airspace frequencies no longer use enums. They are defined in dictionary lookups and are handled by custom validation to enable extendability with plugins. ### Fixed - Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) diff --git a/docs/source/simulation_components/network/network.rst b/docs/source/simulation_components/network/network.rst index 636ffbcc..4cc121a3 100644 --- a/docs/source/simulation_components/network/network.rst +++ b/docs/source/simulation_components/network/network.rst @@ -103,13 +103,13 @@ we'll use the following Network that has a client, server, two switches, and a r router_1.acl.add_rule( action=ACLAction.PERMIT, - src_port=Port.ARP, - dst_port=Port.ARP, + src_port=Port["ARP"], + dst_port=Port["ARP"], position=22 ) router_1.acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.ICMP, + protocol=IPProtocol["ICMP"], position=23 ) diff --git a/docs/source/simulation_components/network/nodes/firewall.rst b/docs/source/simulation_components/network/nodes/firewall.rst index 149d3e67..1ef16d63 100644 --- a/docs/source/simulation_components/network/nodes/firewall.rst +++ b/docs/source/simulation_components/network/nodes/firewall.rst @@ -156,8 +156,8 @@ To prevent all external traffic from accessing the internal network, with except # Exception rule to allow HTTP traffic from external to internal network firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="192.168.1.0", dst_wildcard_mask="0.0.0.255", position=2 @@ -172,16 +172,16 @@ To enable external traffic to access specific services hosted within the DMZ: # Allow HTTP and HTTPS traffic to the DMZ firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="172.16.0.0", dst_wildcard_mask="0.0.0.255", position=3 ) firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTPS, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.0", dst_wildcard_mask="0.0.0.255", position=4 @@ -196,9 +196,9 @@ To permit SSH access from a designated external IP to a specific server within t # Allow SSH from a specific external IP to an internal server firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.2", - dst_port=Port.SSH, + dst_port=Port["SSH"], dst_ip_address="192.168.1.10", position=5 ) @@ -212,9 +212,9 @@ To limit database server access to selected external IP addresses: # Allow PostgreSQL traffic from an authorized external IP to the internal DB server firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.3", - dst_port=Port.POSTGRES_SERVER, + dst_port=Port["POSTGRES_SERVER"], dst_ip_address="192.168.1.20", position=6 ) @@ -222,8 +222,8 @@ To limit database server access to selected external IP addresses: # Deny all other PostgreSQL traffic from external sources firewall.internal_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, - dst_port=Port.POSTGRES_SERVER, + protocol=IPProtocol["TCP"], + dst_port=Port["POSTGRES_SERVER"], dst_ip_address="192.168.1.0", dst_wildcard_mask="0.0.0.255", position=7 @@ -247,15 +247,15 @@ To authorize HTTP/HTTPS access to a DMZ-hosted web server, excluding known malic # Allow HTTP/HTTPS traffic to the DMZ web server firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="172.16.0.2", position=9 ) firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTPS, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.2", position=10 ) @@ -269,9 +269,9 @@ To facilitate restricted access from the internal network to DMZ-hosted services # Permit specific internal application server HTTPS access to a DMZ-hosted API firewall.internal_outbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.30", # Internal application server IP - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=11 ) @@ -289,9 +289,9 @@ To facilitate restricted access from the internal network to DMZ-hosted services # Corresponding rule in DMZ inbound ACL to allow the traffic from the specific internal server firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.30", # Ensuring this specific source is allowed - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=13 ) @@ -301,7 +301,7 @@ To facilitate restricted access from the internal network to DMZ-hosted services action=ACLAction.DENY, src_ip_address="192.168.1.0", src_wildcard_mask="0.0.0.255", - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=14 ) @@ -315,8 +315,8 @@ To block all SSH access attempts from the external network: # Deny all SSH traffic from any external source firewall.external_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, - dst_port=Port.SSH, + protocol=IPProtocol["TCP"], + dst_port=Port["SSH"], position=1 ) @@ -329,8 +329,8 @@ To allow the internal network to initiate HTTP connections to the external netwo # Permit outgoing HTTP traffic from the internal network to any external destination firewall.external_outbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], position=2 ) diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index c78c8419..436852ea 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -49,7 +49,7 @@ additional steps to configure wireless settings: wireless_router.configure_wireless_access_point( port=1, ip_address="192.168.2.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency="WIFI_2_4", ) @@ -102,8 +102,8 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. network.connect(pc_a.network_interface[1], router_1.router_interface) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # Configure PC B pc_b = Computer( @@ -130,13 +130,13 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency="WIFI_2_4", ) router_2.configure_wireless_access_point( port=1, ip_address="192.168.1.2", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency="WIFI_2_4", ) # Configure routes for inter-router communication diff --git a/docs/source/simulation_components/network/transport_to_data_link_layer.rst b/docs/source/simulation_components/network/transport_to_data_link_layer.rst index cc546021..02bfdcdc 100644 --- a/docs/source/simulation_components/network/transport_to_data_link_layer.rst +++ b/docs/source/simulation_components/network/transport_to_data_link_layer.rst @@ -104,7 +104,7 @@ address of 'aa:bb:cc:dd:ee:ff' to port 8080 on the host 10.0.0.10 which has a NI ip_packet = IPPacket( src_ip_address="192.168.0.100", dst_ip_address="10.0.0.10", - protocol=IPProtocol.TCP + protocol=IPProtocol["TCP"] ) # Data Link Layer ethernet_header = EthernetHeader( diff --git a/docs/source/simulation_components/system/applications/nmap.rst b/docs/source/simulation_components/system/applications/nmap.rst index 1e7f5ea4..e2cd474e 100644 --- a/docs/source/simulation_components/system/applications/nmap.rst +++ b/docs/source/simulation_components/system/applications/nmap.rst @@ -165,8 +165,8 @@ Perform a horizontal port scan on port 5432 across multiple IP addresses: { IPv4Address('192.168.1.12'): { - : [ - + : [ + ] } } @@ -192,7 +192,7 @@ Perform a vertical port scan on multiple ports on a single IP address: vertical_scan_results = pc_1_nmap.port_scan( target_ip_address=[IPv4Address("192.168.1.12")], - target_port=[Port(21), Port(22), Port(80), Port(443)] + target_port=[21, 22, 80, 443] ) .. code-block:: python @@ -200,9 +200,9 @@ Perform a vertical port scan on multiple ports on a single IP address: { IPv4Address('192.168.1.12'): { - : [ - , - + : [ + , + ] } } @@ -233,7 +233,7 @@ Perform a box scan on multiple ports across multiple IP addresses: box_scan_results = pc_1_nmap.port_scan( target_ip_address=[IPv4Address("192.168.1.12"), IPv4Address("192.168.1.13")], - target_port=[Port(21), Port(22), Port(80), Port(443)] + target_port=[21, 22, 80, 443] ) .. code-block:: python @@ -241,15 +241,15 @@ Perform a box scan on multiple ports across multiple IP addresses: { IPv4Address('192.168.1.13'): { - : [ - , - + : [ + , + ] }, IPv4Address('192.168.1.12'): { - : [ - , - + : [ + , + ] } } @@ -289,36 +289,36 @@ Perform a full box scan on all ports, over both TCP and UDP, on a whole subnet: { IPv4Address('192.168.1.11'): { - : [ - + : [ + ] }, IPv4Address('192.168.1.1'): { - : [ - + : [ + ] }, IPv4Address('192.168.1.12'): { - : [ - , - , - , - + : [ + , + , + , + ], - : [ - , - + : [ + , + ] }, IPv4Address('192.168.1.13'): { - : [ - , - , - + : [ + , + , + ], - : [ - , - + : [ + , + ] } } diff --git a/docs/source/simulation_components/system/services/ftp_client.rst b/docs/source/simulation_components/system/services/ftp_client.rst index fdf9cfcf..0c9a781c 100644 --- a/docs/source/simulation_components/system/services/ftp_client.rst +++ b/docs/source/simulation_components/system/services/ftp_client.rst @@ -15,7 +15,7 @@ Key features - Connects to the :ref:`FTPServer` via the ``SoftwareManager``. - Simulates FTP requests and FTPPacket transfer across a network - Allows the emulation of FTP commands between an FTP client and server: - - PORT: specifies the port that server should connect to on the client (currently only uses ``Port.FTP``) + - PORT: specifies the port that server should connect to on the client (currently only uses ``Port["FTP"]``) - STOR: stores a file from client to server - RETR: retrieves a file from the FTP server - QUIT: disconnect from server diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py index b00c26f6..39040d76 100644 --- a/src/primaite/config/load.py +++ b/src/primaite/config/load.py @@ -60,9 +60,10 @@ def data_manipulation_marl_config_path() -> Path: raise FileNotFoundError(msg) return path + def get_extended_config_path() -> Path: """ - Get the path to an 'extended' example config that contains nodes using the extension framework + Get the path to an 'extended' example config that contains nodes using the extension framework. :return: Path to the extended example config :rtype: Path @@ -72,4 +73,4 @@ def get_extended_config_path() -> Path: msg = f"Example config does not exist: {path}. Have you run `primaite setup`?" _LOGGER.error(msg) raise FileNotFoundError(msg) - return path \ No newline at end of file + return path diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 4419ccc7..617e8eee 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -12,6 +12,8 @@ from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) @@ -44,7 +46,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): """Number of spaces for network interface observations on this host.""" include_nmne: Optional[bool] = None """Whether network interface observations should include number of malicious network events.""" - monitored_traffic: Optional[Dict] = None + monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Whether to include the number of accesses to files observations on this host.""" diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 002ee4da..d180b641 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,14 +1,15 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, Optional +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): @@ -21,7 +22,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """Number of the network interface.""" include_nmne: Optional[bool] = None """Whether to include number of malicious network events (NMNE) in the observation.""" - monitored_traffic: Optional[Dict] = None + monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: @@ -58,7 +59,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): def _default_monitored_traffic_observation(self, monitored_traffic_config: Dict) -> Dict: default_traffic_obs = {"TRAFFIC": {}} - for protocol in monitored_traffic_config: + for protocol in self.monitored_traffic: protocol = str(protocol).lower() default_traffic_obs["TRAFFIC"][protocol] = {} @@ -66,8 +67,8 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): default_traffic_obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0} else: default_traffic_obs["TRAFFIC"][protocol] = {} - for port in monitored_traffic_config[protocol]: - default_traffic_obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + for port in self.monitored_traffic[protocol]: + default_traffic_obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0} return default_traffic_obs @@ -142,17 +143,16 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): } else: for port in self.monitored_traffic[protocol]: - port_enum = Port[port] - obs["TRAFFIC"][protocol][port_enum.value] = {} + obs["TRAFFIC"][protocol][port] = {} traffic = {"inbound": 0, "outbound": 0} - if nic_state["traffic"][protocol].get(port_enum.value) is not None: - traffic = nic_state["traffic"][protocol][port_enum.value] + if nic_state["traffic"][protocol].get(port) is not None: + traffic = nic_state["traffic"][protocol][port] - obs["TRAFFIC"][protocol][port_enum.value]["inbound"] = self._categorise_traffic( + obs["TRAFFIC"][protocol][port]["inbound"] = self._categorise_traffic( traffic_value=traffic["inbound"], nic_state=nic_state ) - obs["TRAFFIC"][protocol][port_enum.value]["outbound"] = self._categorise_traffic( + obs["TRAFFIC"][protocol][port]["outbound"] = self._categorise_traffic( traffic_value=traffic["outbound"], nic_state=nic_state ) @@ -162,7 +162,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0} else: for port in self.monitored_traffic[protocol]: - obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0} if self.include_nmne: obs.update({"NMNE": {}}) @@ -201,7 +201,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): else: space["TRAFFIC"][protocol] = spaces.Dict({}) for port in self.monitored_traffic[protocol]: - space["TRAFFIC"][protocol][Port[port].value] = spaces.Dict( + space["TRAFFIC"][protocol][port] = spaces.Dict( {"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)} ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index e263cadb..e11521b6 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -12,6 +12,8 @@ from primaite.game.agent.observations.firewall_observation import FirewallObserv from primaite.game.agent.observations.host_observations import HostObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.router_observation import RouterObservation +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) @@ -40,7 +42,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """Number of network interface cards (NICs).""" include_nmne: Optional[bool] = None """Flag to include nmne.""" - monitored_traffic: Optional[Dict] = None + monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 11c968af..6d1c0920 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -17,10 +17,9 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT -from primaite.simulator.network.airspace import AirSpaceFrequency from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC from primaite.simulator.network.hardware.nodes.host.server import Printer, Server from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode @@ -28,8 +27,6 @@ from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.network.nmne import NMNEConfig -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient # noqa: F401 @@ -52,6 +49,8 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import Software +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -82,9 +81,9 @@ class PrimaiteGameOptions(BaseModel): """Random number seed for RNGs.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" - ports: List[str] + ports: List[Port] """A whitelist of available ports in the simulation.""" - protocols: List[str] + protocols: List[IPProtocol] """A whitelist of available protocols in the simulation.""" thresholds: Optional[Dict] = {} """A dict containing the thresholds used for determining what is acceptable during observations.""" @@ -267,10 +266,7 @@ class PrimaiteGame: network_config = simulation_config.get("network", {}) airspace_cfg = network_config.get("airspace", {}) frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {}) - - frequency_max_capacity_mbps_cfg = {AirSpaceFrequency[k]: v for k, v in frequency_max_capacity_mbps_cfg.items()} - - net.airspace.frequency_max_capacity_mbps_ = frequency_max_capacity_mbps_cfg + net.airspace.set_frequency_max_capacity_mbps(frequency_max_capacity_mbps_cfg) nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) @@ -291,11 +287,10 @@ class PrimaiteGame: dns_server=node_cfg.get("dns_server", None), operating_state=NodeOperatingState.ON if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()]) - elif n_type in NetworkNode._registry: - new_node = NetworkNode._registry[n_type]( - **node_cfg + else NodeOperatingState[p.upper()], ) + elif n_type in NetworkNode._registry: + new_node = NetworkNode._registry[n_type](**node_cfg) # Default PrimAITE nodes elif n_type == "computer": new_node = Computer( @@ -358,9 +353,9 @@ class PrimaiteGame: for port_id in set(software_cfg.get("options", {}).get("listen_on_ports", [])): port = None if isinstance(port_id, int): - port = Port(port_id) + port = port_id elif isinstance(port_id, str): - port = Port[port_id] + port = PORT_LOOKUP[port_id] if port: listen_on_ports.append(port) software.listen_on_ports = set(listen_on_ports) @@ -475,7 +470,7 @@ class PrimaiteGame: opt = application_cfg["options"] new_application.configure( target_ip_address=IPv4Address(opt.get("target_ip_address")), - target_port=Port(opt.get("target_port", Port.POSTGRES_SERVER.value)), + target_port=PORT_LOOKUP[opt.get("target_port", "POSTGRES_SERVER")], payload=opt.get("payload"), repeat=bool(opt.get("repeat")), port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), @@ -488,8 +483,10 @@ class PrimaiteGame: new_application.configure( c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")), keep_alive_frequency=(opt.get("keep_alive_frequency", 5)), - masquerade_protocol=IPProtocol[(opt.get("masquerade_protocol", IPProtocol.TCP))], - masquerade_port=Port[(opt.get("masquerade_port", Port.HTTP))], + masquerade_protocol=PROTOCOL_LOOKUP[ + (opt.get("masquerade_protocol", PROTOCOL_LOOKUP["TCP"])) + ], + masquerade_port=PORT_LOOKUP[(opt.get("masquerade_port", PORT_LOOKUP["HTTP"]))], ) if "network_interfaces" in node_cfg: for nic_num, nic_cfg in node_cfg["network_interfaces"].items(): diff --git a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb index b6b13f28..a5cc385b 100644 --- a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb @@ -1783,7 +1783,7 @@ "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", "from primaite.simulator.network.transmission.transport_layer import Port\n", "# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n", - "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol.UDP, masquerade_port=Port.DNS)\n", + "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n", "c2_beacon.establish()\n", "c2_beacon.show()" ] diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index 77ac4842..f573f251 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -182,7 +182,7 @@ "metadata": {}, "outputs": [], "source": [ - "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port.HTTP, protocol = IPProtocol.NONE,operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" + "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" ] }, { diff --git a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb index 17a0f796..2d5b4772 100644 --- a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb +++ b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb @@ -537,7 +537,7 @@ "from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n", "network.get_node_by_hostname(\"router_1\").acl.add_rule(\n", " action=ACLAction.DENY,\n", - " protocol=IPProtocol.ICMP,\n", + " protocol=IPProtocol["ICMP"],\n", " src_ip_address=\"192.168.10.22\",\n", " position=1\n", ")" diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index cdb01514..03d43130 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -1,12 +1,12 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations +import copy from abc import ABC, abstractmethod -from enum import Enum from typing import Any, Dict, List from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validate_call from primaite import getLogger from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface @@ -41,50 +41,30 @@ def format_hertz(hertz: float, format_terahertz: bool = False, decimals: int = 3 return format_str.format(hertz) + " Hz" -class AirSpaceFrequency(Enum): - """Enumeration representing the operating frequencies for wireless communications.""" +_default_frequency_set: Dict[str, Dict] = { + "WIFI_2_4": {"frequency": 2.4e9, "data_rate_bps": 100_000_000.0}, + "WIFI_5": {"frequency": 5e9, "data_rate_bps": 500_000_000.0}, +} +"""Frequency configuration that is automatically used for any new airspace.""" - WIFI_2_4 = 2.4e9 - """WiFi 2.4 GHz. Known for its extensive range and ability to penetrate solid objects effectively.""" - WIFI_5 = 5e9 - """WiFi 5 GHz. Known for its higher data transmission speeds and reduced interference from other devices.""" - def __str__(self) -> str: - hertz_str = format_hertz(hertz=self.value) - if self == AirSpaceFrequency.WIFI_2_4: - return f"WiFi {hertz_str}" - if self == AirSpaceFrequency.WIFI_5: - return f"WiFi {hertz_str}" - return "Unknown Frequency" +def register_default_frequency(freq_name: str, freq_hz: float, data_rate_bps: float) -> None: + """Add to the default frequency configuration. This is intended as a plugin hook. - @property - def maximum_data_rate_bps(self) -> float: - """ - Retrieves the maximum data transmission rate in bits per second (bps). + If your plugin makes use of bespoke frequencies for wireless communication, you should make a call to this method + wherever you define components that rely on the bespoke frequencies. That way, as soon as your components are + imported, this function automatically updates the default frequency set. - The maximum rates are predefined for frequencies.: - - WIFI 2.4 supports 100,000,000 bps - - WIFI 5 supports 500,000,000 bps + This should also be run before instances of AirSpace are created. - :return: The maximum data rate in bits per second. - """ - if self == AirSpaceFrequency.WIFI_2_4: - return 100_000_000.0 # 100 Megabits per second - if self == AirSpaceFrequency.WIFI_5: - return 500_000_000.0 # 500 Megabits per second - return 0.0 - - @property - def maximum_data_rate_mbps(self) -> float: - """ - Retrieves the maximum data transmission rate in megabits per second (Mbps). - - This is derived by converting the maximum data rate from bits per second, as defined - in `maximum_data_rate_bps`, to megabits per second. - - :return: The maximum data rate in megabits per second. - """ - return self.maximum_data_rate_bps / 1_000_000.0 + :param freq_name: The frequency name. If this clashes with an existing frequency name, it will be overwritten. + :type freq_name: str + :param freq_hz: The frequency itself, measured in Hertz. + :type freq_hz: float + :param data_rate_bps: The transmission capacity over this frequency, in bits per second. + :type data_rate_bps: float + """ + _default_frequency_set.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}}) class AirSpace(BaseModel): @@ -97,38 +77,51 @@ class AirSpace(BaseModel): """ wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {}) - wireless_interfaces_by_frequency: Dict[AirSpaceFrequency, List[WirelessNetworkInterface]] = Field( - default_factory=lambda: {} - ) - bandwidth_load: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) - frequency_max_capacity_mbps_: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) + wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field(default_factory=lambda: {}) + bandwidth_load: Dict[int, float] = Field(default_factory=lambda: {}) + frequencies: Dict[str, Dict] = Field(default_factory=lambda: copy.deepcopy(_default_frequency_set)) - def get_frequency_max_capacity_mbps(self, frequency: AirSpaceFrequency) -> float: + @validate_call + def get_frequency_max_capacity_mbps(self, freq_name: str) -> float: """ Retrieves the maximum data transmission capacity for a specified frequency. - This method checks a dictionary holding custom maximum capacities. If the frequency is found, it returns the - custom set maximum capacity. If the frequency is not found in the dictionary, it defaults to the standard - maximum data rate associated with that frequency. - - :param frequency: The frequency for which the maximum capacity is queried. - + :param freq_name: The frequency for which the maximum capacity is queried. :return: The maximum capacity in Mbps for the specified frequency. """ - if frequency in self.frequency_max_capacity_mbps_: - return self.frequency_max_capacity_mbps_[frequency] - return frequency.maximum_data_rate_mbps + if freq_name in self.frequencies: + return self.frequencies[freq_name]["data_rate_bps"] / (1024.0 * 1024.0) + return 0.0 - def set_frequency_max_capacity_mbps(self, cfg: Dict[AirSpaceFrequency, float]): + def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]) -> None: """ Sets custom maximum data transmission capacities for multiple frequencies. :param cfg: A dictionary mapping frequencies to their new maximum capacities in Mbps. """ - self.frequency_max_capacity_mbps_ = cfg for freq, mbps in cfg.items(): + self.frequencies[freq]["data_rate_bps"] = mbps * 1024 * 1024 print(f"Overriding {freq} max capacity as {mbps:.3f} mbps") + def register_frequency(self, freq_name: str, freq_hz: float, data_rate_bps: float) -> None: + """ + Define a new frequency for this airspace. + + :param freq_name: The frequency name. If this clashes with an existing frequency name, it will be overwritten. + :type freq_name: str + :param freq_hz: The frequency itself, measured in Hertz. + :type freq_hz: float + :param data_rate_bps: The transmission capacity over this frequency, in bits per second. + :type data_rate_bps: float + """ + if freq_name in self.frequencies: + _LOGGER.info( + f"Overwriting Air space frequency {freq_name}. " + f"Previous data rate: {self.frequencies[freq_name]['data_rate_bps']}. " + f"Current data rate: {data_rate_bps}." + ) + self.frequencies.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}}) + def show_bandwidth_load(self, markdown: bool = False): """ Prints a table of the current bandwidth load for each frequency on the airspace. @@ -150,7 +143,13 @@ class AirSpace(BaseModel): load_percent = load / maximum_capacity if maximum_capacity > 0 else 0.0 if load_percent > 1.0: load_percent = 1.0 - table.add_row([format_hertz(frequency.value), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"]) + table.add_row( + [ + format_hertz(self.frequencies[frequency]["frequency"]), + f"{load_percent:.0%}", + f"{maximum_capacity:.3f}", + ] + ) print(table) def show_wireless_interfaces(self, markdown: bool = False): @@ -182,7 +181,7 @@ class AirSpace(BaseModel): interface.mac_address, interface.ip_address if hasattr(interface, "ip_address") else None, interface.subnet_mask if hasattr(interface, "subnet_mask") else None, - format_hertz(interface.frequency.value), + format_hertz(self.frequencies[interface.frequency]["frequency"]), f"{interface.speed:.3f}", status, ] @@ -298,7 +297,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC): """ airspace: AirSpace - frequency: AirSpaceFrequency = AirSpaceFrequency.WIFI_2_4 + frequency: str = "WIFI_2_4" def enable(self): """Attempt to enable the network interface.""" @@ -430,7 +429,7 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC) # Update the state with information from Layer3Interface state.update(Layer3Interface.describe_state(self)) - state["frequency"] = self.frequency.value + state["frequency"] = self.frequency return state diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 39fbe783..6e019f32 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -130,15 +130,15 @@ class Network(SimComponent): def firewall_nodes(self) -> List[Node]: """The Firewalls in the Network.""" return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"] - + @property def extended_hostnodes(self) -> List[Node]: - """Extended nodes that inherited HostNode in the network""" + """Extended nodes that inherited HostNode in the network.""" return [node for node in self.nodes.values() if node.__class__.__name__.lower() in HostNode._registry] - + @property def extended_networknodes(self) -> List[Node]: - """Extended nodes that inherited NetworkNode in the network""" + """Extended nodes that inherited NetworkNode in the network.""" return [node for node in self.nodes.values() if node.__class__.__name__.lower() in NetworkNode._registry] @property diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index 61a37a90..891c445e 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -6,8 +6,8 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int: @@ -98,8 +98,10 @@ def create_office_lan( default_gateway = IPv4Address(f"192.168.{subnet_base}.1") router = Router(hostname=f"router_{lan_name}", start_up_duration=0) router.power_on() - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) network.add_node(router) router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0") router.enable_port(1) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index bf230e07..050f4667 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -21,8 +21,6 @@ from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.packet_capture import PacketCapture from primaite.simulator.system.core.session_manager import SessionManager @@ -33,7 +31,9 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.software import IOSoftware, Software from primaite.utils.converters import convert_dict_enum_keys_to_enum_values -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware) @@ -203,16 +203,16 @@ class NetworkInterface(SimComponent, ABC): # Initialise basic frame data variables direction = "inbound" if inbound else "outbound" # Direction of the traffic ip_address = str(frame.ip.src_ip_address if inbound else frame.ip.dst_ip_address) # Source or destination IP - protocol = frame.ip.protocol.name # Network protocol used in the frame + protocol = frame.ip.protocol # Network protocol used in the frame # Initialise port variable; will be determined based on protocol type port = None # Determine the source or destination port based on the protocol (TCP/UDP) if frame.tcp: - port = frame.tcp.src_port.value if inbound else frame.tcp.dst_port.value + port = frame.tcp.src_port if inbound else frame.tcp.dst_port elif frame.udp: - port = frame.udp.src_port.value if inbound else frame.udp.dst_port.value + port = frame.udp.src_port if inbound else frame.udp.dst_port # Convert frame payload to string for keyword checking frame_str = str(frame.payload) @@ -274,20 +274,20 @@ class NetworkInterface(SimComponent, ABC): # Identify the protocol and port from the frame if frame.tcp: - protocol = IPProtocol.TCP + protocol = PROTOCOL_LOOKUP["TCP"] port = frame.tcp.dst_port elif frame.udp: - protocol = IPProtocol.UDP + protocol = PROTOCOL_LOOKUP["UDP"] port = frame.udp.dst_port elif frame.icmp: - protocol = IPProtocol.ICMP + protocol = PROTOCOL_LOOKUP["ICMP"] # Ensure the protocol is in the capture dict if protocol not in self.traffic: self.traffic[protocol] = {} # Handle non-ICMP protocols that use ports - if protocol != IPProtocol.ICMP: + if protocol != PROTOCOL_LOOKUP["ICMP"]: if port not in self.traffic[protocol]: self.traffic[protocol][port] = {"inbound": 0, "outbound": 0} self.traffic[protocol][port][direction] += frame.size_Mbits @@ -843,8 +843,8 @@ class UserManager(Service): :param password: The password for the default admin user """ kwargs["name"] = "UserManager" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self.start() @@ -1166,8 +1166,8 @@ class UserSessionManager(Service): :param password: The password for the default admin user """ kwargs["name"] = "UserSessionManager" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self.start() @@ -1312,7 +1312,7 @@ class UserSessionManager(Service): software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( payload={"type": "user_timeout", "connection_id": session.uuid}, - dest_port=Port.SSH, + dest_port=PORT_LOOKUP["SSH"], dest_ip_address=session.remote_ip_address, ) @@ -1839,14 +1839,14 @@ class Node(SimComponent): def show_open_ports(self, markdown: bool = False): """Prints a table of the open ports on the Node.""" - table = PrettyTable(["Port", "Name"]) + table = PrettyTable(["Port"]) if markdown: table.set_style(MARKDOWN) table.align = "l" table.title = f"{self.hostname} Open Ports" for port in self.software_manager.get_open_ports(): - if port.value > 0: - table.add_row([port.value, port.name]) + if port > 0: + table.add_row([port]) print(table.get_string(sortby="Port")) @property diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index ea162e88..5699721b 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -22,7 +22,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.icmp.icmp import ICMP from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.terminal.terminal import Terminal -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ipv4_address import IPV4Address _LOGGER = getLogger(__name__) @@ -332,7 +332,7 @@ class HostNode(Node): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register a hostnode type. @@ -340,7 +340,7 @@ class HostNode(Node): :type identifier: str :raises ValueError: When attempting to register an hostnode with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. identifier = identifier.lower() diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 4510eac0..47cfae57 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -14,10 +14,10 @@ from primaite.simulator.network.hardware.nodes.network.router import ( RouterInterface, ) from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.sys_log import SysLog -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP EXTERNAL_PORT_ID: Final[int] = 1 """The Firewall port ID of the external port.""" @@ -58,8 +58,8 @@ class Firewall(Router): >>> # Permit HTTP traffic to the DMZ >>> firewall.dmz_inbound_acl.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, - ... dst_port=Port.HTTP, + ... protocol=IPProtocol["TCP"], + ... dst_port=Port["HTTP"], ... src_ip_address="0.0.0.0", ... src_wildcard_mask="0.0.0.0", ... dst_ip_address="172.16.0.0", @@ -596,9 +596,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items(): firewall.internal_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -611,9 +611,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items(): firewall.internal_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -626,9 +626,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items(): firewall.dmz_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -641,9 +641,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items(): firewall.dmz_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -656,9 +656,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items(): firewall.external_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -671,9 +671,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items(): firewall.external_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), diff --git a/src/primaite/simulator/network/hardware/nodes/network/network_node.py b/src/primaite/simulator/network/hardware/nodes/network/network_node.py index 6515bb02..a0cb63e1 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/network_node.py +++ b/src/primaite/simulator/network/hardware/nodes/network/network_node.py @@ -19,7 +19,7 @@ class NetworkNode(Node): _registry: ClassVar[Dict[str, Type["NetworkNode"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register a networknode type. @@ -27,7 +27,7 @@ class NetworkNode(Node): :type identifier: str :raises ValueError: When attempting to register an networknode with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return identifier = identifier.lower() super().__init_subclass__(**kwargs) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index ceb91695..1080dca8 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -17,15 +17,15 @@ from primaite.simulator.network.hardware.nodes.network.network_node import Netwo from primaite.simulator.network.protocols.arp import ARPPacket from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.nmap import NMAP from primaite.simulator.system.core.session_manager import SessionManager from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.arp.arp import ARP from primaite.simulator.system.services.icmp.icmp import ICMP from primaite.simulator.system.services.terminal.terminal import Terminal -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import Port, PORT_LOOKUP @validate_call() @@ -106,7 +106,7 @@ class ACLRule(SimComponent): :ivar ACLAction action: Specifies whether to `PERMIT` or `DENY` the traffic that matches the rule conditions. The default action is `DENY`. - :ivar Optional[IPProtocol] protocol: The network protocol (e.g., TCP, UDP, ICMP) to match. If `None`, the rule + :ivar Optional[str] protocol: The network protocol (e.g., TCP, UDP, ICMP) to match. If `None`, the rule applies to all protocols. :ivar Optional[IPV4Address] src_ip_address: The source IP address to match. If combined with `src_wildcard_mask`, it specifies the start of an IP range. @@ -116,8 +116,8 @@ class ACLRule(SimComponent): `dst_wildcard_mask`, it specifies the start of an IP range. :ivar Optional[IPv4Address] dst_wildcard_mask: The wildcard mask for the destination IP address, defining the range of addresses to match. - :ivar Optional[Port] src_port: The source port number to match. Relevant for TCP/UDP protocols. - :ivar Optional[Port] dst_port: The destination port number to match. Relevant for TCP/UDP protocols. + :ivar Optional[int] src_port: The source port number to match. Relevant for TCP/UDP protocols. + :ivar Optional[int] dst_port: The destination port number to match. Relevant for TCP/UDP protocols. """ action: ACLAction = ACLAction.DENY @@ -149,13 +149,13 @@ class ACLRule(SimComponent): """ state = super().describe_state() state["action"] = self.action.value - state["protocol"] = self.protocol.name if self.protocol else None + state["protocol"] = self.protocol if self.protocol else None state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None - state["src_port"] = self.src_port.name if self.src_port else None + state["src_port"] = self.src_port if self.src_port else None state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None - state["dst_port"] = self.dst_port.name if self.dst_port else None + state["dst_port"] = self.dst_port if self.dst_port else None state["match_count"] = self.match_count return state @@ -265,7 +265,7 @@ class AccessControlList(SimComponent): >>> acl = AccessControlList() >>> acl.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.0", ... src_wildcard_mask="0.0.0.255", ... dst_ip_address="192.168.2.0", @@ -323,13 +323,13 @@ class AccessControlList(SimComponent): func=lambda request, context: RequestResponse.from_bool( self.add_rule( action=ACLAction[request[0]], - protocol=None if request[1] == "ALL" else IPProtocol[request[1]], + protocol=None if request[1] == "ALL" else request[1], src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]), src_wildcard_mask=None if request[3] == "NONE" else IPv4Address(request[3]), - src_port=None if request[4] == "ALL" else Port[request[4]], + src_port=None if request[4] == "ALL" else request[4], dst_ip_address=None if request[5] == "ALL" else IPv4Address(request[5]), dst_wildcard_mask=None if request[6] == "NONE" else IPv4Address(request[6]), - dst_port=None if request[7] == "ALL" else Port[request[7]], + dst_port=None if request[7] == "ALL" else request[7], position=int(request[8]), ) ) @@ -399,11 +399,11 @@ class AccessControlList(SimComponent): >>> router = Router("router") >>> router.add_rule( ... action=ACLAction.DENY, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.0", ... src_wildcard_mask="0.0.0.255", ... dst_ip_address="10.10.10.5", - ... dst_port=Port.SSH, + ... dst_port=Port["SSH"], ... position=5 ... ) >>> # This permits SSH traffic from the 192.168.1.0/24 subnet to the 10.10.10.5 server. @@ -411,10 +411,10 @@ class AccessControlList(SimComponent): >>> # Then if we want to allow a specific IP address from this subnet to SSH into the server >>> router.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.25", ... dst_ip_address="10.10.10.5", - ... dst_port=Port.SSH, + ... dst_port=Port["SSH"], ... position=4 ... ) @@ -552,13 +552,13 @@ class AccessControlList(SimComponent): [ index, rule.action.name if rule.action else "ANY", - rule.protocol.name if rule.protocol else "ANY", + rule.protocol if rule.protocol else "ANY", rule.src_ip_address if rule.src_ip_address else "ANY", rule.src_wildcard_mask if rule.src_wildcard_mask else "ANY", - f"{rule.src_port.value} ({rule.src_port.name})" if rule.src_port else "ANY", + f"{rule.src_port}" if rule.src_port else "ANY", rule.dst_ip_address if rule.dst_ip_address else "ANY", rule.dst_wildcard_mask if rule.dst_wildcard_mask else "ANY", - f"{rule.dst_port.value} ({rule.dst_port.name})" if rule.dst_port else "ANY", + f"{rule.dst_port}" if rule.dst_port else "ANY", rule.match_count, ] ) @@ -1257,8 +1257,10 @@ class Router(NetworkNode): Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP and ICMP, which are necessary for basic network operations and diagnostics. """ - self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + self.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) def setup_for_episode(self, episode: int): """ @@ -1357,9 +1359,9 @@ class Router(NetworkNode): """ dst_ip_address = frame.ip.dst_ip_address dst_port = None - if frame.ip.protocol == IPProtocol.TCP: + if frame.ip.protocol == PROTOCOL_LOOKUP["TCP"]: dst_port = frame.tcp.dst_port - elif frame.ip.protocol == IPProtocol.UDP: + elif frame.ip.protocol == PROTOCOL_LOOKUP["UDP"]: dst_port = frame.udp.dst_port if self.ip_is_router_interface(dst_ip_address) and ( @@ -1632,9 +1634,9 @@ class Router(NetworkNode): for r_num, r_cfg in cfg["acl"].items(): router.acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 3cb4c515..27a13154 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -4,13 +4,13 @@ from typing import Any, Dict, Optional, Union from pydantic import validate_call -from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, IPWirelessNetworkInterface +from primaite.simulator.network.airspace import AirSpace, IPWirelessNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP class WirelessAccessPoint(IPWirelessNetworkInterface): @@ -116,7 +116,7 @@ class WirelessRouter(Router): >>> wireless_router.configure_wireless_access_point( ... ip_address="10.10.10.1", ... subnet_mask="255.255.255.0" - ... frequency=AirSpaceFrequency.WIFI_2_4 + ... frequency="WIFI_2_4" ... ) """ @@ -153,7 +153,7 @@ class WirelessRouter(Router): self, ip_address: IPV4Address, subnet_mask: IPV4Address, - frequency: Optional[AirSpaceFrequency] = AirSpaceFrequency.WIFI_2_4, + frequency: Optional[str] = "WIFI_2_4", ): """ Configures a wireless access point (WAP). @@ -166,12 +166,12 @@ class WirelessRouter(Router): :param ip_address: The IP address to be assigned to the wireless access point. :param subnet_mask: The subnet mask associated with the IP address - :param frequency: The operating frequency of the wireless access point, defined by the AirSpaceFrequency + :param frequency: The operating frequency of the wireless access point, defined by the air space frequency enum. This determines the frequency band (e.g., 2.4 GHz or 5 GHz) the access point will use for wireless - communication. Default is AirSpaceFrequency.WIFI_2_4. + communication. Default is "WIFI_2_4". """ if not frequency: - frequency = AirSpaceFrequency.WIFI_2_4 + frequency = "WIFI_2_4" self.sys_log.info("Configuring wireless access point") self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration @@ -264,16 +264,16 @@ class WirelessRouter(Router): if "wireless_access_point" in cfg: ip_address = cfg["wireless_access_point"]["ip_address"] subnet_mask = cfg["wireless_access_point"]["subnet_mask"] - frequency = AirSpaceFrequency[cfg["wireless_access_point"]["frequency"]] + frequency = cfg["wireless_access_point"]["frequency"] router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency) if "acl" in cfg: for r_num, r_cfg in cfg["acl"].items(): router.acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), dst_ip_address=r_cfg.get("dst_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index cb0965eb..2c3c15b4 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -12,14 +12,14 @@ from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -79,9 +79,11 @@ def client_server_routed() -> Network: server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) return network @@ -271,23 +273,30 @@ def arcd_uc2_network() -> Network: security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0")) network.connect(endpoint_b=security_suite.network_interface[2], endpoint_a=switch_2.network_interface[7]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Allow PostgreSQL requests router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3 + ) return network diff --git a/src/primaite/simulator/network/protocols/masquerade.py b/src/primaite/simulator/network/protocols/masquerade.py index e2a7b6a0..5c5f03b2 100644 --- a/src/primaite/simulator/network/protocols/masquerade.py +++ b/src/primaite/simulator/network/protocols/masquerade.py @@ -3,14 +3,16 @@ from enum import Enum from typing import Optional from primaite.simulator.network.protocols.packet import DataPacket +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port class MasqueradePacket(DataPacket): """Represents an generic malicious packet that is masquerading as another protocol.""" - masquerade_protocol: Enum # The 'Masquerade' protocol that is currently in use + masquerade_protocol: IPProtocol # The 'Masquerade' protocol that is currently in use - masquerade_port: Enum # The 'Masquerade' port that is currently in use + masquerade_port: Port # The 'Masquerade' port that is currently in use class C2Packet(MasqueradePacket): diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index 159eca7f..259d62e3 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -7,10 +7,12 @@ from pydantic import BaseModel from primaite import getLogger from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.protocols.packet import DataPacket -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol +from primaite.simulator.network.transmission.network_layer import IPPacket from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader from primaite.simulator.network.utils import convert_bytes_to_megabits +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -70,15 +72,15 @@ class Frame(BaseModel): msg = "Network Frame cannot have both a TCP header and a UDP header" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.TCP and not kwargs.get("tcp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["TCP"] and not kwargs.get("tcp"): msg = "Cannot build a Frame using the TCP IP Protocol without a TCPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.UDP and not kwargs.get("udp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["UDP"] and not kwargs.get("udp"): msg = "Cannot build a Frame using the UDP IP Protocol without a UDPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.ICMP and not kwargs.get("icmp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["ICMP"] and not kwargs.get("icmp"): msg = "Cannot build a Frame using the ICMP IP Protocol without a ICMPPacket" _LOGGER.error(msg) raise ValueError(msg) @@ -165,7 +167,7 @@ class Frame(BaseModel): :return: True if the Frame is an ARP packet, otherwise False. """ - return self.udp.dst_port == Port.ARP + return self.udp.dst_port == PORT_LOOKUP["ARP"] @property def is_icmp(self) -> bool: diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index d493cbdf..49dcd1f5 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -4,32 +4,12 @@ from enum import Enum from pydantic import BaseModel from primaite import getLogger -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address _LOGGER = getLogger(__name__) -class IPProtocol(Enum): - """ - Enum representing transport layer protocols in IP header. - - .. _List of IPProtocols: - """ - - NONE = "none" - """Placeholder for a non-protocol.""" - TCP = "tcp" - """Transmission Control Protocol.""" - UDP = "udp" - """User Datagram Protocol.""" - ICMP = "icmp" - """Internet Control Message Protocol.""" - - def model_dump(self) -> str: - """Return as JSON-serialisable string.""" - return self.name - - class Precedence(Enum): """ Enum representing the Precedence levels in Quality of Service (QoS) for IP packets. @@ -81,7 +61,7 @@ class IPPacket(BaseModel): >>> ip_packet = IPPacket( ... src_ip_address=IPv4Address('192.168.0.1'), ... dst_ip_address=IPv4Address('10.0.0.1'), - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... ttl=64, ... precedence=Precedence.CRITICAL ... ) @@ -91,7 +71,7 @@ class IPPacket(BaseModel): "Source IP address." dst_ip_address: IPV4Address "Destination IP address." - protocol: IPProtocol = IPProtocol.TCP + protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"] "IPProtocol." ttl: int = 64 "Time to Live (TTL) for the packet." diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index 7f0d2d7a..10cf802c 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -1,82 +1,10 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from enum import Enum -from typing import List, Union +from typing import List from pydantic import BaseModel -class Port(Enum): - """ - Enumeration of common known TCP/UDP ports used by protocols for operation of network applications. - - .. _List of Ports: - """ - - UNUSED = -1 - "An unused port stub." - - NONE = 0 - "Place holder for a non-port." - WOL = 9 - "Wake-on-Lan (WOL) - Used to turn or awaken a computer from sleep mode by a network message." - FTP_DATA = 20 - "File Transfer [Default Data]" - FTP = 21 - "File Transfer Protocol (FTP) - FTP control (command)" - SSH = 22 - "Secure Shell (SSH) - Used for secure remote access and command execution." - SMTP = 25 - "Simple Mail Transfer Protocol (SMTP) - Used for email delivery between servers." - DNS = 53 - "Domain Name System (DNS) - Used for translating domain names to IP addresses." - HTTP = 80 - "HyperText Transfer Protocol (HTTP) - Used for web traffic." - POP3 = 110 - "Post Office Protocol version 3 (POP3) - Used for retrieving emails from a mail server." - SFTP = 115 - "Secure File Transfer Protocol (SFTP) - Used for secure file transfer over SSH." - NTP = 123 - "Network Time Protocol (NTP) - Used for clock synchronization between computer systems." - IMAP = 143 - "Internet Message Access Protocol (IMAP) - Used for retrieving emails from a mail server." - SNMP = 161 - "Simple Network Management Protocol (SNMP) - Used for network device management." - SNMP_TRAP = 162 - "SNMP Trap - Used for sending SNMP notifications (traps) to a network management system." - ARP = 219 - "Address resolution Protocol - Used to connect a MAC address to an IP address." - LDAP = 389 - "Lightweight Directory Access Protocol (LDAP) - Used for accessing and modifying directory information." - HTTPS = 443 - "HyperText Transfer Protocol Secure (HTTPS) - Used for secure web traffic." - SMB = 445 - "Server Message Block (SMB) - Used for file sharing and printer sharing in Windows environments." - IPP = 631 - "Internet Printing Protocol (IPP) - Used for printing over the internet or an intranet." - SQL_SERVER = 1433 - "Microsoft SQL Server Database Engine - Used for communication with the SQL Server." - MYSQL = 3306 - "MySQL Database Server - Used for MySQL database communication." - RDP = 3389 - "Remote Desktop Protocol (RDP) - Used for remote desktop access to Windows machines." - RTP = 5004 - "Real-time Transport Protocol (RTP) - Used for transmitting real-time media, e.g., audio and video." - RTP_ALT = 5005 - "Alternative port for RTP (RTP_ALT) - Used in some configurations for transmitting real-time media." - DNS_ALT = 5353 - "Alternative port for DNS (DNS_ALT) - Used in some configurations for DNS service." - HTTP_ALT = 8080 - "Alternative port for HTTP (HTTP_ALT) - Often used as an alternative HTTP port for web applications." - HTTPS_ALT = 8443 - "Alternative port for HTTPS (HTTPS_ALT) - Used in some configurations for secure web traffic." - POSTGRES_SERVER = 5432 - "Postgres SQL Server." - - def model_dump(self) -> str: - """Return a json-serialisable string.""" - return self.name - - class UDPHeader(BaseModel): """ Represents a UDP header for the transport layer of a Network Frame. @@ -87,13 +15,13 @@ class UDPHeader(BaseModel): :Example: >>> udp_header = UDPHeader( - ... src_port=Port.HTTP_ALT, - ... dst_port=Port.HTTP, + ... src_port=Port["HTTP_ALT"], + ... dst_port=Port["HTTP"], ... ) """ - src_port: Union[Port, int] - dst_port: Union[Port, int] + src_port: int + dst_port: int class TCPFlags(Enum): @@ -126,12 +54,12 @@ class TCPHeader(BaseModel): :Example: >>> tcp_header = TCPHeader( - ... src_port=Port.HTTP_ALT, - ... dst_port=Port.HTTP, + ... src_port=Port["HTTP_ALT"], + ... dst_port=Port["HTTP"], ... flags=[TCPFlags.SYN, TCPFlags.ACK] ... ) """ - src_port: Port - dst_port: Port + src_port: int + dst_port: int flags: List[TCPFlags] = [TCPFlags.SYN] diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index b5284968..a7871315 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -44,7 +44,7 @@ class Application(IOSoftware): _registry: ClassVar[Dict[str, Type["Application"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register an application type. @@ -52,7 +52,7 @@ class Application(IOSoftware): :type identifier: str :raises ValueError: When attempting to register an application with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return super().__init_subclass__(**kwargs) if identifier in cls._registry: diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 3f80c745..cd4b2a03 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -11,11 +11,11 @@ from pydantic import BaseModel from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.hardware.nodes.host.host_node import HostNode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.software_manager import SoftwareManager -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP class DatabaseClientConnection(BaseModel): @@ -90,8 +90,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"): def __init__(self, **kwargs): kwargs["name"] = "DatabaseClient" - kwargs["port"] = Port.POSTGRES_SERVER - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def _init_request_manager(self) -> RequestManager: diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index c87eaaf5..e2b9117d 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -7,10 +7,10 @@ from pydantic import validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import IPProtocol, is_valid_protocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import is_valid_port, Port, PORT_LOOKUP class PortScanPayload(SimComponent): @@ -37,8 +37,8 @@ class PortScanPayload(SimComponent): """ state = super().describe_state() state["ip_address"] = str(self.ip_address) - state["port"] = self.port.value - state["protocol"] = self.protocol.value + state["port"] = self.port + state["protocol"] = self.protocol state["request"] = self.request return state @@ -64,8 +64,8 @@ class NMAP(Application, identifier="NMAP"): def __init__(self, **kwargs): kwargs["name"] = "NMAP" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) def _can_perform_network_action(self) -> bool: @@ -272,8 +272,8 @@ class NMAP(Application, identifier="NMAP"): payload = PortScanPayload(ip_address=ip_address, port=port, protocol=protocol) self._active_port_scans[payload.uuid] = payload self.sys_log.info( - f"{self.name}: Sending port scan request over {payload.protocol.name} on port {payload.port.value} " - f"({payload.port.name}) to {payload.ip_address}" + f"{self.name}: Sending port scan request over {payload.protocol} on port {payload.port} " + f"({payload.port}) to {payload.ip_address}" ) self.software_manager.send_payload_to_session_manager( payload=payload, dest_ip_address=ip_address, src_port=port, dest_port=port, ip_protocol=protocol @@ -295,8 +295,8 @@ class NMAP(Application, identifier="NMAP"): self._active_port_scans.pop(payload.uuid) self._port_scan_responses[payload.uuid] = payload self.sys_log.info( - f"{self.name}: Received port scan response from {payload.ip_address} on port {payload.port.value} " - f"({payload.port.name}) over {payload.protocol.name}" + f"{self.name}: Received port scan response from {payload.ip_address} on port {payload.port} " + f"({payload.port}) over {payload.protocol}" ) def _process_port_scan_request(self, payload: PortScanPayload, session_id: str) -> None: @@ -311,8 +311,8 @@ class NMAP(Application, identifier="NMAP"): if self.software_manager.check_port_is_open(port=payload.port, protocol=payload.protocol): payload.request = False self.sys_log.info( - f"{self.name}: Responding to port scan request for port {payload.port.value} " - f"({payload.port.name}) over {payload.protocol.name}", + f"{self.name}: Responding to port scan request for port {payload.port} " + f"({payload.port}) over {payload.protocol}", ) self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) @@ -345,20 +345,20 @@ class NMAP(Application, identifier="NMAP"): """ ip_addresses = self._explode_ip_address_network_array(target_ip_address) - if isinstance(target_port, Port): + if is_valid_port(target_port): target_port = [target_port] elif target_port is None: - target_port = [port for port in Port if port not in {Port.NONE, Port.UNUSED}] + target_port = [PORT_LOOKUP[port] for port in PORT_LOOKUP if port not in {"NONE", "UNUSED"}] - if isinstance(target_protocol, IPProtocol): + if is_valid_protocol(target_protocol): target_protocol = [target_protocol] elif target_protocol is None: - target_protocol = [IPProtocol.TCP, IPProtocol.UDP] + target_protocol = [PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]] scan_type = self._determine_port_scan_type(list(ip_addresses), target_port) active_ports = {} if show: - table = PrettyTable(["IP Address", "Port", "Name", "Protocol"]) + table = PrettyTable(["IP Address", "Port", "Protocol"]) table.align = "l" table.title = f"{self.software_manager.node.hostname} NMAP Port Scan ({scan_type})" self.sys_log.info(f"{self.name}: Starting port scan") @@ -369,13 +369,12 @@ class NMAP(Application, identifier="NMAP"): for protocol in target_protocol: for port in set(target_port): port_open = self._check_port_open_on_ip_address(ip_address=ip_address, port=port, protocol=protocol) - if port_open: if show: - table.add_row([ip_address, port.value, port.name, protocol.name]) + table.add_row([ip_address, port, protocol]) _ip_address = ip_address if not json_serializable else str(ip_address) - _protocol = protocol if not json_serializable else protocol.value - _port = port if not json_serializable else port.value + _protocol = protocol + _port = port if _ip_address not in active_ports: active_ports[_ip_address] = dict() if _protocol not in active_ports[_ip_address]: diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index 5d4cc8e0..f77bc33a 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -9,14 +9,14 @@ from pydantic import BaseModel, Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.ip_protocol import IPProtocol, is_valid_protocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import is_valid_port, Port, PORT_LOOKUP class C2Command(Enum): @@ -81,10 +81,10 @@ class AbstractC2(Application, identifier="AbstractC2"): keep_alive_frequency: int = Field(default=5, ge=1) """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" - masquerade_protocol: IPProtocol = Field(default=IPProtocol.TCP) + masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"]) """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" - masquerade_port: Port = Field(default=Port.HTTP) + masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"]) """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" c2_config: _C2Opts = _C2Opts() @@ -142,9 +142,9 @@ class AbstractC2(Application, identifier="AbstractC2"): def __init__(self, **kwargs): """Initialise the C2 applications to by default listen for HTTP traffic.""" - kwargs["listen_on_ports"] = {Port.HTTP, Port.FTP, Port.DNS} - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.TCP + kwargs["listen_on_ports"] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]} + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) @property @@ -366,8 +366,8 @@ class AbstractC2(Application, identifier="AbstractC2"): :return: True on successful configuration, false otherwise. :rtype: bool """ - # Validating that they are valid Enums. - if not isinstance(payload.masquerade_port, Port) or not isinstance(payload.masquerade_protocol, IPProtocol): + # Validating that they are valid Ports and Protocols. + if not is_valid_port(payload.masquerade_port) or not is_valid_protocol(payload.masquerade_protocol): self.sys_log.warning( f"{self.name}: Received invalid Masquerade Values within Keep Alive." f"Port: {payload.masquerade_port} Protocol: {payload.masquerade_protocol}." @@ -410,8 +410,8 @@ class AbstractC2(Application, identifier="AbstractC2"): self.keep_alive_inactivity = 0 self.keep_alive_frequency = 5 self.c2_remote_connection = None - self.c2_config.masquerade_port = Port.HTTP - self.c2_config.masquerade_protocol = IPProtocol.TCP + self.c2_config.masquerade_port = PORT_LOOKUP["HTTP"] + self.c2_config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"] @abstractmethod def _confirm_remote_connection(self, timestep: int) -> bool: diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index fa0271e5..c0c3d872 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -1,5 +1,4 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from enum import Enum from ipaddress import IPv4Address from typing import Dict, Optional @@ -9,12 +8,12 @@ from pydantic import validate_call from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts, RansomwareOpts, TerminalOpts from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP class C2Beacon(AbstractC2, identifier="C2Beacon"): @@ -112,8 +111,8 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self.configure( c2_server_ip_address=c2_remote_ip, keep_alive_frequency=frequency, - masquerade_protocol=IPProtocol[protocol], - masquerade_port=Port[port], + masquerade_protocol=PROTOCOL_LOOKUP[protocol], + masquerade_port=PORT_LOOKUP[port], ) ) @@ -130,8 +129,8 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self, c2_server_ip_address: IPv4Address = None, keep_alive_frequency: int = 5, - masquerade_protocol: Enum = IPProtocol.TCP, - masquerade_port: Enum = Port.HTTP, + masquerade_protocol: str = PROTOCOL_LOOKUP["TCP"], + masquerade_port: int = PORT_LOOKUP["HTTP"], ) -> bool: """ Configures the C2 beacon to communicate with the C2 server. diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index fefb22c3..9fdbae57 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -7,10 +7,10 @@ from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -50,8 +50,8 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"): def __init__(self, **kwargs): kwargs["name"] = "DataManipulationBot" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index fcad3b3e..fb2c8847 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -7,8 +7,8 @@ from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -85,7 +85,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): if "target_ip_address" in request[-1]: request[-1]["target_ip_address"] = IPv4Address(request[-1]["target_ip_address"]) if "target_port" in request[-1]: - request[-1]["target_port"] = Port[request[-1]["target_port"]] + request[-1]["target_port"] = PORT_LOOKUP[request[-1]["target_port"]] return RequestResponse.from_bool(self.configure(**request[-1])) rm.add_request("configure", request_type=RequestType(func=_configure)) @@ -94,7 +94,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): def configure( self, target_ip_address: IPv4Address, - target_port: Optional[Port] = Port.POSTGRES_SERVER, + target_port: Optional[int] = PORT_LOOKUP["POSTGRES_SERVER"], payload: Optional[str] = None, repeat: bool = False, port_scan_p_of_success: float = 0.1, @@ -105,7 +105,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): Configure the Denial of Service bot. :param: target_ip_address: The IP address of the Node containing the target service. - :param: target_port: The port of the target service. Optional - Default is `Port.HTTP` + :param: target_port: The port of the target service. Optional - Default is `Port["HTTP"]` :param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None` :param: repeat: If True, the bot will maintain the attack. Optional - Default is `True` :param: port_scan_p_of_success: The chance of the port scan being successful. Optional - Default is 0.1 (10%) diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index 2046affc..93b4c50d 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -6,10 +6,10 @@ from prettytable import MARKDOWN, PrettyTable from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP class RansomwareScript(Application, identifier="RansomwareScript"): @@ -27,8 +27,8 @@ class RansomwareScript(Application, identifier="RansomwareScript"): def __init__(self, **kwargs): kwargs["name"] = "RansomwareScript" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 73791676..c57a9bd3 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -15,10 +15,10 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -43,10 +43,10 @@ class WebBrowser(Application, identifier="WebBrowser"): def __init__(self, **kwargs): kwargs["name"] = "WebBrowser" - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port.HTTP + kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) self.run() @@ -126,7 +126,7 @@ class WebBrowser(Application, identifier="WebBrowser"): if self.send( payload=payload, dest_ip_address=self.domain_name_ip_address, - dest_port=parsed_url.port if parsed_url.port else Port.HTTP, + dest_port=parsed_url.port if parsed_url.port else PORT_LOOKUP["HTTP"], ): self.sys_log.info( f"{self.name}: Received HTTP {payload.request_method.name} " @@ -154,7 +154,7 @@ class WebBrowser(Application, identifier="WebBrowser"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = Port.HTTP, + dest_port: Optional[Port] = PORT_LOOKUP["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index b7e2c021..75322e86 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -10,8 +10,10 @@ from primaite.simulator.core import SimComponent from primaite.simulator.network.protocols.arp import ARPPacket from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.network_layer import IPPacket +from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP if TYPE_CHECKING: from primaite.simulator.network.hardware.base import NetworkInterface @@ -34,7 +36,7 @@ class Session(SimComponent): :param connected: A flag indicating whether the session is connected. """ - protocol: IPProtocol + protocol: str with_ip_address: IPv4Address src_port: Optional[Port] dst_port: Optional[Port] @@ -119,7 +121,7 @@ class SessionManager: """ protocol = frame.ip.protocol with_ip_address = frame.ip.src_ip_address - if protocol == IPProtocol.TCP: + if protocol == PROTOCOL_LOOKUP["TCP"]: if inbound_frame: src_port = frame.tcp.src_port dst_port = frame.tcp.dst_port @@ -127,7 +129,7 @@ class SessionManager: dst_port = frame.tcp.src_port src_port = frame.tcp.dst_port with_ip_address = frame.ip.dst_ip_address - elif protocol == IPProtocol.UDP: + elif protocol == PROTOCOL_LOOKUP["UDP"]: if inbound_frame: src_port = frame.udp.src_port dst_port = frame.udp.dst_port @@ -262,7 +264,7 @@ class SessionManager: src_port: Optional[Port] = None, dst_port: Optional[Port] = None, session_id: Optional[str] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + ip_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"], icmp_packet: Optional[ICMPPacket] = None, ) -> Union[Any, None]: """ @@ -286,7 +288,7 @@ class SessionManager: dst_mac_address = payload.target_mac_addr outbound_network_interface = self.resolve_outbound_network_interface(payload.target_ip_address) is_broadcast = payload.request - ip_protocol = IPProtocol.UDP + ip_protocol = PROTOCOL_LOOKUP["UDP"] else: vals = self.resolve_outbound_transmission_details( dst_ip_address=dst_ip_address, @@ -311,26 +313,26 @@ class SessionManager: if not outbound_network_interface or not dst_mac_address: return False - if not (src_port or dst_port): + if src_port is None and dst_port is None: raise ValueError( "Failed to resolve src or dst port. Have you sent the port from the service or application?" ) tcp_header = None udp_header = None - if ip_protocol == IPProtocol.TCP: + if ip_protocol == PROTOCOL_LOOKUP["TCP"]: tcp_header = TCPHeader( src_port=dst_port, dst_port=dst_port, ) - elif ip_protocol == IPProtocol.UDP: + elif ip_protocol == PROTOCOL_LOOKUP["UDP"]: udp_header = UDPHeader( src_port=dst_port, dst_port=dst_port, ) # TODO: Only create IP packet if not ARP # ip_packet = None - # if dst_port != Port.ARP: + # if dst_port != Port["ARP"]: # IPPacket( # src_ip_address=outbound_network_interface.ip_address, # dst_ip_address=dst_ip_address, @@ -387,7 +389,7 @@ class SessionManager: elif frame.udp: dst_port = frame.udp.dst_port elif frame.icmp: - dst_port = Port.NONE + dst_port = PORT_LOOKUP["NONE"] self.software_manager.receive_payload_from_session_manager( payload=frame.payload, port=dst_port, @@ -413,5 +415,5 @@ class SessionManager: table.align = "l" table.title = f"{self.sys_log.hostname} Session Manager" for session in self.sessions_by_key.values(): - table.add_row([session.dst_ip_address, session.dst_port.value, session.protocol.name]) + table.add_row([session.dst_ip_address, session.dst_port, session.protocol]) print(table) diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index d45611ed..60621384 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -8,12 +8,12 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.core import RequestType from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import IOSoftware +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP if TYPE_CHECKING: from primaite.simulator.system.core.session_manager import SessionManager @@ -191,7 +191,7 @@ class SoftwareManager: dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, src_port: Optional[Port] = None, dest_port: Optional[Port] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + ip_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"], session_id: Optional[str] = None, ) -> bool: """ @@ -275,8 +275,8 @@ class SoftwareManager: software_type, software.operating_state.name, software.health_state_actual.name, - software.port.value if software.port != Port.NONE else None, - software.protocol.value, + software.port if software.port != PORT_LOOKUP["NONE"] else None, + software.protocol, ] ) print(table) diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index efadf189..816eb99e 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -8,10 +8,10 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP class ARP(Service): @@ -26,8 +26,8 @@ class ARP(Service): def __init__(self, **kwargs): kwargs["name"] = "ARP" - kwargs["port"] = Port.ARP - kwargs["protocol"] = IPProtocol.UDP + kwargs["port"] = PORT_LOOKUP["ARP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index b38e87b4..b7cd8886 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -7,12 +7,12 @@ from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.file_system.folder import Folder -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -38,8 +38,8 @@ class DatabaseService(Service): def __init__(self, **kwargs): kwargs["name"] = "DatabaseService" - kwargs["port"] = Port.POSTGRES_SERVER - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self._create_db_file() diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index d7ba0cd4..78642fa6 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -4,10 +4,10 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -22,11 +22,11 @@ class DNSClient(Service): def __init__(self, **kwargs): kwargs["name"] = "DNSClient" - kwargs["port"] = Port.DNS + kwargs["port"] = PORT_LOOKUP["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() @@ -95,7 +95,7 @@ class DNSClient(Service): # send a request to check if domain name exists in the DNS Server software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload=payload, dest_ip_address=self.dns_server, dest_port=Port.DNS + payload=payload, dest_ip_address=self.dns_server, dest_port=PORT_LOOKUP["DNS"] ) # recursively re-call the function passing is_reattempt=True diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 8a4bbaed..5b380320 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -6,9 +6,9 @@ from prettytable import MARKDOWN, PrettyTable from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -21,11 +21,11 @@ class DNSServer(Service): def __init__(self, **kwargs): kwargs["name"] = "DNSServer" - kwargs["port"] = Port.DNS + kwargs["port"] = PORT_LOOKUP["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index f823e42c..00b70332 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -7,10 +7,10 @@ from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -25,8 +25,8 @@ class FTPClient(FTPServiceABC): def __init__(self, **kwargs): kwargs["name"] = "FTPClient" - kwargs["port"] = Port.FTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = PORT_LOOKUP["FTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() @@ -104,7 +104,7 @@ class FTPClient(FTPServiceABC): def _connect_to_server( self, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[Port] = PORT_LOOKUP["FTP"], session_id: Optional[str] = None, is_reattempt: Optional[bool] = False, ) -> bool: @@ -124,13 +124,13 @@ class FTPClient(FTPServiceABC): # normally FTP will choose a random port for the transfer, but using the FTP command port will do for now # create FTP packet - payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=Port.FTP) + payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=PORT_LOOKUP["FTP"]) if self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id): if payload.status_code == FTPStatusCode.OK: self.sys_log.info( f"{self.name}: Successfully connected to FTP Server " - f"{dest_ip_address} via port {payload.ftp_command_args.value}" + f"{dest_ip_address} via port {payload.ftp_command_args}" ) self.add_connection(connection_id="server_connection", session_id=session_id) return True @@ -139,7 +139,7 @@ class FTPClient(FTPServiceABC): # reattempt failed self.sys_log.warning( f"{self.name}: Unable to connect to FTP Server " - f"{dest_ip_address} via port {payload.ftp_command_args.value}" + f"{dest_ip_address} via port {payload.ftp_command_args}" ) return False else: @@ -152,7 +152,7 @@ class FTPClient(FTPServiceABC): return False def _disconnect_from_server( - self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[Port] = Port.FTP + self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[Port] = PORT_LOOKUP["FTP"] ) -> bool: """ Connects the client from a given FTP server. @@ -179,7 +179,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[Port] = PORT_LOOKUP["FTP"], session_id: Optional[str] = None, ) -> bool: """ @@ -203,7 +203,7 @@ class FTPClient(FTPServiceABC): :param: dest_file_name: The name of the file to be saved on the FTP Server. :type: dest_file_name: str - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. :type: dest_port: Optional[Port] :param: session_id: The id of the session @@ -241,7 +241,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[Port] = PORT_LOOKUP["FTP"], ) -> bool: """ Request a file from a target IP address. @@ -263,8 +263,8 @@ class FTPClient(FTPServiceABC): :param: dest_file_name: The name of the file to be saved on the FTP Server. :type: dest_file_name: str - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. - :type: dest_port: Optional[Port] + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. + :type: dest_port: Optional[int] """ # check if FTP is currently connected to IP self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index f02d01f4..671200f5 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -3,9 +3,9 @@ from typing import Any, Optional from primaite import getLogger from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -23,8 +23,8 @@ class FTPServer(FTPServiceABC): def __init__(self, **kwargs): kwargs["name"] = "FTPServer" - kwargs["port"] = Port.FTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = PORT_LOOKUP["FTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() @@ -52,7 +52,7 @@ class FTPServer(FTPServiceABC): # process server specific commands, otherwise call super if payload.ftp_command == FTPCommand.PORT: # check that the port is valid - if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535): + if is_valid_port(payload.ftp_command_args): # return successful connection self.add_connection(connection_id=session_id, session_id=session_id) payload.status_code = FTPStatusCode.OK diff --git a/src/primaite/simulator/system/services/ftp/ftp_service.py b/src/primaite/simulator/system/services/ftp/ftp_service.py index 689a3da7..77d82997 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_service.py +++ b/src/primaite/simulator/system/services/ftp/ftp_service.py @@ -5,8 +5,8 @@ from typing import Dict, Optional from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service +from primaite.utils.validation.port import Port class FTPServiceABC(Service, ABC): @@ -97,7 +97,7 @@ class FTPServiceABC(Service, ABC): :param: dest_ip_address: The IP address of the machine that hosts the FTP Server. :type: dest_ip_address: Optional[IPv4Address] - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. :type: dest_port: Optional[Port] :param: session_id: session ID linked to the FTP Packet. Optional. diff --git a/src/primaite/simulator/system/services/icmp/icmp.py b/src/primaite/simulator/system/services/icmp/icmp.py index 6741d86a..84ad995d 100644 --- a/src/primaite/simulator/system/services/icmp/icmp.py +++ b/src/primaite/simulator/system/services/icmp/icmp.py @@ -7,9 +7,9 @@ from primaite import getLogger from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -26,8 +26,8 @@ class ICMP(Service): def __init__(self, **kwargs): kwargs["name"] = "ICMP" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.ICMP + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["ICMP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/icmp/router_icmp.py b/src/primaite/simulator/system/services/icmp/router_icmp.py index 4fdc6baa..19c0ac2d 100644 --- a/src/primaite/simulator/system/services/icmp/router_icmp.py +++ b/src/primaite/simulator/system/services/icmp/router_icmp.py @@ -36,13 +36,13 @@ # self.sys_log.info(f"Received echo request from {frame.ip.src_ip_address}") # target_mac_address = self.arp.get_arp_cache_mac_address(frame.ip.src_ip_address) # src_nic = self.arp.get_arp_cache_network_interface(frame.ip.src_ip_address) -# tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) +# tcp_header = TCPHeader(src_port=Port["ARP"], dst_port=Port["ARP"]) # # # Network Layer # ip_packet = IPPacket( # src_ip_address=network_interface.ip_address, # dst_ip_address=frame.ip.src_ip_address, -# protocol=IPProtocol.ICMP, +# protocol=IPProtocol["ICMP"], # ) # # Data Link Layer # ethernet_header = EthernetHeader( diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 8924a821..ed89971f 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -5,9 +5,9 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -21,8 +21,8 @@ class NTPClient(Service): def __init__(self, **kwargs): kwargs["name"] = "NTPClient" - kwargs["port"] = Port.NTP - kwargs["protocol"] = IPProtocol.UDP + kwargs["port"] = PORT_LOOKUP["NTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) self.start() @@ -55,7 +55,7 @@ class NTPClient(Service): payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: Port = Port.NTP, + dest_port: Port = PORT_LOOKUP["NTP"], **kwargs, ) -> bool: """Requests NTP data from NTP server. diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 547bbc06..b674a296 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -4,9 +4,9 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -16,8 +16,8 @@ class NTPServer(Service): def __init__(self, **kwargs): kwargs["name"] = "NTPServer" - kwargs["port"] = Port.NTP - kwargs["protocol"] = IPProtocol.UDP + kwargs["port"] = PORT_LOOKUP["NTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 74dcb506..4f0b879c 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -52,7 +52,7 @@ class Service(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register a hostnode type. @@ -60,7 +60,7 @@ class Service(IOSoftware): :type identifier: str :raises ValueError: When attempting to register an hostnode with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. identifier = identifier.lower() diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index e98e8555..ae3557f7 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -17,10 +17,10 @@ from primaite.simulator.network.protocols.ssh import ( SSHTransportMessage, SSHUserCredentials, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP # TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor @@ -137,8 +137,8 @@ class Terminal(Service): def __init__(self, **kwargs): kwargs["name"] = "Terminal" - kwargs["port"] = Port.SSH - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = PORT_LOOKUP["SSH"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 4fc64e1f..75d9c472 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -10,11 +10,11 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClientConnection from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -49,10 +49,10 @@ class WebServer(Service): def __init__(self, **kwargs): kwargs["name"] = "WebServer" - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port.HTTP + kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) self._install_web_files() diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index f1d1b9a1..6fb09a16 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -13,10 +13,10 @@ from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port if TYPE_CHECKING: from primaite.simulator.system.core.software_manager import SoftwareManager @@ -277,7 +277,7 @@ class IOSoftware(Software): "max_sessions": self.max_sessions, "tcp": self.tcp, "udp": self.udp, - "port": self.port.value, + "port": self.port, } ) return state @@ -386,8 +386,8 @@ class IOSoftware(Software): payload: Any, session_id: Optional[str] = None, dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - dest_port: Optional[Port] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + dest_port: Optional[int] = None, + ip_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"], **kwargs, ) -> bool: """ diff --git a/src/primaite/utils/validation/__init__.py b/src/primaite/utils/validation/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/src/primaite/utils/validation/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/src/primaite/utils/validation/ip_protocol.py b/src/primaite/utils/validation/ip_protocol.py new file mode 100644 index 00000000..4e358305 --- /dev/null +++ b/src/primaite/utils/validation/ip_protocol.py @@ -0,0 +1,47 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# Define a custom IP protocol validator +from typing import Any + +from pydantic import BeforeValidator, TypeAdapter, ValidationError +from typing_extensions import Annotated, Final + +PROTOCOL_LOOKUP: dict[str, str] = dict( + NONE="none", + TCP="tcp", + UDP="udp", + ICMP="icmp", +) +""" +Lookup table used for compatibility with PrimAITE <= 3.3. Configs with the capitalised protocol names are converted +to lowercase at runtime. +""" +VALID_PROTOCOLS = ["none", "tcp", "udp", "icmp"] +"""Supported protocols.""" + + +def protocol_validator(v: Any) -> str: + """ + Validate that IP Protocols are chosen from the list of supported IP Protocols. + + The protocol list is dynamic because plugins are able to extend it, therefore it is necessary to use this custom + validator instead of being able to specify a union of string literals. + """ + if isinstance(v, str) and v in PROTOCOL_LOOKUP: + return PROTOCOL_LOOKUP[v] + if v in VALID_PROTOCOLS: + return v + raise ValueError(f"{v} is not a valid IP Protocol. It must be one of the following: {VALID_PROTOCOLS}") + + +IPProtocol: Final[Annotated] = Annotated[str, BeforeValidator(protocol_validator)] +"""Validates that IP Protocols used in the simulation belong to the list of supported protocols.""" +_IPProtocolTypeAdapter = TypeAdapter(IPProtocol) + + +def is_valid_protocol(v: Any) -> bool: + """Convenience method to return true if the value matches the schema, and false otherwise.""" + try: + _IPProtocolTypeAdapter.validate_python(v) + return True + except ValidationError: + return False diff --git a/src/primaite/utils/validators.py b/src/primaite/utils/validation/ipv4_address.py similarity index 99% rename from src/primaite/utils/validators.py rename to src/primaite/utils/validation/ipv4_address.py index 139d303c..eb0e2574 100644 --- a/src/primaite/utils/validators.py +++ b/src/primaite/utils/validation/ipv4_address.py @@ -1,4 +1,6 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + + from ipaddress import IPv4Address from typing import Any, Final @@ -6,6 +8,9 @@ from pydantic import BeforeValidator from typing_extensions import Annotated +# Define a custom type IPV4Address using the typing_extensions.Annotated. +# Annotated is used to attach metadata to type hints. In this case, it's used to associate the ipv4_validator +# with the IPv4Address type, ensuring that any usage of IPV4Address undergoes validation before assignment. def ipv4_validator(v: Any) -> IPv4Address: """ Validate the input and ensure it can be converted to an IPv4Address instance. @@ -24,9 +29,6 @@ def ipv4_validator(v: Any) -> IPv4Address: return IPv4Address(v) -# Define a custom type IPV4Address using the typing_extensions.Annotated. -# Annotated is used to attach metadata to type hints. In this case, it's used to associate the ipv4_validator -# with the IPv4Address type, ensuring that any usage of IPV4Address undergoes validation before assignment. IPV4Address: Final[Annotated] = Annotated[IPv4Address, BeforeValidator(ipv4_validator)] """ IPv4Address with with IPv4Address with with pre-validation and auto-conversion from str using ipv4_validator.. diff --git a/src/primaite/utils/validation/port.py b/src/primaite/utils/validation/port.py new file mode 100644 index 00000000..90c36add --- /dev/null +++ b/src/primaite/utils/validation/port.py @@ -0,0 +1,70 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# Define a custom port validator +from typing import Any + +from pydantic import BeforeValidator, TypeAdapter, ValidationError +from typing_extensions import Annotated, Final + +PORT_LOOKUP: dict[str, int] = dict( + UNUSED=-1, + NONE=0, + WOL=9, + FTP_DATA=20, + FTP=21, + SSH=22, + SMTP=25, + DNS=53, + HTTP=80, + POP3=110, + SFTP=115, + NTP=123, + IMAP=143, + SNMP=161, + SNMP_TRAP=162, + ARP=219, + LDAP=389, + HTTPS=443, + SMB=445, + IPP=631, + SQL_SERVER=1433, + MYSQL=3306, + RDP=3389, + RTP=5004, + RTP_ALT=5005, + DNS_ALT=5353, + HTTP_ALT=8080, + HTTPS_ALT=8443, + POSTGRES_SERVER=5432, +) +""" +Lookup table used for compatibility with PrimAITE <= 3.3. Configs with named ports names are converted +to port integers at runtime. +""" + + +def port_validator(v: Any) -> int: + """ + Validate that Ports are chosen from the list of supported Ports. + + The protocol list is dynamic because plugins are able to extend it, therefore it is necessary to use this custom + validator instead of being able to specify a union of string literals. + """ + if isinstance(v, str) and v in PORT_LOOKUP: + v = PORT_LOOKUP[v] + if isinstance(v, int) and (0 <= v <= 65535): + return v + raise ValueError(f"{v} is not a valid Port. It must be an integer in the range [0,65535] or ") + + +Port: Final[Annotated] = Annotated[int, BeforeValidator(port_validator)] +"""Validates that network ports lie in the appropriate range of [0,65535].""" +_PortTypeAdapter = TypeAdapter(Port) + + +def is_valid_port(v: Any) -> bool: + """Convenience method to return true if the value matches the schema, and false otherwise.""" + try: + _PortTypeAdapter.validate_python(v) + return True + except ValidationError: + return False diff --git a/tests/conftest.py b/tests/conftest.py index 1bbff8f2..64fe0699 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,8 +18,6 @@ from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.web_browser import WebBrowser @@ -28,6 +26,8 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT rayinit() @@ -45,8 +45,8 @@ class DummyService(Service): def __init__(self, **kwargs): kwargs["name"] = "DummyService" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: @@ -58,8 +58,8 @@ class DummyApplication(Application, identifier="DummyApplication"): def __init__(self, **kwargs): kwargs["name"] = "DummyApplication" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -77,7 +77,7 @@ def uc2_network() -> Network: @pytest.fixture(scope="function") def service(file_system) -> DummyService: return DummyService( - name="DummyService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_service") + name="DummyService", port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_service") ) @@ -90,7 +90,7 @@ def service_class(): def application(file_system) -> DummyApplication: return DummyApplication( name="DummyApplication", - port=Port.ARP, + port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_application"), ) @@ -350,10 +350,10 @@ def install_stuff_to_sim(sim: Simulation): network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) # 2: Configure base ACL - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3) # 3: Install server software server_1.software_manager.install(DNSServer) @@ -379,13 +379,13 @@ def install_stuff_to_sim(sim: Simulation): r = sim.network.router_nodes[0] for i, acl_rule in enumerate(r.acl.acl): if i == 1: - assert acl_rule.src_port == acl_rule.dst_port == Port.DNS + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["DNS"] elif i == 3: - assert acl_rule.src_port == acl_rule.dst_port == Port.HTTP + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["HTTP"] elif i == 22: - assert acl_rule.src_port == acl_rule.dst_port == Port.ARP + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["ARP"] elif i == 23: - assert acl_rule.protocol == IPProtocol.ICMP + assert acl_rule.protocol == PROTOCOL_LOOKUP["ICMP"] elif i == 24: ... else: diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py index 457fdb42..7f251613 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py @@ -9,8 +9,8 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.configuration_file_parsing import BASIC_FIREWALL, DMZ_NETWORK, load_config @@ -68,44 +68,44 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed internal_inbound assert firewall.internal_inbound_acl.num_rules == 2 assert firewall.internal_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.internal_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.internal_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.internal_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.internal_inbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.internal_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed internal_outbound assert firewall.internal_outbound_acl.num_rules == 2 assert firewall.internal_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.internal_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.internal_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.internal_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.internal_outbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.internal_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_inbound assert firewall.dmz_inbound_acl.num_rules == 2 assert firewall.dmz_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.dmz_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.dmz_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.dmz_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.dmz_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.dmz_inbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.dmz_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_outbound assert firewall.dmz_outbound_acl.num_rules == 2 assert firewall.dmz_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.dmz_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.dmz_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.dmz_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.dmz_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.dmz_outbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.dmz_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed external_inbound assert firewall.external_inbound_acl.num_rules == 1 assert firewall.external_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.external_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.external_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.external_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] # external_inbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_inbound_acl.implicit_action == ACLAction.PERMIT @@ -113,8 +113,8 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed external_outbound assert firewall.external_outbound_acl.num_rules == 1 assert firewall.external_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.external_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.external_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.external_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] # external_outbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_outbound_acl.implicit_action == ACLAction.PERMIT diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py index ccde3a02..d10c7dbb 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -6,8 +6,8 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config @@ -63,8 +63,8 @@ def test_router_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed assert router_1.acl.num_rules == 2 assert router_1.acl.acl[22].action == ACLAction.PERMIT - assert router_1.acl.acl[22].src_port == Port.ARP - assert router_1.acl.acl[22].dst_port == Port.ARP + assert router_1.acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert router_1.acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert router_1.acl.acl[23].action == ACLAction.PERMIT - assert router_1.acl.acl[23].protocol == IPProtocol.ICMP + assert router_1.acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert router_1.acl.implicit_action == ACLAction.DENY diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py index c9b3006d..70dc7cba 100644 --- a/tests/integration_tests/extensions/applications/extended_application.py +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -15,11 +15,11 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -44,10 +44,10 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): def __init__(self, **kwargs): kwargs["name"] = "ExtendedApplication" - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port.HTTP + kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) self.run() @@ -127,7 +127,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): if self.send( payload=payload, dest_ip_address=self.domain_name_ip_address, - dest_port=parsed_url.port if parsed_url.port else Port.HTTP, + dest_port=parsed_url.port if parsed_url.port else PORT_LOOKUP["HTTP"], ): self.sys_log.info( f"{self.name}: Received HTTP {payload.request_method.name} " @@ -155,7 +155,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = Port.HTTP, + dest_port: Optional[int] = PORT_LOOKUP["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/tests/integration_tests/extensions/nodes/giga_switch.py b/tests/integration_tests/extensions/nodes/giga_switch.py index b86bea7d..e4100741 100644 --- a/tests/integration_tests/extensions/nodes/giga_switch.py +++ b/tests/integration_tests/extensions/nodes/giga_switch.py @@ -1,3 +1,4 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from typing import Dict from prettytable import MARKDOWN, PrettyTable @@ -27,7 +28,7 @@ class GigaSwitch(NetworkNode, identifier="gigaswitch"): "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." def __init__(self, **kwargs): - print('--- Extended Component: GigaSwitch ---') + print("--- Extended Component: GigaSwitch ---") super().__init__(**kwargs) for i in range(1, self.num_ports + 1): self.connect_nic(SwitchPort()) diff --git a/tests/integration_tests/extensions/nodes/super_computer.py b/tests/integration_tests/extensions/nodes/super_computer.py index 8a1465e9..80f7e3c3 100644 --- a/tests/integration_tests/extensions/nodes/super_computer.py +++ b/tests/integration_tests/extensions/nodes/super_computer.py @@ -1,9 +1,9 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from typing import ClassVar, Dict -from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC from primaite.simulator.system.services.ftp.ftp_client import FTPClient -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ipv4_address import IPV4Address class SuperComputer(HostNode, identifier="supercomputer"): @@ -37,7 +37,7 @@ class SuperComputer(HostNode, identifier="supercomputer"): SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): - print('--- Extended Component: SuperComputer ---') + print("--- Extended Component: SuperComputer ---") super().__init__(ip_address=ip_address, subnet_mask=subnet_mask, **kwargs) pass diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index 3151571b..ddaf4a1e 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -7,17 +7,17 @@ from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.file_system.folder import Folder -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class ExtendedService(Service, identifier='extendedservice'): +class ExtendedService(Service, identifier="extendedservice"): """ A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE. @@ -38,11 +38,11 @@ class ExtendedService(Service, identifier='extendedservice'): def __init__(self, **kwargs): kwargs["name"] = "ExtendedService" - kwargs["port"] = Port.POSTGRES_SERVER - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self._create_db_file() - if kwargs.get('options'): + if kwargs.get("options"): opt = kwargs["options"] self.password = opt.get("db_password", None) if "backup_server_ip" in opt: @@ -139,7 +139,9 @@ class ExtendedService(Service, identifier='extendedservice'): old_visible_state = SoftwareHealthState.GOOD # get db file regardless of whether or not it was deleted - db_file = self.file_system.get_file(folder_name="database", file_name="extended_service_database.db", include_deleted=True) + db_file = self.file_system.get_file( + folder_name="database", file_name="extended_service_database.db", include_deleted=True + ) if db_file is None: self.sys_log.warning("Database file not initialised.") @@ -153,7 +155,9 @@ class ExtendedService(Service, identifier='extendedservice'): self.file_system.delete_file(folder_name="database", file_name="extended_service_database.db") # replace db file - self.file_system.copy_file(src_folder_name="downloads", src_file_name="extended_service_database.db", dst_folder_name="database") + self.file_system.copy_file( + src_folder_name="downloads", src_file_name="extended_service_database.db", dst_folder_name="database" + ) if self.db_file is None: self.sys_log.error("Copying database backup failed.") diff --git a/tests/integration_tests/extensions/test_extendable_config.py b/tests/integration_tests/extensions/test_extendable_config.py index 5d8af64d..8467151b 100644 --- a/tests/integration_tests/extensions/test_extendable_config.py +++ b/tests/integration_tests/extensions/test_extendable_config.py @@ -1,22 +1,22 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import os + from primaite.config.load import get_extended_config_path from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config -import os +from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication +from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch # Import the extended components so that PrimAITE registers them from tests.integration_tests.extensions.nodes.super_computer import SuperComputer -from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch from tests.integration_tests.extensions.services.extended_service import ExtendedService -from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication def test_extended_example_config(): - """Test that the example config can be parsed properly.""" - config_path = os.path.join( "tests", "assets", "configs", "extended_config.yaml") + config_path = os.path.join("tests", "assets", "configs", "extended_config.yaml") game = load_config(config_path) network: Network = game.simulation.network @@ -25,8 +25,8 @@ def test_extended_example_config(): assert len(network.router_nodes) == 1 # 1 router in network assert len(network.switch_nodes) == 1 # 1 switches in network assert len(network.server_nodes) == 5 # 5 servers in network - assert len(network.extended_hostnodes) == 1 # One extended node based on HostNode - assert len(network.extended_networknodes) == 1 # One extended node based on NetworkNode + assert len(network.extended_hostnodes) == 1 # One extended node based on HostNode + assert len(network.extended_networknodes) == 1 # One extended node based on NetworkNode - assert 'ExtendedApplication' in network.extended_hostnodes[0].software_manager.software - assert 'ExtendedService' in network.extended_hostnodes[0].software_manager.software + assert "ExtendedApplication" in network.extended_hostnodes[0].software_manager.software + assert "ExtendedService" in network.extended_hostnodes[0].software_manager.software diff --git a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py index 806ce063..187fb1fe 100644 --- a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py +++ b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py @@ -11,13 +11,13 @@ from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Command, C2Server from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture @@ -26,9 +26,9 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=4) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=5) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=6) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=5) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=6) c2_server_host = game.simulation.network.get_node_by_hostname("client_1") c2_server_host.software_manager.install(software_class=C2Server) diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index 0c9ec6f0..508bd5a4 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -11,12 +11,12 @@ from primaite.game.agent.actions import ( ) from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent @@ -200,7 +200,7 @@ class TestConfigureDoSBot: game.step() assert dos_bot.target_ip_address == IPv4Address("192.168.1.99") - assert dos_bot.target_port == Port.POSTGRES_SERVER + assert dos_bot.target_port == PORT_LOOKUP["POSTGRES_SERVER"] assert dos_bot.payload == "HACC" assert not dos_bot.repeat assert dos_bot.port_scan_p_of_success == 0.875 diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index d011c1e8..a70cea72 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -9,9 +9,9 @@ from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture @@ -20,7 +20,7 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) return (game, agent) diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index f1d9d416..e7212f3c 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -4,10 +4,10 @@ import pytest from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -33,13 +33,13 @@ def test_acl_observations(simulation): server.software_manager.install(NTPServer) # add router acl rule - router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, dst_port=PORT_LOOKUP["NTP"], src_port=PORT_LOOKUP["NTP"], position=1) acl_obs = ACLObservation( where=["network", "nodes", router.hostname, "acl", "acl"], ip_list=[], - port_list=["NTP", "HTTP", "POSTGRES_SERVER"], - protocol_list=["TCP", "UDP", "ICMP"], + port_list=[123, 80, 5432], + protocol_list=["tcp", "udp", "icmp"], num_rules=10, wildcard_list=[], ) diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 34a37f5e..05cf910c 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -5,8 +5,8 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def check_default_rules(acl_obs): @@ -31,8 +31,8 @@ def test_firewall_observation(): num_rules=7, ip_list=["10.0.0.1", "10.0.0.2"], wildcard_list=["0.0.0.255", "0.0.0.1"], - port_list=["HTTP", "DNS"], - protocol_list=["TCP"], + port_list=[80, 53], + protocol_list=["tcp"], include_users=False, ) @@ -62,13 +62,13 @@ def test_firewall_observation(): # add a rule to the internal inbound and check that the observation is correct firewall.internal_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port.HTTP, - dst_port=Port.HTTP, + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], position=5, ) diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index ef789ba7..8254dad2 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -152,7 +152,12 @@ def test_config_nic_categories(simulation): def test_nic_monitored_traffic(simulation): - monitored_traffic = {"icmp": ["NONE"], "tcp": ["DNS"]} + monitored_traffic = { + "icmp": ["NONE"], + "tcp": [ + 53, + ], + } pc: Computer = simulation.network.get_node_by_hostname("client_1") pc2: Computer = simulation.network.get_node_by_hostname("client_2") diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index 48d29cfb..4ced02f5 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -8,9 +8,9 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_router_observation(): @@ -24,8 +24,8 @@ def test_router_observation(): num_rules=7, ip_list=["10.0.0.1", "10.0.0.2"], wildcard_list=["0.0.0.255", "0.0.0.1"], - port_list=["HTTP", "DNS"], - protocol_list=["TCP"], + port_list=[80, 53], + protocol_list=["tcp"], ) router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False) @@ -39,13 +39,13 @@ def test_router_observation(): # Add an ACL rule to the router router.acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port.HTTP, - dst_port=Port.HTTP, + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], position=5, ) # Observe the state using the RouterObservation instance diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py index ca5e2543..e7287eee 100644 --- a/tests/integration_tests/game_layer/observations/test_user_observations.py +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -3,7 +3,7 @@ import pytest from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" @@ -15,7 +15,7 @@ def env_with_ssh() -> PrimaiteGymEnv: env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) env.agent.flatten_obs = False router: Router = env.game.simulation.network.get_node_by_hostname("router_1") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3) + router.acl.add_rule(ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=3) return env diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index a1005f34..e03a7d26 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -21,11 +21,11 @@ from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network.yaml" @@ -608,9 +608,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.internal_outbound_acl.acl[1].action.name == "DENY" assert firewall.internal_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.0.10") assert firewall.internal_outbound_acl.acl[1].dst_ip_address is None - assert firewall.internal_outbound_acl.acl[1].dst_port == Port.DNS - assert firewall.internal_outbound_acl.acl[1].src_port == Port.ARP - assert firewall.internal_outbound_acl.acl[1].protocol == IPProtocol.ICMP + assert firewall.internal_outbound_acl.acl[1].dst_port == PORT_LOOKUP["DNS"] + assert firewall.internal_outbound_acl.acl[1].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["ICMP"] env.step(4) # Remove ACL rule from Internal Outbound assert firewall.internal_outbound_acl.num_rules == 2 @@ -620,9 +620,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_inbound_acl.acl[1].action.name == "DENY" assert firewall.dmz_inbound_acl.acl[1].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_inbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_inbound_acl.acl[1].dst_port == Port.HTTP - assert firewall.dmz_inbound_acl.acl[1].src_port == Port.HTTP - assert firewall.dmz_inbound_acl.acl[1].protocol == IPProtocol.UDP + assert firewall.dmz_inbound_acl.acl[1].dst_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].src_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["UDP"] env.step(6) # Remove ACL rule from DMZ Inbound assert firewall.dmz_inbound_acl.num_rules == 2 @@ -632,9 +632,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_outbound_acl.acl[2].action.name == "DENY" assert firewall.dmz_outbound_acl.acl[2].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_outbound_acl.acl[2].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_outbound_acl.acl[2].dst_port == Port.HTTP - assert firewall.dmz_outbound_acl.acl[2].src_port == Port.HTTP - assert firewall.dmz_outbound_acl.acl[2].protocol == IPProtocol.TCP + assert firewall.dmz_outbound_acl.acl[2].dst_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].src_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].protocol == PROTOCOL_LOOKUP["TCP"] env.step(8) # Remove ACL rule from DMZ Outbound assert firewall.dmz_outbound_acl.num_rules == 2 @@ -644,9 +644,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.external_inbound_acl.acl[10].action.name == "DENY" assert firewall.external_inbound_acl.acl[10].src_ip_address == IPv4Address("192.168.20.10") assert firewall.external_inbound_acl.acl[10].dst_ip_address == IPv4Address("192.168.10.10") - assert firewall.external_inbound_acl.acl[10].dst_port == Port.POSTGRES_SERVER - assert firewall.external_inbound_acl.acl[10].src_port == Port.POSTGRES_SERVER - assert firewall.external_inbound_acl.acl[10].protocol == IPProtocol.ICMP + assert firewall.external_inbound_acl.acl[10].dst_port == PORT_LOOKUP["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].src_port == PORT_LOOKUP["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].protocol == PROTOCOL_LOOKUP["ICMP"] env.step(10) # Remove ACL rule from External Inbound assert firewall.external_inbound_acl.num_rules == 1 diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 58783d70..0005b508 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -9,11 +9,11 @@ from primaite.interface.request import RequestResponse from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent @@ -42,7 +42,12 @@ def test_WebpageUnavailablePenalty(game_and_agent): # Block the web traffic, check that failing to fetch the webpage yields a reward of -0.7 router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.TCP, src_port=Port.HTTP, dst_port=Port.HTTP) + router.acl.add_rule( + action=ACLAction.DENY, + protocol=PROTOCOL_LOOKUP["TCP"], + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], + ) agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})) game.step() assert agent.reward_function.current_reward == -0.7 @@ -65,7 +70,9 @@ def test_uc2_rewards(game_and_agent): db_client.run() router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2) + router.acl.add_rule( + ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2 + ) comp = GreenAdminDatabaseUnreachablePenalty("client_1") diff --git a/tests/integration_tests/network/test_airspace_config.py b/tests/integration_tests/network/test_airspace_config.py index 78d00b47..e000f6ae 100644 --- a/tests/integration_tests/network/test_airspace_config.py +++ b/tests/integration_tests/network/test_airspace_config.py @@ -2,7 +2,6 @@ import yaml from primaite.game.game import PrimaiteGame -from primaite.simulator.network.airspace import AirSpaceFrequency from tests import TEST_ASSETS_ROOT @@ -13,8 +12,8 @@ def test_override_freq_max_capacity_mbps(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 123.45 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_2_4") == 123.45 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_5") == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") @@ -32,8 +31,8 @@ def test_override_freq_max_capacity_mbps_blocked(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 0.0 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_2_4") == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_5") == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index 80007c46..f07f02e7 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -8,10 +8,10 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP class BroadcastTestService(Service): @@ -20,8 +20,8 @@ class BroadcastTestService(Service): def __init__(self, **kwargs): # Set default service properties for broadcasting kwargs["name"] = "BroadcastService" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -33,12 +33,14 @@ class BroadcastTestService(Service): super().send( payload="unicast", dest_ip_address=ip_address, - dest_port=Port.HTTP, + dest_port=PORT_LOOKUP["HTTP"], ) def broadcast(self, ip_network: IPv4Network): # Send a broadcast payload to an entire IP network - super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol) + super().send( + payload="broadcast", dest_ip_address=ip_network, dest_port=PORT_LOOKUP["HTTP"], ip_protocol=self.protocol + ) class BroadcastTestClient(Application, identifier="BroadcastTestClient"): @@ -49,8 +51,8 @@ class BroadcastTestClient(Application, identifier="BroadcastTestClient"): def __init__(self, **kwargs): # Set default client properties kwargs["name"] = "BroadcastTestClient" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index b15ee51a..79452318 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -7,10 +7,10 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -53,28 +53,32 @@ def dmz_external_internal_network() -> Network: ) # Allow ICMP - firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Allow ARP firewall_node.internal_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.internal_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.external_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.external_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + firewall_node.dmz_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + firewall_node.dmz_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) # external node external_node = Computer( @@ -262,8 +266,12 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not internal_ntp_client.time - firewall.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) - firewall.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) + firewall.internal_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) + firewall.internal_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) internal_ntp_client.request_time() @@ -271,8 +279,12 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not dmz_ntp_client.time - firewall.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) - firewall.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) + firewall.dmz_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) + firewall.dmz_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) dmz_ntp_client.request_time() diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index 62b58cbd..04cdbe78 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -6,10 +6,10 @@ import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -73,8 +73,10 @@ def multi_hop_network() -> Network: router_1.enable_port(2) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B pc_b = Computer( @@ -197,8 +199,12 @@ def test_routing_services(multi_hop_network): router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21) - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=21 + ) + router_2.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=21 + ) assert ntp_client.time is None ntp_client.request_time() diff --git a/tests/integration_tests/network/test_wireless_router.py b/tests/integration_tests/network/test_wireless_router.py index 733de6f6..fb0035e9 100644 --- a/tests/integration_tests/network/test_wireless_router.py +++ b/tests/integration_tests/network/test_wireless_router.py @@ -7,8 +7,8 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT @@ -37,8 +37,10 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B pc_b = Computer( diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index 9d12f2cf..2cbd4d11 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -13,8 +13,6 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import AccessControlList, ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon @@ -25,6 +23,8 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT @@ -227,7 +227,7 @@ def test_c2_suite_acl_block(basic_network): assert c2_beacon.c2_connection_active == True # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) c2_beacon.apply_timestep(2) c2_beacon.apply_timestep(3) @@ -322,8 +322,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Confirm Default Setup ######################### # Permitting all HTTP & FTP traffic - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=1) c2_beacon.apply_timestep(0) assert c2_beacon.keep_alive_inactivity == 1 @@ -337,7 +337,7 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying HTTP Traffic ######################### # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) blocking_acl: AccessControlList = router.acl.acl[0] # Asserts to show the C2 Suite is unable to maintain connection: @@ -359,8 +359,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port.FTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=PORT_LOOKUP["FTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) c2_beacon.establish() @@ -407,8 +407,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying FTP Traffic & Enable HTTP ######################### # Blocking FTP and re-permitting HTTP: - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.FTP, dst_port=Port.FTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=1) blocking_acl: AccessControlList = router.acl.acl[1] # Asserts to show the C2 Suite is unable to maintain connection: @@ -430,8 +430,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port.HTTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=PORT_LOOKUP["HTTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) c2_beacon.establish() diff --git a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py index 2e87578d..50b0ceac 100644 --- a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py @@ -9,7 +9,6 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( @@ -19,6 +18,7 @@ from primaite.simulator.system.applications.red_applications.data_manipulation_b from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -52,7 +52,10 @@ def data_manipulation_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py index 68c1fbfe..1a09e875 100644 --- a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py @@ -8,12 +8,12 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -26,7 +26,7 @@ def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseServ dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port.POSTGRES_SERVER, + target_port=PORT_LOOKUP["POSTGRES_SERVER"], ) # Install DB Server service on server @@ -43,7 +43,10 @@ def dos_bot_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") @@ -56,7 +59,7 @@ def dos_bot_db_server_green_client(example_network) -> Network: dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port.POSTGRES_SERVER, + target_port=PORT_LOOKUP["POSTGRES_SERVER"], ) # install db server service on server diff --git a/tests/integration_tests/system/red_applications/test_ransomware_script.py b/tests/integration_tests/system/red_applications/test_ransomware_script.py index 97abafb5..a5adbb04 100644 --- a/tests/integration_tests/system/red_applications/test_ransomware_script.py +++ b/tests/integration_tests/system/red_applications/test_ransomware_script.py @@ -9,11 +9,11 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -47,7 +47,10 @@ def ransomware_script_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index 2b8691cc..c52b5caa 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -5,9 +5,9 @@ from ipaddress import IPv4Address, IPv4Network import yaml from primaite.game.game import PrimaiteGame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.nmap import NMAP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT @@ -73,10 +73,12 @@ def test_port_scan_one_node_one_port(example_network): client_2 = network.get_node_by_hostname("client_2") actual_result = client_1_nmap.port_scan( - target_ip_address=client_2.network_interface[1].ip_address, target_port=Port.DNS, target_protocol=IPProtocol.TCP + target_ip_address=client_2.network_interface[1].ip_address, + target_port=PORT_LOOKUP["DNS"], + target_protocol=PROTOCOL_LOOKUP["TCP"], ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.DNS]}} + expected_result = {IPv4Address("192.168.10.22"): {PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["DNS"]]}} assert actual_result == expected_result @@ -101,14 +103,20 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network): actual_result = client_1_nmap.port_scan( target_ip_address=IPv4Network("192.168.10.0/24"), - target_port=[Port.ARP, Port.HTTP, Port.FTP, Port.DNS, Port.NTP], + target_port=[ + PORT_LOOKUP["ARP"], + PORT_LOOKUP["HTTP"], + PORT_LOOKUP["FTP"], + PORT_LOOKUP["DNS"], + PORT_LOOKUP["NTP"], + ], ) expected_result = { - IPv4Address("192.168.10.1"): {IPProtocol.UDP: [Port.ARP]}, + IPv4Address("192.168.10.1"): {PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"]]}, IPv4Address("192.168.10.22"): { - IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS], - IPProtocol.UDP: [Port.ARP, Port.NTP], + PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]], + PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"], PORT_LOOKUP["NTP"]], }, } @@ -122,10 +130,12 @@ def test_network_service_recon_all_ports_and_protocols(example_network): client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa actual_result = client_1_nmap.network_service_recon( - target_ip_address=IPv4Network("192.168.10.0/24"), target_port=Port.HTTP, target_protocol=IPProtocol.TCP + target_ip_address=IPv4Network("192.168.10.0/24"), + target_port=PORT_LOOKUP["HTTP"], + target_protocol=PROTOCOL_LOOKUP["TCP"], ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.HTTP]}} + expected_result = {IPv4Address("192.168.10.22"): {PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["HTTP"]]}} assert sort_dict(actual_result) == sort_dict(expected_result) diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py index fd502a70..7a085ee1 100644 --- a/tests/integration_tests/system/test_service_listening_on_ports.py +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -6,19 +6,19 @@ from pydantic import Field from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.service import Service +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT class _DatabaseListener(Service): name: str = "DatabaseListener" - protocol: IPProtocol = IPProtocol.TCP - port: Port = Port.NONE - listen_on_ports: Set[Port] = {Port.POSTGRES_SERVER} + protocol: str = PROTOCOL_LOOKUP["TCP"] + port: int = PORT_LOOKUP["NONE"] + listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]} payloads_received: List[Any] = Field(default_factory=list) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: @@ -51,8 +51,8 @@ def test_http_listener(client_server): computer.session_manager.receive_payload_from_software_manager( payload="masquerade as Database traffic", dst_ip_address=server.network_interface[1].ip_address, - dst_port=Port.POSTGRES_SERVER, - ip_protocol=IPProtocol.TCP, + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + ip_protocol=PROTOCOL_LOOKUP["TCP"], ) assert len(server_db_listener.payloads_received) == 1 @@ -76,9 +76,9 @@ def test_set_listen_on_ports_from_config(): network = PrimaiteGame.from_config(cfg=config_dict).simulation.network client: Computer = network.get_node_by_hostname("client") - assert Port.SMB in client.software_manager.get_open_ports() - assert Port.IPP in client.software_manager.get_open_ports() + assert PORT_LOOKUP["SMB"] in client.software_manager.get_open_ports() + assert PORT_LOOKUP["IPP"] in client.software_manager.get_open_ports() web_browser = client.software_manager.software["WebBrowser"] - assert not web_browser.listen_on_ports.difference({Port.SMB, Port.IPP}) + assert not web_browser.listen_on_ports.difference({PORT_LOOKUP["SMB"], PORT_LOOKUP["IPP"]}) diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py index 5a765763..f2ac1183 100644 --- a/tests/integration_tests/system/test_web_client_server_and_database.py +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -9,7 +9,6 @@ from primaite.simulator.network.hardware.base import Link from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService @@ -17,6 +16,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -24,17 +24,22 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer, # add rules to network router router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3 + ) # Create Computer computer: Computer = example_network.get_node_by_hostname("client_1") @@ -148,7 +153,9 @@ class TestWebBrowserHistory: assert web_browser.history[-1].response_code == 200 router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule( + action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0 + ) assert not web_browser.get_webpage() assert len(web_browser.history) == 3 # with current NIC behaviour, even if you block communication, you won't get SERVER_UNREACHABLE because @@ -166,7 +173,9 @@ class TestWebBrowserHistory: web_browser.get_webpage() router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule( + action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0 + ) web_browser.get_webpage() state = computer.describe_state() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 95634cf1..a767f365 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -12,7 +12,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.port import PORT_LOOKUP from tests.conftest import DummyApplication, DummyService @@ -171,7 +171,7 @@ class TestDataManipulationGreenRequests: assert client_1_browser_execute.status == "success" assert client_2_browser_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router.acl.add_rule(ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3) client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) assert client_1_browser_execute.status == "failure" @@ -182,7 +182,9 @@ class TestDataManipulationGreenRequests: assert client_1_db_client_execute.status == "success" assert client_2_db_client_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER) + router.acl.add_rule( + ACLAction.DENY, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"] + ) client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"]) client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"]) assert client_1_db_client_execute.status == "failure" diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py index 9bc1abfd..6eca0c44 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py @@ -7,8 +7,10 @@ from primaite.simulator.network.hardware.base import generate_mac_address from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.network_layer import IPPacket +from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -28,20 +30,20 @@ def router_with_acl_rules(): # Add rules here as needed acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.1", - src_port=Port.HTTPS, + src_port=PORT_LOOKUP["HTTPS"], dst_ip_address="192.168.1.2", - dst_port=Port.HTTP, + dst_port=PORT_LOOKUP["HTTP"], position=1, ) acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.3", - src_port=Port(8080), + src_port=8080, dst_ip_address="192.168.1.4", - dst_port=Port(80), + dst_port=80, position=2, ) return router @@ -65,21 +67,21 @@ def router_with_wildcard_acl(): # Rule to permit traffic from a specific source IP and port to a specific destination IP and port acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.1", - src_port=Port(8080), + src_port=8080, dst_ip_address="10.1.1.2", - dst_port=Port(80), + dst_port=80, position=1, ) # Rule to deny traffic from an IP range to a specific destination IP and port acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.0", src_wildcard_mask="0.0.0.255", dst_ip_address="10.1.1.3", - dst_port=Port(443), + dst_port=443, position=2, ) # Rule to permit any traffic to a range of destination IPs @@ -109,11 +111,11 @@ def test_add_rule(router_with_acl_rules): acl = router_with_acl_rules.acl assert acl.acl[1].action == ACLAction.PERMIT - assert acl.acl[1].protocol == IPProtocol.TCP + assert acl.acl[1].protocol == PROTOCOL_LOOKUP["TCP"] assert acl.acl[1].src_ip_address == IPv4Address("192.168.1.1") - assert acl.acl[1].src_port == Port.HTTPS + assert acl.acl[1].src_port == PORT_LOOKUP["HTTPS"] assert acl.acl[1].dst_ip_address == IPv4Address("192.168.1.2") - assert acl.acl[1].dst_port == Port.HTTP + assert acl.acl[1].dst_port == PORT_LOOKUP["HTTP"] def test_remove_rule(router_with_acl_rules): @@ -136,8 +138,8 @@ def test_traffic_permitted_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.HTTPS, dst_port=Port.HTTP), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["HTTPS"], dst_port=PORT_LOOKUP["HTTP"]), ) is_permitted, _ = acl.is_permitted(permitted_frame) assert is_permitted @@ -153,8 +155,8 @@ def test_traffic_denied_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(80)), + ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=80), ) is_permitted, _ = acl.is_permitted(not_permitted_frame) assert not is_permitted @@ -173,8 +175,8 @@ def test_default_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.HTTPS, dst_port=Port.HTTP), + ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["HTTPS"], dst_port=PORT_LOOKUP["HTTP"]), ) is_permitted, rule = acl.is_permitted(not_permitted_frame) assert not is_permitted @@ -189,8 +191,8 @@ def test_direct_ip_match_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(80)), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=80), ) assert acl.is_permitted(frame)[0], "Direct IP match should be permitted." @@ -204,8 +206,8 @@ def test_ip_range_match_denied_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(443)), + ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=443), ) assert not acl.is_permitted(frame)[0], "IP range match with wildcard mask should be denied." @@ -219,8 +221,8 @@ def test_traffic_permitted_to_destination_range_with_acl(router_with_wildcard_ac acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port(1433), dst_port=Port(1433)), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=1433, dst_port=1433), ) assert acl.is_permitted(frame)[0], "Traffic to destination IP range should be permitted." @@ -253,23 +255,23 @@ def test_ip_traffic_from_specific_subnet(): permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"]), ) assert acl.is_permitted(permitted_frame_1)[0] permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.NTP, dst_port=Port.NTP), + ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"]), ) assert acl.is_permitted(permitted_frame_2)[0] permitted_frame_3 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=PROTOCOL_LOOKUP["ICMP"]), icmp=ICMPPacket(identifier=1), ) @@ -277,16 +279,16 @@ def test_ip_traffic_from_specific_subnet(): not_permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER), + ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"]), ) assert not acl.is_permitted(not_permitted_frame_1)[0] not_permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.NTP, dst_port=Port.NTP), + ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"]), ) assert not acl.is_permitted(not_permitted_frame_2)[0] diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py index d4e38ded..fe9387de 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -2,8 +2,8 @@ from ipaddress import IPv4Address from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_wireless_router_from_config(): @@ -67,12 +67,12 @@ def test_wireless_router_from_config(): r0 = rt.acl.acl[0] assert r0.action == ACLAction.PERMIT - assert r0.src_port == r0.dst_port == Port.POSTGRES_SERVER + assert r0.src_port == r0.dst_port == PORT_LOOKUP["POSTGRES_SERVER"] assert r0.src_ip_address == r0.dst_ip_address == r0.dst_wildcard_mask == r0.src_wildcard_mask == r0.protocol == None r1 = rt.acl.acl[1] assert r1.action == ACLAction.PERMIT - assert r1.protocol == IPProtocol.ICMP + assert r1.protocol == PROTOCOL_LOOKUP["ICMP"] assert ( r1.src_ip_address == r1.dst_ip_address diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py index 92618baa..e7e425b1 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py @@ -3,9 +3,11 @@ import pytest from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol, Precedence +from primaite.simulator.network.transmission.network_layer import IPPacket, Precedence from primaite.simulator.network.transmission.primaite_layer import AgentSource, DataStatus -from primaite.simulator.network.transmission.transport_layer import Port, TCPFlags, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.transport_layer import TCPFlags, TCPHeader, UDPHeader +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_frame_minimal_instantiation(): @@ -20,7 +22,7 @@ def test_frame_minimal_instantiation(): ) # Check network layer default values - assert frame.ip.protocol == IPProtocol.TCP + assert frame.ip.protocol == PROTOCOL_LOOKUP["TCP"] assert frame.ip.ttl == 64 assert frame.ip.precedence == Precedence.ROUTINE @@ -40,7 +42,7 @@ def test_frame_creation_fails_tcp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.TCP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["TCP"]), ) @@ -49,7 +51,7 @@ def test_frame_creation_fails_udp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.UDP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["UDP"]), ) @@ -58,7 +60,7 @@ def test_frame_creation_fails_tcp_with_udp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.TCP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["TCP"]), udp=UDPHeader(src_port=8080, dst_port=80), ) @@ -68,7 +70,7 @@ def test_frame_creation_fails_udp_with_tcp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.UDP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["UDP"]), udp=TCPHeader(src_port=8080, dst_port=80), ) @@ -77,7 +79,7 @@ def test_icmp_frame_creation(): """Tests Frame creation for ICMP.""" frame = Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["ICMP"]), icmp=ICMPPacket(), ) assert frame @@ -88,5 +90,5 @@ def test_icmp_frame_creation_fails_without_icmp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["ICMP"]), ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index 885a3cb6..12dddf67 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -4,11 +4,11 @@ import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Command, C2Server +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -129,19 +129,19 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon are configured correctly. assert c2_beacon.c2_config.keep_alive_frequency is 2 - assert c2_beacon.c2_config.masquerade_port is Port.HTTP - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] assert c2_server.c2_config.keep_alive_frequency is 2 - assert c2_server.c2_config.masquerade_port is Port.HTTP - assert c2_server.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] # Configuring the C2 Beacon. c2_beacon.configure( c2_server_ip_address="192.168.0.1", keep_alive_frequency=2, - masquerade_port=Port.FTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=PORT_LOOKUP["FTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) # Asserting that the c2 applications have established a c2 connection @@ -150,11 +150,11 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon # Have reconfigured their C2 settings. - assert c2_beacon.c2_config.masquerade_port is Port.FTP - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] - assert c2_server.c2_config.masquerade_port is Port.FTP - assert c2_server.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] def test_c2_handle_switching_frequency(basic_c2_network): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index 0811d2a0..34a29cd0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -3,13 +3,13 @@ import pytest from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( DataManipulationAttackStage, DataManipulationBot, ) +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -27,8 +27,8 @@ def test_create_dm_bot(dm_client): data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot") assert data_manipulation_bot.name == "DataManipulationBot" - assert data_manipulation_bot.port == Port.NONE - assert data_manipulation_bot.protocol == IPProtocol.NONE + assert data_manipulation_bot.port == PORT_LOOKUP["NONE"] + assert data_manipulation_bot.protocol == PROTOCOL_LOOKUP["NONE"] assert data_manipulation_bot.payload == "DELETE" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index 2acd991a..e9762476 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -5,9 +5,9 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index ce98d164..f1be475a 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -4,10 +4,10 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.http import HttpResponsePacket, HttpStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -39,8 +39,8 @@ def test_create_web_client(): # Web Browser should be pre-installed in computer web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") assert web_browser.name is "WebBrowser" - assert web_browser.port is Port.HTTP - assert web_browser.protocol is IPProtocol.TCP + assert web_browser.port is PORT_LOOKUP["HTTP"] + assert web_browser.protocol is PROTOCOL_LOOKUP["TCP"] def test_receive_invalid_payload(web_browser): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index e9ce4884..db7e8d58 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -6,10 +6,10 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -28,8 +28,8 @@ def test_create_dns_client(dns_client): assert dns_client is not None dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") assert dns_client_service.name is "DNSClient" - assert dns_client_service.port is Port.DNS - assert dns_client_service.protocol is IPProtocol.TCP + assert dns_client_service.port is PORT_LOOKUP["DNS"] + assert dns_client_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_dns_client_add_domain_to_cache_when_not_running(dns_client): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py index 4658fe76..c64602c0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -8,10 +8,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -32,8 +32,8 @@ def test_create_dns_server(dns_server): assert dns_server is not None dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer") assert dns_server_service.name is "DNSServer" - assert dns_server_service.port is Port.DNS - assert dns_server_service.protocol is IPProtocol.TCP + assert dns_server_service.port is PORT_LOOKUP["DNS"] + assert dns_server_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_dns_server_domain_name_registration(dns_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 3ce4d8ee..95788834 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -8,10 +8,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -31,8 +31,8 @@ def test_create_ftp_client(ftp_client): assert ftp_client is not None ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") assert ftp_client_service.name is "FTPClient" - assert ftp_client_service.port is Port.FTP - assert ftp_client_service.protocol is IPProtocol.TCP + assert ftp_client_service.port is PORT_LOOKUP["FTP"] + assert ftp_client_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_ftp_client_store_file(ftp_client): @@ -61,7 +61,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client): """Method _process_ftp_command should return false if service is not running.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=PORT_LOOKUP["FTP"], status_code=FTPStatusCode.OK, ) @@ -102,7 +102,7 @@ def test_offline_ftp_client_receives_request(ftp_client): payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=PORT_LOOKUP["FTP"], status_code=FTPStatusCode.OK, ) @@ -119,7 +119,7 @@ def test_receive_should_ignore_payload_with_none_status_code(ftp_client): """Receive should ignore payload with no set status code to prevent infinite send/receive loops.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=PORT_LOOKUP["FTP"], status_code=None, ) ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index a1c2ba59..291cdede 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -6,10 +6,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -30,8 +30,8 @@ def test_create_ftp_server(ftp_server): assert ftp_server is not None ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") assert ftp_server_service.name is "FTPServer" - assert ftp_server_service.port is Port.FTP - assert ftp_server_service.protocol is IPProtocol.TCP + assert ftp_server_service.port is PORT_LOOKUP["FTP"] + assert ftp_server_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_ftp_server_store_file(ftp_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 41858b90..9b6a4bf3 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -18,13 +18,13 @@ from primaite.simulator.network.protocols.ssh import ( SSHTransportMessage, SSHUserCredentials, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection, Terminal from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -77,11 +77,15 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # add ACL rule to allow SSH traffic - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21 + ) # Configure PC B pc_b = Computer( @@ -329,7 +333,9 @@ def test_SSH_across_network(wireless_wan_network): terminal_a: Terminal = pc_a.software_manager.software.get("Terminal") terminal_b: Terminal = pc_b.software_manager.software.get("Terminal") - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + router_2.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21 + ) assert len(terminal_a._connections) == 0 diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index 9af176be..54f86ec8 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -9,9 +9,9 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") @@ -33,8 +33,8 @@ def test_create_web_server(web_server): assert web_server is not None web_server_service: WebServer = web_server.software_manager.software.get("WebServer") assert web_server_service.name is "WebServer" - assert web_server_service.port is Port.HTTP - assert web_server_service.protocol is IPProtocol.TCP + assert web_server_service.port is PORT_LOOKUP["HTTP"] + assert web_server_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_handling_get_request_not_found_path(web_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index 4cf83370..300f8d9d 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -3,11 +3,11 @@ from typing import Dict import pytest -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import IOSoftware, SoftwareHealthState +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP class TestSoftware(Service): @@ -19,10 +19,10 @@ class TestSoftware(Service): def software(file_system): return TestSoftware( name="TestSoftware", - port=Port.ARP, + port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="test_service"), - protocol=IPProtocol.TCP, + protocol=PROTOCOL_LOOKUP["TCP"], ) diff --git a/tests/unit_tests/_primaite/_utils/_validation/__init__.py b/tests/unit_tests/_primaite/_utils/_validation/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py b/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py new file mode 100644 index 00000000..27829570 --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py @@ -0,0 +1,23 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.utils.validation.ip_protocol import IPProtocol, is_valid_protocol, PROTOCOL_LOOKUP, protocol_validator + + +def test_port_conversion(): + for proto_name, proto_val in PROTOCOL_LOOKUP.items(): + assert protocol_validator(proto_name) == proto_val + assert is_valid_protocol(proto_name) + + +def test_port_passthrough(): + for proto_val in PROTOCOL_LOOKUP.values(): + assert protocol_validator(proto_val) == proto_val + assert is_valid_protocol(proto_val) + + +def test_invalid_ports(): + for port in (123, "abcdefg", "NONEXISTENT_PROTO"): + with pytest.raises(ValueError): + protocol_validator(port) + assert not is_valid_protocol(port) diff --git a/tests/unit_tests/_primaite/_utils/_validation/test_port.py b/tests/unit_tests/_primaite/_utils/_validation/test_port.py new file mode 100644 index 00000000..6a8a2429 --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/test_port.py @@ -0,0 +1,25 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.utils.validation.port import is_valid_port, Port, PORT_LOOKUP, port_validator + + +def test_port_conversion(): + valid_port_lookup = {k: v for k, v in PORT_LOOKUP.items() if k != "UNUSED"} + for port_name, port_val in valid_port_lookup.items(): + assert port_validator(port_name) == port_val + assert is_valid_port(port_name) + + +def test_port_passthrough(): + valid_port_lookup = {k: v for k, v in PORT_LOOKUP.items() if k != "UNUSED"} + for port_val in valid_port_lookup.values(): + assert port_validator(port_val) == port_val + assert is_valid_port(port_val) + + +def test_invalid_ports(): + for port in (999999, -20, 3.214, "NONEXISTENT_PORT"): + with pytest.raises(ValueError): + port_validator(port) + assert not is_valid_port(port) diff --git a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py index a8fb0a3a..1a1848ac 100644 --- a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py +++ b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.utils.converters import convert_dict_enum_keys_to_enum_values +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_simple_conversion(): @@ -11,7 +11,7 @@ def test_simple_conversion(): The original dictionary contains one level of nested dictionary with enums as keys. The expected output should have string values of enums as keys. """ - original_dict = {IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0}}} + original_dict = {PROTOCOL_LOOKUP["UDP"]: {PORT_LOOKUP["ARP"]: {"inbound": 0, "outbound": 1016.0}}} expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -36,8 +36,8 @@ def test_mixed_keys(): The expected output should have string values of enums and original string keys. """ original_dict = { - IPProtocol.TCP: {"port": {"inbound": 0, "outbound": 1016.0}}, - "protocol": {Port.HTTP: {"inbound": 10, "outbound": 2020.0}}, + PROTOCOL_LOOKUP["TCP"]: {"port": {"inbound": 0, "outbound": 1016.0}}, + "protocol": {PORT_LOOKUP["HTTP"]: {"inbound": 10, "outbound": 2020.0}}, } expected_dict = { "tcp": {"port": {"inbound": 0, "outbound": 1016.0}}, @@ -66,7 +66,13 @@ def test_nested_dicts(): The expected output should have string values of enums as keys at all levels. """ original_dict = { - IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol.TCP: {"latency": "low"}}}} + PROTOCOL_LOOKUP["UDP"]: { + PORT_LOOKUP["ARP"]: { + "inbound": 0, + "outbound": 1016.0, + "details": {PROTOCOL_LOOKUP["TCP"]: {"latency": "low"}}, + } + } } expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0, "details": {"tcp": {"latency": "low"}}}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -79,6 +85,12 @@ def test_non_dict_values(): The original dictionary contains lists and tuples as values. The expected output should preserve these non-dictionary values while converting enum keys to string values. """ - original_dict = {IPProtocol.UDP: [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)} - expected_dict = {"udp": [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)} + original_dict = { + PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"], PORT_LOOKUP["HTTP"]], + "protocols": (PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]), + } + expected_dict = { + "udp": [PORT_LOOKUP["ARP"], PORT_LOOKUP["HTTP"]], + "protocols": (PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]), + } assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict