diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 30ccd195..0984f008 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -5,7 +5,6 @@ from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType -from pydantic import field_validator from primaite import getLogger from primaite.game.agent.observations.file_system_observations import FolderObservation @@ -13,8 +12,7 @@ from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validators import IPProtocol, Port _LOGGER = getLogger(__name__) @@ -47,7 +45,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): """Number of spaces for network interface observations on this host.""" include_nmne: Optional[bool] = None """Whether network interface observations should include number of malicious network events.""" - monitored_traffic: Optional[Dict] = None + monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Whether to include the number of accesses to files observations on this host.""" @@ -58,26 +56,26 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_users: Optional[bool] = True """If True, report user session information.""" - @field_validator("monitored_traffic", mode="before") - def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: - """ - Convert monitored_traffic by lookup against Port and Protocol dicts. + # @field_validator("monitored_traffic", mode="before") + # def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: + # """ + # Convert monitored_traffic by lookup against Port and Protocol dicts. - This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. - This method will be removed in PrimAITE >= 4.0 - """ - if val is None: - return val - new_val = {} - for proto, port_list in val.items(): - # convert protocol, for instance ICMP becomes "icmp" - proto = IPProtocol[proto] if proto in IPProtocol else proto - new_val[proto] = [] - for port in port_list: - # convert ports, for instance "HTTP" becomes 80 - port = Port[port] if port in Port else port - new_val[proto].append(port) - return new_val + # This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. + # This method will be removed in PrimAITE >= 4.0 + # """ + # if val is None: + # return val + # new_val = {} + # for proto, port_list in val.items(): + # # convert protocol, for instance ICMP becomes "icmp" + # proto = PROTOCOL_LOOKUP[proto] if proto in PROTOCOL_LOOKUP else proto + # new_val[proto] = [] + # for port in port_list: + # # convert ports, for instance "HTTP" becomes 80 + # port = PORT_LOOKUP[port] if port in PORT_LOOKUP else port + # new_val[proto].append(port) + # return new_val def __init__( self, diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 296ce04c..c51cb427 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -9,8 +9,8 @@ from pydantic import field_validator from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): @@ -39,11 +39,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): new_val = {} for proto, port_list in val.items(): # convert protocol, for instance ICMP becomes "icmp" - proto = IPProtocol[proto] if proto in IPProtocol else proto + proto = PROTOCOL_LOOKUP[proto] if proto in PROTOCOL_LOOKUP else proto new_val[proto] = [] for port in port_list: # convert ports, for instance "HTTP" becomes 80 - port = Port[port] if port in Port else port + port = PORT_LOOKUP[port] if port in PORT_LOOKUP else port new_val[proto].append(port) return new_val diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 054ffcdb..0bb8ea0f 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -12,8 +12,8 @@ from primaite.game.agent.observations.firewall_observation import FirewallObserv from primaite.game.agent.observations.host_observations import HostObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.router_observation import RouterObservation -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -76,11 +76,11 @@ class NodesObservation(AbstractObservation, identifier="NODES"): new_val = {} for proto, port_list in val.items(): # convert protocol, for instance ICMP becomes "icmp" - proto = IPProtocol[proto] if proto in IPProtocol else proto + proto = PROTOCOL_LOOKUP[proto] if proto in PROTOCOL_LOOKUP else proto new_val[proto] = [] for port in port_list: # convert ports, for instance "HTTP" becomes 80 - port = Port[port] if port in Port else port + port = PORT_LOOKUP[port] if port in PORT_LOOKUP else port new_val[proto].append(port) return new_val diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8e0abb1e..a0d2ceb4 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -27,8 +27,7 @@ from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.network.nmne import NMNEConfig -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient # noqa: F401 @@ -51,6 +50,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import Software +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -97,8 +97,8 @@ class PrimaiteGameOptions(BaseModel): :warning: This will be deprecated in PrimAITE 4.0 and configs will need to be converted. """ for i, port_val in enumerate(vals): - if port_val in Port: - vals[i] = Port[port_val] + if port_val in PORT_LOOKUP: + vals[i] = PORT_LOOKUP[port_val] return vals @field_validator("protocols", mode="before") @@ -110,8 +110,8 @@ class PrimaiteGameOptions(BaseModel): :warning: This will be deprecated in PrimAITE 4.0 and configs will need to be converted. """ for i, proto_val in enumerate(vals): - if proto_val in IPProtocol: - vals[i] = IPProtocol[proto_val] + if proto_val in PROTOCOL_LOOKUP: + vals[i] = PROTOCOL_LOOKUP[proto_val] return vals @@ -381,7 +381,7 @@ class PrimaiteGame: if isinstance(port_id, int): port = port_id elif isinstance(port_id, str): - port = Port[port_id] + port = PORT_LOOKUP[port_id] if port: listen_on_ports.append(port) software.listen_on_ports = set(listen_on_ports) @@ -496,7 +496,7 @@ class PrimaiteGame: opt = application_cfg["options"] new_application.configure( target_ip_address=IPv4Address(opt.get("target_ip_address")), - target_port=Port[opt.get("target_port", "POSTGRES_SERVER")], + target_port=PORT_LOOKUP[opt.get("target_port", "POSTGRES_SERVER")], payload=opt.get("payload"), repeat=bool(opt.get("repeat")), port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), @@ -509,8 +509,10 @@ class PrimaiteGame: new_application.configure( c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")), keep_alive_frequency=(opt.get("keep_alive_frequency", 5)), - masquerade_protocol=IPProtocol[(opt.get("masquerade_protocol", IPProtocol["TCP"]))], - masquerade_port=Port[(opt.get("masquerade_port", Port["HTTP"]))], + masquerade_protocol=PROTOCOL_LOOKUP[ + (opt.get("masquerade_protocol", PROTOCOL_LOOKUP["TCP"])) + ], + masquerade_port=PORT_LOOKUP[(opt.get("masquerade_port", PORT_LOOKUP["HTTP"]))], ) if "network_interfaces" in node_cfg: for nic_num, nic_cfg in node_cfg["network_interfaces"].items(): diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index c2524b4b..9e2e5502 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -6,8 +6,8 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int: @@ -98,8 +98,10 @@ def create_office_lan( default_gateway = IPv4Address(f"192.168.{subnet_base}.1") router = Router(hostname=f"router_{lan_name}", start_up_duration=0) router.power_on() - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) network.add_node(router) router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0") router.enable_port(1) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 4154cc08..affaf3cc 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -21,8 +21,7 @@ from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.packet_capture import PacketCapture from primaite.simulator.system.core.session_manager import SessionManager @@ -33,7 +32,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.software import IOSoftware, Software from primaite.utils.converters import convert_dict_enum_keys_to_enum_values -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware) @@ -274,20 +273,20 @@ class NetworkInterface(SimComponent, ABC): # Identify the protocol and port from the frame if frame.tcp: - protocol = IPProtocol["TCP"] + protocol = PROTOCOL_LOOKUP["TCP"] port = frame.tcp.dst_port elif frame.udp: - protocol = IPProtocol["UDP"] + protocol = PROTOCOL_LOOKUP["UDP"] port = frame.udp.dst_port elif frame.icmp: - protocol = IPProtocol["ICMP"] + protocol = PROTOCOL_LOOKUP["ICMP"] # Ensure the protocol is in the capture dict if protocol not in self.traffic: self.traffic[protocol] = {} # Handle non-ICMP protocols that use ports - if protocol != IPProtocol["ICMP"]: + if protocol != PROTOCOL_LOOKUP["ICMP"]: if port not in self.traffic[protocol]: self.traffic[protocol][port] = {"inbound": 0, "outbound": 0} self.traffic[protocol][port][direction] += frame.size_Mbits @@ -843,8 +842,8 @@ class UserManager(Service): :param password: The password for the default admin user """ kwargs["name"] = "UserManager" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["NONE"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self.start() @@ -1166,8 +1165,8 @@ class UserSessionManager(Service): :param password: The password for the default admin user """ kwargs["name"] = "UserSessionManager" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["NONE"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self.start() @@ -1312,7 +1311,7 @@ class UserSessionManager(Service): software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( payload={"type": "user_timeout", "connection_id": session.uuid}, - dest_port=Port["SSH"], + dest_port=PORT_LOOKUP["SSH"], dest_ip_address=session.remote_ip_address, ) diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 6d8e084d..eed1132b 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -14,10 +14,9 @@ from primaite.simulator.network.hardware.nodes.network.router import ( RouterInterface, ) from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.sys_log import SysLog -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP EXTERNAL_PORT_ID: Final[int] = 1 """The Firewall port ID of the external port.""" @@ -596,9 +595,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items(): firewall.internal_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -611,9 +610,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items(): firewall.internal_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -626,9 +625,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items(): firewall.dmz_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -641,9 +640,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items(): firewall.dmz_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -656,9 +655,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items(): firewall.external_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -671,9 +670,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items(): firewall.external_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index fded23f9..46efe668 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -17,15 +17,14 @@ from primaite.simulator.network.hardware.nodes.network.network_node import Netwo from primaite.simulator.network.protocols.arp import ARPPacket from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.nmap import NMAP from primaite.simulator.system.core.session_manager import SessionManager from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.arp.arp import ARP from primaite.simulator.system.services.icmp.icmp import ICMP from primaite.simulator.system.services.terminal.terminal import Terminal -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP @validate_call() @@ -134,14 +133,14 @@ class ACLRule(SimComponent): def protocol_valid(cls, val: Optional[str]) -> Optional[str]: """Assert that the protocol for the rule is predefined in the IPProtocol lookup.""" if val is not None: - assert val in IPProtocol.values(), f"Cannot create ACL rule with invalid protocol {val}" + assert val in PROTOCOL_LOOKUP.values(), f"Cannot create ACL rule with invalid protocol {val}" return val @field_validator("src_port", "dst_port", mode="before") def ports_valid(cls, val: Optional[int]) -> Optional[int]: """Assert that the port for the rule is predefined in the Port lookup.""" if val is not None: - assert val in Port.values(), f"Cannot create ACL rule with invalid port {val}" + assert val in PORT_LOOKUP.values(), f"Cannot create ACL rule with invalid port {val}" return val def __str__(self) -> str: @@ -1271,8 +1270,10 @@ class Router(NetworkNode): Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP and ICMP, which are necessary for basic network operations and diagnostics. """ - self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + self.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) def setup_for_episode(self, episode: int): """ @@ -1371,9 +1372,9 @@ class Router(NetworkNode): """ dst_ip_address = frame.ip.dst_ip_address dst_port = None - if frame.ip.protocol == IPProtocol["TCP"]: + if frame.ip.protocol == PROTOCOL_LOOKUP["TCP"]: dst_port = frame.tcp.dst_port - elif frame.ip.protocol == IPProtocol["UDP"]: + elif frame.ip.protocol == PROTOCOL_LOOKUP["UDP"]: dst_port = frame.udp.dst_port if self.ip_is_router_interface(dst_ip_address) and ( @@ -1646,9 +1647,9 @@ class Router(NetworkNode): for r_num, r_cfg in cfg["acl"].items(): router.acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 1969a121..3615ef54 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -8,9 +8,8 @@ from primaite.simulator.network.airspace import AirSpace, IPWirelessNetworkInter from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.utils.validators import IPV4Address +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP class WirelessAccessPoint(IPWirelessNetworkInterface): @@ -271,9 +270,9 @@ class WirelessRouter(Router): for r_num, r_cfg in cfg["acl"].items(): router.acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), dst_ip_address=r_cfg.get("dst_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index a73f3b12..c3b4a341 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -12,14 +12,14 @@ from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -79,9 +79,11 @@ def client_server_routed() -> Network: server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) return network @@ -271,23 +273,30 @@ def arcd_uc2_network() -> Network: security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0")) network.connect(endpoint_b=security_suite.network_interface[2], endpoint_a=switch_2.network_interface[7]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Allow PostgreSQL requests router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3 + ) return network diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index b9bc48d9..ca212c58 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -7,10 +7,11 @@ from pydantic import BaseModel from primaite import getLogger from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.protocols.packet import DataPacket -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol +from primaite.simulator.network.transmission.network_layer import IPPacket from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPHeader, UDPHeader from primaite.simulator.network.utils import convert_bytes_to_megabits +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -70,15 +71,15 @@ class Frame(BaseModel): msg = "Network Frame cannot have both a TCP header and a UDP header" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol["TCP"] and not kwargs.get("tcp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["TCP"] and not kwargs.get("tcp"): msg = "Cannot build a Frame using the TCP IP Protocol without a TCPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol["UDP"] and not kwargs.get("udp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["UDP"] and not kwargs.get("udp"): msg = "Cannot build a Frame using the UDP IP Protocol without a UDPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol["ICMP"] and not kwargs.get("icmp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["ICMP"] and not kwargs.get("icmp"): msg = "Cannot build a Frame using the ICMP IP Protocol without a ICMPPacket" _LOGGER.error(msg) raise ValueError(msg) @@ -165,7 +166,7 @@ class Frame(BaseModel): :return: True if the Frame is an ARP packet, otherwise False. """ - return self.udp.dst_port == Port["ARP"] + return self.udp.dst_port == PORT_LOOKUP["ARP"] @property def is_icmp(self) -> bool: diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index a01b7f42..47e8a032 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -4,39 +4,11 @@ from enum import Enum from pydantic import BaseModel from primaite import getLogger -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPProtocol, IPV4Address, PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) -IPProtocol: dict[str, str] = dict( - NONE="none", - TCP="tcp", - UDP="udp", - ICMP="icmp", -) - -# class IPProtocol(Enum): -# """ -# Enum representing transport layer protocols in IP header. - -# .. _List of IPProtocols: -# """ - -# NONE = "none" -# """Placeholder for a non-protocol.""" -# TCP = "tcp" -# """Transmission Control Protocol.""" -# UDP = "udp" -# """User Datagram Protocol.""" -# ICMP = "icmp" -# """Internet Control Message Protocol.""" - -# def model_dump(self) -> str: -# """Return as JSON-serialisable string.""" -# return self.name - - class Precedence(Enum): """ Enum representing the Precedence levels in Quality of Service (QoS) for IP packets. @@ -98,7 +70,7 @@ class IPPacket(BaseModel): "Source IP address." dst_ip_address: IPV4Address "Destination IP address." - protocol: str = IPProtocol["TCP"] + protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"] "IPProtocol." ttl: int = 64 "Time to Live (TTL) for the packet." diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index 60f2f070..fbc4b5ad 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -4,7 +4,7 @@ from typing import List from pydantic import BaseModel -Port: dict[str, int] = dict( +PORT_LOOKUP: dict[str, int] = dict( UNUSED=-1, NONE=0, WOL=9, @@ -36,81 +36,6 @@ Port: dict[str, int] = dict( POSTGRES_SERVER=5432, ) -# class Port(): -# def __getattr__() - - -# class Port(Enum): -# """ -# Enumeration of common known TCP/UDP ports used by protocols for operation of network applications. - -# .. _List of Ports: -# """ - -# UNUSED = -1 -# "An unused port stub." - -# NONE = 0 -# "Place holder for a non-port." -# WOL = 9 -# "Wake-on-Lan (WOL) - Used to turn or awaken a computer from sleep mode by a network message." -# FTP_DATA = 20 -# "File Transfer [Default Data]" -# FTP = 21 -# "File Transfer Protocol (FTP) - FTP control (command)" -# SSH = 22 -# "Secure Shell (SSH) - Used for secure remote access and command execution." -# SMTP = 25 -# "Simple Mail Transfer Protocol (SMTP) - Used for email delivery between servers." -# DNS = 53 -# "Domain Name System (DNS) - Used for translating domain names to IP addresses." -# HTTP = 80 -# "HyperText Transfer Protocol (HTTP) - Used for web traffic." -# POP3 = 110 -# "Post Office Protocol version 3 (POP3) - Used for retrieving emails from a mail server." -# SFTP = 115 -# "Secure File Transfer Protocol (SFTP) - Used for secure file transfer over SSH." -# NTP = 123 -# "Network Time Protocol (NTP) - Used for clock synchronization between computer systems." -# IMAP = 143 -# "Internet Message Access Protocol (IMAP) - Used for retrieving emails from a mail server." -# SNMP = 161 -# "Simple Network Management Protocol (SNMP) - Used for network device management." -# SNMP_TRAP = 162 -# "SNMP Trap - Used for sending SNMP notifications (traps) to a network management system." -# ARP = 219 -# "Address resolution Protocol - Used to connect a MAC address to an IP address." -# LDAP = 389 -# "Lightweight Directory Access Protocol (LDAP) - Used for accessing and modifying directory information." -# HTTPS = 443 -# "HyperText Transfer Protocol Secure (HTTPS) - Used for secure web traffic." -# SMB = 445 -# "Server Message Block (SMB) - Used for file sharing and printer sharing in Windows environments." -# IPP = 631 -# "Internet Printing Protocol (IPP) - Used for printing over the internet or an intranet." -# SQL_SERVER = 1433 -# "Microsoft SQL Server Database Engine - Used for communication with the SQL Server." -# MYSQL = 3306 -# "MySQL Database Server - Used for MySQL database communication." -# RDP = 3389 -# "Remote Desktop Protocol (RDP) - Used for remote desktop access to Windows machines." -# RTP = 5004 -# "Real-time Transport Protocol (RTP) - Used for transmitting real-time media, e.g., audio and video." -# RTP_ALT = 5005 -# "Alternative port for RTP (RTP_ALT) - Used in some configurations for transmitting real-time media." -# DNS_ALT = 5353 -# "Alternative port for DNS (DNS_ALT) - Used in some configurations for DNS service." -# HTTP_ALT = 8080 -# "Alternative port for HTTP (HTTP_ALT) - Often used as an alternative HTTP port for web applications." -# HTTPS_ALT = 8443 -# "Alternative port for HTTPS (HTTPS_ALT) - Used in some configurations for secure web traffic." -# POSTGRES_SERVER = 5432 -# "Postgres SQL Server." - -# def model_dump(self) -> str: -# """Return a json-serialisable string.""" -# return self.name - class UDPHeader(BaseModel): """ diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 170e2647..4967f519 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -11,11 +11,10 @@ from pydantic import BaseModel from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.hardware.nodes.host.host_node import HostNode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.software_manager import SoftwareManager -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP class DatabaseClientConnection(BaseModel): @@ -90,8 +89,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"): def __init__(self, **kwargs): kwargs["name"] = "DatabaseClient" - kwargs["port"] = Port["POSTGRES_SERVER"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def _init_request_manager(self) -> RequestManager: diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index 74bce85d..34433e65 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -7,10 +7,9 @@ from pydantic import validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP class PortScanPayload(SimComponent): @@ -64,8 +63,8 @@ class NMAP(Application, identifier="NMAP"): def __init__(self, **kwargs): kwargs["name"] = "NMAP" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["NONE"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) def _can_perform_network_action(self) -> bool: @@ -348,12 +347,12 @@ class NMAP(Application, identifier="NMAP"): if isinstance(target_port, int): target_port = [target_port] elif target_port is None: - target_port = [port for port in Port if port not in {Port["NONE"], Port["UNUSED"]}] + target_port = [port for port in PORT_LOOKUP if port not in {PORT_LOOKUP["NONE"], PORT_LOOKUP["UNUSED"]}] if isinstance(target_protocol, str): target_protocol = [target_protocol] elif target_protocol is None: - target_protocol = [IPProtocol["TCP"], IPProtocol["UDP"]] + target_protocol = [PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]] scan_type = self._determine_port_scan_type(list(ip_addresses), target_port) active_ports = {} diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index d442d968..b0cdefba 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -9,14 +9,14 @@ from pydantic import BaseModel, Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP class C2Command(Enum): @@ -81,10 +81,10 @@ class AbstractC2(Application, identifier="AbstractC2"): keep_alive_frequency: int = Field(default=5, ge=1) """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" - masquerade_protocol: str = Field(default=IPProtocol["TCP"]) + masquerade_protocol: str = Field(default=PROTOCOL_LOOKUP["TCP"]) """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" - masquerade_port: int = Field(default=Port["HTTP"]) + masquerade_port: int = Field(default=PORT_LOOKUP["HTTP"]) """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" c2_config: _C2Opts = _C2Opts() @@ -142,9 +142,9 @@ class AbstractC2(Application, identifier="AbstractC2"): def __init__(self, **kwargs): """Initialise the C2 applications to by default listen for HTTP traffic.""" - kwargs["listen_on_ports"] = {Port["HTTP"], Port["FTP"], Port["DNS"]} - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["listen_on_ports"] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]} + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) @property @@ -410,8 +410,8 @@ class AbstractC2(Application, identifier="AbstractC2"): self.keep_alive_inactivity = 0 self.keep_alive_frequency = 5 self.c2_remote_connection = None - self.c2_config.masquerade_port = Port["HTTP"] - self.c2_config.masquerade_protocol = IPProtocol["TCP"] + self.c2_config.masquerade_port = PORT_LOOKUP["HTTP"] + self.c2_config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"] @abstractmethod def _confirm_remote_connection(self, timestep: int) -> bool: diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index 9178e68a..450c60ad 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -8,12 +8,12 @@ from pydantic import validate_call from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts, RansomwareOpts, TerminalOpts from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection +from primaite.utils.validators import PROTOCOL_LOOKUP class C2Beacon(AbstractC2, identifier="C2Beacon"): @@ -111,8 +111,8 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self.configure( c2_server_ip_address=c2_remote_ip, keep_alive_frequency=frequency, - masquerade_protocol=IPProtocol[protocol], - masquerade_port=Port[port], + masquerade_protocol=PROTOCOL_LOOKUP[protocol], + masquerade_port=PORT_LOOKUP[port], ) ) @@ -129,8 +129,8 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self, c2_server_ip_address: IPv4Address = None, keep_alive_frequency: int = 5, - masquerade_protocol: str = IPProtocol["TCP"], - masquerade_port: int = Port["HTTP"], + masquerade_protocol: str = PROTOCOL_LOOKUP["TCP"], + masquerade_port: int = PORT_LOOKUP["HTTP"], ) -> bool: """ Configures the C2 beacon to communicate with the C2 server. diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index d74ae384..c2d19160 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -7,10 +7,10 @@ from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -50,8 +50,8 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"): def __init__(self, **kwargs): kwargs["name"] = "DataManipulationBot" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["NONE"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 2cc99c4a..7e199b48 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -7,7 +7,7 @@ from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient _LOGGER = getLogger(__name__) @@ -85,7 +85,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): if "target_ip_address" in request[-1]: request[-1]["target_ip_address"] = IPv4Address(request[-1]["target_ip_address"]) if "target_port" in request[-1]: - request[-1]["target_port"] = Port[request[-1]["target_port"]] + request[-1]["target_port"] = PORT_LOOKUP[request[-1]["target_port"]] return RequestResponse.from_bool(self.configure(**request[-1])) rm.add_request("configure", request_type=RequestType(func=_configure)) @@ -94,7 +94,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): def configure( self, target_ip_address: IPv4Address, - target_port: Optional[int] = Port["POSTGRES_SERVER"], + target_port: Optional[int] = PORT_LOOKUP["POSTGRES_SERVER"], payload: Optional[str] = None, repeat: bool = False, port_scan_p_of_success: float = 0.1, diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index a819190c..56f885f4 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -6,10 +6,10 @@ from prettytable import MARKDOWN, PrettyTable from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection +from primaite.utils.validators import PROTOCOL_LOOKUP class RansomwareScript(Application, identifier="RansomwareScript"): @@ -27,8 +27,8 @@ class RansomwareScript(Application, identifier="RansomwareScript"): def __init__(self, **kwargs): kwargs["name"] = "RansomwareScript" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["NONE"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 6707fa52..faa7b5ec 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -15,10 +15,10 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -43,10 +43,10 @@ class WebBrowser(Application, identifier="WebBrowser"): def __init__(self, **kwargs): kwargs["name"] = "WebBrowser" - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port["HTTP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) self.run() @@ -126,7 +126,7 @@ class WebBrowser(Application, identifier="WebBrowser"): if self.send( payload=payload, dest_ip_address=self.domain_name_ip_address, - dest_port=parsed_url.port if parsed_url.port else Port["HTTP"], + dest_port=parsed_url.port if parsed_url.port else PORT_LOOKUP["HTTP"], ): self.sys_log.info( f"{self.name}: Received HTTP {payload.request_method.name} " @@ -154,7 +154,7 @@ class WebBrowser(Application, identifier="WebBrowser"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = Port["HTTP"], + dest_port: Optional[int] = PORT_LOOKUP["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 33de3443..fcf07d9f 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -10,8 +10,9 @@ from primaite.simulator.core import SimComponent from primaite.simulator.network.protocols.arp import ARPPacket from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.network_layer import IPPacket +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPHeader, UDPHeader +from primaite.utils.validators import PROTOCOL_LOOKUP if TYPE_CHECKING: from primaite.simulator.network.hardware.base import NetworkInterface @@ -117,7 +118,7 @@ class SessionManager: """ protocol = frame.ip.protocol with_ip_address = frame.ip.src_ip_address - if protocol == IPProtocol["TCP"]: + if protocol == PROTOCOL_LOOKUP["TCP"]: if inbound_frame: src_port = frame.tcp.src_port dst_port = frame.tcp.dst_port @@ -125,7 +126,7 @@ class SessionManager: dst_port = frame.tcp.src_port src_port = frame.tcp.dst_port with_ip_address = frame.ip.dst_ip_address - elif protocol == IPProtocol["UDP"]: + elif protocol == PROTOCOL_LOOKUP["UDP"]: if inbound_frame: src_port = frame.udp.src_port dst_port = frame.udp.dst_port @@ -260,7 +261,7 @@ class SessionManager: src_port: Optional[int] = None, dst_port: Optional[int] = None, session_id: Optional[str] = None, - ip_protocol: str = IPProtocol["TCP"], + ip_protocol: str = PROTOCOL_LOOKUP["TCP"], icmp_packet: Optional[ICMPPacket] = None, ) -> Union[Any, None]: """ @@ -284,7 +285,7 @@ class SessionManager: dst_mac_address = payload.target_mac_addr outbound_network_interface = self.resolve_outbound_network_interface(payload.target_ip_address) is_broadcast = payload.request - ip_protocol = IPProtocol["UDP"] + ip_protocol = PROTOCOL_LOOKUP["UDP"] else: vals = self.resolve_outbound_transmission_details( dst_ip_address=dst_ip_address, @@ -316,12 +317,12 @@ class SessionManager: tcp_header = None udp_header = None - if ip_protocol == IPProtocol["TCP"]: + if ip_protocol == PROTOCOL_LOOKUP["TCP"]: tcp_header = TCPHeader( src_port=dst_port, dst_port=dst_port, ) - elif ip_protocol == IPProtocol["UDP"]: + elif ip_protocol == PROTOCOL_LOOKUP["UDP"]: udp_header = UDPHeader( src_port=dst_port, dst_port=dst_port, @@ -385,7 +386,7 @@ class SessionManager: elif frame.udp: dst_port = frame.udp.dst_port elif frame.icmp: - dst_port = Port["NONE"] + dst_port = PORT_LOOKUP["NONE"] self.software_manager.receive_payload_from_session_manager( payload=frame.payload, port=dst_port, diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 8eac33fa..abf2ca3a 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -8,12 +8,12 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.core import RequestType from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import IOSoftware +from primaite.utils.validators import PROTOCOL_LOOKUP if TYPE_CHECKING: from primaite.simulator.system.core.session_manager import SessionManager @@ -191,7 +191,7 @@ class SoftwareManager: dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, src_port: Optional[int] = None, dest_port: Optional[int] = None, - ip_protocol: str = IPProtocol["TCP"], + ip_protocol: str = PROTOCOL_LOOKUP["TCP"], session_id: Optional[str] = None, ) -> bool: """ @@ -275,7 +275,7 @@ class SoftwareManager: software_type, software.operating_state.name, software.health_state_actual.name, - software.port if software.port != Port["NONE"] else None, + software.port if software.port != PORT_LOOKUP["NONE"] else None, software.protocol, ] ) diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index b8dd5f89..2641f1c8 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -8,10 +8,9 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP class ARP(Service): @@ -26,8 +25,8 @@ class ARP(Service): def __init__(self, **kwargs): kwargs["name"] = "ARP" - kwargs["port"] = Port["ARP"] - kwargs["protocol"] = IPProtocol["UDP"] + kwargs["port"] = PORT_LOOKUP["ARP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 11ca9eb2..f9a5d087 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -7,12 +7,12 @@ from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.file_system.folder import Folder -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -38,8 +38,8 @@ class DatabaseService(Service): def __init__(self, **kwargs): kwargs["name"] = "DatabaseService" - kwargs["port"] = Port["POSTGRES_SERVER"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self._create_db_file() diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 62f14366..316189a7 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -4,10 +4,10 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -22,11 +22,11 @@ class DNSClient(Service): def __init__(self, **kwargs): kwargs["name"] = "DNSClient" - kwargs["port"] = Port["DNS"] + kwargs["port"] = PORT_LOOKUP["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() @@ -95,7 +95,7 @@ class DNSClient(Service): # send a request to check if domain name exists in the DNS Server software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload=payload, dest_ip_address=self.dns_server, dest_port=Port["DNS"] + payload=payload, dest_ip_address=self.dns_server, dest_port=PORT_LOOKUP["DNS"] ) # recursively re-call the function passing is_reattempt=True diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 93895825..e0786124 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -6,9 +6,9 @@ from prettytable import MARKDOWN, PrettyTable from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -21,11 +21,11 @@ class DNSServer(Service): def __init__(self, **kwargs): kwargs["name"] = "DNSServer" - kwargs["port"] = Port["DNS"] + kwargs["port"] = PORT_LOOKUP["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 1fce4133..11a926cf 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -7,10 +7,10 @@ from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -25,8 +25,8 @@ class FTPClient(FTPServiceABC): def __init__(self, **kwargs): kwargs["name"] = "FTPClient" - kwargs["port"] = Port["FTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["FTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() @@ -104,7 +104,7 @@ class FTPClient(FTPServiceABC): def _connect_to_server( self, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = Port["FTP"], + dest_port: Optional[int] = PORT_LOOKUP["FTP"], session_id: Optional[str] = None, is_reattempt: Optional[bool] = False, ) -> bool: @@ -124,7 +124,7 @@ class FTPClient(FTPServiceABC): # normally FTP will choose a random port for the transfer, but using the FTP command port will do for now # create FTP packet - payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=Port["FTP"]) + payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=PORT_LOOKUP["FTP"]) if self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id): if payload.status_code == FTPStatusCode.OK: @@ -152,7 +152,7 @@ class FTPClient(FTPServiceABC): return False def _disconnect_from_server( - self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[int] = Port["FTP"] + self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[int] = PORT_LOOKUP["FTP"] ) -> bool: """ Connects the client from a given FTP server. @@ -179,7 +179,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[int] = Port["FTP"], + dest_port: Optional[int] = PORT_LOOKUP["FTP"], session_id: Optional[str] = None, ) -> bool: """ @@ -241,7 +241,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[int] = Port["FTP"], + dest_port: Optional[int] = PORT_LOOKUP["FTP"], ) -> bool: """ Request a file from a target IP address. diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 701bff79..38a253be 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -3,9 +3,9 @@ from typing import Any, Optional from primaite import getLogger from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -23,8 +23,8 @@ class FTPServer(FTPServiceABC): def __init__(self, **kwargs): kwargs["name"] = "FTPServer" - kwargs["port"] = Port["FTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["FTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/icmp/icmp.py b/src/primaite/simulator/system/services/icmp/icmp.py index a2dfac0d..486ba2b0 100644 --- a/src/primaite/simulator/system/services/icmp/icmp.py +++ b/src/primaite/simulator/system/services/icmp/icmp.py @@ -7,9 +7,9 @@ from primaite import getLogger from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -26,8 +26,8 @@ class ICMP(Service): def __init__(self, **kwargs): kwargs["name"] = "ICMP" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["ICMP"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["ICMP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 40b8d273..184833e1 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -5,9 +5,9 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -21,8 +21,8 @@ class NTPClient(Service): def __init__(self, **kwargs): kwargs["name"] = "NTPClient" - kwargs["port"] = Port["NTP"] - kwargs["protocol"] = IPProtocol["UDP"] + kwargs["port"] = PORT_LOOKUP["NTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) self.start() @@ -55,7 +55,7 @@ class NTPClient(Service): payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: int = Port["NTP"], + dest_port: int = PORT_LOOKUP["NTP"], **kwargs, ) -> bool: """Requests NTP data from NTP server. diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index d9de40c6..4764bffb 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -4,9 +4,9 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -16,8 +16,8 @@ class NTPServer(Service): def __init__(self, **kwargs): kwargs["name"] = "NTPServer" - kwargs["port"] = Port["NTP"] - kwargs["protocol"] = IPProtocol["UDP"] + kwargs["port"] = PORT_LOOKUP["NTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 41987aff..2b0bc02b 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -17,10 +17,10 @@ from primaite.simulator.network.protocols.ssh import ( SSHTransportMessage, SSHUserCredentials, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.utils.validators import PROTOCOL_LOOKUP # TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor @@ -137,8 +137,8 @@ class Terminal(Service): def __init__(self, **kwargs): kwargs["name"] = "Terminal" - kwargs["port"] = Port["SSH"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["SSH"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index c021a86e..2805b1b2 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -10,11 +10,11 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClientConnection from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -49,10 +49,10 @@ class WebServer(Service): def __init__(self, **kwargs): kwargs["name"] = "WebServer" - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port["HTTP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) self._install_web_files() diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 084bdaf6..d34678b9 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -13,9 +13,9 @@ from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog +from primaite.utils.validators import PROTOCOL_LOOKUP if TYPE_CHECKING: from primaite.simulator.system.core.software_manager import SoftwareManager @@ -386,7 +386,7 @@ class IOSoftware(Software): session_id: Optional[str] = None, dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, dest_port: Optional[int] = None, - ip_protocol: str = IPProtocol["TCP"], + ip_protocol: str = PROTOCOL_LOOKUP["TCP"], **kwargs, ) -> bool: """ diff --git a/src/primaite/utils/validators.py b/src/primaite/utils/validators.py index 139d303c..f07b475d 100644 --- a/src/primaite/utils/validators.py +++ b/src/primaite/utils/validators.py @@ -6,6 +6,9 @@ from pydantic import BeforeValidator from typing_extensions import Annotated +# Define a custom type IPV4Address using the typing_extensions.Annotated. +# Annotated is used to attach metadata to type hints. In this case, it's used to associate the ipv4_validator +# with the IPv4Address type, ensuring that any usage of IPV4Address undergoes validation before assignment. def ipv4_validator(v: Any) -> IPv4Address: """ Validate the input and ensure it can be converted to an IPv4Address instance. @@ -24,9 +27,6 @@ def ipv4_validator(v: Any) -> IPv4Address: return IPv4Address(v) -# Define a custom type IPV4Address using the typing_extensions.Annotated. -# Annotated is used to attach metadata to type hints. In this case, it's used to associate the ipv4_validator -# with the IPv4Address type, ensuring that any usage of IPV4Address undergoes validation before assignment. IPV4Address: Final[Annotated] = Annotated[IPv4Address, BeforeValidator(ipv4_validator)] """ IPv4Address with with IPv4Address with with pre-validation and auto-conversion from str using ipv4_validator.. @@ -37,3 +37,39 @@ will automatically check and convert the input value to an instance of IPv4Addre any Pydantic model uses it. This ensures that any field marked with this type is not just an IPv4Address in form, but also valid according to the rules defined in ipv4_validator. """ + +# Define a custom port validator +Port: Final[Annotated] = Annotated[int, BeforeValidator(lambda n: 0 <= n <= 65535)] +"""Validates that network ports lie in the appropriate range of [0,65535].""" + +# Define a custom IP protocol validator +PROTOCOL_LOOKUP: dict[str, str] = dict( + NONE="none", + TCP="tcp", + UDP="udp", + ICMP="icmp", +) +""" +Lookup table used for compatibility with PrimAITE <= 3.3. Configs with the capitalised protocol names are converted +to lowercase at runtime. +""" +VALID_PROTOCOLS = ["none", "tcp", "udp", "icmp"] +"""Supported protocols.""" + + +def protocol_validator(v: Any) -> str: + """ + Validate that IP Protocols are chosen from the list of supported IP Protocols. + + The protocol list is dynamic because plugins are able to extend it, therefore it is necessary to use this custom + validator instead of being able to specify a union of string literals. + """ + if v in PROTOCOL_LOOKUP: + return PROTOCOL_LOOKUP(v) + if v in VALID_PROTOCOLS: + return v + raise ValueError(f"{v} is not a valid IP Protocol. It must be one of the following: {VALID_PROTOCOLS}") + + +IPProtocol: Final[Annotated] = Annotated[str, BeforeValidator(protocol_validator)] +"""Validates that IP Protocols used in the simulation belong to the list of supported protocols.""" diff --git a/tests/conftest.py b/tests/conftest.py index 1ffa2146..687bec92 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,8 +18,7 @@ from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.web_browser import WebBrowser @@ -28,6 +27,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT rayinit() @@ -45,8 +45,8 @@ class DummyService(Service): def __init__(self, **kwargs): kwargs["name"] = "DummyService" - kwargs["port"] = Port["HTTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: @@ -58,8 +58,8 @@ class DummyApplication(Application, identifier="DummyApplication"): def __init__(self, **kwargs): kwargs["name"] = "DummyApplication" - kwargs["port"] = Port["HTTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -77,7 +77,7 @@ def uc2_network() -> Network: @pytest.fixture(scope="function") def service(file_system) -> DummyService: return DummyService( - name="DummyService", port=Port["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_service") + name="DummyService", port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_service") ) @@ -90,7 +90,7 @@ def service_class(): def application(file_system) -> DummyApplication: return DummyApplication( name="DummyApplication", - port=Port["ARP"], + port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_application"), ) @@ -350,10 +350,10 @@ def install_stuff_to_sim(sim: Simulation): network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) # 2: Configure base ACL - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3) # 3: Install server software server_1.software_manager.install(DNSServer) @@ -379,13 +379,13 @@ def install_stuff_to_sim(sim: Simulation): r = sim.network.router_nodes[0] for i, acl_rule in enumerate(r.acl.acl): if i == 1: - assert acl_rule.src_port == acl_rule.dst_port == Port["DNS"] + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["DNS"] elif i == 3: - assert acl_rule.src_port == acl_rule.dst_port == Port["HTTP"] + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["HTTP"] elif i == 22: - assert acl_rule.src_port == acl_rule.dst_port == Port["ARP"] + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["ARP"] elif i == 23: - assert acl_rule.protocol == IPProtocol["ICMP"] + assert acl_rule.protocol == PROTOCOL_LOOKUP["ICMP"] elif i == 24: ... else: diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py index d35e2ebb..6d0ef7b0 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py @@ -9,8 +9,8 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP from tests.integration_tests.configuration_file_parsing import BASIC_FIREWALL, DMZ_NETWORK, load_config @@ -68,44 +68,44 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed internal_inbound assert firewall.internal_inbound_acl.num_rules == 2 assert firewall.internal_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.internal_inbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.internal_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.internal_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[23].protocol == IPProtocol["ICMP"] + assert firewall.internal_inbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.internal_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed internal_outbound assert firewall.internal_outbound_acl.num_rules == 2 assert firewall.internal_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.internal_outbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.internal_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.internal_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[23].protocol == IPProtocol["ICMP"] + assert firewall.internal_outbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.internal_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_inbound assert firewall.dmz_inbound_acl.num_rules == 2 assert firewall.dmz_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.dmz_inbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.dmz_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.dmz_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.dmz_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[23].protocol == IPProtocol["ICMP"] + assert firewall.dmz_inbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.dmz_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_outbound assert firewall.dmz_outbound_acl.num_rules == 2 assert firewall.dmz_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.dmz_outbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.dmz_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.dmz_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.dmz_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[23].protocol == IPProtocol["ICMP"] + assert firewall.dmz_outbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.dmz_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed external_inbound assert firewall.external_inbound_acl.num_rules == 1 assert firewall.external_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_inbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.external_inbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.external_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.external_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] # external_inbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_inbound_acl.implicit_action == ACLAction.PERMIT @@ -113,8 +113,8 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed external_outbound assert firewall.external_outbound_acl.num_rules == 1 assert firewall.external_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_outbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.external_outbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.external_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.external_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] # external_outbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_outbound_acl.implicit_action == ACLAction.PERMIT diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py index 16543565..c348ee81 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -6,8 +6,8 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config @@ -63,8 +63,8 @@ def test_router_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed assert router_1.acl.num_rules == 2 assert router_1.acl.acl[22].action == ACLAction.PERMIT - assert router_1.acl.acl[22].src_port == Port["ARP"] - assert router_1.acl.acl[22].dst_port == Port["ARP"] + assert router_1.acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert router_1.acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert router_1.acl.acl[23].action == ACLAction.PERMIT - assert router_1.acl.acl[23].protocol == IPProtocol["ICMP"] + assert router_1.acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert router_1.acl.implicit_action == ACLAction.DENY diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py index 8e3d33e1..28029b32 100644 --- a/tests/integration_tests/extensions/applications/extended_application.py +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -15,11 +15,11 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -44,10 +44,10 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): def __init__(self, **kwargs): kwargs["name"] = "ExtendedApplication" - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port["HTTP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) self.run() @@ -127,7 +127,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): if self.send( payload=payload, dest_ip_address=self.domain_name_ip_address, - dest_port=parsed_url.port if parsed_url.port else Port["HTTP"], + dest_port=parsed_url.port if parsed_url.port else PORT_LOOKUP["HTTP"], ): self.sys_log.info( f"{self.name}: Received HTTP {payload.request_method.name} " @@ -155,7 +155,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = Port["HTTP"], + dest_port: Optional[int] = PORT_LOOKUP["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index b745b774..70d47aaa 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -7,12 +7,12 @@ from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.file_system.folder import Folder -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -38,8 +38,8 @@ class ExtendedService(Service, identifier="extendedservice"): def __init__(self, **kwargs): kwargs["name"] = "ExtendedService" - kwargs["port"] = Port["POSTGRES_SERVER"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self._create_db_file() if kwargs.get("options"): diff --git a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py index 17b0ba8c..2c750621 100644 --- a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py +++ b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py @@ -11,7 +11,7 @@ from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Command, C2Server from primaite.simulator.system.services.database.database_service import DatabaseService @@ -26,9 +26,9 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=4) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=5) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=6) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=5) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=6) c2_server_host = game.simulation.network.get_node_by_hostname("client_1") c2_server_host.software_manager.install(software_class=C2Server) diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index 34ee25d6..b56a4b99 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -11,7 +11,7 @@ from primaite.game.agent.actions import ( ) from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot @@ -200,7 +200,7 @@ class TestConfigureDoSBot: game.step() assert dos_bot.target_ip_address == IPv4Address("192.168.1.99") - assert dos_bot.target_port == Port["POSTGRES_SERVER"] + assert dos_bot.target_port == PORT_LOOKUP["POSTGRES_SERVER"] assert dos_bot.payload == "HACC" assert not dos_bot.repeat assert dos_bot.port_scan_p_of_success == 0.875 diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index 857edd26..bc168c3c 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -9,7 +9,7 @@ from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection @@ -20,7 +20,7 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) return (game, agent) diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index 28f9ac5a..2bf0486c 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -4,7 +4,7 @@ import pytest from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer @@ -33,7 +33,7 @@ def test_acl_observations(simulation): server.software_manager.install(NTPServer) # add router acl rule - router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port["NTP"], src_port=Port["NTP"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, dst_port=PORT_LOOKUP["NTP"], src_port=PORT_LOOKUP["NTP"], position=1) acl_obs = ACLObservation( where=["network", "nodes", router.hostname, "acl", "acl"], diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 21fe4bed..af8c4669 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -5,8 +5,8 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP def check_default_rules(acl_obs): @@ -62,13 +62,13 @@ def test_firewall_observation(): # add a rule to the internal inbound and check that the observation is correct firewall.internal_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port["HTTP"], - dst_port=Port["HTTP"], + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], position=5, ) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index c28e1bb8..cdd428b0 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -8,9 +8,9 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation +from primaite.utils.validators import PROTOCOL_LOOKUP def test_router_observation(): @@ -39,13 +39,13 @@ def test_router_observation(): # Add an ACL rule to the router router.acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port["HTTP"], - dst_port=Port["HTTP"], + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], position=5, ) # Observe the state using the RouterObservation instance diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py index 6ca4bc9e..70637b0d 100644 --- a/tests/integration_tests/game_layer/observations/test_user_observations.py +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -3,7 +3,7 @@ import pytest from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from tests import TEST_ASSETS_ROOT DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" @@ -15,7 +15,7 @@ def env_with_ssh() -> PrimaiteGymEnv: env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) env.agent.flatten_obs = False router: Router = env.game.simulation.network.get_node_by_hostname("router_1") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=3) + router.acl.add_rule(ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=3) return env diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index c3e86263..2675b615 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -21,11 +21,11 @@ from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network.yaml" @@ -608,9 +608,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.internal_outbound_acl.acl[1].action.name == "DENY" assert firewall.internal_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.0.10") assert firewall.internal_outbound_acl.acl[1].dst_ip_address is None - assert firewall.internal_outbound_acl.acl[1].dst_port == Port["DNS"] - assert firewall.internal_outbound_acl.acl[1].src_port == Port["ARP"] - assert firewall.internal_outbound_acl.acl[1].protocol == IPProtocol["ICMP"] + assert firewall.internal_outbound_acl.acl[1].dst_port == PORT_LOOKUP["DNS"] + assert firewall.internal_outbound_acl.acl[1].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["ICMP"] env.step(4) # Remove ACL rule from Internal Outbound assert firewall.internal_outbound_acl.num_rules == 2 @@ -620,9 +620,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_inbound_acl.acl[1].action.name == "DENY" assert firewall.dmz_inbound_acl.acl[1].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_inbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_inbound_acl.acl[1].dst_port == Port["HTTP"] - assert firewall.dmz_inbound_acl.acl[1].src_port == Port["HTTP"] - assert firewall.dmz_inbound_acl.acl[1].protocol == IPProtocol["UDP"] + assert firewall.dmz_inbound_acl.acl[1].dst_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].src_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["UDP"] env.step(6) # Remove ACL rule from DMZ Inbound assert firewall.dmz_inbound_acl.num_rules == 2 @@ -632,9 +632,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_outbound_acl.acl[2].action.name == "DENY" assert firewall.dmz_outbound_acl.acl[2].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_outbound_acl.acl[2].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_outbound_acl.acl[2].dst_port == Port["HTTP"] - assert firewall.dmz_outbound_acl.acl[2].src_port == Port["HTTP"] - assert firewall.dmz_outbound_acl.acl[2].protocol == IPProtocol["TCP"] + assert firewall.dmz_outbound_acl.acl[2].dst_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].src_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].protocol == PROTOCOL_LOOKUP["TCP"] env.step(8) # Remove ACL rule from DMZ Outbound assert firewall.dmz_outbound_acl.num_rules == 2 @@ -644,9 +644,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.external_inbound_acl.acl[10].action.name == "DENY" assert firewall.external_inbound_acl.acl[10].src_ip_address == IPv4Address("192.168.20.10") assert firewall.external_inbound_acl.acl[10].dst_ip_address == IPv4Address("192.168.10.10") - assert firewall.external_inbound_acl.acl[10].dst_port == Port["POSTGRES_SERVER"] - assert firewall.external_inbound_acl.acl[10].src_port == Port["POSTGRES_SERVER"] - assert firewall.external_inbound_acl.acl[10].protocol == IPProtocol["ICMP"] + assert firewall.external_inbound_acl.acl[10].dst_port == PORT_LOOKUP["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].src_port == PORT_LOOKUP["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].protocol == PROTOCOL_LOOKUP["ICMP"] env.step(10) # Remove ACL rule from External Inbound assert firewall.external_inbound_acl.num_rules == 1 diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 570c4ad6..0afe666c 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -9,11 +9,11 @@ from primaite.interface.request import RequestResponse from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent @@ -42,7 +42,12 @@ def test_WebpageUnavailablePenalty(game_and_agent): # Block the web traffic, check that failing to fetch the webpage yields a reward of -0.7 router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol["TCP"], src_port=Port["HTTP"], dst_port=Port["HTTP"]) + router.acl.add_rule( + action=ACLAction.DENY, + protocol=PROTOCOL_LOOKUP["TCP"], + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], + ) agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})) game.step() assert agent.reward_function.current_reward == -0.7 @@ -66,7 +71,7 @@ def test_uc2_rewards(game_and_agent): router: Router = game.simulation.network.get_node_by_hostname("router") router.acl.add_rule( - ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=2 + ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2 ) comp = GreenAdminDatabaseUnreachablePenalty("client_1") diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index da0af89d..b5b2acbc 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -8,10 +8,10 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP class BroadcastTestService(Service): @@ -20,8 +20,8 @@ class BroadcastTestService(Service): def __init__(self, **kwargs): # Set default service properties for broadcasting kwargs["name"] = "BroadcastService" - kwargs["port"] = Port["HTTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -33,12 +33,14 @@ class BroadcastTestService(Service): super().send( payload="unicast", dest_ip_address=ip_address, - dest_port=Port["HTTP"], + dest_port=PORT_LOOKUP["HTTP"], ) def broadcast(self, ip_network: IPv4Network): # Send a broadcast payload to an entire IP network - super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port["HTTP"], ip_protocol=self.protocol) + super().send( + payload="broadcast", dest_ip_address=ip_network, dest_port=PORT_LOOKUP["HTTP"], ip_protocol=self.protocol + ) class BroadcastTestClient(Application, identifier="BroadcastTestClient"): @@ -49,8 +51,8 @@ class BroadcastTestClient(Application, identifier="BroadcastTestClient"): def __init__(self, **kwargs): # Set default client properties kwargs["name"] = "BroadcastTestClient" - kwargs["port"] = Port["HTTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index 44b660cf..58763c3e 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -7,10 +7,10 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -53,31 +53,31 @@ def dmz_external_internal_network() -> Network: ) # Allow ICMP - firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Allow ARP firewall_node.internal_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.internal_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.external_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.external_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.dmz_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.dmz_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) # external node @@ -267,10 +267,10 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not internal_ntp_client.time firewall.internal_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 ) firewall.internal_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 ) internal_ntp_client.request_time() @@ -279,8 +279,12 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not dmz_ntp_client.time - firewall.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) - firewall.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) + firewall.dmz_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) + firewall.dmz_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) dmz_ntp_client.request_time() diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index 641342e2..dde66a43 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -6,10 +6,10 @@ import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -73,8 +73,10 @@ def multi_hop_network() -> Network: router_1.enable_port(2) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B pc_b = Computer( @@ -197,8 +199,12 @@ def test_routing_services(multi_hop_network): router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=21) - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=21) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=21 + ) + router_2.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=21 + ) assert ntp_client.time is None ntp_client.request_time() diff --git a/tests/integration_tests/network/test_wireless_router.py b/tests/integration_tests/network/test_wireless_router.py index 2f1be930..520ec21a 100644 --- a/tests/integration_tests/network/test_wireless_router.py +++ b/tests/integration_tests/network/test_wireless_router.py @@ -7,8 +7,8 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT @@ -37,8 +37,10 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B pc_b = Computer( diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index d819b511..b1979154 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -13,8 +13,7 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import AccessControlList, ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon @@ -25,6 +24,7 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT @@ -227,7 +227,7 @@ def test_c2_suite_acl_block(basic_network): assert c2_beacon.c2_connection_active == True # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) c2_beacon.apply_timestep(2) c2_beacon.apply_timestep(3) @@ -322,8 +322,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Confirm Default Setup ######################### # Permitting all HTTP & FTP traffic - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=1) c2_beacon.apply_timestep(0) assert c2_beacon.keep_alive_inactivity == 1 @@ -337,7 +337,7 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying HTTP Traffic ######################### # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) blocking_acl: AccessControlList = router.acl.acl[0] # Asserts to show the C2 Suite is unable to maintain connection: @@ -359,8 +359,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port["FTP"], - masquerade_protocol=IPProtocol["TCP"], + masquerade_port=PORT_LOOKUP["FTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) c2_beacon.establish() @@ -407,8 +407,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying FTP Traffic & Enable HTTP ######################### # Blocking FTP and re-permitting HTTP: - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port["FTP"], dst_port=Port["FTP"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=1) blocking_acl: AccessControlList = router.acl.acl[1] # Asserts to show the C2 Suite is unable to maintain connection: @@ -430,8 +430,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port["HTTP"], - masquerade_protocol=IPProtocol["TCP"], + masquerade_port=PORT_LOOKUP["HTTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) c2_beacon.establish() diff --git a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py index 9c0760b7..54c372e4 100644 --- a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py @@ -9,7 +9,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( @@ -52,7 +52,10 @@ def data_manipulation_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py index 709a417f..ad0a519b 100644 --- a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py @@ -8,7 +8,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot @@ -26,7 +26,7 @@ def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseServ dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port["POSTGRES_SERVER"], + target_port=PORT_LOOKUP["POSTGRES_SERVER"], ) # Install DB Server service on server @@ -43,7 +43,10 @@ def dos_bot_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") @@ -56,7 +59,7 @@ def dos_bot_db_server_green_client(example_network) -> Network: dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port["POSTGRES_SERVER"], + target_port=PORT_LOOKUP["POSTGRES_SERVER"], ) # install db server service on server diff --git a/tests/integration_tests/system/red_applications/test_ransomware_script.py b/tests/integration_tests/system/red_applications/test_ransomware_script.py index b34e9b30..09cbcf85 100644 --- a/tests/integration_tests/system/red_applications/test_ransomware_script.py +++ b/tests/integration_tests/system/red_applications/test_ransomware_script.py @@ -9,7 +9,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.database.database_service import DatabaseService @@ -47,7 +47,10 @@ def ransomware_script_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index 9d92b660..c1c4df82 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -5,9 +5,9 @@ from ipaddress import IPv4Address, IPv4Network import yaml from primaite.game.game import PrimaiteGame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.nmap import NMAP +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT @@ -74,11 +74,11 @@ def test_port_scan_one_node_one_port(example_network): actual_result = client_1_nmap.port_scan( target_ip_address=client_2.network_interface[1].ip_address, - target_port=Port["DNS"], - target_protocol=IPProtocol["TCP"], + target_port=PORT_LOOKUP["DNS"], + target_protocol=PROTOCOL_LOOKUP["TCP"], ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol["TCP"]: [Port["DNS"]]}} + expected_result = {IPv4Address("192.168.10.22"): {PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["DNS"]]}} assert actual_result == expected_result @@ -103,14 +103,20 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network): actual_result = client_1_nmap.port_scan( target_ip_address=IPv4Network("192.168.10.0/24"), - target_port=[Port["ARP"], Port["HTTP"], Port["FTP"], Port["DNS"], Port["NTP"]], + target_port=[ + PORT_LOOKUP["ARP"], + PORT_LOOKUP["HTTP"], + PORT_LOOKUP["FTP"], + PORT_LOOKUP["DNS"], + PORT_LOOKUP["NTP"], + ], ) expected_result = { - IPv4Address("192.168.10.1"): {IPProtocol["UDP"]: [Port["ARP"]]}, + IPv4Address("192.168.10.1"): {PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"]]}, IPv4Address("192.168.10.22"): { - IPProtocol["TCP"]: [Port["HTTP"], Port["FTP"], Port["DNS"]], - IPProtocol["UDP"]: [Port["ARP"], Port["NTP"]], + PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]], + PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"], PORT_LOOKUP["NTP"]], }, } @@ -124,10 +130,12 @@ def test_network_service_recon_all_ports_and_protocols(example_network): client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa actual_result = client_1_nmap.network_service_recon( - target_ip_address=IPv4Network("192.168.10.0/24"), target_port=Port["HTTP"], target_protocol=IPProtocol["TCP"] + target_ip_address=IPv4Network("192.168.10.0/24"), + target_port=PORT_LOOKUP["HTTP"], + target_protocol=PROTOCOL_LOOKUP["TCP"], ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol["TCP"]: [Port["HTTP"]]}} + expected_result = {IPv4Address("192.168.10.22"): {PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["HTTP"]]}} assert sort_dict(actual_result) == sort_dict(expected_result) diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py index 5226ab4a..4108041d 100644 --- a/tests/integration_tests/system/test_service_listening_on_ports.py +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -6,19 +6,19 @@ from pydantic import Field from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT class _DatabaseListener(Service): name: str = "DatabaseListener" - protocol: str = IPProtocol["TCP"] - port: int = Port["NONE"] - listen_on_ports: Set[int] = {Port["POSTGRES_SERVER"]} + protocol: str = PROTOCOL_LOOKUP["TCP"] + port: int = PORT_LOOKUP["NONE"] + listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]} payloads_received: List[Any] = Field(default_factory=list) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: @@ -51,8 +51,8 @@ def test_http_listener(client_server): computer.session_manager.receive_payload_from_software_manager( payload="masquerade as Database traffic", dst_ip_address=server.network_interface[1].ip_address, - dst_port=Port["POSTGRES_SERVER"], - ip_protocol=IPProtocol["TCP"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + ip_protocol=PROTOCOL_LOOKUP["TCP"], ) assert len(server_db_listener.payloads_received) == 1 @@ -76,9 +76,9 @@ def test_set_listen_on_ports_from_config(): network = PrimaiteGame.from_config(cfg=config_dict).simulation.network client: Computer = network.get_node_by_hostname("client") - assert Port["SMB"] in client.software_manager.get_open_ports() - assert Port["IPP"] in client.software_manager.get_open_ports() + assert PORT_LOOKUP["SMB"] in client.software_manager.get_open_ports() + assert PORT_LOOKUP["IPP"] in client.software_manager.get_open_ports() web_browser = client.software_manager.software["WebBrowser"] - assert not web_browser.listen_on_ports.difference({Port["SMB"], Port["IPP"]}) + assert not web_browser.listen_on_ports.difference({PORT_LOOKUP["SMB"], PORT_LOOKUP["IPP"]}) diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py index 6c37360f..854ef41b 100644 --- a/tests/integration_tests/system/test_web_client_server_and_database.py +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -9,7 +9,7 @@ from primaite.simulator.network.hardware.base import Link from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService @@ -24,17 +24,22 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer, # add rules to network router router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3 + ) # Create Computer computer: Computer = example_network.get_node_by_hostname("client_1") @@ -148,7 +153,9 @@ class TestWebBrowserHistory: assert web_browser.history[-1].response_code == 200 router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule( + action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0 + ) assert not web_browser.get_webpage() assert len(web_browser.history) == 3 # with current NIC behaviour, even if you block communication, you won't get SERVER_UNREACHABLE because @@ -166,7 +173,9 @@ class TestWebBrowserHistory: web_browser.get_webpage() router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule( + action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0 + ) web_browser.get_webpage() state = computer.describe_state() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index ff73e621..7813628c 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -12,7 +12,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from tests.conftest import DummyApplication, DummyService @@ -171,7 +171,7 @@ class TestDataManipulationGreenRequests: assert client_1_browser_execute.status == "success" assert client_2_browser_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) + router.acl.add_rule(ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3) client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) assert client_1_browser_execute.status == "failure" @@ -182,7 +182,9 @@ class TestDataManipulationGreenRequests: assert client_1_db_client_execute.status == "success" assert client_2_db_client_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]) + router.acl.add_rule( + ACLAction.DENY, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"] + ) client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"]) client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"]) assert client_1_db_client_execute.status == "failure" diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py index 4c471faa..ba7628c2 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py @@ -7,8 +7,9 @@ from primaite.simulator.network.hardware.base import generate_mac_address from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.network_layer import IPPacket +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPHeader, UDPHeader +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -28,16 +29,16 @@ def router_with_acl_rules(): # Add rules here as needed acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.1", - src_port=Port["HTTPS"], + src_port=PORT_LOOKUP["HTTPS"], dst_ip_address="192.168.1.2", - dst_port=Port["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], position=1, ) acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.3", src_port=8080, dst_ip_address="192.168.1.4", @@ -65,7 +66,7 @@ def router_with_wildcard_acl(): # Rule to permit traffic from a specific source IP and port to a specific destination IP and port acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.1", src_port=8080, dst_ip_address="10.1.1.2", @@ -75,7 +76,7 @@ def router_with_wildcard_acl(): # Rule to deny traffic from an IP range to a specific destination IP and port acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.0", src_wildcard_mask="0.0.0.255", dst_ip_address="10.1.1.3", @@ -109,11 +110,11 @@ def test_add_rule(router_with_acl_rules): acl = router_with_acl_rules.acl assert acl.acl[1].action == ACLAction.PERMIT - assert acl.acl[1].protocol == IPProtocol["TCP"] + assert acl.acl[1].protocol == PROTOCOL_LOOKUP["TCP"] assert acl.acl[1].src_ip_address == IPv4Address("192.168.1.1") - assert acl.acl[1].src_port == Port["HTTPS"] + assert acl.acl[1].src_port == PORT_LOOKUP["HTTPS"] assert acl.acl[1].dst_ip_address == IPv4Address("192.168.1.2") - assert acl.acl[1].dst_port == Port["HTTP"] + assert acl.acl[1].dst_port == PORT_LOOKUP["HTTP"] def test_remove_rule(router_with_acl_rules): @@ -136,8 +137,8 @@ def test_traffic_permitted_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=IPProtocol["TCP"]), - tcp=TCPHeader(src_port=Port["HTTPS"], dst_port=Port["HTTP"]), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["HTTPS"], dst_port=PORT_LOOKUP["HTTP"]), ) is_permitted, _ = acl.is_permitted(permitted_frame) assert is_permitted @@ -153,7 +154,7 @@ def test_traffic_denied_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=IPProtocol["TCP"]), + ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=PROTOCOL_LOOKUP["TCP"]), tcp=TCPHeader(src_port=8080, dst_port=80), ) is_permitted, _ = acl.is_permitted(not_permitted_frame) @@ -173,8 +174,8 @@ def test_default_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=IPProtocol["UDP"]), - udp=UDPHeader(src_port=Port["HTTPS"], dst_port=Port["HTTP"]), + ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["HTTPS"], dst_port=PORT_LOOKUP["HTTP"]), ) is_permitted, rule = acl.is_permitted(not_permitted_frame) assert not is_permitted @@ -189,7 +190,7 @@ def test_direct_ip_match_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=IPProtocol["TCP"]), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=PROTOCOL_LOOKUP["TCP"]), tcp=TCPHeader(src_port=8080, dst_port=80), ) assert acl.is_permitted(frame)[0], "Direct IP match should be permitted." @@ -204,7 +205,7 @@ def test_ip_range_match_denied_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=IPProtocol["TCP"]), + ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=PROTOCOL_LOOKUP["TCP"]), tcp=TCPHeader(src_port=8080, dst_port=443), ) assert not acl.is_permitted(frame)[0], "IP range match with wildcard mask should be denied." @@ -219,7 +220,7 @@ def test_traffic_permitted_to_destination_range_with_acl(router_with_wildcard_ac acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["UDP"]), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["UDP"]), udp=UDPHeader(src_port=1433, dst_port=1433), ) assert acl.is_permitted(frame)[0], "Traffic to destination IP range should be permitted." @@ -253,23 +254,23 @@ def test_ip_traffic_from_specific_subnet(): permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["TCP"]), - tcp=TCPHeader(src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"]), ) assert acl.is_permitted(permitted_frame_1)[0] permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=IPProtocol["UDP"]), - udp=UDPHeader(src_port=Port["NTP"], dst_port=Port["NTP"]), + ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"]), ) assert acl.is_permitted(permitted_frame_2)[0] permitted_frame_3 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=IPProtocol["ICMP"]), + ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=PROTOCOL_LOOKUP["ICMP"]), icmp=ICMPPacket(identifier=1), ) @@ -277,16 +278,16 @@ def test_ip_traffic_from_specific_subnet(): not_permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["TCP"]), - tcp=TCPHeader(src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]), + ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"]), ) assert not acl.is_permitted(not_permitted_frame_1)[0] not_permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=IPProtocol["UDP"]), - udp=UDPHeader(src_port=Port["NTP"], dst_port=Port["NTP"]), + ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"]), ) assert not acl.is_permitted(not_permitted_frame_2)[0] diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py index 3551ce38..0e1844c4 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -2,8 +2,8 @@ from ipaddress import IPv4Address from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP def test_wireless_router_from_config(): @@ -67,12 +67,12 @@ def test_wireless_router_from_config(): r0 = rt.acl.acl[0] assert r0.action == ACLAction.PERMIT - assert r0.src_port == r0.dst_port == Port["POSTGRES_SERVER"] + assert r0.src_port == r0.dst_port == PORT_LOOKUP["POSTGRES_SERVER"] assert r0.src_ip_address == r0.dst_ip_address == r0.dst_wildcard_mask == r0.src_wildcard_mask == r0.protocol == None r1 = rt.acl.acl[1] assert r1.action == ACLAction.PERMIT - assert r1.protocol == IPProtocol["ICMP"] + assert r1.protocol == PROTOCOL_LOOKUP["ICMP"] assert ( r1.src_ip_address == r1.dst_ip_address diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py index 9fd39dfc..9e9a1f72 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py @@ -3,9 +3,10 @@ import pytest from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol, Precedence +from primaite.simulator.network.transmission.network_layer import IPPacket, Precedence from primaite.simulator.network.transmission.primaite_layer import AgentSource, DataStatus -from primaite.simulator.network.transmission.transport_layer import Port, TCPFlags, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPFlags, TCPHeader, UDPHeader +from primaite.utils.validators import PROTOCOL_LOOKUP def test_frame_minimal_instantiation(): @@ -20,7 +21,7 @@ def test_frame_minimal_instantiation(): ) # Check network layer default values - assert frame.ip.protocol == IPProtocol["TCP"] + assert frame.ip.protocol == PROTOCOL_LOOKUP["TCP"] assert frame.ip.ttl == 64 assert frame.ip.precedence == Precedence.ROUTINE @@ -40,7 +41,7 @@ def test_frame_creation_fails_tcp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["TCP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["TCP"]), ) @@ -49,7 +50,7 @@ def test_frame_creation_fails_udp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["UDP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["UDP"]), ) @@ -58,7 +59,7 @@ def test_frame_creation_fails_tcp_with_udp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["TCP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["TCP"]), udp=UDPHeader(src_port=8080, dst_port=80), ) @@ -68,7 +69,7 @@ def test_frame_creation_fails_udp_with_tcp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["UDP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["UDP"]), udp=TCPHeader(src_port=8080, dst_port=80), ) @@ -77,7 +78,7 @@ def test_icmp_frame_creation(): """Tests Frame creation for ICMP.""" frame = Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["ICMP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["ICMP"]), icmp=ICMPPacket(), ) assert frame @@ -88,5 +89,5 @@ def test_icmp_frame_creation_fails_without_icmp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["ICMP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["ICMP"]), ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index 6e53aebc..fde70616 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -4,11 +4,11 @@ import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Command, C2Server +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -129,19 +129,19 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon are configured correctly. assert c2_beacon.c2_config.keep_alive_frequency is 2 - assert c2_beacon.c2_config.masquerade_port is Port["HTTP"] - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol["TCP"] + assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] assert c2_server.c2_config.keep_alive_frequency is 2 - assert c2_server.c2_config.masquerade_port is Port["HTTP"] - assert c2_server.c2_config.masquerade_protocol is IPProtocol["TCP"] + assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] # Configuring the C2 Beacon. c2_beacon.configure( c2_server_ip_address="192.168.0.1", keep_alive_frequency=2, - masquerade_port=Port["FTP"], - masquerade_protocol=IPProtocol["TCP"], + masquerade_port=PORT_LOOKUP["FTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) # Asserting that the c2 applications have established a c2 connection @@ -150,11 +150,11 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon # Have reconfigured their C2 settings. - assert c2_beacon.c2_config.masquerade_port is Port["FTP"] - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol["TCP"] + assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] - assert c2_server.c2_config.masquerade_port is Port["FTP"] - assert c2_server.c2_config.masquerade_protocol is IPProtocol["TCP"] + assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] def test_c2_handle_switching_frequency(basic_c2_network): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index 229f98fe..f4750158 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -3,13 +3,13 @@ import pytest from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( DataManipulationAttackStage, DataManipulationBot, ) +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -27,8 +27,8 @@ def test_create_dm_bot(dm_client): data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot") assert data_manipulation_bot.name == "DataManipulationBot" - assert data_manipulation_bot.port == Port["NONE"] - assert data_manipulation_bot.protocol == IPProtocol["NONE"] + assert data_manipulation_bot.port == PORT_LOOKUP["NONE"] + assert data_manipulation_bot.protocol == PROTOCOL_LOOKUP["NONE"] assert data_manipulation_bot.payload == "DELETE" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index 2acd991a..d0c65266 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -5,7 +5,7 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index c274c18e..f5781485 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -4,10 +4,10 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.http import HttpResponsePacket, HttpStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -39,8 +39,8 @@ def test_create_web_client(): # Web Browser should be pre-installed in computer web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") assert web_browser.name is "WebBrowser" - assert web_browser.port is Port["HTTP"] - assert web_browser.protocol is IPProtocol["TCP"] + assert web_browser.port is PORT_LOOKUP["HTTP"] + assert web_browser.protocol is PROTOCOL_LOOKUP["TCP"] def test_receive_invalid_payload(web_browser): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index 1a51708d..09099c5c 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -6,10 +6,10 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -28,8 +28,8 @@ def test_create_dns_client(dns_client): assert dns_client is not None dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") assert dns_client_service.name is "DNSClient" - assert dns_client_service.port is Port["DNS"] - assert dns_client_service.protocol is IPProtocol["TCP"] + assert dns_client_service.port is PORT_LOOKUP["DNS"] + assert dns_client_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_dns_client_add_domain_to_cache_when_not_running(dns_client): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py index 8cdb1b84..688bfd7d 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -8,10 +8,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -32,8 +32,8 @@ def test_create_dns_server(dns_server): assert dns_server is not None dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer") assert dns_server_service.name is "DNSServer" - assert dns_server_service.port is Port["DNS"] - assert dns_server_service.protocol is IPProtocol["TCP"] + assert dns_server_service.port is PORT_LOOKUP["DNS"] + assert dns_server_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_dns_server_domain_name_registration(dns_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 3c1afb28..b4fe8633 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -8,10 +8,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -31,8 +31,8 @@ def test_create_ftp_client(ftp_client): assert ftp_client is not None ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") assert ftp_client_service.name is "FTPClient" - assert ftp_client_service.port is Port["FTP"] - assert ftp_client_service.protocol is IPProtocol["TCP"] + assert ftp_client_service.port is PORT_LOOKUP["FTP"] + assert ftp_client_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_ftp_client_store_file(ftp_client): @@ -61,7 +61,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client): """Method _process_ftp_command should return false if service is not running.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port["FTP"], + ftp_command_args=PORT_LOOKUP["FTP"], status_code=FTPStatusCode.OK, ) @@ -102,7 +102,7 @@ def test_offline_ftp_client_receives_request(ftp_client): payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port["FTP"], + ftp_command_args=PORT_LOOKUP["FTP"], status_code=FTPStatusCode.OK, ) @@ -119,7 +119,7 @@ def test_receive_should_ignore_payload_with_none_status_code(ftp_client): """Receive should ignore payload with no set status code to prevent infinite send/receive loops.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port["FTP"], + ftp_command_args=PORT_LOOKUP["FTP"], status_code=None, ) ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index aa13ec5e..3f10db4d 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -6,10 +6,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -30,8 +30,8 @@ def test_create_ftp_server(ftp_server): assert ftp_server is not None ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") assert ftp_server_service.name is "FTPServer" - assert ftp_server_service.port is Port["FTP"] - assert ftp_server_service.protocol is IPProtocol["TCP"] + assert ftp_server_service.port is PORT_LOOKUP["FTP"] + assert ftp_server_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_ftp_server_store_file(ftp_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 21ed839b..f2895091 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -18,13 +18,13 @@ from primaite.simulator.network.protocols.ssh import ( SSHTransportMessage, SSHUserCredentials, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection, Terminal from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -77,11 +77,15 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # add ACL rule to allow SSH traffic - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=21) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21 + ) # Configure PC B pc_b = Computer( @@ -329,7 +333,9 @@ def test_SSH_across_network(wireless_wan_network): terminal_a: Terminal = pc_a.software_manager.software.get("Terminal") terminal_b: Terminal = pc_b.software_manager.software.get("Terminal") - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=21) + router_2.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21 + ) assert len(terminal_a._connections) == 0 diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index c1df3857..c78a381e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -9,9 +9,9 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -33,8 +33,8 @@ def test_create_web_server(web_server): assert web_server is not None web_server_service: WebServer = web_server.software_manager.software.get("WebServer") assert web_server_service.name is "WebServer" - assert web_server_service.port is Port["HTTP"] - assert web_server_service.protocol is IPProtocol["TCP"] + assert web_server_service.port is PORT_LOOKUP["HTTP"] + assert web_server_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_handling_get_request_not_found_path(web_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index b7a663af..1baaf88e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -3,11 +3,11 @@ from typing import Dict import pytest -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import IOSoftware, SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP class TestSoftware(Service): @@ -19,10 +19,10 @@ class TestSoftware(Service): def software(file_system): return TestSoftware( name="TestSoftware", - port=Port["ARP"], + port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="test_service"), - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], ) diff --git a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py index 8becc6ae..10ed36e0 100644 --- a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py +++ b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.utils.converters import convert_dict_enum_keys_to_enum_values +from primaite.utils.validators import PROTOCOL_LOOKUP def test_simple_conversion(): @@ -11,7 +11,7 @@ def test_simple_conversion(): The original dictionary contains one level of nested dictionary with enums as keys. The expected output should have string values of enums as keys. """ - original_dict = {IPProtocol["UDP"]: {Port["ARP"]: {"inbound": 0, "outbound": 1016.0}}} + original_dict = {PROTOCOL_LOOKUP["UDP"]: {PORT_LOOKUP["ARP"]: {"inbound": 0, "outbound": 1016.0}}} expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -36,8 +36,8 @@ def test_mixed_keys(): The expected output should have string values of enums and original string keys. """ original_dict = { - IPProtocol["TCP"]: {"port": {"inbound": 0, "outbound": 1016.0}}, - "protocol": {Port["HTTP"]: {"inbound": 10, "outbound": 2020.0}}, + PROTOCOL_LOOKUP["TCP"]: {"port": {"inbound": 0, "outbound": 1016.0}}, + "protocol": {PORT_LOOKUP["HTTP"]: {"inbound": 10, "outbound": 2020.0}}, } expected_dict = { "tcp": {"port": {"inbound": 0, "outbound": 1016.0}}, @@ -66,8 +66,12 @@ def test_nested_dicts(): The expected output should have string values of enums as keys at all levels. """ original_dict = { - IPProtocol["UDP"]: { - Port["ARP"]: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol["TCP"]: {"latency": "low"}}} + PROTOCOL_LOOKUP["UDP"]: { + PORT_LOOKUP["ARP"]: { + "inbound": 0, + "outbound": 1016.0, + "details": {PROTOCOL_LOOKUP["TCP"]: {"latency": "low"}}, + } } } expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0, "details": {"tcp": {"latency": "low"}}}}} @@ -82,8 +86,11 @@ def test_non_dict_values(): The expected output should preserve these non-dictionary values while converting enum keys to string values. """ original_dict = { - IPProtocol["UDP"]: [Port["ARP"], Port["HTTP"]], - "protocols": (IPProtocol["TCP"], IPProtocol["UDP"]), + PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"], PORT_LOOKUP["HTTP"]], + "protocols": (PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]), + } + expected_dict = { + "udp": [PORT_LOOKUP["ARP"], PORT_LOOKUP["HTTP"]], + "protocols": (PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]), } - expected_dict = {"udp": [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict