diff --git a/docs/source/simulation_components/network/network.rst b/docs/source/simulation_components/network/network.rst index 636ffbcc..4cc121a3 100644 --- a/docs/source/simulation_components/network/network.rst +++ b/docs/source/simulation_components/network/network.rst @@ -103,13 +103,13 @@ we'll use the following Network that has a client, server, two switches, and a r router_1.acl.add_rule( action=ACLAction.PERMIT, - src_port=Port.ARP, - dst_port=Port.ARP, + src_port=Port["ARP"], + dst_port=Port["ARP"], position=22 ) router_1.acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.ICMP, + protocol=IPProtocol["ICMP"], position=23 ) diff --git a/docs/source/simulation_components/network/nodes/firewall.rst b/docs/source/simulation_components/network/nodes/firewall.rst index 149d3e67..1ef16d63 100644 --- a/docs/source/simulation_components/network/nodes/firewall.rst +++ b/docs/source/simulation_components/network/nodes/firewall.rst @@ -156,8 +156,8 @@ To prevent all external traffic from accessing the internal network, with except # Exception rule to allow HTTP traffic from external to internal network firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="192.168.1.0", dst_wildcard_mask="0.0.0.255", position=2 @@ -172,16 +172,16 @@ To enable external traffic to access specific services hosted within the DMZ: # Allow HTTP and HTTPS traffic to the DMZ firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="172.16.0.0", dst_wildcard_mask="0.0.0.255", position=3 ) firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTPS, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.0", dst_wildcard_mask="0.0.0.255", position=4 @@ -196,9 +196,9 @@ To permit SSH access from a designated external IP to a specific server within t # Allow SSH from a specific external IP to an internal server firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.2", - dst_port=Port.SSH, + dst_port=Port["SSH"], dst_ip_address="192.168.1.10", position=5 ) @@ -212,9 +212,9 @@ To limit database server access to selected external IP addresses: # Allow PostgreSQL traffic from an authorized external IP to the internal DB server firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.3", - dst_port=Port.POSTGRES_SERVER, + dst_port=Port["POSTGRES_SERVER"], dst_ip_address="192.168.1.20", position=6 ) @@ -222,8 +222,8 @@ To limit database server access to selected external IP addresses: # Deny all other PostgreSQL traffic from external sources firewall.internal_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, - dst_port=Port.POSTGRES_SERVER, + protocol=IPProtocol["TCP"], + dst_port=Port["POSTGRES_SERVER"], dst_ip_address="192.168.1.0", dst_wildcard_mask="0.0.0.255", position=7 @@ -247,15 +247,15 @@ To authorize HTTP/HTTPS access to a DMZ-hosted web server, excluding known malic # Allow HTTP/HTTPS traffic to the DMZ web server firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="172.16.0.2", position=9 ) firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTPS, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.2", position=10 ) @@ -269,9 +269,9 @@ To facilitate restricted access from the internal network to DMZ-hosted services # Permit specific internal application server HTTPS access to a DMZ-hosted API firewall.internal_outbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.30", # Internal application server IP - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=11 ) @@ -289,9 +289,9 @@ To facilitate restricted access from the internal network to DMZ-hosted services # Corresponding rule in DMZ inbound ACL to allow the traffic from the specific internal server firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.30", # Ensuring this specific source is allowed - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=13 ) @@ -301,7 +301,7 @@ To facilitate restricted access from the internal network to DMZ-hosted services action=ACLAction.DENY, src_ip_address="192.168.1.0", src_wildcard_mask="0.0.0.255", - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=14 ) @@ -315,8 +315,8 @@ To block all SSH access attempts from the external network: # Deny all SSH traffic from any external source firewall.external_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, - dst_port=Port.SSH, + protocol=IPProtocol["TCP"], + dst_port=Port["SSH"], position=1 ) @@ -329,8 +329,8 @@ To allow the internal network to initiate HTTP connections to the external netwo # Permit outgoing HTTP traffic from the internal network to any external destination firewall.external_outbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], position=2 ) diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index c78c8419..bd361afa 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -49,7 +49,7 @@ additional steps to configure wireless settings: wireless_router.configure_wireless_access_point( port=1, ip_address="192.168.2.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency=AirSpaceFrequency["WIFI_2_4"], ) @@ -102,8 +102,8 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. network.connect(pc_a.network_interface[1], router_1.router_interface) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # Configure PC B pc_b = Computer( @@ -130,13 +130,13 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency=AirSpaceFrequency["WIFI_2_4"], ) router_2.configure_wireless_access_point( port=1, ip_address="192.168.1.2", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency=AirSpaceFrequency["WIFI_2_4"], ) # Configure routes for inter-router communication diff --git a/docs/source/simulation_components/network/transport_to_data_link_layer.rst b/docs/source/simulation_components/network/transport_to_data_link_layer.rst index cc546021..02bfdcdc 100644 --- a/docs/source/simulation_components/network/transport_to_data_link_layer.rst +++ b/docs/source/simulation_components/network/transport_to_data_link_layer.rst @@ -104,7 +104,7 @@ address of 'aa:bb:cc:dd:ee:ff' to port 8080 on the host 10.0.0.10 which has a NI ip_packet = IPPacket( src_ip_address="192.168.0.100", dst_ip_address="10.0.0.10", - protocol=IPProtocol.TCP + protocol=IPProtocol["TCP"] ) # Data Link Layer ethernet_header = EthernetHeader( diff --git a/docs/source/simulation_components/system/applications/nmap.rst b/docs/source/simulation_components/system/applications/nmap.rst index 1e7f5ea4..e2cd474e 100644 --- a/docs/source/simulation_components/system/applications/nmap.rst +++ b/docs/source/simulation_components/system/applications/nmap.rst @@ -165,8 +165,8 @@ Perform a horizontal port scan on port 5432 across multiple IP addresses: { IPv4Address('192.168.1.12'): { - : [ - + : [ + ] } } @@ -192,7 +192,7 @@ Perform a vertical port scan on multiple ports on a single IP address: vertical_scan_results = pc_1_nmap.port_scan( target_ip_address=[IPv4Address("192.168.1.12")], - target_port=[Port(21), Port(22), Port(80), Port(443)] + target_port=[21, 22, 80, 443] ) .. code-block:: python @@ -200,9 +200,9 @@ Perform a vertical port scan on multiple ports on a single IP address: { IPv4Address('192.168.1.12'): { - : [ - , - + : [ + , + ] } } @@ -233,7 +233,7 @@ Perform a box scan on multiple ports across multiple IP addresses: box_scan_results = pc_1_nmap.port_scan( target_ip_address=[IPv4Address("192.168.1.12"), IPv4Address("192.168.1.13")], - target_port=[Port(21), Port(22), Port(80), Port(443)] + target_port=[21, 22, 80, 443] ) .. code-block:: python @@ -241,15 +241,15 @@ Perform a box scan on multiple ports across multiple IP addresses: { IPv4Address('192.168.1.13'): { - : [ - , - + : [ + , + ] }, IPv4Address('192.168.1.12'): { - : [ - , - + : [ + , + ] } } @@ -289,36 +289,36 @@ Perform a full box scan on all ports, over both TCP and UDP, on a whole subnet: { IPv4Address('192.168.1.11'): { - : [ - + : [ + ] }, IPv4Address('192.168.1.1'): { - : [ - + : [ + ] }, IPv4Address('192.168.1.12'): { - : [ - , - , - , - + : [ + , + , + , + ], - : [ - , - + : [ + , + ] }, IPv4Address('192.168.1.13'): { - : [ - , - , - + : [ + , + , + ], - : [ - , - + : [ + , + ] } } diff --git a/docs/source/simulation_components/system/services/ftp_client.rst b/docs/source/simulation_components/system/services/ftp_client.rst index fdf9cfcf..0c9a781c 100644 --- a/docs/source/simulation_components/system/services/ftp_client.rst +++ b/docs/source/simulation_components/system/services/ftp_client.rst @@ -15,7 +15,7 @@ Key features - Connects to the :ref:`FTPServer` via the ``SoftwareManager``. - Simulates FTP requests and FTPPacket transfer across a network - Allows the emulation of FTP commands between an FTP client and server: - - PORT: specifies the port that server should connect to on the client (currently only uses ``Port.FTP``) + - PORT: specifies the port that server should connect to on the client (currently only uses ``Port["FTP"]``) - STOR: stores a file from client to server - RETR: retrieves a file from the FTP server - QUIT: disconnect from server diff --git a/notebooks/test.ipynb b/notebooks/test.ipynb new file mode 100644 index 00000000..5afe04b0 --- /dev/null +++ b/notebooks/test.ipynb @@ -0,0 +1,157 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import yaml\n", + "\n", + "from primaite.game.game import PrimaiteGame\n", + "from primaite.session.environment import PrimaiteGymEnv\n", + "from primaite.simulator.network.hardware.nodes.host.computer import Computer\n", + "from primaite.simulator.network.hardware.nodes.host.server import Server\n", + "from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection\n", + "from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot\n", + "from primaite.simulator.system.services.database.database_service import DatabaseService\n", + "from primaite import getLogger, PRIMAITE_PATHS" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "with open(PRIMAITE_PATHS.user_config_path / \"example_config\" / \"data_manipulation.yaml\") as f:\n", + " cfg = yaml.safe_load(f)\n", + "game = PrimaiteGame.from_config(cfg)\n", + "uc2_network = game.simulation.network" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "client_1: Computer = uc2_network.get_node_by_hostname(\"client_1\")\n", + "db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get(\"DataManipulationBot\")\n", + "\n", + "database_server: Server = uc2_network.get_node_by_hostname(\"database_server\")\n", + "db_service: DatabaseService = database_server.software_manager.software.get(\"DatabaseService\")\n", + "\n", + "web_server: Server = uc2_network.get_node_by_hostname(\"web_server\")\n", + "db_client: DatabaseClient = web_server.software_manager.software.get(\"DatabaseClient\")\n", + "db_connection: DatabaseClientConnection = db_client.get_new_connection()\n", + "db_service.backup_database()\n", + "\n", + "# First check that the DB client on the web_server can successfully query the users table on the database\n", + "assert db_connection.query(\"SELECT\")\n", + "\n", + "db_manipulation_bot.data_manipulation_p_of_success = 1.0\n", + "db_manipulation_bot.port_scan_p_of_success = 1.0\n", + "\n", + "# Now we run the DataManipulationBot\n", + "db_manipulation_bot.attack()\n", + "\n", + "# Now check that the DB client on the web_server cannot query the users table on the database\n", + "assert not db_connection.query(\"SELECT\")\n", + "\n", + "# Now restore the database\n", + "db_service.restore_backup()\n", + "\n", + "# Now check that the DB client on the web_server can successfully query the users table on the database\n", + "assert db_connection.query(\"SELECT\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "router = uc2_network.get_node_by_hostname(\"router_1\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----------------------------------------------------------------------------------------------------------------+\n", + "| router_1 Access Control List |\n", + "+-------+--------+----------+--------+--------------+-------------+--------+--------------+-------------+---------+\n", + "| Index | Action | Protocol | Src IP | Src Wildcard | Src Port | Dst IP | Dst Wildcard | Dst Port | Matched |\n", + "+-------+--------+----------+--------+--------------+-------------+--------+--------------+-------------+---------+\n", + "| 18 | PERMIT | ANY | ANY | ANY | 5432 (5432) | ANY | ANY | 5432 (5432) | 4 |\n", + "| 19 | PERMIT | ANY | ANY | ANY | 53 (53) | ANY | ANY | 53 (53) | 0 |\n", + "| 20 | PERMIT | ANY | ANY | ANY | 21 (21) | ANY | ANY | 21 (21) | 0 |\n", + "| 21 | PERMIT | ANY | ANY | ANY | 80 (80) | ANY | ANY | 80 (80) | 0 |\n", + "| 22 | PERMIT | ANY | ANY | ANY | 219 (219) | ANY | ANY | 219 (219) | 9 |\n", + "| 23 | PERMIT | icmp | ANY | ANY | ANY | ANY | ANY | ANY | 0 |\n", + "| 24 | DENY | ANY | ANY | ANY | ANY | ANY | ANY | ANY | 0 |\n", + "+-------+--------+----------+--------+--------------+-------------+--------+--------------+-------------+---------+\n" + ] + } + ], + "source": [ + "router.acl.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "AccessControlList.is_permitted() missing 1 required positional argument: 'frame'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mrouter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_permitted\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mTypeError\u001b[0m: AccessControlList.is_permitted() missing 1 required positional argument: 'frame'" + ] + } + ], + "source": [ + "router.acl.is_permitted()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index 41af5a8f..abb6f1f8 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -10,6 +10,8 @@ from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port _LOGGER = getLogger(__name__) @@ -61,7 +63,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)} self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)} self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)} - self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} + self.protocol_to_id: Dict[str, int] = {IPProtocol[p]: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { i + 1: { diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 4419ccc7..05b25952 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -5,6 +5,7 @@ from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType +from pydantic import field_validator from primaite import getLogger from primaite.game.agent.observations.file_system_observations import FolderObservation @@ -12,6 +13,8 @@ from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port _LOGGER = getLogger(__name__) @@ -55,6 +58,21 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_users: Optional[bool] = True """If True, report user session information.""" + @field_validator('monitored_traffic', mode='before') + def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + if val is None: + return val + new_val = {} + for proto, port_list in val.items(): + # convert protocol, for instance ICMP becomes "icmp" + proto = IPProtocol[proto] if proto in IPProtocol else proto + new_val[proto] = [] + for port in port_list: + # convert ports, for instance "HTTP" becomes 80 + port = Port[port] if port in Port else port + new_val[proto].append(port) + return new_val + def __init__( self, where: WhereType, diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 002ee4da..200187f5 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -5,9 +5,11 @@ from typing import Dict, Optional from gymnasium import spaces from gymnasium.core import ObsType +from pydantic import field_validator from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -24,6 +26,22 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): monitored_traffic: Optional[Dict] = None """A dict containing which traffic types are to be included in the observation.""" + @field_validator('monitored_traffic', mode='before') + def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + if val is None: + return val + new_val = {} + for proto, port_list in val.items(): + # convert protocol, for instance ICMP becomes "icmp" + proto = IPProtocol[proto] if proto in IPProtocol else proto + new_val[proto] = [] + for port in port_list: + # convert ports, for instance "HTTP" becomes 80 + port = Port[port] if port in Port else port + new_val[proto].append(port) + return new_val + + def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: """ Initialise a network interface observation instance. @@ -67,7 +85,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): else: default_traffic_obs["TRAFFIC"][protocol] = {} for port in monitored_traffic_config[protocol]: - default_traffic_obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + default_traffic_obs["TRAFFIC"][protocol] = {"inbound": 0, "outbound": 0} return default_traffic_obs @@ -142,17 +160,16 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): } else: for port in self.monitored_traffic[protocol]: - port_enum = Port[port] - obs["TRAFFIC"][protocol][port_enum.value] = {} + obs["TRAFFIC"][protocol][port] = {} traffic = {"inbound": 0, "outbound": 0} - if nic_state["traffic"][protocol].get(port_enum.value) is not None: - traffic = nic_state["traffic"][protocol][port_enum.value] + if nic_state["traffic"][protocol].get(port) is not None: + traffic = nic_state["traffic"][protocol][port] - obs["TRAFFIC"][protocol][port_enum.value]["inbound"] = self._categorise_traffic( + obs["TRAFFIC"][protocol][port]["inbound"] = self._categorise_traffic( traffic_value=traffic["inbound"], nic_state=nic_state ) - obs["TRAFFIC"][protocol][port_enum.value]["outbound"] = self._categorise_traffic( + obs["TRAFFIC"][protocol][port]["outbound"] = self._categorise_traffic( traffic_value=traffic["outbound"], nic_state=nic_state ) @@ -162,7 +179,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0} else: for port in self.monitored_traffic[protocol]: - obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0} if self.include_nmne: obs.update({"NMNE": {}}) @@ -201,7 +218,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): else: space["TRAFFIC"][protocol] = spaces.Dict({}) for port in self.monitored_traffic[protocol]: - space["TRAFFIC"][protocol][Port[port].value] = spaces.Dict( + space["TRAFFIC"][protocol][port] = spaces.Dict( {"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)} ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index e263cadb..3e51c3b3 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -5,13 +5,15 @@ from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType -from pydantic import model_validator +from pydantic import field_validator, model_validator from primaite import getLogger from primaite.game.agent.observations.firewall_observation import FirewallObservation from primaite.game.agent.observations.host_observations import HostObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.router_observation import RouterObservation +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port _LOGGER = getLogger(__name__) @@ -61,6 +63,21 @@ class NodesObservation(AbstractObservation, identifier="NODES"): num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + @field_validator('monitored_traffic', mode='before') + def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + if val is None: + return val + new_val = {} + for proto, port_list in val.items(): + # convert protocol, for instance ICMP becomes "icmp" + proto = IPProtocol[proto] if proto in IPProtocol else proto + new_val[proto] = [] + for port in port_list: + # convert ports, for instance "HTTP" becomes 80 + port = Port[port] if port in Port else port + new_val[proto].append(port) + return new_val + @model_validator(mode="after") def force_optional_fields(self) -> NodesObservation.ConfigSchema: """Check that options are specified only if they are needed for the nodes that are part of the config.""" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 11c968af..e8329c63 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -4,7 +4,7 @@ from ipaddress import IPv4Address from typing import Dict, List, Optional, Union import numpy as np -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, field_validator from primaite import DEFAULT_BANDWIDTH, getLogger from primaite.game.agent.actions import ActionManager @@ -82,13 +82,40 @@ class PrimaiteGameOptions(BaseModel): """Random number seed for RNGs.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" - ports: List[str] + ports: List[int] """A whitelist of available ports in the simulation.""" protocols: List[str] """A whitelist of available protocols in the simulation.""" thresholds: Optional[Dict] = {} """A dict containing the thresholds used for determining what is acceptable during observations.""" + @field_validator('ports', mode='before') + def ports_str2int(cls, vals:Union[List[str],List[int]]) -> List[int]: + """ + Convert named port strings to port integer values. Integer ports remain unaffected. + + This is necessary to retain backwards compatibility with configs written for PrimAITE<=3.3. + :warning: This will be deprecated in PrimAITE 4.0 and configs will need to be converted. + """ + for i, port_val in enumerate(vals): + if port_val in Port: + vals[i] = Port[port_val] + return vals + + @field_validator('protocols', mode='before') + def protocols_str2int(cls, vals:List[str]) -> List[str]: + """ + Convert old-style named protocols to their proper values. + + This is necessary to retain backwards compatibility with configs written for PrimAITE<=3.3. + :warning: This will be deprecated in PrimAITE 4.0 and configs will need to be converted. + """ + for i, proto_val in enumerate(vals): + if proto_val in IPProtocol: + vals[i] = IPProtocol[proto_val] + return vals + + class PrimaiteGame: """ @@ -358,7 +385,7 @@ class PrimaiteGame: for port_id in set(software_cfg.get("options", {}).get("listen_on_ports", [])): port = None if isinstance(port_id, int): - port = Port(port_id) + port = port_id elif isinstance(port_id, str): port = Port[port_id] if port: @@ -475,7 +502,7 @@ class PrimaiteGame: opt = application_cfg["options"] new_application.configure( target_ip_address=IPv4Address(opt.get("target_ip_address")), - target_port=Port(opt.get("target_port", Port.POSTGRES_SERVER.value)), + target_port = Port[opt.get("target_port", "POSTGRES_SERVER")], payload=opt.get("payload"), repeat=bool(opt.get("repeat")), port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), @@ -488,8 +515,8 @@ class PrimaiteGame: new_application.configure( c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")), keep_alive_frequency=(opt.get("keep_alive_frequency", 5)), - masquerade_protocol=IPProtocol[(opt.get("masquerade_protocol", IPProtocol.TCP))], - masquerade_port=Port[(opt.get("masquerade_port", Port.HTTP))], + masquerade_protocol=IPProtocol[(opt.get("masquerade_protocol", IPProtocol["TCP"]))], + masquerade_port=Port[(opt.get("masquerade_port", Port["HTTP"]))], ) if "network_interfaces" in node_cfg: for nic_num, nic_cfg in node_cfg["network_interfaces"].items(): diff --git a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb index b6b13f28..a5cc385b 100644 --- a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb @@ -1783,7 +1783,7 @@ "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", "from primaite.simulator.network.transmission.transport_layer import Port\n", "# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n", - "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol.UDP, masquerade_port=Port.DNS)\n", + "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n", "c2_beacon.establish()\n", "c2_beacon.show()" ] diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index 77ac4842..f573f251 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -182,7 +182,7 @@ "metadata": {}, "outputs": [], "source": [ - "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port.HTTP, protocol = IPProtocol.NONE,operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" + "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" ] }, { diff --git a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb index 17a0f796..2d5b4772 100644 --- a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb +++ b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb @@ -537,7 +537,7 @@ "from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n", "network.get_node_by_hostname(\"router_1\").acl.add_rule(\n", " action=ACLAction.DENY,\n", - " protocol=IPProtocol.ICMP,\n", + " protocol=IPProtocol["ICMP"],\n", " src_ip_address=\"192.168.10.22\",\n", " position=1\n", ")" diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index cdb01514..29326df8 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -3,10 +3,12 @@ from __future__ import annotations from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List +from typing import Any, ClassVar, Dict, List, Type +from pydantic._internal._generics import PydanticGenericMetadata +from typing_extensions import Unpack from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from primaite import getLogger from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface @@ -40,51 +42,29 @@ def format_hertz(hertz: float, format_terahertz: bool = False, decimals: int = 3 else: # Hertz return format_str.format(hertz) + " Hz" +AirSpaceFrequencyRegistry: Dict[str,Dict] = { + "WIFI_2_4" : {'frequency': 2.4e9, 'data_rate_bps':100_000_000.0}, + "WIFI_5" : {'frequency': 5e9, 'data_rate_bps':500_000_000.0}, +} -class AirSpaceFrequency(Enum): - """Enumeration representing the operating frequencies for wireless communications.""" +def register_frequency(freq_name: str, freq_hz: int, data_rate_bps: int) -> None: + if freq_name in AirSpaceFrequencyRegistry: + raise RuntimeError(f"Cannot register new frequency {freq_name} because it's already registered.") + AirSpaceFrequencyRegistry.update({freq_name:{'frequency': freq_hz, 'data_rate_bps':data_rate_bps}}) - WIFI_2_4 = 2.4e9 - """WiFi 2.4 GHz. Known for its extensive range and ability to penetrate solid objects effectively.""" - WIFI_5 = 5e9 - """WiFi 5 GHz. Known for its higher data transmission speeds and reduced interference from other devices.""" +def maximum_data_rate_mbps(frequency_name:str) -> float: + """ + Retrieves the maximum data transmission rate in megabits per second (Mbps). - def __str__(self) -> str: - hertz_str = format_hertz(hertz=self.value) - if self == AirSpaceFrequency.WIFI_2_4: - return f"WiFi {hertz_str}" - if self == AirSpaceFrequency.WIFI_5: - return f"WiFi {hertz_str}" - return "Unknown Frequency" + This is derived by converting the maximum data rate from bits per second, as defined + in `maximum_data_rate_bps`, to megabits per second. - @property - def maximum_data_rate_bps(self) -> float: - """ - Retrieves the maximum data transmission rate in bits per second (bps). + :return: The maximum data rate in megabits per second. + """ + return AirSpaceFrequencyRegistry[frequency_name]['data_rate_bps'] + return data_rate / 1_000_000.0 - The maximum rates are predefined for frequencies.: - - WIFI 2.4 supports 100,000,000 bps - - WIFI 5 supports 500,000,000 bps - :return: The maximum data rate in bits per second. - """ - if self == AirSpaceFrequency.WIFI_2_4: - return 100_000_000.0 # 100 Megabits per second - if self == AirSpaceFrequency.WIFI_5: - return 500_000_000.0 # 500 Megabits per second - return 0.0 - - @property - def maximum_data_rate_mbps(self) -> float: - """ - Retrieves the maximum data transmission rate in megabits per second (Mbps). - - This is derived by converting the maximum data rate from bits per second, as defined - in `maximum_data_rate_bps`, to megabits per second. - - :return: The maximum data rate in megabits per second. - """ - return self.maximum_data_rate_bps / 1_000_000.0 class AirSpace(BaseModel): @@ -97,13 +77,13 @@ class AirSpace(BaseModel): """ wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {}) - wireless_interfaces_by_frequency: Dict[AirSpaceFrequency, List[WirelessNetworkInterface]] = Field( + wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field( default_factory=lambda: {} ) - bandwidth_load: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) - frequency_max_capacity_mbps_: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) + bandwidth_load: Dict[int, float] = Field(default_factory=lambda: {}) + frequency_max_capacity_mbps_: Dict[int, float] = Field(default_factory=lambda: {}) - def get_frequency_max_capacity_mbps(self, frequency: AirSpaceFrequency) -> float: + def get_frequency_max_capacity_mbps(self, frequency: str) -> float: """ Retrieves the maximum data transmission capacity for a specified frequency. @@ -117,9 +97,9 @@ class AirSpace(BaseModel): """ if frequency in self.frequency_max_capacity_mbps_: return self.frequency_max_capacity_mbps_[frequency] - return frequency.maximum_data_rate_mbps + return maximum_data_rate_mbps(frequency) - def set_frequency_max_capacity_mbps(self, cfg: Dict[AirSpaceFrequency, float]): + def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]): """ Sets custom maximum data transmission capacities for multiple frequencies. @@ -150,7 +130,7 @@ class AirSpace(BaseModel): load_percent = load / maximum_capacity if maximum_capacity > 0 else 0.0 if load_percent > 1.0: load_percent = 1.0 - table.add_row([format_hertz(frequency.value), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"]) + table.add_row([format_hertz(frequency), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"]) print(table) def show_wireless_interfaces(self, markdown: bool = False): @@ -182,7 +162,7 @@ class AirSpace(BaseModel): interface.mac_address, interface.ip_address if hasattr(interface, "ip_address") else None, interface.subnet_mask if hasattr(interface, "subnet_mask") else None, - format_hertz(interface.frequency.value), + format_hertz(interface.frequency), f"{interface.speed:.3f}", status, ] @@ -298,7 +278,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC): """ airspace: AirSpace - frequency: AirSpaceFrequency = AirSpaceFrequency.WIFI_2_4 + frequency: str = "WIFI_2_4" def enable(self): """Attempt to enable the network interface.""" @@ -430,7 +410,7 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC) # Update the state with information from Layer3Interface state.update(Layer3Interface.describe_state(self)) - state["frequency"] = self.frequency.value + state["frequency"] = self.frequency return state diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index 61a37a90..c2524b4b 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -98,8 +98,8 @@ def create_office_lan( default_gateway = IPv4Address(f"192.168.{subnet_base}.1") router = Router(hostname=f"router_{lan_name}", start_up_duration=0) router.power_on() - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) network.add_node(router) router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0") router.enable_port(1) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index bf230e07..4154cc08 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -203,16 +203,16 @@ class NetworkInterface(SimComponent, ABC): # Initialise basic frame data variables direction = "inbound" if inbound else "outbound" # Direction of the traffic ip_address = str(frame.ip.src_ip_address if inbound else frame.ip.dst_ip_address) # Source or destination IP - protocol = frame.ip.protocol.name # Network protocol used in the frame + protocol = frame.ip.protocol # Network protocol used in the frame # Initialise port variable; will be determined based on protocol type port = None # Determine the source or destination port based on the protocol (TCP/UDP) if frame.tcp: - port = frame.tcp.src_port.value if inbound else frame.tcp.dst_port.value + port = frame.tcp.src_port if inbound else frame.tcp.dst_port elif frame.udp: - port = frame.udp.src_port.value if inbound else frame.udp.dst_port.value + port = frame.udp.src_port if inbound else frame.udp.dst_port # Convert frame payload to string for keyword checking frame_str = str(frame.payload) @@ -274,20 +274,20 @@ class NetworkInterface(SimComponent, ABC): # Identify the protocol and port from the frame if frame.tcp: - protocol = IPProtocol.TCP + protocol = IPProtocol["TCP"] port = frame.tcp.dst_port elif frame.udp: - protocol = IPProtocol.UDP + protocol = IPProtocol["UDP"] port = frame.udp.dst_port elif frame.icmp: - protocol = IPProtocol.ICMP + protocol = IPProtocol["ICMP"] # Ensure the protocol is in the capture dict if protocol not in self.traffic: self.traffic[protocol] = {} # Handle non-ICMP protocols that use ports - if protocol != IPProtocol.ICMP: + if protocol != IPProtocol["ICMP"]: if port not in self.traffic[protocol]: self.traffic[protocol][port] = {"inbound": 0, "outbound": 0} self.traffic[protocol][port][direction] += frame.size_Mbits @@ -843,8 +843,8 @@ class UserManager(Service): :param password: The password for the default admin user """ kwargs["name"] = "UserManager" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["NONE"] super().__init__(**kwargs) self.start() @@ -1166,8 +1166,8 @@ class UserSessionManager(Service): :param password: The password for the default admin user """ kwargs["name"] = "UserSessionManager" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["NONE"] super().__init__(**kwargs) self.start() @@ -1312,7 +1312,7 @@ class UserSessionManager(Service): software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( payload={"type": "user_timeout", "connection_id": session.uuid}, - dest_port=Port.SSH, + dest_port=Port["SSH"], dest_ip_address=session.remote_ip_address, ) @@ -1845,8 +1845,9 @@ class Node(SimComponent): table.align = "l" table.title = f"{self.hostname} Open Ports" for port in self.software_manager.get_open_ports(): - if port.value > 0: - table.add_row([port.value, port.name]) + if port > 0: + # TODO: do a reverse lookup for port name, or change this to only show port int + table.add_row([port, port]) print(table.get_string(sortby="Port")) @property diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 4510eac0..6d8e084d 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -58,8 +58,8 @@ class Firewall(Router): >>> # Permit HTTP traffic to the DMZ >>> firewall.dmz_inbound_acl.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, - ... dst_port=Port.HTTP, + ... protocol=IPProtocol["TCP"], + ... dst_port=Port["HTTP"], ... src_ip_address="0.0.0.0", ... src_wildcard_mask="0.0.0.0", ... dst_ip_address="172.16.0.0", diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index ceb91695..013c473e 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -7,7 +7,7 @@ from ipaddress import IPv4Address, IPv4Network from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import field_validator, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -106,7 +106,7 @@ class ACLRule(SimComponent): :ivar ACLAction action: Specifies whether to `PERMIT` or `DENY` the traffic that matches the rule conditions. The default action is `DENY`. - :ivar Optional[IPProtocol] protocol: The network protocol (e.g., TCP, UDP, ICMP) to match. If `None`, the rule + :ivar Optional[str] protocol: The network protocol (e.g., TCP, UDP, ICMP) to match. If `None`, the rule applies to all protocols. :ivar Optional[IPV4Address] src_ip_address: The source IP address to match. If combined with `src_wildcard_mask`, it specifies the start of an IP range. @@ -116,20 +116,33 @@ class ACLRule(SimComponent): `dst_wildcard_mask`, it specifies the start of an IP range. :ivar Optional[IPv4Address] dst_wildcard_mask: The wildcard mask for the destination IP address, defining the range of addresses to match. - :ivar Optional[Port] src_port: The source port number to match. Relevant for TCP/UDP protocols. - :ivar Optional[Port] dst_port: The destination port number to match. Relevant for TCP/UDP protocols. + :ivar Optional[int] src_port: The source port number to match. Relevant for TCP/UDP protocols. + :ivar Optional[int] dst_port: The destination port number to match. Relevant for TCP/UDP protocols. """ action: ACLAction = ACLAction.DENY - protocol: Optional[IPProtocol] = None + protocol: Optional[str] = None src_ip_address: Optional[IPV4Address] = None src_wildcard_mask: Optional[IPV4Address] = None dst_ip_address: Optional[IPV4Address] = None dst_wildcard_mask: Optional[IPV4Address] = None - src_port: Optional[Port] = None - dst_port: Optional[Port] = None + src_port: Optional[int] = None + dst_port: Optional[int] = None match_count: int = 0 + @field_validator('protocol', mode='before') + def protocol_valid(cls, val:Optional[str]) -> Optional[str]: + if val is not None: + assert val in IPProtocol.values(), f"Cannot create ACL rule with invalid protocol {val}" + return val + + @field_validator('src_port', 'dst_port', mode='before') + def ports_valid(cls, val:Optional[int]) -> Optional[int]: + if val is not None: + assert val in Port.values(), f"Cannot create ACL rule with invalid port {val}" + return val + + def __str__(self) -> str: rule_strings = [] for key, value in self.model_dump(exclude={"uuid", "request_manager"}).items(): @@ -149,13 +162,13 @@ class ACLRule(SimComponent): """ state = super().describe_state() state["action"] = self.action.value - state["protocol"] = self.protocol.name if self.protocol else None + state["protocol"] = self.protocol if self.protocol else None state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None - state["src_port"] = self.src_port.name if self.src_port else None + state["src_port"] = self.src_port if self.src_port else None state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None - state["dst_port"] = self.dst_port.name if self.dst_port else None + state["dst_port"] = self.dst_port if self.dst_port else None state["match_count"] = self.match_count return state @@ -265,7 +278,7 @@ class AccessControlList(SimComponent): >>> acl = AccessControlList() >>> acl.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.0", ... src_wildcard_mask="0.0.0.255", ... dst_ip_address="192.168.2.0", @@ -323,13 +336,13 @@ class AccessControlList(SimComponent): func=lambda request, context: RequestResponse.from_bool( self.add_rule( action=ACLAction[request[0]], - protocol=None if request[1] == "ALL" else IPProtocol[request[1]], + protocol=None if request[1] == "ALL" else request[1], src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]), src_wildcard_mask=None if request[3] == "NONE" else IPv4Address(request[3]), - src_port=None if request[4] == "ALL" else Port[request[4]], + src_port=None if request[4] == "ALL" else request[4], dst_ip_address=None if request[5] == "ALL" else IPv4Address(request[5]), dst_wildcard_mask=None if request[6] == "NONE" else IPv4Address(request[6]), - dst_port=None if request[7] == "ALL" else Port[request[7]], + dst_port=None if request[7] == "ALL" else request[7], position=int(request[8]), ) ) @@ -377,13 +390,13 @@ class AccessControlList(SimComponent): def add_rule( self, action: ACLAction = ACLAction.DENY, - protocol: Optional[IPProtocol] = None, + protocol: Optional[str] = None, src_ip_address: Optional[IPV4Address] = None, src_wildcard_mask: Optional[IPV4Address] = None, dst_ip_address: Optional[IPV4Address] = None, dst_wildcard_mask: Optional[IPV4Address] = None, - src_port: Optional[Port] = None, - dst_port: Optional[Port] = None, + src_port: Optional[int] = None, + dst_port: Optional[int] = None, position: int = 0, ) -> bool: """ @@ -399,11 +412,11 @@ class AccessControlList(SimComponent): >>> router = Router("router") >>> router.add_rule( ... action=ACLAction.DENY, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.0", ... src_wildcard_mask="0.0.0.255", ... dst_ip_address="10.10.10.5", - ... dst_port=Port.SSH, + ... dst_port=Port["SSH"], ... position=5 ... ) >>> # This permits SSH traffic from the 192.168.1.0/24 subnet to the 10.10.10.5 server. @@ -411,10 +424,10 @@ class AccessControlList(SimComponent): >>> # Then if we want to allow a specific IP address from this subnet to SSH into the server >>> router.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.25", ... dst_ip_address="10.10.10.5", - ... dst_port=Port.SSH, + ... dst_port=Port["SSH"], ... position=4 ... ) @@ -485,11 +498,11 @@ class AccessControlList(SimComponent): def get_relevant_rules( self, - protocol: IPProtocol, + protocol: str, src_ip_address: Union[str, IPv4Address], - src_port: Port, + src_port: int, dst_ip_address: Union[str, IPv4Address], - dst_port: Port, + dst_port: int, ) -> List[ACLRule]: """ Get the list of relevant rules for a packet with given properties. @@ -552,13 +565,13 @@ class AccessControlList(SimComponent): [ index, rule.action.name if rule.action else "ANY", - rule.protocol.name if rule.protocol else "ANY", + rule.protocol if rule.protocol else "ANY", rule.src_ip_address if rule.src_ip_address else "ANY", rule.src_wildcard_mask if rule.src_wildcard_mask else "ANY", - f"{rule.src_port.value} ({rule.src_port.name})" if rule.src_port else "ANY", + f"{rule.src_port} ({rule.src_port})" if rule.src_port else "ANY", rule.dst_ip_address if rule.dst_ip_address else "ANY", rule.dst_wildcard_mask if rule.dst_wildcard_mask else "ANY", - f"{rule.dst_port.value} ({rule.dst_port.name})" if rule.dst_port else "ANY", + f"{rule.dst_port} ({rule.dst_port})" if rule.dst_port else "ANY", rule.match_count, ] ) @@ -1088,17 +1101,17 @@ class RouterSessionManager(SessionManager): def resolve_outbound_transmission_details( self, dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[Port] = None, - dst_port: Optional[Port] = None, - protocol: Optional[IPProtocol] = None, + src_port: Optional[int] = None, + dst_port: Optional[int] = None, + protocol: Optional[str] = None, session_id: Optional[str] = None, ) -> Tuple[ Optional[RouterInterface], Optional[str], IPv4Address, - Optional[Port], - Optional[Port], - Optional[IPProtocol], + Optional[int], + Optional[int], + Optional[str], bool, ]: """ @@ -1118,19 +1131,19 @@ class RouterSessionManager(SessionManager): treats the transmission as a broadcast to that network. Optional. :type dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] :param src_port: The source port number for the transmission. Optional. - :type src_port: Optional[Port] + :type src_port: Optional[int] :param dst_port: The destination port number for the transmission. Optional. - :type dst_port: Optional[Port] + :type dst_port: Optional[int] :param protocol: The IP protocol to be used for the transmission. Optional. - :type protocol: Optional[IPProtocol] + :type protocol: Optional[str] :param session_id: The session ID associated with the transmission. If provided, the session details override other parameters. Optional. :type session_id: Optional[str] :return: A tuple containing the resolved outbound network interface, destination MAC address, destination IP address, source port, destination port, protocol, and a boolean indicating whether the transmission is a broadcast. - :rtype: Tuple[Optional[RouterInterface], Optional[str], IPv4Address, Optional[Port], Optional[Port], - Optional[IPProtocol], bool] + :rtype: Tuple[Optional[RouterInterface], Optional[str], IPv4Address, Optional[int], Optional[int], + Optional[str], bool] """ if dst_ip_address and not isinstance(dst_ip_address, (IPv4Address, IPv4Network)): dst_ip_address = IPv4Address(dst_ip_address) @@ -1257,8 +1270,8 @@ class Router(NetworkNode): Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP and ICMP, which are necessary for basic network operations and diagnostics. """ - self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) def setup_for_episode(self, episode: int): """ @@ -1357,9 +1370,9 @@ class Router(NetworkNode): """ dst_ip_address = frame.ip.dst_ip_address dst_port = None - if frame.ip.protocol == IPProtocol.TCP: + if frame.ip.protocol == IPProtocol["TCP"]: dst_port = frame.tcp.dst_port - elif frame.ip.protocol == IPProtocol.UDP: + elif frame.ip.protocol == IPProtocol["UDP"]: dst_port = frame.udp.dst_port if self.ip_is_router_interface(dst_ip_address) and ( diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 3cb4c515..d73bc756 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -116,7 +116,7 @@ class WirelessRouter(Router): >>> wireless_router.configure_wireless_access_point( ... ip_address="10.10.10.1", ... subnet_mask="255.255.255.0" - ... frequency=AirSpaceFrequency.WIFI_2_4 + ... frequency=AirSpaceFrequency["WIFI_2_4"] ... ) """ @@ -153,7 +153,7 @@ class WirelessRouter(Router): self, ip_address: IPV4Address, subnet_mask: IPV4Address, - frequency: Optional[AirSpaceFrequency] = AirSpaceFrequency.WIFI_2_4, + frequency: Optional[int] = AirSpaceFrequency["WIFI_2_4"], ): """ Configures a wireless access point (WAP). @@ -168,10 +168,10 @@ class WirelessRouter(Router): :param subnet_mask: The subnet mask associated with the IP address :param frequency: The operating frequency of the wireless access point, defined by the AirSpaceFrequency enum. This determines the frequency band (e.g., 2.4 GHz or 5 GHz) the access point will use for wireless - communication. Default is AirSpaceFrequency.WIFI_2_4. + communication. Default is AirSpaceFrequency["WIFI_2_4"]. """ if not frequency: - frequency = AirSpaceFrequency.WIFI_2_4 + frequency = AirSpaceFrequency["WIFI_2_4"] self.sys_log.info("Configuring wireless access point") self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index cb0965eb..a73f3b12 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -79,9 +79,9 @@ def client_server_routed() -> Network: server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) return network @@ -271,23 +271,23 @@ def arcd_uc2_network() -> Network: security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0")) network.connect(endpoint_b=security_suite.network_interface[2], endpoint_a=switch_2.network_interface[7]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # Allow PostgreSQL requests router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) return network diff --git a/src/primaite/simulator/network/protocols/masquerade.py b/src/primaite/simulator/network/protocols/masquerade.py index e2a7b6a0..ef060bc7 100644 --- a/src/primaite/simulator/network/protocols/masquerade.py +++ b/src/primaite/simulator/network/protocols/masquerade.py @@ -8,9 +8,9 @@ from primaite.simulator.network.protocols.packet import DataPacket class MasqueradePacket(DataPacket): """Represents an generic malicious packet that is masquerading as another protocol.""" - masquerade_protocol: Enum # The 'Masquerade' protocol that is currently in use + masquerade_protocol: str # The 'Masquerade' protocol that is currently in use - masquerade_port: Enum # The 'Masquerade' port that is currently in use + masquerade_port: int # The 'Masquerade' port that is currently in use class C2Packet(MasqueradePacket): diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index 159eca7f..b9bc48d9 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -70,15 +70,15 @@ class Frame(BaseModel): msg = "Network Frame cannot have both a TCP header and a UDP header" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.TCP and not kwargs.get("tcp"): + if kwargs["ip"].protocol == IPProtocol["TCP"] and not kwargs.get("tcp"): msg = "Cannot build a Frame using the TCP IP Protocol without a TCPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.UDP and not kwargs.get("udp"): + if kwargs["ip"].protocol == IPProtocol["UDP"] and not kwargs.get("udp"): msg = "Cannot build a Frame using the UDP IP Protocol without a UDPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.ICMP and not kwargs.get("icmp"): + if kwargs["ip"].protocol == IPProtocol["ICMP"] and not kwargs.get("icmp"): msg = "Cannot build a Frame using the ICMP IP Protocol without a ICMPPacket" _LOGGER.error(msg) raise ValueError(msg) @@ -165,7 +165,7 @@ class Frame(BaseModel): :return: True if the Frame is an ARP packet, otherwise False. """ - return self.udp.dst_port == Port.ARP + return self.udp.dst_port == Port["ARP"] @property def is_icmp(self) -> bool: diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index d493cbdf..36ff2751 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -9,25 +9,32 @@ from primaite.utils.validators import IPV4Address _LOGGER = getLogger(__name__) -class IPProtocol(Enum): - """ - Enum representing transport layer protocols in IP header. +IPProtocol : dict[str, str] = dict( + NONE = "none", + TCP = "tcp", + UDP = "udp", + ICMP = "icmp", +) - .. _List of IPProtocols: - """ +# class IPProtocol(Enum): +# """ +# Enum representing transport layer protocols in IP header. - NONE = "none" - """Placeholder for a non-protocol.""" - TCP = "tcp" - """Transmission Control Protocol.""" - UDP = "udp" - """User Datagram Protocol.""" - ICMP = "icmp" - """Internet Control Message Protocol.""" +# .. _List of IPProtocols: +# """ - def model_dump(self) -> str: - """Return as JSON-serialisable string.""" - return self.name +# NONE = "none" +# """Placeholder for a non-protocol.""" +# TCP = "tcp" +# """Transmission Control Protocol.""" +# UDP = "udp" +# """User Datagram Protocol.""" +# ICMP = "icmp" +# """Internet Control Message Protocol.""" + +# def model_dump(self) -> str: +# """Return as JSON-serialisable string.""" +# return self.name class Precedence(Enum): @@ -81,7 +88,7 @@ class IPPacket(BaseModel): >>> ip_packet = IPPacket( ... src_ip_address=IPv4Address('192.168.0.1'), ... dst_ip_address=IPv4Address('10.0.0.1'), - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... ttl=64, ... precedence=Precedence.CRITICAL ... ) @@ -91,7 +98,7 @@ class IPPacket(BaseModel): "Source IP address." dst_ip_address: IPV4Address "Destination IP address." - protocol: IPProtocol = IPProtocol.TCP + protocol: str = IPProtocol["TCP"] "IPProtocol." ttl: int = 64 "Time to Live (TTL) for the packet." diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index 7f0d2d7a..c77ef532 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -5,76 +5,112 @@ from typing import List, Union from pydantic import BaseModel -class Port(Enum): - """ - Enumeration of common known TCP/UDP ports used by protocols for operation of network applications. +Port: dict[str, int] = dict( + UNUSED = -1, + NONE = 0, + WOL = 9, + FTP_DATA = 20, + FTP = 21, + SSH = 22, + SMTP = 25, + DNS = 53, + HTTP = 80, + POP3 = 110, + SFTP = 115, + NTP = 123, + IMAP = 143, + SNMP = 161, + SNMP_TRAP = 162, + ARP = 219, + LDAP = 389, + HTTPS = 443, + SMB = 445, + IPP = 631, + SQL_SERVER = 1433, + MYSQL = 3306, + RDP = 3389, + RTP = 5004, + RTP_ALT = 5005, + DNS_ALT = 5353, + HTTP_ALT = 8080, + HTTPS_ALT = 8443, + POSTGRES_SERVER = 5432, +) - .. _List of Ports: - """ +# class Port(): +# def __getattr__() - UNUSED = -1 - "An unused port stub." - NONE = 0 - "Place holder for a non-port." - WOL = 9 - "Wake-on-Lan (WOL) - Used to turn or awaken a computer from sleep mode by a network message." - FTP_DATA = 20 - "File Transfer [Default Data]" - FTP = 21 - "File Transfer Protocol (FTP) - FTP control (command)" - SSH = 22 - "Secure Shell (SSH) - Used for secure remote access and command execution." - SMTP = 25 - "Simple Mail Transfer Protocol (SMTP) - Used for email delivery between servers." - DNS = 53 - "Domain Name System (DNS) - Used for translating domain names to IP addresses." - HTTP = 80 - "HyperText Transfer Protocol (HTTP) - Used for web traffic." - POP3 = 110 - "Post Office Protocol version 3 (POP3) - Used for retrieving emails from a mail server." - SFTP = 115 - "Secure File Transfer Protocol (SFTP) - Used for secure file transfer over SSH." - NTP = 123 - "Network Time Protocol (NTP) - Used for clock synchronization between computer systems." - IMAP = 143 - "Internet Message Access Protocol (IMAP) - Used for retrieving emails from a mail server." - SNMP = 161 - "Simple Network Management Protocol (SNMP) - Used for network device management." - SNMP_TRAP = 162 - "SNMP Trap - Used for sending SNMP notifications (traps) to a network management system." - ARP = 219 - "Address resolution Protocol - Used to connect a MAC address to an IP address." - LDAP = 389 - "Lightweight Directory Access Protocol (LDAP) - Used for accessing and modifying directory information." - HTTPS = 443 - "HyperText Transfer Protocol Secure (HTTPS) - Used for secure web traffic." - SMB = 445 - "Server Message Block (SMB) - Used for file sharing and printer sharing in Windows environments." - IPP = 631 - "Internet Printing Protocol (IPP) - Used for printing over the internet or an intranet." - SQL_SERVER = 1433 - "Microsoft SQL Server Database Engine - Used for communication with the SQL Server." - MYSQL = 3306 - "MySQL Database Server - Used for MySQL database communication." - RDP = 3389 - "Remote Desktop Protocol (RDP) - Used for remote desktop access to Windows machines." - RTP = 5004 - "Real-time Transport Protocol (RTP) - Used for transmitting real-time media, e.g., audio and video." - RTP_ALT = 5005 - "Alternative port for RTP (RTP_ALT) - Used in some configurations for transmitting real-time media." - DNS_ALT = 5353 - "Alternative port for DNS (DNS_ALT) - Used in some configurations for DNS service." - HTTP_ALT = 8080 - "Alternative port for HTTP (HTTP_ALT) - Often used as an alternative HTTP port for web applications." - HTTPS_ALT = 8443 - "Alternative port for HTTPS (HTTPS_ALT) - Used in some configurations for secure web traffic." - POSTGRES_SERVER = 5432 - "Postgres SQL Server." +# class Port(Enum): +# """ +# Enumeration of common known TCP/UDP ports used by protocols for operation of network applications. - def model_dump(self) -> str: - """Return a json-serialisable string.""" - return self.name +# .. _List of Ports: +# """ + +# UNUSED = -1 +# "An unused port stub." + +# NONE = 0 +# "Place holder for a non-port." +# WOL = 9 +# "Wake-on-Lan (WOL) - Used to turn or awaken a computer from sleep mode by a network message." +# FTP_DATA = 20 +# "File Transfer [Default Data]" +# FTP = 21 +# "File Transfer Protocol (FTP) - FTP control (command)" +# SSH = 22 +# "Secure Shell (SSH) - Used for secure remote access and command execution." +# SMTP = 25 +# "Simple Mail Transfer Protocol (SMTP) - Used for email delivery between servers." +# DNS = 53 +# "Domain Name System (DNS) - Used for translating domain names to IP addresses." +# HTTP = 80 +# "HyperText Transfer Protocol (HTTP) - Used for web traffic." +# POP3 = 110 +# "Post Office Protocol version 3 (POP3) - Used for retrieving emails from a mail server." +# SFTP = 115 +# "Secure File Transfer Protocol (SFTP) - Used for secure file transfer over SSH." +# NTP = 123 +# "Network Time Protocol (NTP) - Used for clock synchronization between computer systems." +# IMAP = 143 +# "Internet Message Access Protocol (IMAP) - Used for retrieving emails from a mail server." +# SNMP = 161 +# "Simple Network Management Protocol (SNMP) - Used for network device management." +# SNMP_TRAP = 162 +# "SNMP Trap - Used for sending SNMP notifications (traps) to a network management system." +# ARP = 219 +# "Address resolution Protocol - Used to connect a MAC address to an IP address." +# LDAP = 389 +# "Lightweight Directory Access Protocol (LDAP) - Used for accessing and modifying directory information." +# HTTPS = 443 +# "HyperText Transfer Protocol Secure (HTTPS) - Used for secure web traffic." +# SMB = 445 +# "Server Message Block (SMB) - Used for file sharing and printer sharing in Windows environments." +# IPP = 631 +# "Internet Printing Protocol (IPP) - Used for printing over the internet or an intranet." +# SQL_SERVER = 1433 +# "Microsoft SQL Server Database Engine - Used for communication with the SQL Server." +# MYSQL = 3306 +# "MySQL Database Server - Used for MySQL database communication." +# RDP = 3389 +# "Remote Desktop Protocol (RDP) - Used for remote desktop access to Windows machines." +# RTP = 5004 +# "Real-time Transport Protocol (RTP) - Used for transmitting real-time media, e.g., audio and video." +# RTP_ALT = 5005 +# "Alternative port for RTP (RTP_ALT) - Used in some configurations for transmitting real-time media." +# DNS_ALT = 5353 +# "Alternative port for DNS (DNS_ALT) - Used in some configurations for DNS service." +# HTTP_ALT = 8080 +# "Alternative port for HTTP (HTTP_ALT) - Often used as an alternative HTTP port for web applications." +# HTTPS_ALT = 8443 +# "Alternative port for HTTPS (HTTPS_ALT) - Used in some configurations for secure web traffic." +# POSTGRES_SERVER = 5432 +# "Postgres SQL Server." + +# def model_dump(self) -> str: +# """Return a json-serialisable string.""" +# return self.name class UDPHeader(BaseModel): @@ -87,13 +123,13 @@ class UDPHeader(BaseModel): :Example: >>> udp_header = UDPHeader( - ... src_port=Port.HTTP_ALT, - ... dst_port=Port.HTTP, + ... src_port=Port["HTTP_ALT"], + ... dst_port=Port["HTTP"], ... ) """ - src_port: Union[Port, int] - dst_port: Union[Port, int] + src_port: int + dst_port: int class TCPFlags(Enum): @@ -126,12 +162,12 @@ class TCPHeader(BaseModel): :Example: >>> tcp_header = TCPHeader( - ... src_port=Port.HTTP_ALT, - ... dst_port=Port.HTTP, + ... src_port=Port["HTTP_ALT"], + ... dst_port=Port["HTTP"], ... flags=[TCPFlags.SYN, TCPFlags.ACK] ... ) """ - src_port: Port - dst_port: Port + src_port: int + dst_port: int flags: List[TCPFlags] = [TCPFlags.SYN] diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 3f80c745..170e2647 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -90,8 +90,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"): def __init__(self, **kwargs): kwargs["name"] = "DatabaseClient" - kwargs["port"] = Port.POSTGRES_SERVER - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["POSTGRES_SERVER"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def _init_request_manager(self) -> RequestManager: diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index c87eaaf5..74bce85d 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -24,8 +24,8 @@ class PortScanPayload(SimComponent): """ ip_address: IPV4Address - port: Port - protocol: IPProtocol + port: int + protocol: str request: bool = True def describe_state(self) -> Dict: @@ -37,8 +37,8 @@ class PortScanPayload(SimComponent): """ state = super().describe_state() state["ip_address"] = str(self.ip_address) - state["port"] = self.port.value - state["protocol"] = self.protocol.value + state["port"] = self.port + state["protocol"] = self.protocol state["request"] = self.request return state @@ -64,8 +64,8 @@ class NMAP(Application, identifier="NMAP"): def __init__(self, **kwargs): kwargs["name"] = "NMAP" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["NONE"] super().__init__(**kwargs) def _can_perform_network_action(self) -> bool: @@ -218,14 +218,14 @@ class NMAP(Application, identifier="NMAP"): print(table.get_string(sortby="IP Address")) return active_nodes - def _determine_port_scan_type(self, target_ip_addresses: List[IPV4Address], target_ports: List[Port]) -> str: + def _determine_port_scan_type(self, target_ip_addresses: List[IPV4Address], target_ports: List[int]) -> str: """ Determine the type of port scan based on the number of target IP addresses and ports. :param target_ip_addresses: The list of target IP addresses. :type target_ip_addresses: List[IPV4Address] :param target_ports: The list of target ports. - :type target_ports: List[Port] + :type target_ports: List[int] :return: The type of port scan. :rtype: str @@ -238,8 +238,8 @@ class NMAP(Application, identifier="NMAP"): def _check_port_open_on_ip_address( self, ip_address: IPv4Address, - port: Port, - protocol: IPProtocol, + port: int, + protocol: str, is_re_attempt: bool = False, port_scan_uuid: Optional[str] = None, ) -> bool: @@ -251,7 +251,7 @@ class NMAP(Application, identifier="NMAP"): :param port: The target port. :type port: Port :param protocol: The protocol used for the port scan. - :type protocol: IPProtocol + :type protocol: str :param is_re_attempt: Flag indicating if this is a reattempt. Defaults to False. :type is_re_attempt: bool :param port_scan_uuid: The UUID of the port scan payload. Defaults to None. @@ -272,8 +272,8 @@ class NMAP(Application, identifier="NMAP"): payload = PortScanPayload(ip_address=ip_address, port=port, protocol=protocol) self._active_port_scans[payload.uuid] = payload self.sys_log.info( - f"{self.name}: Sending port scan request over {payload.protocol.name} on port {payload.port.value} " - f"({payload.port.name}) to {payload.ip_address}" + f"{self.name}: Sending port scan request over {payload.protocol} on port {payload.port} " + f"({payload.port}) to {payload.ip_address}" ) self.software_manager.send_payload_to_session_manager( payload=payload, dest_ip_address=ip_address, src_port=port, dest_port=port, ip_protocol=protocol @@ -295,8 +295,8 @@ class NMAP(Application, identifier="NMAP"): self._active_port_scans.pop(payload.uuid) self._port_scan_responses[payload.uuid] = payload self.sys_log.info( - f"{self.name}: Received port scan response from {payload.ip_address} on port {payload.port.value} " - f"({payload.port.name}) over {payload.protocol.name}" + f"{self.name}: Received port scan response from {payload.ip_address} on port {payload.port} " + f"({payload.port}) over {payload.protocol}" ) def _process_port_scan_request(self, payload: PortScanPayload, session_id: str) -> None: @@ -311,8 +311,8 @@ class NMAP(Application, identifier="NMAP"): if self.software_manager.check_port_is_open(port=payload.port, protocol=payload.protocol): payload.request = False self.sys_log.info( - f"{self.name}: Responding to port scan request for port {payload.port.value} " - f"({payload.port.name}) over {payload.protocol.name}", + f"{self.name}: Responding to port scan request for port {payload.port} " + f"({payload.port}) over {payload.protocol}", ) self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) @@ -320,20 +320,20 @@ class NMAP(Application, identifier="NMAP"): def port_scan( self, target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]], - target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None, - target_port: Optional[Union[Port, List[Port]]] = None, + target_protocol: Optional[Union[str, List[str]]] = None, + target_port: Optional[Union[int, List[int]]] = None, show: bool = True, json_serializable: bool = False, - ) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]: + ) -> Dict[IPv4Address, Dict[str, List[int]]]: """ Perform a port scan on the target IP address(es). :param target_ip_address: The target IP address(es) or network(s) for the port scan. :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]] :param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP. - :type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] + :type target_protocol: Optional[Union[str, List[str]]] :param target_port: The port(s) to scan. Defaults to None, which includes all valid ports. - :type target_port: Optional[Union[Port, List[Port]]] + :type target_port: Optional[Union[int, List[int]]] :param show: Flag indicating whether to display the scan results. Defaults to True. :type show: bool :param json_serializable: Flag indicating whether the return value should be JSON serializable. Defaults to @@ -341,19 +341,19 @@ class NMAP(Application, identifier="NMAP"): :type json_serializable: bool :return: A dictionary mapping IP addresses to protocols and lists of open ports. - :rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]] + :rtype: Dict[IPv4Address, Dict[str, List[int]]] """ ip_addresses = self._explode_ip_address_network_array(target_ip_address) - if isinstance(target_port, Port): + if isinstance(target_port, int): target_port = [target_port] elif target_port is None: - target_port = [port for port in Port if port not in {Port.NONE, Port.UNUSED}] + target_port = [port for port in Port if port not in {Port["NONE"], Port["UNUSED"]}] - if isinstance(target_protocol, IPProtocol): + if isinstance(target_protocol, str): target_protocol = [target_protocol] elif target_protocol is None: - target_protocol = [IPProtocol.TCP, IPProtocol.UDP] + target_protocol = [IPProtocol["TCP"], IPProtocol["UDP"]] scan_type = self._determine_port_scan_type(list(ip_addresses), target_port) active_ports = {} @@ -372,10 +372,10 @@ class NMAP(Application, identifier="NMAP"): if port_open: if show: - table.add_row([ip_address, port.value, port.name, protocol.name]) + table.add_row([ip_address, port, port, protocol]) _ip_address = ip_address if not json_serializable else str(ip_address) - _protocol = protocol if not json_serializable else protocol.value - _port = port if not json_serializable else port.value + _protocol = protocol if not json_serializable else protocol + _port = port if not json_serializable else port if _ip_address not in active_ports: active_ports[_ip_address] = dict() if _protocol not in active_ports[_ip_address]: @@ -390,12 +390,12 @@ class NMAP(Application, identifier="NMAP"): def network_service_recon( self, target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]], - target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None, - target_port: Optional[Union[Port, List[Port]]] = None, + target_protocol: Optional[Union[str, List[str]]] = None, + target_port: Optional[Union[int, List[int]]] = None, show: bool = True, show_online_only: bool = True, json_serializable: bool = False, - ) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]: + ) -> Dict[IPv4Address, Dict[str, List[int]]]: """ Perform a network service reconnaissance which includes a ping scan followed by a port scan. @@ -408,9 +408,9 @@ class NMAP(Application, identifier="NMAP"): :param target_ip_address: The target IP address(es) or network(s) for the port scan. :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]] :param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP. - :type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] + :type target_protocol: Optional[Union[str, List[str]]] :param target_port: The port(s) to scan. Defaults to None, which includes all valid ports. - :type target_port: Optional[Union[Port, List[Port]]] + :type target_port: Optional[Union[int, List[int]]] :param show: Flag indicating whether to display the scan results. Defaults to True. :type show: bool :param show_online_only: Flag indicating whether to show only the online hosts. Defaults to True. @@ -420,7 +420,7 @@ class NMAP(Application, identifier="NMAP"): :type json_serializable: bool :return: A dictionary mapping IP addresses to protocols and lists of open ports. - :rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]] + :rtype: Dict[IPv4Address, Dict[str, List[int]]] """ ping_scan_results = self.ping_scan( target_ip_address=target_ip_address, show=show, show_online_only=show_online_only, json_serializable=False diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index 5d4cc8e0..d442d968 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -81,10 +81,10 @@ class AbstractC2(Application, identifier="AbstractC2"): keep_alive_frequency: int = Field(default=5, ge=1) """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" - masquerade_protocol: IPProtocol = Field(default=IPProtocol.TCP) + masquerade_protocol: str = Field(default=IPProtocol["TCP"]) """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" - masquerade_port: Port = Field(default=Port.HTTP) + masquerade_port: int = Field(default=Port["HTTP"]) """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" c2_config: _C2Opts = _C2Opts() @@ -142,9 +142,9 @@ class AbstractC2(Application, identifier="AbstractC2"): def __init__(self, **kwargs): """Initialise the C2 applications to by default listen for HTTP traffic.""" - kwargs["listen_on_ports"] = {Port.HTTP, Port.FTP, Port.DNS} - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.TCP + kwargs["listen_on_ports"] = {Port["HTTP"], Port["FTP"], Port["DNS"]} + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) @property @@ -367,7 +367,7 @@ class AbstractC2(Application, identifier="AbstractC2"): :rtype: bool """ # Validating that they are valid Enums. - if not isinstance(payload.masquerade_port, Port) or not isinstance(payload.masquerade_protocol, IPProtocol): + if not isinstance(payload.masquerade_port, int) or not isinstance(payload.masquerade_protocol, str): self.sys_log.warning( f"{self.name}: Received invalid Masquerade Values within Keep Alive." f"Port: {payload.masquerade_port} Protocol: {payload.masquerade_protocol}." @@ -410,8 +410,8 @@ class AbstractC2(Application, identifier="AbstractC2"): self.keep_alive_inactivity = 0 self.keep_alive_frequency = 5 self.c2_remote_connection = None - self.c2_config.masquerade_port = Port.HTTP - self.c2_config.masquerade_protocol = IPProtocol.TCP + self.c2_config.masquerade_port = Port["HTTP"] + self.c2_config.masquerade_protocol = IPProtocol["TCP"] @abstractmethod def _confirm_remote_connection(self, timestep: int) -> bool: diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index fa0271e5..06453330 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -130,8 +130,8 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self, c2_server_ip_address: IPv4Address = None, keep_alive_frequency: int = 5, - masquerade_protocol: Enum = IPProtocol.TCP, - masquerade_port: Enum = Port.HTTP, + masquerade_protocol: str = IPProtocol["TCP"], + masquerade_port: int = Port["HTTP"], ) -> bool: """ Configures the C2 beacon to communicate with the C2 server. diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index fefb22c3..d74ae384 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -50,8 +50,8 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"): def __init__(self, **kwargs): kwargs["name"] = "DataManipulationBot" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index fcad3b3e..2cc99c4a 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -35,7 +35,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): target_ip_address: Optional[IPv4Address] = None """IP address of the target service.""" - target_port: Optional[Port] = None + target_port: Optional[int] = None """Port of the target service.""" payload: Optional[str] = None @@ -94,7 +94,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): def configure( self, target_ip_address: IPv4Address, - target_port: Optional[Port] = Port.POSTGRES_SERVER, + target_port: Optional[int] = Port["POSTGRES_SERVER"], payload: Optional[str] = None, repeat: bool = False, port_scan_p_of_success: float = 0.1, @@ -105,7 +105,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): Configure the Denial of Service bot. :param: target_ip_address: The IP address of the Node containing the target service. - :param: target_port: The port of the target service. Optional - Default is `Port.HTTP` + :param: target_port: The port of the target service. Optional - Default is `Port["HTTP"]` :param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None` :param: repeat: If True, the bot will maintain the attack. Optional - Default is `True` :param: port_scan_p_of_success: The chance of the port scan being successful. Optional - Default is 0.1 (10%) diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index 2046affc..a819190c 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -27,8 +27,8 @@ class RansomwareScript(Application, identifier="RansomwareScript"): def __init__(self, **kwargs): kwargs["name"] = "RansomwareScript" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 73791676..6707fa52 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -43,10 +43,10 @@ class WebBrowser(Application, identifier="WebBrowser"): def __init__(self, **kwargs): kwargs["name"] = "WebBrowser" - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = IPProtocol["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port.HTTP + kwargs["port"] = Port["HTTP"] super().__init__(**kwargs) self.run() @@ -126,7 +126,7 @@ class WebBrowser(Application, identifier="WebBrowser"): if self.send( payload=payload, dest_ip_address=self.domain_name_ip_address, - dest_port=parsed_url.port if parsed_url.port else Port.HTTP, + dest_port=parsed_url.port if parsed_url.port else Port["HTTP"], ): self.sys_log.info( f"{self.name}: Received HTTP {payload.request_method.name} " @@ -154,7 +154,7 @@ class WebBrowser(Application, identifier="WebBrowser"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = Port.HTTP, + dest_port: Optional[int] = Port["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index b7e2c021..172be453 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -34,14 +34,14 @@ class Session(SimComponent): :param connected: A flag indicating whether the session is connected. """ - protocol: IPProtocol + protocol: str with_ip_address: IPv4Address - src_port: Optional[Port] - dst_port: Optional[Port] + src_port: Optional[int] + dst_port: Optional[int] connected: bool = False @classmethod - def from_session_key(cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]) -> Session: + def from_session_key(cls, session_key: Tuple[str, IPv4Address, Optional[int], Optional[int]]) -> Session: """ Create a Session instance from a session key tuple. @@ -77,7 +77,7 @@ class SessionManager: def __init__(self, sys_log: SysLog): self.sessions_by_key: Dict[ - Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]], Session + Tuple[str, IPv4Address, IPv4Address, Optional[int], Optional[int]], Session ] = {} self.sessions_by_uuid: Dict[str, Session] = {} self.sys_log: SysLog = sys_log @@ -103,7 +103,7 @@ class SessionManager: @staticmethod def _get_session_key( frame: Frame, inbound_frame: bool = True - ) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]: + ) -> Tuple[str, IPv4Address, Optional[int], Optional[int]]: """ Extracts the session key from the given frame. @@ -111,15 +111,15 @@ class SessionManager: - IPProtocol: The transport protocol (e.g. TCP, UDP, ICMP). - IPv4Address: The source IP address. - IPv4Address: The destination IP address. - - Optional[Port]: The source port number (if applicable). - - Optional[Port]: The destination port number (if applicable). + - Optional[int]: The source port number (if applicable). + - Optional[int]: The destination port number (if applicable). :param frame: The network frame from which to extract the session key. :return: A tuple containing the session key. """ protocol = frame.ip.protocol with_ip_address = frame.ip.src_ip_address - if protocol == IPProtocol.TCP: + if protocol == IPProtocol["TCP"]: if inbound_frame: src_port = frame.tcp.src_port dst_port = frame.tcp.dst_port @@ -127,7 +127,7 @@ class SessionManager: dst_port = frame.tcp.src_port src_port = frame.tcp.dst_port with_ip_address = frame.ip.dst_ip_address - elif protocol == IPProtocol.UDP: + elif protocol == IPProtocol["UDP"]: if inbound_frame: src_port = frame.udp.src_port dst_port = frame.udp.dst_port @@ -167,17 +167,17 @@ class SessionManager: def resolve_outbound_transmission_details( self, dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[Port] = None, - dst_port: Optional[Port] = None, - protocol: Optional[IPProtocol] = None, + src_port: Optional[int] = None, + dst_port: Optional[int] = None, + protocol: Optional[str] = None, session_id: Optional[str] = None, ) -> Tuple[ Optional["NetworkInterface"], Optional[str], IPv4Address, - Optional[Port], - Optional[Port], - Optional[IPProtocol], + Optional[int], + Optional[int], + Optional[str], bool, ]: """ @@ -196,19 +196,19 @@ class SessionManager: treats the transmission as a broadcast to that network. Optional. :type dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] :param src_port: The source port number for the transmission. Optional. - :type src_port: Optional[Port] + :type src_port: Optional[int] :param dst_port: The destination port number for the transmission. Optional. - :type dst_port: Optional[Port] + :type dst_port: Optional[int] :param protocol: The IP protocol to be used for the transmission. Optional. - :type protocol: Optional[IPProtocol] + :type protocol: Optional[str] :param session_id: The session ID associated with the transmission. If provided, the session details override other parameters. Optional. :type session_id: Optional[str] :return: A tuple containing the resolved outbound network interface, destination MAC address, destination IP address, source port, destination port, protocol, and a boolean indicating whether the transmission is a broadcast. - :rtype: Tuple[Optional["NetworkInterface"], Optional[str], IPv4Address, Optional[Port], Optional[Port], - Optional[IPProtocol], bool] + :rtype: Tuple[Optional["NetworkInterface"], Optional[str], IPv4Address, Optional[int], Optional[int], + Optional[str], bool] """ if dst_ip_address and not isinstance(dst_ip_address, (IPv4Address, IPv4Network)): dst_ip_address = IPv4Address(dst_ip_address) @@ -259,10 +259,10 @@ class SessionManager: self, payload: Any, dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[Port] = None, - dst_port: Optional[Port] = None, + src_port: Optional[int] = None, + dst_port: Optional[int] = None, session_id: Optional[str] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + ip_protocol: str = IPProtocol["TCP"], icmp_packet: Optional[ICMPPacket] = None, ) -> Union[Any, None]: """ @@ -286,7 +286,7 @@ class SessionManager: dst_mac_address = payload.target_mac_addr outbound_network_interface = self.resolve_outbound_network_interface(payload.target_ip_address) is_broadcast = payload.request - ip_protocol = IPProtocol.UDP + ip_protocol = IPProtocol["UDP"] else: vals = self.resolve_outbound_transmission_details( dst_ip_address=dst_ip_address, @@ -311,26 +311,26 @@ class SessionManager: if not outbound_network_interface or not dst_mac_address: return False - if not (src_port or dst_port): + if src_port is None and dst_port is None: raise ValueError( "Failed to resolve src or dst port. Have you sent the port from the service or application?" ) tcp_header = None udp_header = None - if ip_protocol == IPProtocol.TCP: + if ip_protocol == IPProtocol["TCP"]: tcp_header = TCPHeader( src_port=dst_port, dst_port=dst_port, ) - elif ip_protocol == IPProtocol.UDP: + elif ip_protocol == IPProtocol["UDP"]: udp_header = UDPHeader( src_port=dst_port, dst_port=dst_port, ) # TODO: Only create IP packet if not ARP # ip_packet = None - # if dst_port != Port.ARP: + # if dst_port != Port["ARP"]: # IPPacket( # src_ip_address=outbound_network_interface.ip_address, # dst_ip_address=dst_ip_address, @@ -387,7 +387,7 @@ class SessionManager: elif frame.udp: dst_port = frame.udp.dst_port elif frame.icmp: - dst_port = Port.NONE + dst_port = Port["NONE"] self.software_manager.receive_payload_from_session_manager( payload=frame.payload, port=dst_port, @@ -413,5 +413,5 @@ class SessionManager: table.align = "l" table.title = f"{self.sys_log.hostname} Session Manager" for session in self.sessions_by_key.values(): - table.add_row([session.dst_ip_address, session.dst_port.value, session.protocol.name]) + table.add_row([session.dst_ip_address, session.dst_port, session.protocol]) print(table) diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index d45611ed..8eac33fa 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -52,7 +52,7 @@ class SoftwareManager: self.session_manager = session_manager self.software: Dict[str, Union[Service, Application]] = {} self._software_class_to_name_map: Dict[Type[IOSoftware], str] = {} - self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {} + self.port_protocol_mapping: Dict[Tuple[int, str], Union[Service, Application]] = {} self.sys_log: SysLog = sys_log self.file_system: FileSystem = file_system self.dns_server: Optional[IPv4Address] = dns_server @@ -67,7 +67,7 @@ class SoftwareManager: """Provides access to the ICMP service instance, if installed.""" return self.software.get("ICMP") # noqa - def get_open_ports(self) -> List[Port]: + def get_open_ports(self) -> List[int]: """ Get a list of open ports. @@ -81,7 +81,7 @@ class SoftwareManager: open_ports += list(software.listen_on_ports) return open_ports - def check_port_is_open(self, port: Port, protocol: IPProtocol) -> bool: + def check_port_is_open(self, port: int, protocol: str) -> bool: """ Check if a specific port is open and running a service using the specified protocol. @@ -93,7 +93,7 @@ class SoftwareManager: :param port: The port to check. :type port: Port :param protocol: The protocol to check (e.g., TCP, UDP). - :type protocol: IPProtocol + :type protocol: str :return: True if the port is open and a service is running on it using the specified protocol, False otherwise. :rtype: bool """ @@ -189,9 +189,9 @@ class SoftwareManager: self, payload: Any, dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[Port] = None, - dest_port: Optional[Port] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + src_port: Optional[int] = None, + dest_port: Optional[int] = None, + ip_protocol: str = IPProtocol["TCP"], session_id: Optional[str] = None, ) -> bool: """ @@ -219,8 +219,8 @@ class SoftwareManager: def receive_payload_from_session_manager( self, payload: Any, - port: Port, - protocol: IPProtocol, + port: int, + protocol: str, session_id: str, from_network_interface: "NIC", frame: Frame, @@ -275,8 +275,8 @@ class SoftwareManager: software_type, software.operating_state.name, software.health_state_actual.name, - software.port.value if software.port != Port.NONE else None, - software.protocol.value, + software.port if software.port != Port["NONE"] else None, + software.protocol, ] ) print(table) diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index efadf189..b8dd5f89 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -26,8 +26,8 @@ class ARP(Service): def __init__(self, **kwargs): kwargs["name"] = "ARP" - kwargs["port"] = Port.ARP - kwargs["protocol"] = IPProtocol.UDP + kwargs["port"] = Port["ARP"] + kwargs["protocol"] = IPProtocol["UDP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index b38e87b4..11ca9eb2 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -38,8 +38,8 @@ class DatabaseService(Service): def __init__(self, **kwargs): kwargs["name"] = "DatabaseService" - kwargs["port"] = Port.POSTGRES_SERVER - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["POSTGRES_SERVER"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self._create_db_file() diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index d7ba0cd4..62f14366 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -22,11 +22,11 @@ class DNSClient(Service): def __init__(self, **kwargs): kwargs["name"] = "DNSClient" - kwargs["port"] = Port.DNS + kwargs["port"] = Port["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self.start() @@ -95,7 +95,7 @@ class DNSClient(Service): # send a request to check if domain name exists in the DNS Server software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload=payload, dest_ip_address=self.dns_server, dest_port=Port.DNS + payload=payload, dest_ip_address=self.dns_server, dest_port=Port["DNS"] ) # recursively re-call the function passing is_reattempt=True @@ -110,7 +110,7 @@ class DNSClient(Service): payload: DNSPacket, session_id: Optional[str] = None, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = None, + dest_port: Optional[int] = None, **kwargs, ) -> bool: """ diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 8a4bbaed..93895825 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -21,11 +21,11 @@ class DNSServer(Service): def __init__(self, **kwargs): kwargs["name"] = "DNSServer" - kwargs["port"] = Port.DNS + kwargs["port"] = Port["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index f823e42c..1fce4133 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -25,8 +25,8 @@ class FTPClient(FTPServiceABC): def __init__(self, **kwargs): kwargs["name"] = "FTPClient" - kwargs["port"] = Port.FTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["FTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self.start() @@ -104,7 +104,7 @@ class FTPClient(FTPServiceABC): def _connect_to_server( self, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[int] = Port["FTP"], session_id: Optional[str] = None, is_reattempt: Optional[bool] = False, ) -> bool: @@ -114,7 +114,7 @@ class FTPClient(FTPServiceABC): :param: dest_ip_address: IP address of the FTP server the client needs to connect to. Optional. :type: dest_ip_address: Optional[IPv4Address] :param: dest_port: Port of the FTP server the client needs to connect to. Optional. - :type: dest_port: Optional[Port] + :type: dest_port: Optional[int] :param: is_reattempt: Set to True if attempt to connect to FTP Server has been attempted. Default False. :type: is_reattempt: Optional[bool] """ @@ -124,13 +124,13 @@ class FTPClient(FTPServiceABC): # normally FTP will choose a random port for the transfer, but using the FTP command port will do for now # create FTP packet - payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=Port.FTP) + payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=Port["FTP"]) if self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id): if payload.status_code == FTPStatusCode.OK: self.sys_log.info( f"{self.name}: Successfully connected to FTP Server " - f"{dest_ip_address} via port {payload.ftp_command_args.value}" + f"{dest_ip_address} via port {payload.ftp_command_args}" ) self.add_connection(connection_id="server_connection", session_id=session_id) return True @@ -139,7 +139,7 @@ class FTPClient(FTPServiceABC): # reattempt failed self.sys_log.warning( f"{self.name}: Unable to connect to FTP Server " - f"{dest_ip_address} via port {payload.ftp_command_args.value}" + f"{dest_ip_address} via port {payload.ftp_command_args}" ) return False else: @@ -152,7 +152,7 @@ class FTPClient(FTPServiceABC): return False def _disconnect_from_server( - self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[Port] = Port.FTP + self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[int] = Port["FTP"] ) -> bool: """ Connects the client from a given FTP server. @@ -160,7 +160,7 @@ class FTPClient(FTPServiceABC): :param: dest_ip_address: IP address of the FTP server the client needs to disconnect from. Optional. :type: dest_ip_address: Optional[IPv4Address] :param: dest_port: Port of the FTP server the client needs to disconnect from. Optional. - :type: dest_port: Optional[Port] + :type: dest_port: Optional[int] :param: is_reattempt: Set to True if attempt to disconnect from FTP Server has been attempted. Default False. :type: is_reattempt: Optional[bool] """ @@ -179,7 +179,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[int] = Port["FTP"], session_id: Optional[str] = None, ) -> bool: """ @@ -203,8 +203,8 @@ class FTPClient(FTPServiceABC): :param: dest_file_name: The name of the file to be saved on the FTP Server. :type: dest_file_name: str - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. - :type: dest_port: Optional[Port] + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. + :type: dest_port: Optional[int] :param: session_id: The id of the session :type: session_id: Optional[str] @@ -241,7 +241,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[int] = Port["FTP"], ) -> bool: """ Request a file from a target IP address. @@ -263,8 +263,8 @@ class FTPClient(FTPServiceABC): :param: dest_file_name: The name of the file to be saved on the FTP Server. :type: dest_file_name: str - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. - :type: dest_port: Optional[Port] + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. + :type: dest_port: Optional[int] """ # check if FTP is currently connected to IP self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index f02d01f4..701bff79 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -23,8 +23,8 @@ class FTPServer(FTPServiceABC): def __init__(self, **kwargs): kwargs["name"] = "FTPServer" - kwargs["port"] = Port.FTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["FTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self.start() @@ -52,7 +52,7 @@ class FTPServer(FTPServiceABC): # process server specific commands, otherwise call super if payload.ftp_command == FTPCommand.PORT: # check that the port is valid - if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535): + if isinstance(payload.ftp_command_args, int) and (0 <= payload.ftp_command_args < 65535): # return successful connection self.add_connection(connection_id=session_id, session_id=session_id) payload.status_code = FTPStatusCode.OK diff --git a/src/primaite/simulator/system/services/ftp/ftp_service.py b/src/primaite/simulator/system/services/ftp/ftp_service.py index 689a3da7..36245e0f 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_service.py +++ b/src/primaite/simulator/system/services/ftp/ftp_service.py @@ -78,7 +78,7 @@ class FTPServiceABC(Service, ABC): dest_folder_name: str, dest_file_name: str, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = None, + dest_port: Optional[int] = None, session_id: Optional[str] = None, is_response: bool = False, ) -> bool: @@ -97,8 +97,8 @@ class FTPServiceABC(Service, ABC): :param: dest_ip_address: The IP address of the machine that hosts the FTP Server. :type: dest_ip_address: Optional[IPv4Address] - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. - :type: dest_port: Optional[Port] + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. + :type: dest_port: Optional[int] :param: session_id: session ID linked to the FTP Packet. Optional. :type: session_id: Optional[str] @@ -168,7 +168,7 @@ class FTPServiceABC(Service, ABC): payload: FTPPacket, session_id: Optional[str] = None, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = None, + dest_port: Optional[int] = None, **kwargs, ) -> bool: """ diff --git a/src/primaite/simulator/system/services/icmp/icmp.py b/src/primaite/simulator/system/services/icmp/icmp.py index 6741d86a..a2dfac0d 100644 --- a/src/primaite/simulator/system/services/icmp/icmp.py +++ b/src/primaite/simulator/system/services/icmp/icmp.py @@ -26,8 +26,8 @@ class ICMP(Service): def __init__(self, **kwargs): kwargs["name"] = "ICMP" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.ICMP + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["ICMP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/icmp/router_icmp.py b/src/primaite/simulator/system/services/icmp/router_icmp.py index 4fdc6baa..19c0ac2d 100644 --- a/src/primaite/simulator/system/services/icmp/router_icmp.py +++ b/src/primaite/simulator/system/services/icmp/router_icmp.py @@ -36,13 +36,13 @@ # self.sys_log.info(f"Received echo request from {frame.ip.src_ip_address}") # target_mac_address = self.arp.get_arp_cache_mac_address(frame.ip.src_ip_address) # src_nic = self.arp.get_arp_cache_network_interface(frame.ip.src_ip_address) -# tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) +# tcp_header = TCPHeader(src_port=Port["ARP"], dst_port=Port["ARP"]) # # # Network Layer # ip_packet = IPPacket( # src_ip_address=network_interface.ip_address, # dst_ip_address=frame.ip.src_ip_address, -# protocol=IPProtocol.ICMP, +# protocol=IPProtocol["ICMP"], # ) # # Data Link Layer # ethernet_header = EthernetHeader( diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 8924a821..40b8d273 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -21,8 +21,8 @@ class NTPClient(Service): def __init__(self, **kwargs): kwargs["name"] = "NTPClient" - kwargs["port"] = Port.NTP - kwargs["protocol"] = IPProtocol.UDP + kwargs["port"] = Port["NTP"] + kwargs["protocol"] = IPProtocol["UDP"] super().__init__(**kwargs) self.start() @@ -55,7 +55,7 @@ class NTPClient(Service): payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: Port = Port.NTP, + dest_port: int = Port["NTP"], **kwargs, ) -> bool: """Requests NTP data from NTP server. diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 547bbc06..d9de40c6 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -16,8 +16,8 @@ class NTPServer(Service): def __init__(self, **kwargs): kwargs["name"] = "NTPServer" - kwargs["port"] = Port.NTP - kwargs["protocol"] = IPProtocol.UDP + kwargs["port"] = Port["NTP"] + kwargs["protocol"] = IPProtocol["UDP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index e98e8555..41987aff 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -137,8 +137,8 @@ class Terminal(Service): def __init__(self, **kwargs): kwargs["name"] = "Terminal" - kwargs["port"] = Port.SSH - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["SSH"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 4fc64e1f..c021a86e 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -49,10 +49,10 @@ class WebServer(Service): def __init__(self, **kwargs): kwargs["name"] = "WebServer" - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = IPProtocol["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port.HTTP + kwargs["port"] = Port["HTTP"] super().__init__(**kwargs) self._install_web_files() @@ -145,7 +145,7 @@ class WebServer(Service): payload: HttpResponsePacket, session_id: Optional[str] = None, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = None, + dest_port: Optional[int] = None, **kwargs, ) -> bool: """ diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index f1d1b9a1..1880d244 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -251,11 +251,11 @@ class IOSoftware(Software): "Indicates if the software uses TCP protocol for communication. Default is True." udp: bool = True "Indicates if the software uses UDP protocol for communication. Default is True." - port: Port + port: int "The port to which the software is connected." - listen_on_ports: Set[Port] = Field(default_factory=set) + listen_on_ports: Set[int] = Field(default_factory=set) "The set of ports to listen on." - protocol: IPProtocol + protocol: str "The IP Protocol the Software operates on." _connections: Dict[str, Dict] = {} "Active connections." @@ -277,7 +277,7 @@ class IOSoftware(Software): "max_sessions": self.max_sessions, "tcp": self.tcp, "udp": self.udp, - "port": self.port.value, + "port": self.port, } ) return state @@ -386,8 +386,8 @@ class IOSoftware(Software): payload: Any, session_id: Optional[str] = None, dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - dest_port: Optional[Port] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + dest_port: Optional[int] = None, + ip_protocol: str = IPProtocol["TCP"], **kwargs, ) -> bool: """ diff --git a/tests/conftest.py b/tests/conftest.py index 1bbff8f2..1ffa2146 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,8 +45,8 @@ class DummyService(Service): def __init__(self, **kwargs): kwargs["name"] = "DummyService" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["HTTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: @@ -58,8 +58,8 @@ class DummyApplication(Application, identifier="DummyApplication"): def __init__(self, **kwargs): kwargs["name"] = "DummyApplication" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["HTTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -77,7 +77,7 @@ def uc2_network() -> Network: @pytest.fixture(scope="function") def service(file_system) -> DummyService: return DummyService( - name="DummyService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_service") + name="DummyService", port=Port["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_service") ) @@ -90,7 +90,7 @@ def service_class(): def application(file_system) -> DummyApplication: return DummyApplication( name="DummyApplication", - port=Port.ARP, + port=Port["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_application"), ) @@ -350,10 +350,10 @@ def install_stuff_to_sim(sim: Simulation): network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) # 2: Configure base ACL - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) # 3: Install server software server_1.software_manager.install(DNSServer) @@ -379,13 +379,13 @@ def install_stuff_to_sim(sim: Simulation): r = sim.network.router_nodes[0] for i, acl_rule in enumerate(r.acl.acl): if i == 1: - assert acl_rule.src_port == acl_rule.dst_port == Port.DNS + assert acl_rule.src_port == acl_rule.dst_port == Port["DNS"] elif i == 3: - assert acl_rule.src_port == acl_rule.dst_port == Port.HTTP + assert acl_rule.src_port == acl_rule.dst_port == Port["HTTP"] elif i == 22: - assert acl_rule.src_port == acl_rule.dst_port == Port.ARP + assert acl_rule.src_port == acl_rule.dst_port == Port["ARP"] elif i == 23: - assert acl_rule.protocol == IPProtocol.ICMP + assert acl_rule.protocol == IPProtocol["ICMP"] elif i == 24: ... else: diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py index 457fdb42..d35e2ebb 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py @@ -68,44 +68,44 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed internal_inbound assert firewall.internal_inbound_acl.num_rules == 2 assert firewall.internal_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.internal_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.internal_inbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.internal_inbound_acl.acl[22].dst_port == Port["ARP"] assert firewall.internal_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.internal_inbound_acl.acl[23].protocol == IPProtocol["ICMP"] assert firewall.internal_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed internal_outbound assert firewall.internal_outbound_acl.num_rules == 2 assert firewall.internal_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.internal_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.internal_outbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.internal_outbound_acl.acl[22].dst_port == Port["ARP"] assert firewall.internal_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.internal_outbound_acl.acl[23].protocol == IPProtocol["ICMP"] assert firewall.internal_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_inbound assert firewall.dmz_inbound_acl.num_rules == 2 assert firewall.dmz_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.dmz_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.dmz_inbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.dmz_inbound_acl.acl[22].dst_port == Port["ARP"] assert firewall.dmz_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.dmz_inbound_acl.acl[23].protocol == IPProtocol["ICMP"] assert firewall.dmz_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_outbound assert firewall.dmz_outbound_acl.num_rules == 2 assert firewall.dmz_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.dmz_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.dmz_outbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.dmz_outbound_acl.acl[22].dst_port == Port["ARP"] assert firewall.dmz_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.dmz_outbound_acl.acl[23].protocol == IPProtocol["ICMP"] assert firewall.dmz_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed external_inbound assert firewall.external_inbound_acl.num_rules == 1 assert firewall.external_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.external_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.external_inbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.external_inbound_acl.acl[22].dst_port == Port["ARP"] # external_inbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_inbound_acl.implicit_action == ACLAction.PERMIT @@ -113,8 +113,8 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed external_outbound assert firewall.external_outbound_acl.num_rules == 1 assert firewall.external_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.external_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.external_outbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.external_outbound_acl.acl[22].dst_port == Port["ARP"] # external_outbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_outbound_acl.implicit_action == ACLAction.PERMIT diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py index ccde3a02..16543565 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -63,8 +63,8 @@ def test_router_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed assert router_1.acl.num_rules == 2 assert router_1.acl.acl[22].action == ACLAction.PERMIT - assert router_1.acl.acl[22].src_port == Port.ARP - assert router_1.acl.acl[22].dst_port == Port.ARP + assert router_1.acl.acl[22].src_port == Port["ARP"] + assert router_1.acl.acl[22].dst_port == Port["ARP"] assert router_1.acl.acl[23].action == ACLAction.PERMIT - assert router_1.acl.acl[23].protocol == IPProtocol.ICMP + assert router_1.acl.acl[23].protocol == IPProtocol["ICMP"] assert router_1.acl.implicit_action == ACLAction.DENY diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py index c9b3006d..8e3d33e1 100644 --- a/tests/integration_tests/extensions/applications/extended_application.py +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -44,10 +44,10 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): def __init__(self, **kwargs): kwargs["name"] = "ExtendedApplication" - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = IPProtocol["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port.HTTP + kwargs["port"] = Port["HTTP"] super().__init__(**kwargs) self.run() @@ -127,7 +127,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): if self.send( payload=payload, dest_ip_address=self.domain_name_ip_address, - dest_port=parsed_url.port if parsed_url.port else Port.HTTP, + dest_port=parsed_url.port if parsed_url.port else Port["HTTP"], ): self.sys_log.info( f"{self.name}: Received HTTP {payload.request_method.name} " @@ -155,7 +155,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = Port.HTTP, + dest_port: Optional[int] = Port["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index 3151571b..d4af600f 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -38,8 +38,8 @@ class ExtendedService(Service, identifier='extendedservice'): def __init__(self, **kwargs): kwargs["name"] = "ExtendedService" - kwargs["port"] = Port.POSTGRES_SERVER - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["POSTGRES_SERVER"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self._create_db_file() if kwargs.get('options'): diff --git a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py index 806ce063..17b0ba8c 100644 --- a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py +++ b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py @@ -26,9 +26,9 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=4) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=5) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=6) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=5) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=6) c2_server_host = game.simulation.network.get_node_by_hostname("client_1") c2_server_host.software_manager.install(software_class=C2Server) diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index 0c9ec6f0..34ee25d6 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -200,7 +200,7 @@ class TestConfigureDoSBot: game.step() assert dos_bot.target_ip_address == IPv4Address("192.168.1.99") - assert dos_bot.target_port == Port.POSTGRES_SERVER + assert dos_bot.target_port == Port["POSTGRES_SERVER"] assert dos_bot.payload == "HACC" assert not dos_bot.repeat assert dos_bot.port_scan_p_of_success == 0.875 diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index d011c1e8..857edd26 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -20,7 +20,7 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=4) return (game, agent) diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index f1d9d416..398c43a9 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -33,7 +33,7 @@ def test_acl_observations(simulation): server.software_manager.install(NTPServer) # add router acl rule - router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port["NTP"], src_port=Port["NTP"], position=1) acl_obs = ACLObservation( where=["network", "nodes", router.hostname, "acl", "acl"], diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 34a37f5e..68506d59 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -62,13 +62,13 @@ def test_firewall_observation(): # add a rule to the internal inbound and check that the observation is correct firewall.internal_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port.HTTP, - dst_port=Port.HTTP, + src_port=Port["HTTP"], + dst_port=Port["HTTP"], position=5, ) diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index ef789ba7..bd8dfc4e 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -152,7 +152,7 @@ def test_config_nic_categories(simulation): def test_nic_monitored_traffic(simulation): - monitored_traffic = {"icmp": ["NONE"], "tcp": ["DNS"]} + monitored_traffic = {"icmp": ["NONE"], "tcp": [53,]} pc: Computer = simulation.network.get_node_by_hostname("client_1") pc2: Computer = simulation.network.get_node_by_hostname("client_2") diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index 48d29cfb..937bb061 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -39,13 +39,13 @@ def test_router_observation(): # Add an ACL rule to the router router.acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port.HTTP, - dst_port=Port.HTTP, + src_port=Port["HTTP"], + dst_port=Port["HTTP"], position=5, ) # Observe the state using the RouterObservation instance diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py index ca5e2543..6ca4bc9e 100644 --- a/tests/integration_tests/game_layer/observations/test_user_observations.py +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -15,7 +15,7 @@ def env_with_ssh() -> PrimaiteGymEnv: env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) env.agent.flatten_obs = False router: Router = env.game.simulation.network.get_node_by_hostname("router_1") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3) + router.acl.add_rule(ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=3) return env diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index a1005f34..c3e86263 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -608,9 +608,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.internal_outbound_acl.acl[1].action.name == "DENY" assert firewall.internal_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.0.10") assert firewall.internal_outbound_acl.acl[1].dst_ip_address is None - assert firewall.internal_outbound_acl.acl[1].dst_port == Port.DNS - assert firewall.internal_outbound_acl.acl[1].src_port == Port.ARP - assert firewall.internal_outbound_acl.acl[1].protocol == IPProtocol.ICMP + assert firewall.internal_outbound_acl.acl[1].dst_port == Port["DNS"] + assert firewall.internal_outbound_acl.acl[1].src_port == Port["ARP"] + assert firewall.internal_outbound_acl.acl[1].protocol == IPProtocol["ICMP"] env.step(4) # Remove ACL rule from Internal Outbound assert firewall.internal_outbound_acl.num_rules == 2 @@ -620,9 +620,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_inbound_acl.acl[1].action.name == "DENY" assert firewall.dmz_inbound_acl.acl[1].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_inbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_inbound_acl.acl[1].dst_port == Port.HTTP - assert firewall.dmz_inbound_acl.acl[1].src_port == Port.HTTP - assert firewall.dmz_inbound_acl.acl[1].protocol == IPProtocol.UDP + assert firewall.dmz_inbound_acl.acl[1].dst_port == Port["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].src_port == Port["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].protocol == IPProtocol["UDP"] env.step(6) # Remove ACL rule from DMZ Inbound assert firewall.dmz_inbound_acl.num_rules == 2 @@ -632,9 +632,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_outbound_acl.acl[2].action.name == "DENY" assert firewall.dmz_outbound_acl.acl[2].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_outbound_acl.acl[2].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_outbound_acl.acl[2].dst_port == Port.HTTP - assert firewall.dmz_outbound_acl.acl[2].src_port == Port.HTTP - assert firewall.dmz_outbound_acl.acl[2].protocol == IPProtocol.TCP + assert firewall.dmz_outbound_acl.acl[2].dst_port == Port["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].src_port == Port["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].protocol == IPProtocol["TCP"] env.step(8) # Remove ACL rule from DMZ Outbound assert firewall.dmz_outbound_acl.num_rules == 2 @@ -644,9 +644,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.external_inbound_acl.acl[10].action.name == "DENY" assert firewall.external_inbound_acl.acl[10].src_ip_address == IPv4Address("192.168.20.10") assert firewall.external_inbound_acl.acl[10].dst_ip_address == IPv4Address("192.168.10.10") - assert firewall.external_inbound_acl.acl[10].dst_port == Port.POSTGRES_SERVER - assert firewall.external_inbound_acl.acl[10].src_port == Port.POSTGRES_SERVER - assert firewall.external_inbound_acl.acl[10].protocol == IPProtocol.ICMP + assert firewall.external_inbound_acl.acl[10].dst_port == Port["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].src_port == Port["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].protocol == IPProtocol["ICMP"] env.step(10) # Remove ACL rule from External Inbound assert firewall.external_inbound_acl.num_rules == 1 diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 58783d70..d872c2b0 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -42,7 +42,7 @@ def test_WebpageUnavailablePenalty(game_and_agent): # Block the web traffic, check that failing to fetch the webpage yields a reward of -0.7 router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.TCP, src_port=Port.HTTP, dst_port=Port.HTTP) + router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol["TCP"], src_port=Port["HTTP"], dst_port=Port["HTTP"]) agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})) game.step() assert agent.reward_function.current_reward == -0.7 @@ -65,7 +65,7 @@ def test_uc2_rewards(game_and_agent): db_client.run() router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2) + router.acl.add_rule(ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=2) comp = GreenAdminDatabaseUnreachablePenalty("client_1") diff --git a/tests/integration_tests/network/test_airspace_config.py b/tests/integration_tests/network/test_airspace_config.py index 78d00b47..1794c4bc 100644 --- a/tests/integration_tests/network/test_airspace_config.py +++ b/tests/integration_tests/network/test_airspace_config.py @@ -13,8 +13,8 @@ def test_override_freq_max_capacity_mbps(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 123.45 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_2_4"]) == 123.45 + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_5"]) == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") @@ -32,8 +32,8 @@ def test_override_freq_max_capacity_mbps_blocked(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 0.0 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_2_4"]) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_5"]) == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index 80007c46..da0af89d 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -20,8 +20,8 @@ class BroadcastTestService(Service): def __init__(self, **kwargs): # Set default service properties for broadcasting kwargs["name"] = "BroadcastService" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["HTTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -33,12 +33,12 @@ class BroadcastTestService(Service): super().send( payload="unicast", dest_ip_address=ip_address, - dest_port=Port.HTTP, + dest_port=Port["HTTP"], ) def broadcast(self, ip_network: IPv4Network): # Send a broadcast payload to an entire IP network - super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol) + super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port["HTTP"], ip_protocol=self.protocol) class BroadcastTestClient(Application, identifier="BroadcastTestClient"): @@ -49,8 +49,8 @@ class BroadcastTestClient(Application, identifier="BroadcastTestClient"): def __init__(self, **kwargs): # Set default client properties kwargs["name"] = "BroadcastTestClient" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["HTTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index b15ee51a..8e06ccfb 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -53,28 +53,28 @@ def dmz_external_internal_network() -> Network: ) # Allow ICMP - firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # Allow ARP firewall_node.internal_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 ) firewall_node.internal_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 ) firewall_node.external_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 ) firewall_node.external_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 ) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) # external node external_node = Computer( @@ -262,8 +262,8 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not internal_ntp_client.time - firewall.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) - firewall.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) + firewall.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) + firewall.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) internal_ntp_client.request_time() @@ -271,8 +271,8 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not dmz_ntp_client.time - firewall.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) - firewall.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) + firewall.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) + firewall.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) dmz_ntp_client.request_time() diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index 62b58cbd..641342e2 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -73,8 +73,8 @@ def multi_hop_network() -> Network: router_1.enable_port(2) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # Configure PC B pc_b = Computer( @@ -197,8 +197,8 @@ def test_routing_services(multi_hop_network): router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21) - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=21) + router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=21) assert ntp_client.time is None ntp_client.request_time() diff --git a/tests/integration_tests/network/test_wireless_router.py b/tests/integration_tests/network/test_wireless_router.py index 733de6f6..2f1be930 100644 --- a/tests/integration_tests/network/test_wireless_router.py +++ b/tests/integration_tests/network/test_wireless_router.py @@ -37,8 +37,8 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # Configure PC B pc_b = Computer( diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index 9d12f2cf..d819b511 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -227,7 +227,7 @@ def test_c2_suite_acl_block(basic_network): assert c2_beacon.c2_connection_active == True # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) c2_beacon.apply_timestep(2) c2_beacon.apply_timestep(3) @@ -322,8 +322,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Confirm Default Setup ######################### # Permitting all HTTP & FTP traffic - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=1) c2_beacon.apply_timestep(0) assert c2_beacon.keep_alive_inactivity == 1 @@ -337,7 +337,7 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying HTTP Traffic ######################### # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) blocking_acl: AccessControlList = router.acl.acl[0] # Asserts to show the C2 Suite is unable to maintain connection: @@ -359,8 +359,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port.FTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=Port["FTP"], + masquerade_protocol=IPProtocol["TCP"], ) c2_beacon.establish() @@ -407,8 +407,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying FTP Traffic & Enable HTTP ######################### # Blocking FTP and re-permitting HTTP: - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.FTP, dst_port=Port.FTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port["FTP"], dst_port=Port["FTP"], position=1) blocking_acl: AccessControlList = router.acl.acl[1] # Asserts to show the C2 Suite is unable to maintain connection: @@ -430,8 +430,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port.HTTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=Port["HTTP"], + masquerade_protocol=IPProtocol["TCP"], ) c2_beacon.establish() diff --git a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py index 2e87578d..9c0760b7 100644 --- a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py @@ -52,7 +52,7 @@ def data_manipulation_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 ) client_1: Computer = network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py index 68c1fbfe..709a417f 100644 --- a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py @@ -26,7 +26,7 @@ def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseServ dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port.POSTGRES_SERVER, + target_port=Port["POSTGRES_SERVER"], ) # Install DB Server service on server @@ -43,7 +43,7 @@ def dos_bot_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 ) client_1: Computer = network.get_node_by_hostname("client_1") @@ -56,7 +56,7 @@ def dos_bot_db_server_green_client(example_network) -> Network: dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port.POSTGRES_SERVER, + target_port=Port["POSTGRES_SERVER"], ) # install db server service on server diff --git a/tests/integration_tests/system/red_applications/test_ransomware_script.py b/tests/integration_tests/system/red_applications/test_ransomware_script.py index 97abafb5..b34e9b30 100644 --- a/tests/integration_tests/system/red_applications/test_ransomware_script.py +++ b/tests/integration_tests/system/red_applications/test_ransomware_script.py @@ -47,7 +47,7 @@ def ransomware_script_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 ) client_1: Computer = network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index 2b8691cc..1064ed1b 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -73,10 +73,10 @@ def test_port_scan_one_node_one_port(example_network): client_2 = network.get_node_by_hostname("client_2") actual_result = client_1_nmap.port_scan( - target_ip_address=client_2.network_interface[1].ip_address, target_port=Port.DNS, target_protocol=IPProtocol.TCP + target_ip_address=client_2.network_interface[1].ip_address, target_port=Port["DNS"], target_protocol=IPProtocol["TCP"] ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.DNS]}} + expected_result = {IPv4Address("192.168.10.22"): {IPProtocol["TCP"]: [Port["DNS"]]}} assert actual_result == expected_result @@ -101,14 +101,14 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network): actual_result = client_1_nmap.port_scan( target_ip_address=IPv4Network("192.168.10.0/24"), - target_port=[Port.ARP, Port.HTTP, Port.FTP, Port.DNS, Port.NTP], + target_port=[Port["ARP"], Port["HTTP"], Port["FTP"], Port["DNS"], Port["NTP"]], ) expected_result = { - IPv4Address("192.168.10.1"): {IPProtocol.UDP: [Port.ARP]}, + IPv4Address("192.168.10.1"): {IPProtocol["UDP"]: [Port["ARP"]]}, IPv4Address("192.168.10.22"): { - IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS], - IPProtocol.UDP: [Port.ARP, Port.NTP], + IPProtocol["TCP"]: [Port["HTTP"], Port["FTP"], Port["DNS"]], + IPProtocol["UDP"]: [Port["ARP"], Port["NTP"]], }, } @@ -122,10 +122,10 @@ def test_network_service_recon_all_ports_and_protocols(example_network): client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa actual_result = client_1_nmap.network_service_recon( - target_ip_address=IPv4Network("192.168.10.0/24"), target_port=Port.HTTP, target_protocol=IPProtocol.TCP + target_ip_address=IPv4Network("192.168.10.0/24"), target_port=Port["HTTP"], target_protocol=IPProtocol["TCP"] ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.HTTP]}} + expected_result = {IPv4Address("192.168.10.22"): {IPProtocol["TCP"]: [Port["HTTP"]]}} assert sort_dict(actual_result) == sort_dict(expected_result) diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py index fd502a70..5226ab4a 100644 --- a/tests/integration_tests/system/test_service_listening_on_ports.py +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -16,9 +16,9 @@ from tests import TEST_ASSETS_ROOT class _DatabaseListener(Service): name: str = "DatabaseListener" - protocol: IPProtocol = IPProtocol.TCP - port: Port = Port.NONE - listen_on_ports: Set[Port] = {Port.POSTGRES_SERVER} + protocol: str = IPProtocol["TCP"] + port: int = Port["NONE"] + listen_on_ports: Set[int] = {Port["POSTGRES_SERVER"]} payloads_received: List[Any] = Field(default_factory=list) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: @@ -51,8 +51,8 @@ def test_http_listener(client_server): computer.session_manager.receive_payload_from_software_manager( payload="masquerade as Database traffic", dst_ip_address=server.network_interface[1].ip_address, - dst_port=Port.POSTGRES_SERVER, - ip_protocol=IPProtocol.TCP, + dst_port=Port["POSTGRES_SERVER"], + ip_protocol=IPProtocol["TCP"], ) assert len(server_db_listener.payloads_received) == 1 @@ -76,9 +76,9 @@ def test_set_listen_on_ports_from_config(): network = PrimaiteGame.from_config(cfg=config_dict).simulation.network client: Computer = network.get_node_by_hostname("client") - assert Port.SMB in client.software_manager.get_open_ports() - assert Port.IPP in client.software_manager.get_open_ports() + assert Port["SMB"] in client.software_manager.get_open_ports() + assert Port["IPP"] in client.software_manager.get_open_ports() web_browser = client.software_manager.software["WebBrowser"] - assert not web_browser.listen_on_ports.difference({Port.SMB, Port.IPP}) + assert not web_browser.listen_on_ports.difference({Port["SMB"], Port["IPP"]}) diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py index 5a765763..6c37360f 100644 --- a/tests/integration_tests/system/test_web_client_server_and_database.py +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -24,17 +24,17 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer, # add rules to network router router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) # Create Computer computer: Computer = example_network.get_node_by_hostname("client_1") @@ -148,7 +148,7 @@ class TestWebBrowserHistory: assert web_browser.history[-1].response_code == 200 router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) assert not web_browser.get_webpage() assert len(web_browser.history) == 3 # with current NIC behaviour, even if you block communication, you won't get SERVER_UNREACHABLE because @@ -166,7 +166,7 @@ class TestWebBrowserHistory: web_browser.get_webpage() router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) web_browser.get_webpage() state = computer.describe_state() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 95634cf1..ff73e621 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -171,7 +171,7 @@ class TestDataManipulationGreenRequests: assert client_1_browser_execute.status == "success" assert client_2_browser_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router.acl.add_rule(ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) assert client_1_browser_execute.status == "failure" @@ -182,7 +182,7 @@ class TestDataManipulationGreenRequests: assert client_1_db_client_execute.status == "success" assert client_2_db_client_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER) + router.acl.add_rule(ACLAction.DENY, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]) client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"]) client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"]) assert client_1_db_client_execute.status == "failure" diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py index 9bc1abfd..4c471faa 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py @@ -28,20 +28,20 @@ def router_with_acl_rules(): # Add rules here as needed acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.1", - src_port=Port.HTTPS, + src_port=Port["HTTPS"], dst_ip_address="192.168.1.2", - dst_port=Port.HTTP, + dst_port=Port["HTTP"], position=1, ) acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.3", - src_port=Port(8080), + src_port=8080, dst_ip_address="192.168.1.4", - dst_port=Port(80), + dst_port=80, position=2, ) return router @@ -65,21 +65,21 @@ def router_with_wildcard_acl(): # Rule to permit traffic from a specific source IP and port to a specific destination IP and port acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.1", - src_port=Port(8080), + src_port=8080, dst_ip_address="10.1.1.2", - dst_port=Port(80), + dst_port=80, position=1, ) # Rule to deny traffic from an IP range to a specific destination IP and port acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.0", src_wildcard_mask="0.0.0.255", dst_ip_address="10.1.1.3", - dst_port=Port(443), + dst_port=443, position=2, ) # Rule to permit any traffic to a range of destination IPs @@ -109,11 +109,11 @@ def test_add_rule(router_with_acl_rules): acl = router_with_acl_rules.acl assert acl.acl[1].action == ACLAction.PERMIT - assert acl.acl[1].protocol == IPProtocol.TCP + assert acl.acl[1].protocol == IPProtocol["TCP"] assert acl.acl[1].src_ip_address == IPv4Address("192.168.1.1") - assert acl.acl[1].src_port == Port.HTTPS + assert acl.acl[1].src_port == Port["HTTPS"] assert acl.acl[1].dst_ip_address == IPv4Address("192.168.1.2") - assert acl.acl[1].dst_port == Port.HTTP + assert acl.acl[1].dst_port == Port["HTTP"] def test_remove_rule(router_with_acl_rules): @@ -136,8 +136,8 @@ def test_traffic_permitted_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.HTTPS, dst_port=Port.HTTP), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=Port["HTTPS"], dst_port=Port["HTTP"]), ) is_permitted, _ = acl.is_permitted(permitted_frame) assert is_permitted @@ -153,8 +153,8 @@ def test_traffic_denied_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(80)), + ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=80), ) is_permitted, _ = acl.is_permitted(not_permitted_frame) assert not is_permitted @@ -173,8 +173,8 @@ def test_default_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.HTTPS, dst_port=Port.HTTP), + ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=IPProtocol["UDP"]), + udp=UDPHeader(src_port=Port["HTTPS"], dst_port=Port["HTTP"]), ) is_permitted, rule = acl.is_permitted(not_permitted_frame) assert not is_permitted @@ -189,8 +189,8 @@ def test_direct_ip_match_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(80)), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=80), ) assert acl.is_permitted(frame)[0], "Direct IP match should be permitted." @@ -204,8 +204,8 @@ def test_ip_range_match_denied_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(443)), + ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=443), ) assert not acl.is_permitted(frame)[0], "IP range match with wildcard mask should be denied." @@ -219,8 +219,8 @@ def test_traffic_permitted_to_destination_range_with_acl(router_with_wildcard_ac acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port(1433), dst_port=Port(1433)), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["UDP"]), + udp=UDPHeader(src_port=1433, dst_port=1433), ) assert acl.is_permitted(frame)[0], "Traffic to destination IP range should be permitted." @@ -253,23 +253,23 @@ def test_ip_traffic_from_specific_subnet(): permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]), ) assert acl.is_permitted(permitted_frame_1)[0] permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.NTP, dst_port=Port.NTP), + ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=IPProtocol["UDP"]), + udp=UDPHeader(src_port=Port["NTP"], dst_port=Port["NTP"]), ) assert acl.is_permitted(permitted_frame_2)[0] permitted_frame_3 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=IPProtocol["ICMP"]), icmp=ICMPPacket(identifier=1), ) @@ -277,16 +277,16 @@ def test_ip_traffic_from_specific_subnet(): not_permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER), + ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]), ) assert not acl.is_permitted(not_permitted_frame_1)[0] not_permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.NTP, dst_port=Port.NTP), + ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=IPProtocol["UDP"]), + udp=UDPHeader(src_port=Port["NTP"], dst_port=Port["NTP"]), ) assert not acl.is_permitted(not_permitted_frame_2)[0] diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py index d4e38ded..3551ce38 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -67,12 +67,12 @@ def test_wireless_router_from_config(): r0 = rt.acl.acl[0] assert r0.action == ACLAction.PERMIT - assert r0.src_port == r0.dst_port == Port.POSTGRES_SERVER + assert r0.src_port == r0.dst_port == Port["POSTGRES_SERVER"] assert r0.src_ip_address == r0.dst_ip_address == r0.dst_wildcard_mask == r0.src_wildcard_mask == r0.protocol == None r1 = rt.acl.acl[1] assert r1.action == ACLAction.PERMIT - assert r1.protocol == IPProtocol.ICMP + assert r1.protocol == IPProtocol["ICMP"] assert ( r1.src_ip_address == r1.dst_ip_address diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py index 92618baa..9fd39dfc 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py @@ -20,7 +20,7 @@ def test_frame_minimal_instantiation(): ) # Check network layer default values - assert frame.ip.protocol == IPProtocol.TCP + assert frame.ip.protocol == IPProtocol["TCP"] assert frame.ip.ttl == 64 assert frame.ip.precedence == Precedence.ROUTINE @@ -40,7 +40,7 @@ def test_frame_creation_fails_tcp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.TCP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["TCP"]), ) @@ -49,7 +49,7 @@ def test_frame_creation_fails_udp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.UDP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["UDP"]), ) @@ -58,7 +58,7 @@ def test_frame_creation_fails_tcp_with_udp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.TCP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["TCP"]), udp=UDPHeader(src_port=8080, dst_port=80), ) @@ -68,7 +68,7 @@ def test_frame_creation_fails_udp_with_tcp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.UDP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["UDP"]), udp=TCPHeader(src_port=8080, dst_port=80), ) @@ -77,7 +77,7 @@ def test_icmp_frame_creation(): """Tests Frame creation for ICMP.""" frame = Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["ICMP"]), icmp=ICMPPacket(), ) assert frame @@ -88,5 +88,5 @@ def test_icmp_frame_creation_fails_without_icmp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["ICMP"]), ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index 885a3cb6..6e53aebc 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -129,19 +129,19 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon are configured correctly. assert c2_beacon.c2_config.keep_alive_frequency is 2 - assert c2_beacon.c2_config.masquerade_port is Port.HTTP - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_beacon.c2_config.masquerade_port is Port["HTTP"] + assert c2_beacon.c2_config.masquerade_protocol is IPProtocol["TCP"] assert c2_server.c2_config.keep_alive_frequency is 2 - assert c2_server.c2_config.masquerade_port is Port.HTTP - assert c2_server.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_server.c2_config.masquerade_port is Port["HTTP"] + assert c2_server.c2_config.masquerade_protocol is IPProtocol["TCP"] # Configuring the C2 Beacon. c2_beacon.configure( c2_server_ip_address="192.168.0.1", keep_alive_frequency=2, - masquerade_port=Port.FTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=Port["FTP"], + masquerade_protocol=IPProtocol["TCP"], ) # Asserting that the c2 applications have established a c2 connection @@ -150,11 +150,11 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon # Have reconfigured their C2 settings. - assert c2_beacon.c2_config.masquerade_port is Port.FTP - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_beacon.c2_config.masquerade_port is Port["FTP"] + assert c2_beacon.c2_config.masquerade_protocol is IPProtocol["TCP"] - assert c2_server.c2_config.masquerade_port is Port.FTP - assert c2_server.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_server.c2_config.masquerade_port is Port["FTP"] + assert c2_server.c2_config.masquerade_protocol is IPProtocol["TCP"] def test_c2_handle_switching_frequency(basic_c2_network): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index 0811d2a0..229f98fe 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -27,8 +27,8 @@ def test_create_dm_bot(dm_client): data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot") assert data_manipulation_bot.name == "DataManipulationBot" - assert data_manipulation_bot.port == Port.NONE - assert data_manipulation_bot.protocol == IPProtocol.NONE + assert data_manipulation_bot.port == Port["NONE"] + assert data_manipulation_bot.protocol == IPProtocol["NONE"] assert data_manipulation_bot.payload == "DELETE" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index ce98d164..c274c18e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -39,8 +39,8 @@ def test_create_web_client(): # Web Browser should be pre-installed in computer web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") assert web_browser.name is "WebBrowser" - assert web_browser.port is Port.HTTP - assert web_browser.protocol is IPProtocol.TCP + assert web_browser.port is Port["HTTP"] + assert web_browser.protocol is IPProtocol["TCP"] def test_receive_invalid_payload(web_browser): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index e9ce4884..1a51708d 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -28,8 +28,8 @@ def test_create_dns_client(dns_client): assert dns_client is not None dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") assert dns_client_service.name is "DNSClient" - assert dns_client_service.port is Port.DNS - assert dns_client_service.protocol is IPProtocol.TCP + assert dns_client_service.port is Port["DNS"] + assert dns_client_service.protocol is IPProtocol["TCP"] def test_dns_client_add_domain_to_cache_when_not_running(dns_client): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py index 4658fe76..8cdb1b84 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -32,8 +32,8 @@ def test_create_dns_server(dns_server): assert dns_server is not None dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer") assert dns_server_service.name is "DNSServer" - assert dns_server_service.port is Port.DNS - assert dns_server_service.protocol is IPProtocol.TCP + assert dns_server_service.port is Port["DNS"] + assert dns_server_service.protocol is IPProtocol["TCP"] def test_dns_server_domain_name_registration(dns_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 3ce4d8ee..3c1afb28 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -31,8 +31,8 @@ def test_create_ftp_client(ftp_client): assert ftp_client is not None ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") assert ftp_client_service.name is "FTPClient" - assert ftp_client_service.port is Port.FTP - assert ftp_client_service.protocol is IPProtocol.TCP + assert ftp_client_service.port is Port["FTP"] + assert ftp_client_service.protocol is IPProtocol["TCP"] def test_ftp_client_store_file(ftp_client): @@ -61,7 +61,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client): """Method _process_ftp_command should return false if service is not running.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=Port["FTP"], status_code=FTPStatusCode.OK, ) @@ -102,7 +102,7 @@ def test_offline_ftp_client_receives_request(ftp_client): payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=Port["FTP"], status_code=FTPStatusCode.OK, ) @@ -119,7 +119,7 @@ def test_receive_should_ignore_payload_with_none_status_code(ftp_client): """Receive should ignore payload with no set status code to prevent infinite send/receive loops.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=Port["FTP"], status_code=None, ) ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index a1c2ba59..aa13ec5e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -30,8 +30,8 @@ def test_create_ftp_server(ftp_server): assert ftp_server is not None ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") assert ftp_server_service.name is "FTPServer" - assert ftp_server_service.port is Port.FTP - assert ftp_server_service.protocol is IPProtocol.TCP + assert ftp_server_service.port is Port["FTP"] + assert ftp_server_service.protocol is IPProtocol["TCP"] def test_ftp_server_store_file(ftp_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 41858b90..21ed839b 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -77,11 +77,11 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # add ACL rule to allow SSH traffic - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=21) # Configure PC B pc_b = Computer( @@ -329,7 +329,7 @@ def test_SSH_across_network(wireless_wan_network): terminal_a: Terminal = pc_a.software_manager.software.get("Terminal") terminal_b: Terminal = pc_b.software_manager.software.get("Terminal") - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=21) assert len(terminal_a._connections) == 0 diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index 9af176be..c1df3857 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -33,8 +33,8 @@ def test_create_web_server(web_server): assert web_server is not None web_server_service: WebServer = web_server.software_manager.software.get("WebServer") assert web_server_service.name is "WebServer" - assert web_server_service.port is Port.HTTP - assert web_server_service.protocol is IPProtocol.TCP + assert web_server_service.port is Port["HTTP"] + assert web_server_service.protocol is IPProtocol["TCP"] def test_handling_get_request_not_found_path(web_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index 4cf83370..b7a663af 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -19,10 +19,10 @@ class TestSoftware(Service): def software(file_system): return TestSoftware( name="TestSoftware", - port=Port.ARP, + port=Port["ARP"], file_system=file_system, sys_log=SysLog(hostname="test_service"), - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], ) diff --git a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py index a8fb0a3a..4e40bbd8 100644 --- a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py +++ b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py @@ -11,7 +11,7 @@ def test_simple_conversion(): The original dictionary contains one level of nested dictionary with enums as keys. The expected output should have string values of enums as keys. """ - original_dict = {IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0}}} + original_dict = {IPProtocol["UDP"]: {Port["ARP"]: {"inbound": 0, "outbound": 1016.0}}} expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -36,8 +36,8 @@ def test_mixed_keys(): The expected output should have string values of enums and original string keys. """ original_dict = { - IPProtocol.TCP: {"port": {"inbound": 0, "outbound": 1016.0}}, - "protocol": {Port.HTTP: {"inbound": 10, "outbound": 2020.0}}, + IPProtocol["TCP"]: {"port": {"inbound": 0, "outbound": 1016.0}}, + "protocol": {Port["HTTP"]: {"inbound": 10, "outbound": 2020.0}}, } expected_dict = { "tcp": {"port": {"inbound": 0, "outbound": 1016.0}}, @@ -66,7 +66,7 @@ def test_nested_dicts(): The expected output should have string values of enums as keys at all levels. """ original_dict = { - IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol.TCP: {"latency": "low"}}}} + IPProtocol["UDP"]: {Port["ARP"]: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol["TCP"]: {"latency": "low"}}}} } expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0, "details": {"tcp": {"latency": "low"}}}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -79,6 +79,6 @@ def test_non_dict_values(): The original dictionary contains lists and tuples as values. The expected output should preserve these non-dictionary values while converting enum keys to string values. """ - original_dict = {IPProtocol.UDP: [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)} - expected_dict = {"udp": [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)} + original_dict = {IPProtocol["UDP"]: [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])} + expected_dict = {"udp": [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict