From dd931d900b560f191d2063f1414dd49c18898941 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 18 Sep 2024 16:02:25 +0100 Subject: [PATCH 1/9] port refactor - eod commit, airspace still broken --- .../simulation_components/network/network.rst | 6 +- .../network/nodes/firewall.rst | 50 ++--- .../network/nodes/wireless_router.rst | 10 +- .../network/transport_to_data_link_layer.rst | 2 +- .../system/applications/nmap.rst | 64 +++--- .../system/services/ftp_client.rst | 2 +- notebooks/test.ipynb | 157 +++++++++++++++ .../agent/observations/acl_observation.py | 4 +- .../agent/observations/host_observations.py | 18 ++ .../agent/observations/nic_observations.py | 35 +++- .../agent/observations/node_observations.py | 19 +- src/primaite/game/game.py | 39 +++- .../Command-&-Control-E2E-Demonstration.ipynb | 2 +- .../create-simulation_demo.ipynb | 2 +- .../network_simulator_demo.ipynb | 2 +- src/primaite/simulator/network/airspace.py | 82 +++----- src/primaite/simulator/network/creation.py | 4 +- .../simulator/network/hardware/base.py | 29 +-- .../hardware/nodes/network/firewall.py | 4 +- .../network/hardware/nodes/network/router.py | 97 +++++---- .../hardware/nodes/network/wireless_router.py | 8 +- src/primaite/simulator/network/networks.py | 16 +- .../simulator/network/protocols/masquerade.py | 4 +- .../network/transmission/data_link_layer.py | 8 +- .../network/transmission/network_layer.py | 43 ++-- .../network/transmission/transport_layer.py | 184 +++++++++++------- .../system/applications/database_client.py | 4 +- .../simulator/system/applications/nmap.py | 72 +++---- .../red_applications/c2/abstract_c2.py | 16 +- .../red_applications/c2/c2_beacon.py | 4 +- .../red_applications/data_manipulation_bot.py | 4 +- .../applications/red_applications/dos_bot.py | 6 +- .../red_applications/ransomware_script.py | 4 +- .../system/applications/web_browser.py | 8 +- .../simulator/system/core/session_manager.py | 62 +++--- .../simulator/system/core/software_manager.py | 22 +-- .../simulator/system/services/arp/arp.py | 4 +- .../services/database/database_service.py | 4 +- .../system/services/dns/dns_client.py | 8 +- .../system/services/dns/dns_server.py | 4 +- .../system/services/ftp/ftp_client.py | 30 +-- .../system/services/ftp/ftp_server.py | 6 +- .../system/services/ftp/ftp_service.py | 8 +- .../simulator/system/services/icmp/icmp.py | 4 +- .../system/services/icmp/router_icmp.py | 4 +- .../system/services/ntp/ntp_client.py | 6 +- .../system/services/ntp/ntp_server.py | 4 +- .../system/services/terminal/terminal.py | 4 +- .../system/services/web_server/web_server.py | 6 +- src/primaite/simulator/system/software.py | 12 +- tests/conftest.py | 28 +-- .../nodes/network/test_firewall_config.py | 32 +-- .../nodes/network/test_router_config.py | 6 +- .../applications/extended_application.py | 8 +- .../extensions/services/extended_service.py | 4 +- .../actions/test_c2_suite_actions.py | 6 +- .../actions/test_configure_actions.py | 2 +- .../actions/test_terminal_actions.py | 2 +- .../observations/test_acl_observations.py | 2 +- .../observations/test_firewall_observation.py | 6 +- .../observations/test_nic_observations.py | 2 +- .../observations/test_router_observation.py | 6 +- .../observations/test_user_observations.py | 2 +- .../game_layer/test_actions.py | 24 +-- .../game_layer/test_rewards.py | 4 +- .../network/test_airspace_config.py | 8 +- .../network/test_broadcast.py | 12 +- .../network/test_firewall.py | 32 +-- .../integration_tests/network/test_routing.py | 8 +- .../network/test_wireless_router.py | 4 +- .../test_c2_suite_integration.py | 20 +- .../test_data_manipulation_bot_and_server.py | 2 +- .../test_dos_bot_and_server.py | 6 +- .../test_ransomware_script.py | 2 +- tests/integration_tests/system/test_nmap.py | 16 +- .../system/test_service_listening_on_ports.py | 16 +- .../test_web_client_server_and_database.py | 12 +- .../test_simulation/test_request_response.py | 4 +- .../_network/_hardware/nodes/test_acl.py | 70 +++---- .../_network/_hardware/nodes/test_router.py | 4 +- .../_transmission/test_data_link_layer.py | 14 +- .../_red_applications/test_c2_suite.py | 20 +- .../test_data_manipulation_bot.py | 4 +- .../_system/_applications/test_web_browser.py | 4 +- .../_system/_services/test_dns_client.py | 4 +- .../_system/_services/test_dns_server.py | 4 +- .../_system/_services/test_ftp_client.py | 10 +- .../_system/_services/test_ftp_server.py | 4 +- .../_system/_services/test_terminal.py | 8 +- .../_system/_services/test_web_server.py | 4 +- .../_simulator/_system/test_software.py | 4 +- .../_utils/test_dict_enum_keys_conversion.py | 12 +- 92 files changed, 957 insertions(+), 682 deletions(-) create mode 100644 notebooks/test.ipynb diff --git a/docs/source/simulation_components/network/network.rst b/docs/source/simulation_components/network/network.rst index 636ffbcc..4cc121a3 100644 --- a/docs/source/simulation_components/network/network.rst +++ b/docs/source/simulation_components/network/network.rst @@ -103,13 +103,13 @@ we'll use the following Network that has a client, server, two switches, and a r router_1.acl.add_rule( action=ACLAction.PERMIT, - src_port=Port.ARP, - dst_port=Port.ARP, + src_port=Port["ARP"], + dst_port=Port["ARP"], position=22 ) router_1.acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.ICMP, + protocol=IPProtocol["ICMP"], position=23 ) diff --git a/docs/source/simulation_components/network/nodes/firewall.rst b/docs/source/simulation_components/network/nodes/firewall.rst index 149d3e67..1ef16d63 100644 --- a/docs/source/simulation_components/network/nodes/firewall.rst +++ b/docs/source/simulation_components/network/nodes/firewall.rst @@ -156,8 +156,8 @@ To prevent all external traffic from accessing the internal network, with except # Exception rule to allow HTTP traffic from external to internal network firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="192.168.1.0", dst_wildcard_mask="0.0.0.255", position=2 @@ -172,16 +172,16 @@ To enable external traffic to access specific services hosted within the DMZ: # Allow HTTP and HTTPS traffic to the DMZ firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="172.16.0.0", dst_wildcard_mask="0.0.0.255", position=3 ) firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTPS, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.0", dst_wildcard_mask="0.0.0.255", position=4 @@ -196,9 +196,9 @@ To permit SSH access from a designated external IP to a specific server within t # Allow SSH from a specific external IP to an internal server firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.2", - dst_port=Port.SSH, + dst_port=Port["SSH"], dst_ip_address="192.168.1.10", position=5 ) @@ -212,9 +212,9 @@ To limit database server access to selected external IP addresses: # Allow PostgreSQL traffic from an authorized external IP to the internal DB server firewall.internal_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.3", - dst_port=Port.POSTGRES_SERVER, + dst_port=Port["POSTGRES_SERVER"], dst_ip_address="192.168.1.20", position=6 ) @@ -222,8 +222,8 @@ To limit database server access to selected external IP addresses: # Deny all other PostgreSQL traffic from external sources firewall.internal_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, - dst_port=Port.POSTGRES_SERVER, + protocol=IPProtocol["TCP"], + dst_port=Port["POSTGRES_SERVER"], dst_ip_address="192.168.1.0", dst_wildcard_mask="0.0.0.255", position=7 @@ -247,15 +247,15 @@ To authorize HTTP/HTTPS access to a DMZ-hosted web server, excluding known malic # Allow HTTP/HTTPS traffic to the DMZ web server firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], dst_ip_address="172.16.0.2", position=9 ) firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTPS, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.2", position=10 ) @@ -269,9 +269,9 @@ To facilitate restricted access from the internal network to DMZ-hosted services # Permit specific internal application server HTTPS access to a DMZ-hosted API firewall.internal_outbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.30", # Internal application server IP - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=11 ) @@ -289,9 +289,9 @@ To facilitate restricted access from the internal network to DMZ-hosted services # Corresponding rule in DMZ inbound ACL to allow the traffic from the specific internal server firewall.dmz_inbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.30", # Ensuring this specific source is allowed - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=13 ) @@ -301,7 +301,7 @@ To facilitate restricted access from the internal network to DMZ-hosted services action=ACLAction.DENY, src_ip_address="192.168.1.0", src_wildcard_mask="0.0.0.255", - dst_port=Port.HTTPS, + dst_port=Port["HTTPS"], dst_ip_address="172.16.0.3", # DMZ API server IP position=14 ) @@ -315,8 +315,8 @@ To block all SSH access attempts from the external network: # Deny all SSH traffic from any external source firewall.external_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, - dst_port=Port.SSH, + protocol=IPProtocol["TCP"], + dst_port=Port["SSH"], position=1 ) @@ -329,8 +329,8 @@ To allow the internal network to initiate HTTP connections to the external netwo # Permit outgoing HTTP traffic from the internal network to any external destination firewall.external_outbound_acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, - dst_port=Port.HTTP, + protocol=IPProtocol["TCP"], + dst_port=Port["HTTP"], position=2 ) diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index c78c8419..bd361afa 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -49,7 +49,7 @@ additional steps to configure wireless settings: wireless_router.configure_wireless_access_point( port=1, ip_address="192.168.2.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency=AirSpaceFrequency["WIFI_2_4"], ) @@ -102,8 +102,8 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. network.connect(pc_a.network_interface[1], router_1.router_interface) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # Configure PC B pc_b = Computer( @@ -130,13 +130,13 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency=AirSpaceFrequency["WIFI_2_4"], ) router_2.configure_wireless_access_point( port=1, ip_address="192.168.1.2", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency.WIFI_2_4, + frequency=AirSpaceFrequency["WIFI_2_4"], ) # Configure routes for inter-router communication diff --git a/docs/source/simulation_components/network/transport_to_data_link_layer.rst b/docs/source/simulation_components/network/transport_to_data_link_layer.rst index cc546021..02bfdcdc 100644 --- a/docs/source/simulation_components/network/transport_to_data_link_layer.rst +++ b/docs/source/simulation_components/network/transport_to_data_link_layer.rst @@ -104,7 +104,7 @@ address of 'aa:bb:cc:dd:ee:ff' to port 8080 on the host 10.0.0.10 which has a NI ip_packet = IPPacket( src_ip_address="192.168.0.100", dst_ip_address="10.0.0.10", - protocol=IPProtocol.TCP + protocol=IPProtocol["TCP"] ) # Data Link Layer ethernet_header = EthernetHeader( diff --git a/docs/source/simulation_components/system/applications/nmap.rst b/docs/source/simulation_components/system/applications/nmap.rst index 1e7f5ea4..e2cd474e 100644 --- a/docs/source/simulation_components/system/applications/nmap.rst +++ b/docs/source/simulation_components/system/applications/nmap.rst @@ -165,8 +165,8 @@ Perform a horizontal port scan on port 5432 across multiple IP addresses: { IPv4Address('192.168.1.12'): { - : [ - + : [ + ] } } @@ -192,7 +192,7 @@ Perform a vertical port scan on multiple ports on a single IP address: vertical_scan_results = pc_1_nmap.port_scan( target_ip_address=[IPv4Address("192.168.1.12")], - target_port=[Port(21), Port(22), Port(80), Port(443)] + target_port=[21, 22, 80, 443] ) .. code-block:: python @@ -200,9 +200,9 @@ Perform a vertical port scan on multiple ports on a single IP address: { IPv4Address('192.168.1.12'): { - : [ - , - + : [ + , + ] } } @@ -233,7 +233,7 @@ Perform a box scan on multiple ports across multiple IP addresses: box_scan_results = pc_1_nmap.port_scan( target_ip_address=[IPv4Address("192.168.1.12"), IPv4Address("192.168.1.13")], - target_port=[Port(21), Port(22), Port(80), Port(443)] + target_port=[21, 22, 80, 443] ) .. code-block:: python @@ -241,15 +241,15 @@ Perform a box scan on multiple ports across multiple IP addresses: { IPv4Address('192.168.1.13'): { - : [ - , - + : [ + , + ] }, IPv4Address('192.168.1.12'): { - : [ - , - + : [ + , + ] } } @@ -289,36 +289,36 @@ Perform a full box scan on all ports, over both TCP and UDP, on a whole subnet: { IPv4Address('192.168.1.11'): { - : [ - + : [ + ] }, IPv4Address('192.168.1.1'): { - : [ - + : [ + ] }, IPv4Address('192.168.1.12'): { - : [ - , - , - , - + : [ + , + , + , + ], - : [ - , - + : [ + , + ] }, IPv4Address('192.168.1.13'): { - : [ - , - , - + : [ + , + , + ], - : [ - , - + : [ + , + ] } } diff --git a/docs/source/simulation_components/system/services/ftp_client.rst b/docs/source/simulation_components/system/services/ftp_client.rst index fdf9cfcf..0c9a781c 100644 --- a/docs/source/simulation_components/system/services/ftp_client.rst +++ b/docs/source/simulation_components/system/services/ftp_client.rst @@ -15,7 +15,7 @@ Key features - Connects to the :ref:`FTPServer` via the ``SoftwareManager``. - Simulates FTP requests and FTPPacket transfer across a network - Allows the emulation of FTP commands between an FTP client and server: - - PORT: specifies the port that server should connect to on the client (currently only uses ``Port.FTP``) + - PORT: specifies the port that server should connect to on the client (currently only uses ``Port["FTP"]``) - STOR: stores a file from client to server - RETR: retrieves a file from the FTP server - QUIT: disconnect from server diff --git a/notebooks/test.ipynb b/notebooks/test.ipynb new file mode 100644 index 00000000..5afe04b0 --- /dev/null +++ b/notebooks/test.ipynb @@ -0,0 +1,157 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import yaml\n", + "\n", + "from primaite.game.game import PrimaiteGame\n", + "from primaite.session.environment import PrimaiteGymEnv\n", + "from primaite.simulator.network.hardware.nodes.host.computer import Computer\n", + "from primaite.simulator.network.hardware.nodes.host.server import Server\n", + "from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection\n", + "from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot\n", + "from primaite.simulator.system.services.database.database_service import DatabaseService\n", + "from primaite import getLogger, PRIMAITE_PATHS" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "with open(PRIMAITE_PATHS.user_config_path / \"example_config\" / \"data_manipulation.yaml\") as f:\n", + " cfg = yaml.safe_load(f)\n", + "game = PrimaiteGame.from_config(cfg)\n", + "uc2_network = game.simulation.network" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "client_1: Computer = uc2_network.get_node_by_hostname(\"client_1\")\n", + "db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get(\"DataManipulationBot\")\n", + "\n", + "database_server: Server = uc2_network.get_node_by_hostname(\"database_server\")\n", + "db_service: DatabaseService = database_server.software_manager.software.get(\"DatabaseService\")\n", + "\n", + "web_server: Server = uc2_network.get_node_by_hostname(\"web_server\")\n", + "db_client: DatabaseClient = web_server.software_manager.software.get(\"DatabaseClient\")\n", + "db_connection: DatabaseClientConnection = db_client.get_new_connection()\n", + "db_service.backup_database()\n", + "\n", + "# First check that the DB client on the web_server can successfully query the users table on the database\n", + "assert db_connection.query(\"SELECT\")\n", + "\n", + "db_manipulation_bot.data_manipulation_p_of_success = 1.0\n", + "db_manipulation_bot.port_scan_p_of_success = 1.0\n", + "\n", + "# Now we run the DataManipulationBot\n", + "db_manipulation_bot.attack()\n", + "\n", + "# Now check that the DB client on the web_server cannot query the users table on the database\n", + "assert not db_connection.query(\"SELECT\")\n", + "\n", + "# Now restore the database\n", + "db_service.restore_backup()\n", + "\n", + "# Now check that the DB client on the web_server can successfully query the users table on the database\n", + "assert db_connection.query(\"SELECT\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "router = uc2_network.get_node_by_hostname(\"router_1\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----------------------------------------------------------------------------------------------------------------+\n", + "| router_1 Access Control List |\n", + "+-------+--------+----------+--------+--------------+-------------+--------+--------------+-------------+---------+\n", + "| Index | Action | Protocol | Src IP | Src Wildcard | Src Port | Dst IP | Dst Wildcard | Dst Port | Matched |\n", + "+-------+--------+----------+--------+--------------+-------------+--------+--------------+-------------+---------+\n", + "| 18 | PERMIT | ANY | ANY | ANY | 5432 (5432) | ANY | ANY | 5432 (5432) | 4 |\n", + "| 19 | PERMIT | ANY | ANY | ANY | 53 (53) | ANY | ANY | 53 (53) | 0 |\n", + "| 20 | PERMIT | ANY | ANY | ANY | 21 (21) | ANY | ANY | 21 (21) | 0 |\n", + "| 21 | PERMIT | ANY | ANY | ANY | 80 (80) | ANY | ANY | 80 (80) | 0 |\n", + "| 22 | PERMIT | ANY | ANY | ANY | 219 (219) | ANY | ANY | 219 (219) | 9 |\n", + "| 23 | PERMIT | icmp | ANY | ANY | ANY | ANY | ANY | ANY | 0 |\n", + "| 24 | DENY | ANY | ANY | ANY | ANY | ANY | ANY | ANY | 0 |\n", + "+-------+--------+----------+--------+--------------+-------------+--------+--------------+-------------+---------+\n" + ] + } + ], + "source": [ + "router.acl.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "AccessControlList.is_permitted() missing 1 required positional argument: 'frame'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mrouter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_permitted\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mTypeError\u001b[0m: AccessControlList.is_permitted() missing 1 required positional argument: 'frame'" + ] + } + ], + "source": [ + "router.acl.is_permitted()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index 41af5a8f..abb6f1f8 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -10,6 +10,8 @@ from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port _LOGGER = getLogger(__name__) @@ -61,7 +63,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)} self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)} self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)} - self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} + self.protocol_to_id: Dict[str, int] = {IPProtocol[p]: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { i + 1: { diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 4419ccc7..05b25952 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -5,6 +5,7 @@ from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType +from pydantic import field_validator from primaite import getLogger from primaite.game.agent.observations.file_system_observations import FolderObservation @@ -12,6 +13,8 @@ from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port _LOGGER = getLogger(__name__) @@ -55,6 +58,21 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_users: Optional[bool] = True """If True, report user session information.""" + @field_validator('monitored_traffic', mode='before') + def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + if val is None: + return val + new_val = {} + for proto, port_list in val.items(): + # convert protocol, for instance ICMP becomes "icmp" + proto = IPProtocol[proto] if proto in IPProtocol else proto + new_val[proto] = [] + for port in port_list: + # convert ports, for instance "HTTP" becomes 80 + port = Port[port] if port in Port else port + new_val[proto].append(port) + return new_val + def __init__( self, where: WhereType, diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 002ee4da..200187f5 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -5,9 +5,11 @@ from typing import Dict, Optional from gymnasium import spaces from gymnasium.core import ObsType +from pydantic import field_validator from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -24,6 +26,22 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): monitored_traffic: Optional[Dict] = None """A dict containing which traffic types are to be included in the observation.""" + @field_validator('monitored_traffic', mode='before') + def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + if val is None: + return val + new_val = {} + for proto, port_list in val.items(): + # convert protocol, for instance ICMP becomes "icmp" + proto = IPProtocol[proto] if proto in IPProtocol else proto + new_val[proto] = [] + for port in port_list: + # convert ports, for instance "HTTP" becomes 80 + port = Port[port] if port in Port else port + new_val[proto].append(port) + return new_val + + def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: """ Initialise a network interface observation instance. @@ -67,7 +85,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): else: default_traffic_obs["TRAFFIC"][protocol] = {} for port in monitored_traffic_config[protocol]: - default_traffic_obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + default_traffic_obs["TRAFFIC"][protocol] = {"inbound": 0, "outbound": 0} return default_traffic_obs @@ -142,17 +160,16 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): } else: for port in self.monitored_traffic[protocol]: - port_enum = Port[port] - obs["TRAFFIC"][protocol][port_enum.value] = {} + obs["TRAFFIC"][protocol][port] = {} traffic = {"inbound": 0, "outbound": 0} - if nic_state["traffic"][protocol].get(port_enum.value) is not None: - traffic = nic_state["traffic"][protocol][port_enum.value] + if nic_state["traffic"][protocol].get(port) is not None: + traffic = nic_state["traffic"][protocol][port] - obs["TRAFFIC"][protocol][port_enum.value]["inbound"] = self._categorise_traffic( + obs["TRAFFIC"][protocol][port]["inbound"] = self._categorise_traffic( traffic_value=traffic["inbound"], nic_state=nic_state ) - obs["TRAFFIC"][protocol][port_enum.value]["outbound"] = self._categorise_traffic( + obs["TRAFFIC"][protocol][port]["outbound"] = self._categorise_traffic( traffic_value=traffic["outbound"], nic_state=nic_state ) @@ -162,7 +179,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0} else: for port in self.monitored_traffic[protocol]: - obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0} if self.include_nmne: obs.update({"NMNE": {}}) @@ -201,7 +218,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): else: space["TRAFFIC"][protocol] = spaces.Dict({}) for port in self.monitored_traffic[protocol]: - space["TRAFFIC"][protocol][Port[port].value] = spaces.Dict( + space["TRAFFIC"][protocol][port] = spaces.Dict( {"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)} ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index e263cadb..3e51c3b3 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -5,13 +5,15 @@ from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType -from pydantic import model_validator +from pydantic import field_validator, model_validator from primaite import getLogger from primaite.game.agent.observations.firewall_observation import FirewallObservation from primaite.game.agent.observations.host_observations import HostObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.router_observation import RouterObservation +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port _LOGGER = getLogger(__name__) @@ -61,6 +63,21 @@ class NodesObservation(AbstractObservation, identifier="NODES"): num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + @field_validator('monitored_traffic', mode='before') + def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + if val is None: + return val + new_val = {} + for proto, port_list in val.items(): + # convert protocol, for instance ICMP becomes "icmp" + proto = IPProtocol[proto] if proto in IPProtocol else proto + new_val[proto] = [] + for port in port_list: + # convert ports, for instance "HTTP" becomes 80 + port = Port[port] if port in Port else port + new_val[proto].append(port) + return new_val + @model_validator(mode="after") def force_optional_fields(self) -> NodesObservation.ConfigSchema: """Check that options are specified only if they are needed for the nodes that are part of the config.""" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 11c968af..e8329c63 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -4,7 +4,7 @@ from ipaddress import IPv4Address from typing import Dict, List, Optional, Union import numpy as np -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, field_validator from primaite import DEFAULT_BANDWIDTH, getLogger from primaite.game.agent.actions import ActionManager @@ -82,13 +82,40 @@ class PrimaiteGameOptions(BaseModel): """Random number seed for RNGs.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" - ports: List[str] + ports: List[int] """A whitelist of available ports in the simulation.""" protocols: List[str] """A whitelist of available protocols in the simulation.""" thresholds: Optional[Dict] = {} """A dict containing the thresholds used for determining what is acceptable during observations.""" + @field_validator('ports', mode='before') + def ports_str2int(cls, vals:Union[List[str],List[int]]) -> List[int]: + """ + Convert named port strings to port integer values. Integer ports remain unaffected. + + This is necessary to retain backwards compatibility with configs written for PrimAITE<=3.3. + :warning: This will be deprecated in PrimAITE 4.0 and configs will need to be converted. + """ + for i, port_val in enumerate(vals): + if port_val in Port: + vals[i] = Port[port_val] + return vals + + @field_validator('protocols', mode='before') + def protocols_str2int(cls, vals:List[str]) -> List[str]: + """ + Convert old-style named protocols to their proper values. + + This is necessary to retain backwards compatibility with configs written for PrimAITE<=3.3. + :warning: This will be deprecated in PrimAITE 4.0 and configs will need to be converted. + """ + for i, proto_val in enumerate(vals): + if proto_val in IPProtocol: + vals[i] = IPProtocol[proto_val] + return vals + + class PrimaiteGame: """ @@ -358,7 +385,7 @@ class PrimaiteGame: for port_id in set(software_cfg.get("options", {}).get("listen_on_ports", [])): port = None if isinstance(port_id, int): - port = Port(port_id) + port = port_id elif isinstance(port_id, str): port = Port[port_id] if port: @@ -475,7 +502,7 @@ class PrimaiteGame: opt = application_cfg["options"] new_application.configure( target_ip_address=IPv4Address(opt.get("target_ip_address")), - target_port=Port(opt.get("target_port", Port.POSTGRES_SERVER.value)), + target_port = Port[opt.get("target_port", "POSTGRES_SERVER")], payload=opt.get("payload"), repeat=bool(opt.get("repeat")), port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), @@ -488,8 +515,8 @@ class PrimaiteGame: new_application.configure( c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")), keep_alive_frequency=(opt.get("keep_alive_frequency", 5)), - masquerade_protocol=IPProtocol[(opt.get("masquerade_protocol", IPProtocol.TCP))], - masquerade_port=Port[(opt.get("masquerade_port", Port.HTTP))], + masquerade_protocol=IPProtocol[(opt.get("masquerade_protocol", IPProtocol["TCP"]))], + masquerade_port=Port[(opt.get("masquerade_port", Port["HTTP"]))], ) if "network_interfaces" in node_cfg: for nic_num, nic_cfg in node_cfg["network_interfaces"].items(): diff --git a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb index b6b13f28..a5cc385b 100644 --- a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb @@ -1783,7 +1783,7 @@ "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", "from primaite.simulator.network.transmission.transport_layer import Port\n", "# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n", - "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol.UDP, masquerade_port=Port.DNS)\n", + "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n", "c2_beacon.establish()\n", "c2_beacon.show()" ] diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index 77ac4842..f573f251 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -182,7 +182,7 @@ "metadata": {}, "outputs": [], "source": [ - "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port.HTTP, protocol = IPProtocol.NONE,operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" + "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" ] }, { diff --git a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb index 17a0f796..2d5b4772 100644 --- a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb +++ b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb @@ -537,7 +537,7 @@ "from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n", "network.get_node_by_hostname(\"router_1\").acl.add_rule(\n", " action=ACLAction.DENY,\n", - " protocol=IPProtocol.ICMP,\n", + " protocol=IPProtocol["ICMP"],\n", " src_ip_address=\"192.168.10.22\",\n", " position=1\n", ")" diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index cdb01514..29326df8 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -3,10 +3,12 @@ from __future__ import annotations from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List +from typing import Any, ClassVar, Dict, List, Type +from pydantic._internal._generics import PydanticGenericMetadata +from typing_extensions import Unpack from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from primaite import getLogger from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface @@ -40,51 +42,29 @@ def format_hertz(hertz: float, format_terahertz: bool = False, decimals: int = 3 else: # Hertz return format_str.format(hertz) + " Hz" +AirSpaceFrequencyRegistry: Dict[str,Dict] = { + "WIFI_2_4" : {'frequency': 2.4e9, 'data_rate_bps':100_000_000.0}, + "WIFI_5" : {'frequency': 5e9, 'data_rate_bps':500_000_000.0}, +} -class AirSpaceFrequency(Enum): - """Enumeration representing the operating frequencies for wireless communications.""" +def register_frequency(freq_name: str, freq_hz: int, data_rate_bps: int) -> None: + if freq_name in AirSpaceFrequencyRegistry: + raise RuntimeError(f"Cannot register new frequency {freq_name} because it's already registered.") + AirSpaceFrequencyRegistry.update({freq_name:{'frequency': freq_hz, 'data_rate_bps':data_rate_bps}}) - WIFI_2_4 = 2.4e9 - """WiFi 2.4 GHz. Known for its extensive range and ability to penetrate solid objects effectively.""" - WIFI_5 = 5e9 - """WiFi 5 GHz. Known for its higher data transmission speeds and reduced interference from other devices.""" +def maximum_data_rate_mbps(frequency_name:str) -> float: + """ + Retrieves the maximum data transmission rate in megabits per second (Mbps). - def __str__(self) -> str: - hertz_str = format_hertz(hertz=self.value) - if self == AirSpaceFrequency.WIFI_2_4: - return f"WiFi {hertz_str}" - if self == AirSpaceFrequency.WIFI_5: - return f"WiFi {hertz_str}" - return "Unknown Frequency" + This is derived by converting the maximum data rate from bits per second, as defined + in `maximum_data_rate_bps`, to megabits per second. - @property - def maximum_data_rate_bps(self) -> float: - """ - Retrieves the maximum data transmission rate in bits per second (bps). + :return: The maximum data rate in megabits per second. + """ + return AirSpaceFrequencyRegistry[frequency_name]['data_rate_bps'] + return data_rate / 1_000_000.0 - The maximum rates are predefined for frequencies.: - - WIFI 2.4 supports 100,000,000 bps - - WIFI 5 supports 500,000,000 bps - :return: The maximum data rate in bits per second. - """ - if self == AirSpaceFrequency.WIFI_2_4: - return 100_000_000.0 # 100 Megabits per second - if self == AirSpaceFrequency.WIFI_5: - return 500_000_000.0 # 500 Megabits per second - return 0.0 - - @property - def maximum_data_rate_mbps(self) -> float: - """ - Retrieves the maximum data transmission rate in megabits per second (Mbps). - - This is derived by converting the maximum data rate from bits per second, as defined - in `maximum_data_rate_bps`, to megabits per second. - - :return: The maximum data rate in megabits per second. - """ - return self.maximum_data_rate_bps / 1_000_000.0 class AirSpace(BaseModel): @@ -97,13 +77,13 @@ class AirSpace(BaseModel): """ wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {}) - wireless_interfaces_by_frequency: Dict[AirSpaceFrequency, List[WirelessNetworkInterface]] = Field( + wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field( default_factory=lambda: {} ) - bandwidth_load: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) - frequency_max_capacity_mbps_: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {}) + bandwidth_load: Dict[int, float] = Field(default_factory=lambda: {}) + frequency_max_capacity_mbps_: Dict[int, float] = Field(default_factory=lambda: {}) - def get_frequency_max_capacity_mbps(self, frequency: AirSpaceFrequency) -> float: + def get_frequency_max_capacity_mbps(self, frequency: str) -> float: """ Retrieves the maximum data transmission capacity for a specified frequency. @@ -117,9 +97,9 @@ class AirSpace(BaseModel): """ if frequency in self.frequency_max_capacity_mbps_: return self.frequency_max_capacity_mbps_[frequency] - return frequency.maximum_data_rate_mbps + return maximum_data_rate_mbps(frequency) - def set_frequency_max_capacity_mbps(self, cfg: Dict[AirSpaceFrequency, float]): + def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]): """ Sets custom maximum data transmission capacities for multiple frequencies. @@ -150,7 +130,7 @@ class AirSpace(BaseModel): load_percent = load / maximum_capacity if maximum_capacity > 0 else 0.0 if load_percent > 1.0: load_percent = 1.0 - table.add_row([format_hertz(frequency.value), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"]) + table.add_row([format_hertz(frequency), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"]) print(table) def show_wireless_interfaces(self, markdown: bool = False): @@ -182,7 +162,7 @@ class AirSpace(BaseModel): interface.mac_address, interface.ip_address if hasattr(interface, "ip_address") else None, interface.subnet_mask if hasattr(interface, "subnet_mask") else None, - format_hertz(interface.frequency.value), + format_hertz(interface.frequency), f"{interface.speed:.3f}", status, ] @@ -298,7 +278,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC): """ airspace: AirSpace - frequency: AirSpaceFrequency = AirSpaceFrequency.WIFI_2_4 + frequency: str = "WIFI_2_4" def enable(self): """Attempt to enable the network interface.""" @@ -430,7 +410,7 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC) # Update the state with information from Layer3Interface state.update(Layer3Interface.describe_state(self)) - state["frequency"] = self.frequency.value + state["frequency"] = self.frequency return state diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index 61a37a90..c2524b4b 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -98,8 +98,8 @@ def create_office_lan( default_gateway = IPv4Address(f"192.168.{subnet_base}.1") router = Router(hostname=f"router_{lan_name}", start_up_duration=0) router.power_on() - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) network.add_node(router) router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0") router.enable_port(1) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index bf230e07..4154cc08 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -203,16 +203,16 @@ class NetworkInterface(SimComponent, ABC): # Initialise basic frame data variables direction = "inbound" if inbound else "outbound" # Direction of the traffic ip_address = str(frame.ip.src_ip_address if inbound else frame.ip.dst_ip_address) # Source or destination IP - protocol = frame.ip.protocol.name # Network protocol used in the frame + protocol = frame.ip.protocol # Network protocol used in the frame # Initialise port variable; will be determined based on protocol type port = None # Determine the source or destination port based on the protocol (TCP/UDP) if frame.tcp: - port = frame.tcp.src_port.value if inbound else frame.tcp.dst_port.value + port = frame.tcp.src_port if inbound else frame.tcp.dst_port elif frame.udp: - port = frame.udp.src_port.value if inbound else frame.udp.dst_port.value + port = frame.udp.src_port if inbound else frame.udp.dst_port # Convert frame payload to string for keyword checking frame_str = str(frame.payload) @@ -274,20 +274,20 @@ class NetworkInterface(SimComponent, ABC): # Identify the protocol and port from the frame if frame.tcp: - protocol = IPProtocol.TCP + protocol = IPProtocol["TCP"] port = frame.tcp.dst_port elif frame.udp: - protocol = IPProtocol.UDP + protocol = IPProtocol["UDP"] port = frame.udp.dst_port elif frame.icmp: - protocol = IPProtocol.ICMP + protocol = IPProtocol["ICMP"] # Ensure the protocol is in the capture dict if protocol not in self.traffic: self.traffic[protocol] = {} # Handle non-ICMP protocols that use ports - if protocol != IPProtocol.ICMP: + if protocol != IPProtocol["ICMP"]: if port not in self.traffic[protocol]: self.traffic[protocol][port] = {"inbound": 0, "outbound": 0} self.traffic[protocol][port][direction] += frame.size_Mbits @@ -843,8 +843,8 @@ class UserManager(Service): :param password: The password for the default admin user """ kwargs["name"] = "UserManager" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["NONE"] super().__init__(**kwargs) self.start() @@ -1166,8 +1166,8 @@ class UserSessionManager(Service): :param password: The password for the default admin user """ kwargs["name"] = "UserSessionManager" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["NONE"] super().__init__(**kwargs) self.start() @@ -1312,7 +1312,7 @@ class UserSessionManager(Service): software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( payload={"type": "user_timeout", "connection_id": session.uuid}, - dest_port=Port.SSH, + dest_port=Port["SSH"], dest_ip_address=session.remote_ip_address, ) @@ -1845,8 +1845,9 @@ class Node(SimComponent): table.align = "l" table.title = f"{self.hostname} Open Ports" for port in self.software_manager.get_open_ports(): - if port.value > 0: - table.add_row([port.value, port.name]) + if port > 0: + # TODO: do a reverse lookup for port name, or change this to only show port int + table.add_row([port, port]) print(table.get_string(sortby="Port")) @property diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 4510eac0..6d8e084d 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -58,8 +58,8 @@ class Firewall(Router): >>> # Permit HTTP traffic to the DMZ >>> firewall.dmz_inbound_acl.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, - ... dst_port=Port.HTTP, + ... protocol=IPProtocol["TCP"], + ... dst_port=Port["HTTP"], ... src_ip_address="0.0.0.0", ... src_wildcard_mask="0.0.0.0", ... dst_ip_address="172.16.0.0", diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index ceb91695..013c473e 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -7,7 +7,7 @@ from ipaddress import IPv4Address, IPv4Network from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import field_validator, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -106,7 +106,7 @@ class ACLRule(SimComponent): :ivar ACLAction action: Specifies whether to `PERMIT` or `DENY` the traffic that matches the rule conditions. The default action is `DENY`. - :ivar Optional[IPProtocol] protocol: The network protocol (e.g., TCP, UDP, ICMP) to match. If `None`, the rule + :ivar Optional[str] protocol: The network protocol (e.g., TCP, UDP, ICMP) to match. If `None`, the rule applies to all protocols. :ivar Optional[IPV4Address] src_ip_address: The source IP address to match. If combined with `src_wildcard_mask`, it specifies the start of an IP range. @@ -116,20 +116,33 @@ class ACLRule(SimComponent): `dst_wildcard_mask`, it specifies the start of an IP range. :ivar Optional[IPv4Address] dst_wildcard_mask: The wildcard mask for the destination IP address, defining the range of addresses to match. - :ivar Optional[Port] src_port: The source port number to match. Relevant for TCP/UDP protocols. - :ivar Optional[Port] dst_port: The destination port number to match. Relevant for TCP/UDP protocols. + :ivar Optional[int] src_port: The source port number to match. Relevant for TCP/UDP protocols. + :ivar Optional[int] dst_port: The destination port number to match. Relevant for TCP/UDP protocols. """ action: ACLAction = ACLAction.DENY - protocol: Optional[IPProtocol] = None + protocol: Optional[str] = None src_ip_address: Optional[IPV4Address] = None src_wildcard_mask: Optional[IPV4Address] = None dst_ip_address: Optional[IPV4Address] = None dst_wildcard_mask: Optional[IPV4Address] = None - src_port: Optional[Port] = None - dst_port: Optional[Port] = None + src_port: Optional[int] = None + dst_port: Optional[int] = None match_count: int = 0 + @field_validator('protocol', mode='before') + def protocol_valid(cls, val:Optional[str]) -> Optional[str]: + if val is not None: + assert val in IPProtocol.values(), f"Cannot create ACL rule with invalid protocol {val}" + return val + + @field_validator('src_port', 'dst_port', mode='before') + def ports_valid(cls, val:Optional[int]) -> Optional[int]: + if val is not None: + assert val in Port.values(), f"Cannot create ACL rule with invalid port {val}" + return val + + def __str__(self) -> str: rule_strings = [] for key, value in self.model_dump(exclude={"uuid", "request_manager"}).items(): @@ -149,13 +162,13 @@ class ACLRule(SimComponent): """ state = super().describe_state() state["action"] = self.action.value - state["protocol"] = self.protocol.name if self.protocol else None + state["protocol"] = self.protocol if self.protocol else None state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None - state["src_port"] = self.src_port.name if self.src_port else None + state["src_port"] = self.src_port if self.src_port else None state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None - state["dst_port"] = self.dst_port.name if self.dst_port else None + state["dst_port"] = self.dst_port if self.dst_port else None state["match_count"] = self.match_count return state @@ -265,7 +278,7 @@ class AccessControlList(SimComponent): >>> acl = AccessControlList() >>> acl.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.0", ... src_wildcard_mask="0.0.0.255", ... dst_ip_address="192.168.2.0", @@ -323,13 +336,13 @@ class AccessControlList(SimComponent): func=lambda request, context: RequestResponse.from_bool( self.add_rule( action=ACLAction[request[0]], - protocol=None if request[1] == "ALL" else IPProtocol[request[1]], + protocol=None if request[1] == "ALL" else request[1], src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]), src_wildcard_mask=None if request[3] == "NONE" else IPv4Address(request[3]), - src_port=None if request[4] == "ALL" else Port[request[4]], + src_port=None if request[4] == "ALL" else request[4], dst_ip_address=None if request[5] == "ALL" else IPv4Address(request[5]), dst_wildcard_mask=None if request[6] == "NONE" else IPv4Address(request[6]), - dst_port=None if request[7] == "ALL" else Port[request[7]], + dst_port=None if request[7] == "ALL" else request[7], position=int(request[8]), ) ) @@ -377,13 +390,13 @@ class AccessControlList(SimComponent): def add_rule( self, action: ACLAction = ACLAction.DENY, - protocol: Optional[IPProtocol] = None, + protocol: Optional[str] = None, src_ip_address: Optional[IPV4Address] = None, src_wildcard_mask: Optional[IPV4Address] = None, dst_ip_address: Optional[IPV4Address] = None, dst_wildcard_mask: Optional[IPV4Address] = None, - src_port: Optional[Port] = None, - dst_port: Optional[Port] = None, + src_port: Optional[int] = None, + dst_port: Optional[int] = None, position: int = 0, ) -> bool: """ @@ -399,11 +412,11 @@ class AccessControlList(SimComponent): >>> router = Router("router") >>> router.add_rule( ... action=ACLAction.DENY, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.0", ... src_wildcard_mask="0.0.0.255", ... dst_ip_address="10.10.10.5", - ... dst_port=Port.SSH, + ... dst_port=Port["SSH"], ... position=5 ... ) >>> # This permits SSH traffic from the 192.168.1.0/24 subnet to the 10.10.10.5 server. @@ -411,10 +424,10 @@ class AccessControlList(SimComponent): >>> # Then if we want to allow a specific IP address from this subnet to SSH into the server >>> router.add_rule( ... action=ACLAction.PERMIT, - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... src_ip_address="192.168.1.25", ... dst_ip_address="10.10.10.5", - ... dst_port=Port.SSH, + ... dst_port=Port["SSH"], ... position=4 ... ) @@ -485,11 +498,11 @@ class AccessControlList(SimComponent): def get_relevant_rules( self, - protocol: IPProtocol, + protocol: str, src_ip_address: Union[str, IPv4Address], - src_port: Port, + src_port: int, dst_ip_address: Union[str, IPv4Address], - dst_port: Port, + dst_port: int, ) -> List[ACLRule]: """ Get the list of relevant rules for a packet with given properties. @@ -552,13 +565,13 @@ class AccessControlList(SimComponent): [ index, rule.action.name if rule.action else "ANY", - rule.protocol.name if rule.protocol else "ANY", + rule.protocol if rule.protocol else "ANY", rule.src_ip_address if rule.src_ip_address else "ANY", rule.src_wildcard_mask if rule.src_wildcard_mask else "ANY", - f"{rule.src_port.value} ({rule.src_port.name})" if rule.src_port else "ANY", + f"{rule.src_port} ({rule.src_port})" if rule.src_port else "ANY", rule.dst_ip_address if rule.dst_ip_address else "ANY", rule.dst_wildcard_mask if rule.dst_wildcard_mask else "ANY", - f"{rule.dst_port.value} ({rule.dst_port.name})" if rule.dst_port else "ANY", + f"{rule.dst_port} ({rule.dst_port})" if rule.dst_port else "ANY", rule.match_count, ] ) @@ -1088,17 +1101,17 @@ class RouterSessionManager(SessionManager): def resolve_outbound_transmission_details( self, dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[Port] = None, - dst_port: Optional[Port] = None, - protocol: Optional[IPProtocol] = None, + src_port: Optional[int] = None, + dst_port: Optional[int] = None, + protocol: Optional[str] = None, session_id: Optional[str] = None, ) -> Tuple[ Optional[RouterInterface], Optional[str], IPv4Address, - Optional[Port], - Optional[Port], - Optional[IPProtocol], + Optional[int], + Optional[int], + Optional[str], bool, ]: """ @@ -1118,19 +1131,19 @@ class RouterSessionManager(SessionManager): treats the transmission as a broadcast to that network. Optional. :type dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] :param src_port: The source port number for the transmission. Optional. - :type src_port: Optional[Port] + :type src_port: Optional[int] :param dst_port: The destination port number for the transmission. Optional. - :type dst_port: Optional[Port] + :type dst_port: Optional[int] :param protocol: The IP protocol to be used for the transmission. Optional. - :type protocol: Optional[IPProtocol] + :type protocol: Optional[str] :param session_id: The session ID associated with the transmission. If provided, the session details override other parameters. Optional. :type session_id: Optional[str] :return: A tuple containing the resolved outbound network interface, destination MAC address, destination IP address, source port, destination port, protocol, and a boolean indicating whether the transmission is a broadcast. - :rtype: Tuple[Optional[RouterInterface], Optional[str], IPv4Address, Optional[Port], Optional[Port], - Optional[IPProtocol], bool] + :rtype: Tuple[Optional[RouterInterface], Optional[str], IPv4Address, Optional[int], Optional[int], + Optional[str], bool] """ if dst_ip_address and not isinstance(dst_ip_address, (IPv4Address, IPv4Network)): dst_ip_address = IPv4Address(dst_ip_address) @@ -1257,8 +1270,8 @@ class Router(NetworkNode): Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP and ICMP, which are necessary for basic network operations and diagnostics. """ - self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) def setup_for_episode(self, episode: int): """ @@ -1357,9 +1370,9 @@ class Router(NetworkNode): """ dst_ip_address = frame.ip.dst_ip_address dst_port = None - if frame.ip.protocol == IPProtocol.TCP: + if frame.ip.protocol == IPProtocol["TCP"]: dst_port = frame.tcp.dst_port - elif frame.ip.protocol == IPProtocol.UDP: + elif frame.ip.protocol == IPProtocol["UDP"]: dst_port = frame.udp.dst_port if self.ip_is_router_interface(dst_ip_address) and ( diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 3cb4c515..d73bc756 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -116,7 +116,7 @@ class WirelessRouter(Router): >>> wireless_router.configure_wireless_access_point( ... ip_address="10.10.10.1", ... subnet_mask="255.255.255.0" - ... frequency=AirSpaceFrequency.WIFI_2_4 + ... frequency=AirSpaceFrequency["WIFI_2_4"] ... ) """ @@ -153,7 +153,7 @@ class WirelessRouter(Router): self, ip_address: IPV4Address, subnet_mask: IPV4Address, - frequency: Optional[AirSpaceFrequency] = AirSpaceFrequency.WIFI_2_4, + frequency: Optional[int] = AirSpaceFrequency["WIFI_2_4"], ): """ Configures a wireless access point (WAP). @@ -168,10 +168,10 @@ class WirelessRouter(Router): :param subnet_mask: The subnet mask associated with the IP address :param frequency: The operating frequency of the wireless access point, defined by the AirSpaceFrequency enum. This determines the frequency band (e.g., 2.4 GHz or 5 GHz) the access point will use for wireless - communication. Default is AirSpaceFrequency.WIFI_2_4. + communication. Default is AirSpaceFrequency["WIFI_2_4"]. """ if not frequency: - frequency = AirSpaceFrequency.WIFI_2_4 + frequency = AirSpaceFrequency["WIFI_2_4"] self.sys_log.info("Configuring wireless access point") self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index cb0965eb..a73f3b12 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -79,9 +79,9 @@ def client_server_routed() -> Network: server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) return network @@ -271,23 +271,23 @@ def arcd_uc2_network() -> Network: security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0")) network.connect(endpoint_b=security_suite.network_interface[2], endpoint_a=switch_2.network_interface[7]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # Allow PostgreSQL requests router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) return network diff --git a/src/primaite/simulator/network/protocols/masquerade.py b/src/primaite/simulator/network/protocols/masquerade.py index e2a7b6a0..ef060bc7 100644 --- a/src/primaite/simulator/network/protocols/masquerade.py +++ b/src/primaite/simulator/network/protocols/masquerade.py @@ -8,9 +8,9 @@ from primaite.simulator.network.protocols.packet import DataPacket class MasqueradePacket(DataPacket): """Represents an generic malicious packet that is masquerading as another protocol.""" - masquerade_protocol: Enum # The 'Masquerade' protocol that is currently in use + masquerade_protocol: str # The 'Masquerade' protocol that is currently in use - masquerade_port: Enum # The 'Masquerade' port that is currently in use + masquerade_port: int # The 'Masquerade' port that is currently in use class C2Packet(MasqueradePacket): diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index 159eca7f..b9bc48d9 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -70,15 +70,15 @@ class Frame(BaseModel): msg = "Network Frame cannot have both a TCP header and a UDP header" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.TCP and not kwargs.get("tcp"): + if kwargs["ip"].protocol == IPProtocol["TCP"] and not kwargs.get("tcp"): msg = "Cannot build a Frame using the TCP IP Protocol without a TCPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.UDP and not kwargs.get("udp"): + if kwargs["ip"].protocol == IPProtocol["UDP"] and not kwargs.get("udp"): msg = "Cannot build a Frame using the UDP IP Protocol without a UDPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol.ICMP and not kwargs.get("icmp"): + if kwargs["ip"].protocol == IPProtocol["ICMP"] and not kwargs.get("icmp"): msg = "Cannot build a Frame using the ICMP IP Protocol without a ICMPPacket" _LOGGER.error(msg) raise ValueError(msg) @@ -165,7 +165,7 @@ class Frame(BaseModel): :return: True if the Frame is an ARP packet, otherwise False. """ - return self.udp.dst_port == Port.ARP + return self.udp.dst_port == Port["ARP"] @property def is_icmp(self) -> bool: diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index d493cbdf..36ff2751 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -9,25 +9,32 @@ from primaite.utils.validators import IPV4Address _LOGGER = getLogger(__name__) -class IPProtocol(Enum): - """ - Enum representing transport layer protocols in IP header. +IPProtocol : dict[str, str] = dict( + NONE = "none", + TCP = "tcp", + UDP = "udp", + ICMP = "icmp", +) - .. _List of IPProtocols: - """ +# class IPProtocol(Enum): +# """ +# Enum representing transport layer protocols in IP header. - NONE = "none" - """Placeholder for a non-protocol.""" - TCP = "tcp" - """Transmission Control Protocol.""" - UDP = "udp" - """User Datagram Protocol.""" - ICMP = "icmp" - """Internet Control Message Protocol.""" +# .. _List of IPProtocols: +# """ - def model_dump(self) -> str: - """Return as JSON-serialisable string.""" - return self.name +# NONE = "none" +# """Placeholder for a non-protocol.""" +# TCP = "tcp" +# """Transmission Control Protocol.""" +# UDP = "udp" +# """User Datagram Protocol.""" +# ICMP = "icmp" +# """Internet Control Message Protocol.""" + +# def model_dump(self) -> str: +# """Return as JSON-serialisable string.""" +# return self.name class Precedence(Enum): @@ -81,7 +88,7 @@ class IPPacket(BaseModel): >>> ip_packet = IPPacket( ... src_ip_address=IPv4Address('192.168.0.1'), ... dst_ip_address=IPv4Address('10.0.0.1'), - ... protocol=IPProtocol.TCP, + ... protocol=IPProtocol["TCP"], ... ttl=64, ... precedence=Precedence.CRITICAL ... ) @@ -91,7 +98,7 @@ class IPPacket(BaseModel): "Source IP address." dst_ip_address: IPV4Address "Destination IP address." - protocol: IPProtocol = IPProtocol.TCP + protocol: str = IPProtocol["TCP"] "IPProtocol." ttl: int = 64 "Time to Live (TTL) for the packet." diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index 7f0d2d7a..c77ef532 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -5,76 +5,112 @@ from typing import List, Union from pydantic import BaseModel -class Port(Enum): - """ - Enumeration of common known TCP/UDP ports used by protocols for operation of network applications. +Port: dict[str, int] = dict( + UNUSED = -1, + NONE = 0, + WOL = 9, + FTP_DATA = 20, + FTP = 21, + SSH = 22, + SMTP = 25, + DNS = 53, + HTTP = 80, + POP3 = 110, + SFTP = 115, + NTP = 123, + IMAP = 143, + SNMP = 161, + SNMP_TRAP = 162, + ARP = 219, + LDAP = 389, + HTTPS = 443, + SMB = 445, + IPP = 631, + SQL_SERVER = 1433, + MYSQL = 3306, + RDP = 3389, + RTP = 5004, + RTP_ALT = 5005, + DNS_ALT = 5353, + HTTP_ALT = 8080, + HTTPS_ALT = 8443, + POSTGRES_SERVER = 5432, +) - .. _List of Ports: - """ +# class Port(): +# def __getattr__() - UNUSED = -1 - "An unused port stub." - NONE = 0 - "Place holder for a non-port." - WOL = 9 - "Wake-on-Lan (WOL) - Used to turn or awaken a computer from sleep mode by a network message." - FTP_DATA = 20 - "File Transfer [Default Data]" - FTP = 21 - "File Transfer Protocol (FTP) - FTP control (command)" - SSH = 22 - "Secure Shell (SSH) - Used for secure remote access and command execution." - SMTP = 25 - "Simple Mail Transfer Protocol (SMTP) - Used for email delivery between servers." - DNS = 53 - "Domain Name System (DNS) - Used for translating domain names to IP addresses." - HTTP = 80 - "HyperText Transfer Protocol (HTTP) - Used for web traffic." - POP3 = 110 - "Post Office Protocol version 3 (POP3) - Used for retrieving emails from a mail server." - SFTP = 115 - "Secure File Transfer Protocol (SFTP) - Used for secure file transfer over SSH." - NTP = 123 - "Network Time Protocol (NTP) - Used for clock synchronization between computer systems." - IMAP = 143 - "Internet Message Access Protocol (IMAP) - Used for retrieving emails from a mail server." - SNMP = 161 - "Simple Network Management Protocol (SNMP) - Used for network device management." - SNMP_TRAP = 162 - "SNMP Trap - Used for sending SNMP notifications (traps) to a network management system." - ARP = 219 - "Address resolution Protocol - Used to connect a MAC address to an IP address." - LDAP = 389 - "Lightweight Directory Access Protocol (LDAP) - Used for accessing and modifying directory information." - HTTPS = 443 - "HyperText Transfer Protocol Secure (HTTPS) - Used for secure web traffic." - SMB = 445 - "Server Message Block (SMB) - Used for file sharing and printer sharing in Windows environments." - IPP = 631 - "Internet Printing Protocol (IPP) - Used for printing over the internet or an intranet." - SQL_SERVER = 1433 - "Microsoft SQL Server Database Engine - Used for communication with the SQL Server." - MYSQL = 3306 - "MySQL Database Server - Used for MySQL database communication." - RDP = 3389 - "Remote Desktop Protocol (RDP) - Used for remote desktop access to Windows machines." - RTP = 5004 - "Real-time Transport Protocol (RTP) - Used for transmitting real-time media, e.g., audio and video." - RTP_ALT = 5005 - "Alternative port for RTP (RTP_ALT) - Used in some configurations for transmitting real-time media." - DNS_ALT = 5353 - "Alternative port for DNS (DNS_ALT) - Used in some configurations for DNS service." - HTTP_ALT = 8080 - "Alternative port for HTTP (HTTP_ALT) - Often used as an alternative HTTP port for web applications." - HTTPS_ALT = 8443 - "Alternative port for HTTPS (HTTPS_ALT) - Used in some configurations for secure web traffic." - POSTGRES_SERVER = 5432 - "Postgres SQL Server." +# class Port(Enum): +# """ +# Enumeration of common known TCP/UDP ports used by protocols for operation of network applications. - def model_dump(self) -> str: - """Return a json-serialisable string.""" - return self.name +# .. _List of Ports: +# """ + +# UNUSED = -1 +# "An unused port stub." + +# NONE = 0 +# "Place holder for a non-port." +# WOL = 9 +# "Wake-on-Lan (WOL) - Used to turn or awaken a computer from sleep mode by a network message." +# FTP_DATA = 20 +# "File Transfer [Default Data]" +# FTP = 21 +# "File Transfer Protocol (FTP) - FTP control (command)" +# SSH = 22 +# "Secure Shell (SSH) - Used for secure remote access and command execution." +# SMTP = 25 +# "Simple Mail Transfer Protocol (SMTP) - Used for email delivery between servers." +# DNS = 53 +# "Domain Name System (DNS) - Used for translating domain names to IP addresses." +# HTTP = 80 +# "HyperText Transfer Protocol (HTTP) - Used for web traffic." +# POP3 = 110 +# "Post Office Protocol version 3 (POP3) - Used for retrieving emails from a mail server." +# SFTP = 115 +# "Secure File Transfer Protocol (SFTP) - Used for secure file transfer over SSH." +# NTP = 123 +# "Network Time Protocol (NTP) - Used for clock synchronization between computer systems." +# IMAP = 143 +# "Internet Message Access Protocol (IMAP) - Used for retrieving emails from a mail server." +# SNMP = 161 +# "Simple Network Management Protocol (SNMP) - Used for network device management." +# SNMP_TRAP = 162 +# "SNMP Trap - Used for sending SNMP notifications (traps) to a network management system." +# ARP = 219 +# "Address resolution Protocol - Used to connect a MAC address to an IP address." +# LDAP = 389 +# "Lightweight Directory Access Protocol (LDAP) - Used for accessing and modifying directory information." +# HTTPS = 443 +# "HyperText Transfer Protocol Secure (HTTPS) - Used for secure web traffic." +# SMB = 445 +# "Server Message Block (SMB) - Used for file sharing and printer sharing in Windows environments." +# IPP = 631 +# "Internet Printing Protocol (IPP) - Used for printing over the internet or an intranet." +# SQL_SERVER = 1433 +# "Microsoft SQL Server Database Engine - Used for communication with the SQL Server." +# MYSQL = 3306 +# "MySQL Database Server - Used for MySQL database communication." +# RDP = 3389 +# "Remote Desktop Protocol (RDP) - Used for remote desktop access to Windows machines." +# RTP = 5004 +# "Real-time Transport Protocol (RTP) - Used for transmitting real-time media, e.g., audio and video." +# RTP_ALT = 5005 +# "Alternative port for RTP (RTP_ALT) - Used in some configurations for transmitting real-time media." +# DNS_ALT = 5353 +# "Alternative port for DNS (DNS_ALT) - Used in some configurations for DNS service." +# HTTP_ALT = 8080 +# "Alternative port for HTTP (HTTP_ALT) - Often used as an alternative HTTP port for web applications." +# HTTPS_ALT = 8443 +# "Alternative port for HTTPS (HTTPS_ALT) - Used in some configurations for secure web traffic." +# POSTGRES_SERVER = 5432 +# "Postgres SQL Server." + +# def model_dump(self) -> str: +# """Return a json-serialisable string.""" +# return self.name class UDPHeader(BaseModel): @@ -87,13 +123,13 @@ class UDPHeader(BaseModel): :Example: >>> udp_header = UDPHeader( - ... src_port=Port.HTTP_ALT, - ... dst_port=Port.HTTP, + ... src_port=Port["HTTP_ALT"], + ... dst_port=Port["HTTP"], ... ) """ - src_port: Union[Port, int] - dst_port: Union[Port, int] + src_port: int + dst_port: int class TCPFlags(Enum): @@ -126,12 +162,12 @@ class TCPHeader(BaseModel): :Example: >>> tcp_header = TCPHeader( - ... src_port=Port.HTTP_ALT, - ... dst_port=Port.HTTP, + ... src_port=Port["HTTP_ALT"], + ... dst_port=Port["HTTP"], ... flags=[TCPFlags.SYN, TCPFlags.ACK] ... ) """ - src_port: Port - dst_port: Port + src_port: int + dst_port: int flags: List[TCPFlags] = [TCPFlags.SYN] diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 3f80c745..170e2647 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -90,8 +90,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"): def __init__(self, **kwargs): kwargs["name"] = "DatabaseClient" - kwargs["port"] = Port.POSTGRES_SERVER - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["POSTGRES_SERVER"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def _init_request_manager(self) -> RequestManager: diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index c87eaaf5..74bce85d 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -24,8 +24,8 @@ class PortScanPayload(SimComponent): """ ip_address: IPV4Address - port: Port - protocol: IPProtocol + port: int + protocol: str request: bool = True def describe_state(self) -> Dict: @@ -37,8 +37,8 @@ class PortScanPayload(SimComponent): """ state = super().describe_state() state["ip_address"] = str(self.ip_address) - state["port"] = self.port.value - state["protocol"] = self.protocol.value + state["port"] = self.port + state["protocol"] = self.protocol state["request"] = self.request return state @@ -64,8 +64,8 @@ class NMAP(Application, identifier="NMAP"): def __init__(self, **kwargs): kwargs["name"] = "NMAP" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["NONE"] super().__init__(**kwargs) def _can_perform_network_action(self) -> bool: @@ -218,14 +218,14 @@ class NMAP(Application, identifier="NMAP"): print(table.get_string(sortby="IP Address")) return active_nodes - def _determine_port_scan_type(self, target_ip_addresses: List[IPV4Address], target_ports: List[Port]) -> str: + def _determine_port_scan_type(self, target_ip_addresses: List[IPV4Address], target_ports: List[int]) -> str: """ Determine the type of port scan based on the number of target IP addresses and ports. :param target_ip_addresses: The list of target IP addresses. :type target_ip_addresses: List[IPV4Address] :param target_ports: The list of target ports. - :type target_ports: List[Port] + :type target_ports: List[int] :return: The type of port scan. :rtype: str @@ -238,8 +238,8 @@ class NMAP(Application, identifier="NMAP"): def _check_port_open_on_ip_address( self, ip_address: IPv4Address, - port: Port, - protocol: IPProtocol, + port: int, + protocol: str, is_re_attempt: bool = False, port_scan_uuid: Optional[str] = None, ) -> bool: @@ -251,7 +251,7 @@ class NMAP(Application, identifier="NMAP"): :param port: The target port. :type port: Port :param protocol: The protocol used for the port scan. - :type protocol: IPProtocol + :type protocol: str :param is_re_attempt: Flag indicating if this is a reattempt. Defaults to False. :type is_re_attempt: bool :param port_scan_uuid: The UUID of the port scan payload. Defaults to None. @@ -272,8 +272,8 @@ class NMAP(Application, identifier="NMAP"): payload = PortScanPayload(ip_address=ip_address, port=port, protocol=protocol) self._active_port_scans[payload.uuid] = payload self.sys_log.info( - f"{self.name}: Sending port scan request over {payload.protocol.name} on port {payload.port.value} " - f"({payload.port.name}) to {payload.ip_address}" + f"{self.name}: Sending port scan request over {payload.protocol} on port {payload.port} " + f"({payload.port}) to {payload.ip_address}" ) self.software_manager.send_payload_to_session_manager( payload=payload, dest_ip_address=ip_address, src_port=port, dest_port=port, ip_protocol=protocol @@ -295,8 +295,8 @@ class NMAP(Application, identifier="NMAP"): self._active_port_scans.pop(payload.uuid) self._port_scan_responses[payload.uuid] = payload self.sys_log.info( - f"{self.name}: Received port scan response from {payload.ip_address} on port {payload.port.value} " - f"({payload.port.name}) over {payload.protocol.name}" + f"{self.name}: Received port scan response from {payload.ip_address} on port {payload.port} " + f"({payload.port}) over {payload.protocol}" ) def _process_port_scan_request(self, payload: PortScanPayload, session_id: str) -> None: @@ -311,8 +311,8 @@ class NMAP(Application, identifier="NMAP"): if self.software_manager.check_port_is_open(port=payload.port, protocol=payload.protocol): payload.request = False self.sys_log.info( - f"{self.name}: Responding to port scan request for port {payload.port.value} " - f"({payload.port.name}) over {payload.protocol.name}", + f"{self.name}: Responding to port scan request for port {payload.port} " + f"({payload.port}) over {payload.protocol}", ) self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) @@ -320,20 +320,20 @@ class NMAP(Application, identifier="NMAP"): def port_scan( self, target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]], - target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None, - target_port: Optional[Union[Port, List[Port]]] = None, + target_protocol: Optional[Union[str, List[str]]] = None, + target_port: Optional[Union[int, List[int]]] = None, show: bool = True, json_serializable: bool = False, - ) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]: + ) -> Dict[IPv4Address, Dict[str, List[int]]]: """ Perform a port scan on the target IP address(es). :param target_ip_address: The target IP address(es) or network(s) for the port scan. :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]] :param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP. - :type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] + :type target_protocol: Optional[Union[str, List[str]]] :param target_port: The port(s) to scan. Defaults to None, which includes all valid ports. - :type target_port: Optional[Union[Port, List[Port]]] + :type target_port: Optional[Union[int, List[int]]] :param show: Flag indicating whether to display the scan results. Defaults to True. :type show: bool :param json_serializable: Flag indicating whether the return value should be JSON serializable. Defaults to @@ -341,19 +341,19 @@ class NMAP(Application, identifier="NMAP"): :type json_serializable: bool :return: A dictionary mapping IP addresses to protocols and lists of open ports. - :rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]] + :rtype: Dict[IPv4Address, Dict[str, List[int]]] """ ip_addresses = self._explode_ip_address_network_array(target_ip_address) - if isinstance(target_port, Port): + if isinstance(target_port, int): target_port = [target_port] elif target_port is None: - target_port = [port for port in Port if port not in {Port.NONE, Port.UNUSED}] + target_port = [port for port in Port if port not in {Port["NONE"], Port["UNUSED"]}] - if isinstance(target_protocol, IPProtocol): + if isinstance(target_protocol, str): target_protocol = [target_protocol] elif target_protocol is None: - target_protocol = [IPProtocol.TCP, IPProtocol.UDP] + target_protocol = [IPProtocol["TCP"], IPProtocol["UDP"]] scan_type = self._determine_port_scan_type(list(ip_addresses), target_port) active_ports = {} @@ -372,10 +372,10 @@ class NMAP(Application, identifier="NMAP"): if port_open: if show: - table.add_row([ip_address, port.value, port.name, protocol.name]) + table.add_row([ip_address, port, port, protocol]) _ip_address = ip_address if not json_serializable else str(ip_address) - _protocol = protocol if not json_serializable else protocol.value - _port = port if not json_serializable else port.value + _protocol = protocol if not json_serializable else protocol + _port = port if not json_serializable else port if _ip_address not in active_ports: active_ports[_ip_address] = dict() if _protocol not in active_ports[_ip_address]: @@ -390,12 +390,12 @@ class NMAP(Application, identifier="NMAP"): def network_service_recon( self, target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]], - target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None, - target_port: Optional[Union[Port, List[Port]]] = None, + target_protocol: Optional[Union[str, List[str]]] = None, + target_port: Optional[Union[int, List[int]]] = None, show: bool = True, show_online_only: bool = True, json_serializable: bool = False, - ) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]: + ) -> Dict[IPv4Address, Dict[str, List[int]]]: """ Perform a network service reconnaissance which includes a ping scan followed by a port scan. @@ -408,9 +408,9 @@ class NMAP(Application, identifier="NMAP"): :param target_ip_address: The target IP address(es) or network(s) for the port scan. :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]] :param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP. - :type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] + :type target_protocol: Optional[Union[str, List[str]]] :param target_port: The port(s) to scan. Defaults to None, which includes all valid ports. - :type target_port: Optional[Union[Port, List[Port]]] + :type target_port: Optional[Union[int, List[int]]] :param show: Flag indicating whether to display the scan results. Defaults to True. :type show: bool :param show_online_only: Flag indicating whether to show only the online hosts. Defaults to True. @@ -420,7 +420,7 @@ class NMAP(Application, identifier="NMAP"): :type json_serializable: bool :return: A dictionary mapping IP addresses to protocols and lists of open ports. - :rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]] + :rtype: Dict[IPv4Address, Dict[str, List[int]]] """ ping_scan_results = self.ping_scan( target_ip_address=target_ip_address, show=show, show_online_only=show_online_only, json_serializable=False diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index 5d4cc8e0..d442d968 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -81,10 +81,10 @@ class AbstractC2(Application, identifier="AbstractC2"): keep_alive_frequency: int = Field(default=5, ge=1) """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" - masquerade_protocol: IPProtocol = Field(default=IPProtocol.TCP) + masquerade_protocol: str = Field(default=IPProtocol["TCP"]) """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" - masquerade_port: Port = Field(default=Port.HTTP) + masquerade_port: int = Field(default=Port["HTTP"]) """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" c2_config: _C2Opts = _C2Opts() @@ -142,9 +142,9 @@ class AbstractC2(Application, identifier="AbstractC2"): def __init__(self, **kwargs): """Initialise the C2 applications to by default listen for HTTP traffic.""" - kwargs["listen_on_ports"] = {Port.HTTP, Port.FTP, Port.DNS} - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.TCP + kwargs["listen_on_ports"] = {Port["HTTP"], Port["FTP"], Port["DNS"]} + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) @property @@ -367,7 +367,7 @@ class AbstractC2(Application, identifier="AbstractC2"): :rtype: bool """ # Validating that they are valid Enums. - if not isinstance(payload.masquerade_port, Port) or not isinstance(payload.masquerade_protocol, IPProtocol): + if not isinstance(payload.masquerade_port, int) or not isinstance(payload.masquerade_protocol, str): self.sys_log.warning( f"{self.name}: Received invalid Masquerade Values within Keep Alive." f"Port: {payload.masquerade_port} Protocol: {payload.masquerade_protocol}." @@ -410,8 +410,8 @@ class AbstractC2(Application, identifier="AbstractC2"): self.keep_alive_inactivity = 0 self.keep_alive_frequency = 5 self.c2_remote_connection = None - self.c2_config.masquerade_port = Port.HTTP - self.c2_config.masquerade_protocol = IPProtocol.TCP + self.c2_config.masquerade_port = Port["HTTP"] + self.c2_config.masquerade_protocol = IPProtocol["TCP"] @abstractmethod def _confirm_remote_connection(self, timestep: int) -> bool: diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index fa0271e5..06453330 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -130,8 +130,8 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self, c2_server_ip_address: IPv4Address = None, keep_alive_frequency: int = 5, - masquerade_protocol: Enum = IPProtocol.TCP, - masquerade_port: Enum = Port.HTTP, + masquerade_protocol: str = IPProtocol["TCP"], + masquerade_port: int = Port["HTTP"], ) -> bool: """ Configures the C2 beacon to communicate with the C2 server. diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index fefb22c3..d74ae384 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -50,8 +50,8 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"): def __init__(self, **kwargs): kwargs["name"] = "DataManipulationBot" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index fcad3b3e..2cc99c4a 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -35,7 +35,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): target_ip_address: Optional[IPv4Address] = None """IP address of the target service.""" - target_port: Optional[Port] = None + target_port: Optional[int] = None """Port of the target service.""" payload: Optional[str] = None @@ -94,7 +94,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): def configure( self, target_ip_address: IPv4Address, - target_port: Optional[Port] = Port.POSTGRES_SERVER, + target_port: Optional[int] = Port["POSTGRES_SERVER"], payload: Optional[str] = None, repeat: bool = False, port_scan_p_of_success: float = 0.1, @@ -105,7 +105,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): Configure the Denial of Service bot. :param: target_ip_address: The IP address of the Node containing the target service. - :param: target_port: The port of the target service. Optional - Default is `Port.HTTP` + :param: target_port: The port of the target service. Optional - Default is `Port["HTTP"]` :param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None` :param: repeat: If True, the bot will maintain the attack. Optional - Default is `True` :param: port_scan_p_of_success: The chance of the port scan being successful. Optional - Default is 0.1 (10%) diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index 2046affc..a819190c 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -27,8 +27,8 @@ class RansomwareScript(Application, identifier="RansomwareScript"): def __init__(self, **kwargs): kwargs["name"] = "RansomwareScript" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 73791676..6707fa52 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -43,10 +43,10 @@ class WebBrowser(Application, identifier="WebBrowser"): def __init__(self, **kwargs): kwargs["name"] = "WebBrowser" - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = IPProtocol["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port.HTTP + kwargs["port"] = Port["HTTP"] super().__init__(**kwargs) self.run() @@ -126,7 +126,7 @@ class WebBrowser(Application, identifier="WebBrowser"): if self.send( payload=payload, dest_ip_address=self.domain_name_ip_address, - dest_port=parsed_url.port if parsed_url.port else Port.HTTP, + dest_port=parsed_url.port if parsed_url.port else Port["HTTP"], ): self.sys_log.info( f"{self.name}: Received HTTP {payload.request_method.name} " @@ -154,7 +154,7 @@ class WebBrowser(Application, identifier="WebBrowser"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = Port.HTTP, + dest_port: Optional[int] = Port["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index b7e2c021..172be453 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -34,14 +34,14 @@ class Session(SimComponent): :param connected: A flag indicating whether the session is connected. """ - protocol: IPProtocol + protocol: str with_ip_address: IPv4Address - src_port: Optional[Port] - dst_port: Optional[Port] + src_port: Optional[int] + dst_port: Optional[int] connected: bool = False @classmethod - def from_session_key(cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]) -> Session: + def from_session_key(cls, session_key: Tuple[str, IPv4Address, Optional[int], Optional[int]]) -> Session: """ Create a Session instance from a session key tuple. @@ -77,7 +77,7 @@ class SessionManager: def __init__(self, sys_log: SysLog): self.sessions_by_key: Dict[ - Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]], Session + Tuple[str, IPv4Address, IPv4Address, Optional[int], Optional[int]], Session ] = {} self.sessions_by_uuid: Dict[str, Session] = {} self.sys_log: SysLog = sys_log @@ -103,7 +103,7 @@ class SessionManager: @staticmethod def _get_session_key( frame: Frame, inbound_frame: bool = True - ) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]: + ) -> Tuple[str, IPv4Address, Optional[int], Optional[int]]: """ Extracts the session key from the given frame. @@ -111,15 +111,15 @@ class SessionManager: - IPProtocol: The transport protocol (e.g. TCP, UDP, ICMP). - IPv4Address: The source IP address. - IPv4Address: The destination IP address. - - Optional[Port]: The source port number (if applicable). - - Optional[Port]: The destination port number (if applicable). + - Optional[int]: The source port number (if applicable). + - Optional[int]: The destination port number (if applicable). :param frame: The network frame from which to extract the session key. :return: A tuple containing the session key. """ protocol = frame.ip.protocol with_ip_address = frame.ip.src_ip_address - if protocol == IPProtocol.TCP: + if protocol == IPProtocol["TCP"]: if inbound_frame: src_port = frame.tcp.src_port dst_port = frame.tcp.dst_port @@ -127,7 +127,7 @@ class SessionManager: dst_port = frame.tcp.src_port src_port = frame.tcp.dst_port with_ip_address = frame.ip.dst_ip_address - elif protocol == IPProtocol.UDP: + elif protocol == IPProtocol["UDP"]: if inbound_frame: src_port = frame.udp.src_port dst_port = frame.udp.dst_port @@ -167,17 +167,17 @@ class SessionManager: def resolve_outbound_transmission_details( self, dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[Port] = None, - dst_port: Optional[Port] = None, - protocol: Optional[IPProtocol] = None, + src_port: Optional[int] = None, + dst_port: Optional[int] = None, + protocol: Optional[str] = None, session_id: Optional[str] = None, ) -> Tuple[ Optional["NetworkInterface"], Optional[str], IPv4Address, - Optional[Port], - Optional[Port], - Optional[IPProtocol], + Optional[int], + Optional[int], + Optional[str], bool, ]: """ @@ -196,19 +196,19 @@ class SessionManager: treats the transmission as a broadcast to that network. Optional. :type dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] :param src_port: The source port number for the transmission. Optional. - :type src_port: Optional[Port] + :type src_port: Optional[int] :param dst_port: The destination port number for the transmission. Optional. - :type dst_port: Optional[Port] + :type dst_port: Optional[int] :param protocol: The IP protocol to be used for the transmission. Optional. - :type protocol: Optional[IPProtocol] + :type protocol: Optional[str] :param session_id: The session ID associated with the transmission. If provided, the session details override other parameters. Optional. :type session_id: Optional[str] :return: A tuple containing the resolved outbound network interface, destination MAC address, destination IP address, source port, destination port, protocol, and a boolean indicating whether the transmission is a broadcast. - :rtype: Tuple[Optional["NetworkInterface"], Optional[str], IPv4Address, Optional[Port], Optional[Port], - Optional[IPProtocol], bool] + :rtype: Tuple[Optional["NetworkInterface"], Optional[str], IPv4Address, Optional[int], Optional[int], + Optional[str], bool] """ if dst_ip_address and not isinstance(dst_ip_address, (IPv4Address, IPv4Network)): dst_ip_address = IPv4Address(dst_ip_address) @@ -259,10 +259,10 @@ class SessionManager: self, payload: Any, dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[Port] = None, - dst_port: Optional[Port] = None, + src_port: Optional[int] = None, + dst_port: Optional[int] = None, session_id: Optional[str] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + ip_protocol: str = IPProtocol["TCP"], icmp_packet: Optional[ICMPPacket] = None, ) -> Union[Any, None]: """ @@ -286,7 +286,7 @@ class SessionManager: dst_mac_address = payload.target_mac_addr outbound_network_interface = self.resolve_outbound_network_interface(payload.target_ip_address) is_broadcast = payload.request - ip_protocol = IPProtocol.UDP + ip_protocol = IPProtocol["UDP"] else: vals = self.resolve_outbound_transmission_details( dst_ip_address=dst_ip_address, @@ -311,26 +311,26 @@ class SessionManager: if not outbound_network_interface or not dst_mac_address: return False - if not (src_port or dst_port): + if src_port is None and dst_port is None: raise ValueError( "Failed to resolve src or dst port. Have you sent the port from the service or application?" ) tcp_header = None udp_header = None - if ip_protocol == IPProtocol.TCP: + if ip_protocol == IPProtocol["TCP"]: tcp_header = TCPHeader( src_port=dst_port, dst_port=dst_port, ) - elif ip_protocol == IPProtocol.UDP: + elif ip_protocol == IPProtocol["UDP"]: udp_header = UDPHeader( src_port=dst_port, dst_port=dst_port, ) # TODO: Only create IP packet if not ARP # ip_packet = None - # if dst_port != Port.ARP: + # if dst_port != Port["ARP"]: # IPPacket( # src_ip_address=outbound_network_interface.ip_address, # dst_ip_address=dst_ip_address, @@ -387,7 +387,7 @@ class SessionManager: elif frame.udp: dst_port = frame.udp.dst_port elif frame.icmp: - dst_port = Port.NONE + dst_port = Port["NONE"] self.software_manager.receive_payload_from_session_manager( payload=frame.payload, port=dst_port, @@ -413,5 +413,5 @@ class SessionManager: table.align = "l" table.title = f"{self.sys_log.hostname} Session Manager" for session in self.sessions_by_key.values(): - table.add_row([session.dst_ip_address, session.dst_port.value, session.protocol.name]) + table.add_row([session.dst_ip_address, session.dst_port, session.protocol]) print(table) diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index d45611ed..8eac33fa 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -52,7 +52,7 @@ class SoftwareManager: self.session_manager = session_manager self.software: Dict[str, Union[Service, Application]] = {} self._software_class_to_name_map: Dict[Type[IOSoftware], str] = {} - self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {} + self.port_protocol_mapping: Dict[Tuple[int, str], Union[Service, Application]] = {} self.sys_log: SysLog = sys_log self.file_system: FileSystem = file_system self.dns_server: Optional[IPv4Address] = dns_server @@ -67,7 +67,7 @@ class SoftwareManager: """Provides access to the ICMP service instance, if installed.""" return self.software.get("ICMP") # noqa - def get_open_ports(self) -> List[Port]: + def get_open_ports(self) -> List[int]: """ Get a list of open ports. @@ -81,7 +81,7 @@ class SoftwareManager: open_ports += list(software.listen_on_ports) return open_ports - def check_port_is_open(self, port: Port, protocol: IPProtocol) -> bool: + def check_port_is_open(self, port: int, protocol: str) -> bool: """ Check if a specific port is open and running a service using the specified protocol. @@ -93,7 +93,7 @@ class SoftwareManager: :param port: The port to check. :type port: Port :param protocol: The protocol to check (e.g., TCP, UDP). - :type protocol: IPProtocol + :type protocol: str :return: True if the port is open and a service is running on it using the specified protocol, False otherwise. :rtype: bool """ @@ -189,9 +189,9 @@ class SoftwareManager: self, payload: Any, dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[Port] = None, - dest_port: Optional[Port] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + src_port: Optional[int] = None, + dest_port: Optional[int] = None, + ip_protocol: str = IPProtocol["TCP"], session_id: Optional[str] = None, ) -> bool: """ @@ -219,8 +219,8 @@ class SoftwareManager: def receive_payload_from_session_manager( self, payload: Any, - port: Port, - protocol: IPProtocol, + port: int, + protocol: str, session_id: str, from_network_interface: "NIC", frame: Frame, @@ -275,8 +275,8 @@ class SoftwareManager: software_type, software.operating_state.name, software.health_state_actual.name, - software.port.value if software.port != Port.NONE else None, - software.protocol.value, + software.port if software.port != Port["NONE"] else None, + software.protocol, ] ) print(table) diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index efadf189..b8dd5f89 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -26,8 +26,8 @@ class ARP(Service): def __init__(self, **kwargs): kwargs["name"] = "ARP" - kwargs["port"] = Port.ARP - kwargs["protocol"] = IPProtocol.UDP + kwargs["port"] = Port["ARP"] + kwargs["protocol"] = IPProtocol["UDP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index b38e87b4..11ca9eb2 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -38,8 +38,8 @@ class DatabaseService(Service): def __init__(self, **kwargs): kwargs["name"] = "DatabaseService" - kwargs["port"] = Port.POSTGRES_SERVER - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["POSTGRES_SERVER"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self._create_db_file() diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index d7ba0cd4..62f14366 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -22,11 +22,11 @@ class DNSClient(Service): def __init__(self, **kwargs): kwargs["name"] = "DNSClient" - kwargs["port"] = Port.DNS + kwargs["port"] = Port["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self.start() @@ -95,7 +95,7 @@ class DNSClient(Service): # send a request to check if domain name exists in the DNS Server software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload=payload, dest_ip_address=self.dns_server, dest_port=Port.DNS + payload=payload, dest_ip_address=self.dns_server, dest_port=Port["DNS"] ) # recursively re-call the function passing is_reattempt=True @@ -110,7 +110,7 @@ class DNSClient(Service): payload: DNSPacket, session_id: Optional[str] = None, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = None, + dest_port: Optional[int] = None, **kwargs, ) -> bool: """ diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 8a4bbaed..93895825 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -21,11 +21,11 @@ class DNSServer(Service): def __init__(self, **kwargs): kwargs["name"] = "DNSServer" - kwargs["port"] = Port.DNS + kwargs["port"] = Port["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index f823e42c..1fce4133 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -25,8 +25,8 @@ class FTPClient(FTPServiceABC): def __init__(self, **kwargs): kwargs["name"] = "FTPClient" - kwargs["port"] = Port.FTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["FTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self.start() @@ -104,7 +104,7 @@ class FTPClient(FTPServiceABC): def _connect_to_server( self, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[int] = Port["FTP"], session_id: Optional[str] = None, is_reattempt: Optional[bool] = False, ) -> bool: @@ -114,7 +114,7 @@ class FTPClient(FTPServiceABC): :param: dest_ip_address: IP address of the FTP server the client needs to connect to. Optional. :type: dest_ip_address: Optional[IPv4Address] :param: dest_port: Port of the FTP server the client needs to connect to. Optional. - :type: dest_port: Optional[Port] + :type: dest_port: Optional[int] :param: is_reattempt: Set to True if attempt to connect to FTP Server has been attempted. Default False. :type: is_reattempt: Optional[bool] """ @@ -124,13 +124,13 @@ class FTPClient(FTPServiceABC): # normally FTP will choose a random port for the transfer, but using the FTP command port will do for now # create FTP packet - payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=Port.FTP) + payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=Port["FTP"]) if self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id): if payload.status_code == FTPStatusCode.OK: self.sys_log.info( f"{self.name}: Successfully connected to FTP Server " - f"{dest_ip_address} via port {payload.ftp_command_args.value}" + f"{dest_ip_address} via port {payload.ftp_command_args}" ) self.add_connection(connection_id="server_connection", session_id=session_id) return True @@ -139,7 +139,7 @@ class FTPClient(FTPServiceABC): # reattempt failed self.sys_log.warning( f"{self.name}: Unable to connect to FTP Server " - f"{dest_ip_address} via port {payload.ftp_command_args.value}" + f"{dest_ip_address} via port {payload.ftp_command_args}" ) return False else: @@ -152,7 +152,7 @@ class FTPClient(FTPServiceABC): return False def _disconnect_from_server( - self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[Port] = Port.FTP + self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[int] = Port["FTP"] ) -> bool: """ Connects the client from a given FTP server. @@ -160,7 +160,7 @@ class FTPClient(FTPServiceABC): :param: dest_ip_address: IP address of the FTP server the client needs to disconnect from. Optional. :type: dest_ip_address: Optional[IPv4Address] :param: dest_port: Port of the FTP server the client needs to disconnect from. Optional. - :type: dest_port: Optional[Port] + :type: dest_port: Optional[int] :param: is_reattempt: Set to True if attempt to disconnect from FTP Server has been attempted. Default False. :type: is_reattempt: Optional[bool] """ @@ -179,7 +179,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[int] = Port["FTP"], session_id: Optional[str] = None, ) -> bool: """ @@ -203,8 +203,8 @@ class FTPClient(FTPServiceABC): :param: dest_file_name: The name of the file to be saved on the FTP Server. :type: dest_file_name: str - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. - :type: dest_port: Optional[Port] + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. + :type: dest_port: Optional[int] :param: session_id: The id of the session :type: session_id: Optional[str] @@ -241,7 +241,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[Port] = Port.FTP, + dest_port: Optional[int] = Port["FTP"], ) -> bool: """ Request a file from a target IP address. @@ -263,8 +263,8 @@ class FTPClient(FTPServiceABC): :param: dest_file_name: The name of the file to be saved on the FTP Server. :type: dest_file_name: str - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. - :type: dest_port: Optional[Port] + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. + :type: dest_port: Optional[int] """ # check if FTP is currently connected to IP self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index f02d01f4..701bff79 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -23,8 +23,8 @@ class FTPServer(FTPServiceABC): def __init__(self, **kwargs): kwargs["name"] = "FTPServer" - kwargs["port"] = Port.FTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["FTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self.start() @@ -52,7 +52,7 @@ class FTPServer(FTPServiceABC): # process server specific commands, otherwise call super if payload.ftp_command == FTPCommand.PORT: # check that the port is valid - if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535): + if isinstance(payload.ftp_command_args, int) and (0 <= payload.ftp_command_args < 65535): # return successful connection self.add_connection(connection_id=session_id, session_id=session_id) payload.status_code = FTPStatusCode.OK diff --git a/src/primaite/simulator/system/services/ftp/ftp_service.py b/src/primaite/simulator/system/services/ftp/ftp_service.py index 689a3da7..36245e0f 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_service.py +++ b/src/primaite/simulator/system/services/ftp/ftp_service.py @@ -78,7 +78,7 @@ class FTPServiceABC(Service, ABC): dest_folder_name: str, dest_file_name: str, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = None, + dest_port: Optional[int] = None, session_id: Optional[str] = None, is_response: bool = False, ) -> bool: @@ -97,8 +97,8 @@ class FTPServiceABC(Service, ABC): :param: dest_ip_address: The IP address of the machine that hosts the FTP Server. :type: dest_ip_address: Optional[IPv4Address] - :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP. - :type: dest_port: Optional[Port] + :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. + :type: dest_port: Optional[int] :param: session_id: session ID linked to the FTP Packet. Optional. :type: session_id: Optional[str] @@ -168,7 +168,7 @@ class FTPServiceABC(Service, ABC): payload: FTPPacket, session_id: Optional[str] = None, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = None, + dest_port: Optional[int] = None, **kwargs, ) -> bool: """ diff --git a/src/primaite/simulator/system/services/icmp/icmp.py b/src/primaite/simulator/system/services/icmp/icmp.py index 6741d86a..a2dfac0d 100644 --- a/src/primaite/simulator/system/services/icmp/icmp.py +++ b/src/primaite/simulator/system/services/icmp/icmp.py @@ -26,8 +26,8 @@ class ICMP(Service): def __init__(self, **kwargs): kwargs["name"] = "ICMP" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.ICMP + kwargs["port"] = Port["NONE"] + kwargs["protocol"] = IPProtocol["ICMP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/icmp/router_icmp.py b/src/primaite/simulator/system/services/icmp/router_icmp.py index 4fdc6baa..19c0ac2d 100644 --- a/src/primaite/simulator/system/services/icmp/router_icmp.py +++ b/src/primaite/simulator/system/services/icmp/router_icmp.py @@ -36,13 +36,13 @@ # self.sys_log.info(f"Received echo request from {frame.ip.src_ip_address}") # target_mac_address = self.arp.get_arp_cache_mac_address(frame.ip.src_ip_address) # src_nic = self.arp.get_arp_cache_network_interface(frame.ip.src_ip_address) -# tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) +# tcp_header = TCPHeader(src_port=Port["ARP"], dst_port=Port["ARP"]) # # # Network Layer # ip_packet = IPPacket( # src_ip_address=network_interface.ip_address, # dst_ip_address=frame.ip.src_ip_address, -# protocol=IPProtocol.ICMP, +# protocol=IPProtocol["ICMP"], # ) # # Data Link Layer # ethernet_header = EthernetHeader( diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 8924a821..40b8d273 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -21,8 +21,8 @@ class NTPClient(Service): def __init__(self, **kwargs): kwargs["name"] = "NTPClient" - kwargs["port"] = Port.NTP - kwargs["protocol"] = IPProtocol.UDP + kwargs["port"] = Port["NTP"] + kwargs["protocol"] = IPProtocol["UDP"] super().__init__(**kwargs) self.start() @@ -55,7 +55,7 @@ class NTPClient(Service): payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: Port = Port.NTP, + dest_port: int = Port["NTP"], **kwargs, ) -> bool: """Requests NTP data from NTP server. diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 547bbc06..d9de40c6 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -16,8 +16,8 @@ class NTPServer(Service): def __init__(self, **kwargs): kwargs["name"] = "NTPServer" - kwargs["port"] = Port.NTP - kwargs["protocol"] = IPProtocol.UDP + kwargs["port"] = Port["NTP"] + kwargs["protocol"] = IPProtocol["UDP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index e98e8555..41987aff 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -137,8 +137,8 @@ class Terminal(Service): def __init__(self, **kwargs): kwargs["name"] = "Terminal" - kwargs["port"] = Port.SSH - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["SSH"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 4fc64e1f..c021a86e 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -49,10 +49,10 @@ class WebServer(Service): def __init__(self, **kwargs): kwargs["name"] = "WebServer" - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = IPProtocol["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port.HTTP + kwargs["port"] = Port["HTTP"] super().__init__(**kwargs) self._install_web_files() @@ -145,7 +145,7 @@ class WebServer(Service): payload: HttpResponsePacket, session_id: Optional[str] = None, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = None, + dest_port: Optional[int] = None, **kwargs, ) -> bool: """ diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index f1d1b9a1..1880d244 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -251,11 +251,11 @@ class IOSoftware(Software): "Indicates if the software uses TCP protocol for communication. Default is True." udp: bool = True "Indicates if the software uses UDP protocol for communication. Default is True." - port: Port + port: int "The port to which the software is connected." - listen_on_ports: Set[Port] = Field(default_factory=set) + listen_on_ports: Set[int] = Field(default_factory=set) "The set of ports to listen on." - protocol: IPProtocol + protocol: str "The IP Protocol the Software operates on." _connections: Dict[str, Dict] = {} "Active connections." @@ -277,7 +277,7 @@ class IOSoftware(Software): "max_sessions": self.max_sessions, "tcp": self.tcp, "udp": self.udp, - "port": self.port.value, + "port": self.port, } ) return state @@ -386,8 +386,8 @@ class IOSoftware(Software): payload: Any, session_id: Optional[str] = None, dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - dest_port: Optional[Port] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, + dest_port: Optional[int] = None, + ip_protocol: str = IPProtocol["TCP"], **kwargs, ) -> bool: """ diff --git a/tests/conftest.py b/tests/conftest.py index 1bbff8f2..1ffa2146 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,8 +45,8 @@ class DummyService(Service): def __init__(self, **kwargs): kwargs["name"] = "DummyService" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["HTTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: @@ -58,8 +58,8 @@ class DummyApplication(Application, identifier="DummyApplication"): def __init__(self, **kwargs): kwargs["name"] = "DummyApplication" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["HTTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -77,7 +77,7 @@ def uc2_network() -> Network: @pytest.fixture(scope="function") def service(file_system) -> DummyService: return DummyService( - name="DummyService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_service") + name="DummyService", port=Port["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_service") ) @@ -90,7 +90,7 @@ def service_class(): def application(file_system) -> DummyApplication: return DummyApplication( name="DummyApplication", - port=Port.ARP, + port=Port["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_application"), ) @@ -350,10 +350,10 @@ def install_stuff_to_sim(sim: Simulation): network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) # 2: Configure base ACL - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) # 3: Install server software server_1.software_manager.install(DNSServer) @@ -379,13 +379,13 @@ def install_stuff_to_sim(sim: Simulation): r = sim.network.router_nodes[0] for i, acl_rule in enumerate(r.acl.acl): if i == 1: - assert acl_rule.src_port == acl_rule.dst_port == Port.DNS + assert acl_rule.src_port == acl_rule.dst_port == Port["DNS"] elif i == 3: - assert acl_rule.src_port == acl_rule.dst_port == Port.HTTP + assert acl_rule.src_port == acl_rule.dst_port == Port["HTTP"] elif i == 22: - assert acl_rule.src_port == acl_rule.dst_port == Port.ARP + assert acl_rule.src_port == acl_rule.dst_port == Port["ARP"] elif i == 23: - assert acl_rule.protocol == IPProtocol.ICMP + assert acl_rule.protocol == IPProtocol["ICMP"] elif i == 24: ... else: diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py index 457fdb42..d35e2ebb 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py @@ -68,44 +68,44 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed internal_inbound assert firewall.internal_inbound_acl.num_rules == 2 assert firewall.internal_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.internal_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.internal_inbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.internal_inbound_acl.acl[22].dst_port == Port["ARP"] assert firewall.internal_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.internal_inbound_acl.acl[23].protocol == IPProtocol["ICMP"] assert firewall.internal_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed internal_outbound assert firewall.internal_outbound_acl.num_rules == 2 assert firewall.internal_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.internal_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.internal_outbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.internal_outbound_acl.acl[22].dst_port == Port["ARP"] assert firewall.internal_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.internal_outbound_acl.acl[23].protocol == IPProtocol["ICMP"] assert firewall.internal_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_inbound assert firewall.dmz_inbound_acl.num_rules == 2 assert firewall.dmz_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.dmz_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.dmz_inbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.dmz_inbound_acl.acl[22].dst_port == Port["ARP"] assert firewall.dmz_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.dmz_inbound_acl.acl[23].protocol == IPProtocol["ICMP"] assert firewall.dmz_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_outbound assert firewall.dmz_outbound_acl.num_rules == 2 assert firewall.dmz_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.dmz_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.dmz_outbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.dmz_outbound_acl.acl[22].dst_port == Port["ARP"] assert firewall.dmz_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.dmz_outbound_acl.acl[23].protocol == IPProtocol["ICMP"] assert firewall.dmz_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed external_inbound assert firewall.external_inbound_acl.num_rules == 1 assert firewall.external_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_inbound_acl.acl[22].src_port == Port.ARP - assert firewall.external_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.external_inbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.external_inbound_acl.acl[22].dst_port == Port["ARP"] # external_inbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_inbound_acl.implicit_action == ACLAction.PERMIT @@ -113,8 +113,8 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed external_outbound assert firewall.external_outbound_acl.num_rules == 1 assert firewall.external_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_outbound_acl.acl[22].src_port == Port.ARP - assert firewall.external_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.external_outbound_acl.acl[22].src_port == Port["ARP"] + assert firewall.external_outbound_acl.acl[22].dst_port == Port["ARP"] # external_outbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_outbound_acl.implicit_action == ACLAction.PERMIT diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py index ccde3a02..16543565 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -63,8 +63,8 @@ def test_router_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed assert router_1.acl.num_rules == 2 assert router_1.acl.acl[22].action == ACLAction.PERMIT - assert router_1.acl.acl[22].src_port == Port.ARP - assert router_1.acl.acl[22].dst_port == Port.ARP + assert router_1.acl.acl[22].src_port == Port["ARP"] + assert router_1.acl.acl[22].dst_port == Port["ARP"] assert router_1.acl.acl[23].action == ACLAction.PERMIT - assert router_1.acl.acl[23].protocol == IPProtocol.ICMP + assert router_1.acl.acl[23].protocol == IPProtocol["ICMP"] assert router_1.acl.implicit_action == ACLAction.DENY diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py index c9b3006d..8e3d33e1 100644 --- a/tests/integration_tests/extensions/applications/extended_application.py +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -44,10 +44,10 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): def __init__(self, **kwargs): kwargs["name"] = "ExtendedApplication" - kwargs["protocol"] = IPProtocol.TCP + kwargs["protocol"] = IPProtocol["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port.HTTP + kwargs["port"] = Port["HTTP"] super().__init__(**kwargs) self.run() @@ -127,7 +127,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): if self.send( payload=payload, dest_ip_address=self.domain_name_ip_address, - dest_port=parsed_url.port if parsed_url.port else Port.HTTP, + dest_port=parsed_url.port if parsed_url.port else Port["HTTP"], ): self.sys_log.info( f"{self.name}: Received HTTP {payload.request_method.name} " @@ -155,7 +155,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[Port] = Port.HTTP, + dest_port: Optional[int] = Port["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index 3151571b..d4af600f 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -38,8 +38,8 @@ class ExtendedService(Service, identifier='extendedservice'): def __init__(self, **kwargs): kwargs["name"] = "ExtendedService" - kwargs["port"] = Port.POSTGRES_SERVER - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["POSTGRES_SERVER"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self._create_db_file() if kwargs.get('options'): diff --git a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py index 806ce063..17b0ba8c 100644 --- a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py +++ b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py @@ -26,9 +26,9 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=4) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=5) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=6) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=5) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=6) c2_server_host = game.simulation.network.get_node_by_hostname("client_1") c2_server_host.software_manager.install(software_class=C2Server) diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index 0c9ec6f0..34ee25d6 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -200,7 +200,7 @@ class TestConfigureDoSBot: game.step() assert dos_bot.target_ip_address == IPv4Address("192.168.1.99") - assert dos_bot.target_port == Port.POSTGRES_SERVER + assert dos_bot.target_port == Port["POSTGRES_SERVER"] assert dos_bot.payload == "HACC" assert not dos_bot.repeat assert dos_bot.port_scan_p_of_success == 0.875 diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index d011c1e8..857edd26 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -20,7 +20,7 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=4) return (game, agent) diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index f1d9d416..398c43a9 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -33,7 +33,7 @@ def test_acl_observations(simulation): server.software_manager.install(NTPServer) # add router acl rule - router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port["NTP"], src_port=Port["NTP"], position=1) acl_obs = ACLObservation( where=["network", "nodes", router.hostname, "acl", "acl"], diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 34a37f5e..68506d59 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -62,13 +62,13 @@ def test_firewall_observation(): # add a rule to the internal inbound and check that the observation is correct firewall.internal_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port.HTTP, - dst_port=Port.HTTP, + src_port=Port["HTTP"], + dst_port=Port["HTTP"], position=5, ) diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index ef789ba7..bd8dfc4e 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -152,7 +152,7 @@ def test_config_nic_categories(simulation): def test_nic_monitored_traffic(simulation): - monitored_traffic = {"icmp": ["NONE"], "tcp": ["DNS"]} + monitored_traffic = {"icmp": ["NONE"], "tcp": [53,]} pc: Computer = simulation.network.get_node_by_hostname("client_1") pc2: Computer = simulation.network.get_node_by_hostname("client_2") diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index 48d29cfb..937bb061 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -39,13 +39,13 @@ def test_router_observation(): # Add an ACL rule to the router router.acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port.HTTP, - dst_port=Port.HTTP, + src_port=Port["HTTP"], + dst_port=Port["HTTP"], position=5, ) # Observe the state using the RouterObservation instance diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py index ca5e2543..6ca4bc9e 100644 --- a/tests/integration_tests/game_layer/observations/test_user_observations.py +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -15,7 +15,7 @@ def env_with_ssh() -> PrimaiteGymEnv: env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) env.agent.flatten_obs = False router: Router = env.game.simulation.network.get_node_by_hostname("router_1") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3) + router.acl.add_rule(ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=3) return env diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index a1005f34..c3e86263 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -608,9 +608,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.internal_outbound_acl.acl[1].action.name == "DENY" assert firewall.internal_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.0.10") assert firewall.internal_outbound_acl.acl[1].dst_ip_address is None - assert firewall.internal_outbound_acl.acl[1].dst_port == Port.DNS - assert firewall.internal_outbound_acl.acl[1].src_port == Port.ARP - assert firewall.internal_outbound_acl.acl[1].protocol == IPProtocol.ICMP + assert firewall.internal_outbound_acl.acl[1].dst_port == Port["DNS"] + assert firewall.internal_outbound_acl.acl[1].src_port == Port["ARP"] + assert firewall.internal_outbound_acl.acl[1].protocol == IPProtocol["ICMP"] env.step(4) # Remove ACL rule from Internal Outbound assert firewall.internal_outbound_acl.num_rules == 2 @@ -620,9 +620,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_inbound_acl.acl[1].action.name == "DENY" assert firewall.dmz_inbound_acl.acl[1].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_inbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_inbound_acl.acl[1].dst_port == Port.HTTP - assert firewall.dmz_inbound_acl.acl[1].src_port == Port.HTTP - assert firewall.dmz_inbound_acl.acl[1].protocol == IPProtocol.UDP + assert firewall.dmz_inbound_acl.acl[1].dst_port == Port["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].src_port == Port["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].protocol == IPProtocol["UDP"] env.step(6) # Remove ACL rule from DMZ Inbound assert firewall.dmz_inbound_acl.num_rules == 2 @@ -632,9 +632,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_outbound_acl.acl[2].action.name == "DENY" assert firewall.dmz_outbound_acl.acl[2].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_outbound_acl.acl[2].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_outbound_acl.acl[2].dst_port == Port.HTTP - assert firewall.dmz_outbound_acl.acl[2].src_port == Port.HTTP - assert firewall.dmz_outbound_acl.acl[2].protocol == IPProtocol.TCP + assert firewall.dmz_outbound_acl.acl[2].dst_port == Port["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].src_port == Port["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].protocol == IPProtocol["TCP"] env.step(8) # Remove ACL rule from DMZ Outbound assert firewall.dmz_outbound_acl.num_rules == 2 @@ -644,9 +644,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.external_inbound_acl.acl[10].action.name == "DENY" assert firewall.external_inbound_acl.acl[10].src_ip_address == IPv4Address("192.168.20.10") assert firewall.external_inbound_acl.acl[10].dst_ip_address == IPv4Address("192.168.10.10") - assert firewall.external_inbound_acl.acl[10].dst_port == Port.POSTGRES_SERVER - assert firewall.external_inbound_acl.acl[10].src_port == Port.POSTGRES_SERVER - assert firewall.external_inbound_acl.acl[10].protocol == IPProtocol.ICMP + assert firewall.external_inbound_acl.acl[10].dst_port == Port["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].src_port == Port["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].protocol == IPProtocol["ICMP"] env.step(10) # Remove ACL rule from External Inbound assert firewall.external_inbound_acl.num_rules == 1 diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 58783d70..d872c2b0 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -42,7 +42,7 @@ def test_WebpageUnavailablePenalty(game_and_agent): # Block the web traffic, check that failing to fetch the webpage yields a reward of -0.7 router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.TCP, src_port=Port.HTTP, dst_port=Port.HTTP) + router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol["TCP"], src_port=Port["HTTP"], dst_port=Port["HTTP"]) agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})) game.step() assert agent.reward_function.current_reward == -0.7 @@ -65,7 +65,7 @@ def test_uc2_rewards(game_and_agent): db_client.run() router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2) + router.acl.add_rule(ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=2) comp = GreenAdminDatabaseUnreachablePenalty("client_1") diff --git a/tests/integration_tests/network/test_airspace_config.py b/tests/integration_tests/network/test_airspace_config.py index 78d00b47..1794c4bc 100644 --- a/tests/integration_tests/network/test_airspace_config.py +++ b/tests/integration_tests/network/test_airspace_config.py @@ -13,8 +13,8 @@ def test_override_freq_max_capacity_mbps(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 123.45 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_2_4"]) == 123.45 + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_5"]) == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") @@ -32,8 +32,8 @@ def test_override_freq_max_capacity_mbps_blocked(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 0.0 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_2_4"]) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_5"]) == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index 80007c46..da0af89d 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -20,8 +20,8 @@ class BroadcastTestService(Service): def __init__(self, **kwargs): # Set default service properties for broadcasting kwargs["name"] = "BroadcastService" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["HTTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -33,12 +33,12 @@ class BroadcastTestService(Service): super().send( payload="unicast", dest_ip_address=ip_address, - dest_port=Port.HTTP, + dest_port=Port["HTTP"], ) def broadcast(self, ip_network: IPv4Network): # Send a broadcast payload to an entire IP network - super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol) + super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port["HTTP"], ip_protocol=self.protocol) class BroadcastTestClient(Application, identifier="BroadcastTestClient"): @@ -49,8 +49,8 @@ class BroadcastTestClient(Application, identifier="BroadcastTestClient"): def __init__(self, **kwargs): # Set default client properties kwargs["name"] = "BroadcastTestClient" - kwargs["port"] = Port.HTTP - kwargs["protocol"] = IPProtocol.TCP + kwargs["port"] = Port["HTTP"] + kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index b15ee51a..8e06ccfb 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -53,28 +53,28 @@ def dmz_external_internal_network() -> Network: ) # Allow ICMP - firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # Allow ARP firewall_node.internal_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 ) firewall_node.internal_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 ) firewall_node.external_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 ) firewall_node.external_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22 + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 ) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) # external node external_node = Computer( @@ -262,8 +262,8 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not internal_ntp_client.time - firewall.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) - firewall.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) + firewall.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) + firewall.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) internal_ntp_client.request_time() @@ -271,8 +271,8 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not dmz_ntp_client.time - firewall.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) - firewall.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=1) + firewall.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) + firewall.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) dmz_ntp_client.request_time() diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index 62b58cbd..641342e2 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -73,8 +73,8 @@ def multi_hop_network() -> Network: router_1.enable_port(2) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # Configure PC B pc_b = Computer( @@ -197,8 +197,8 @@ def test_routing_services(multi_hop_network): router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21) - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=21) + router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=21) assert ntp_client.time is None ntp_client.request_time() diff --git a/tests/integration_tests/network/test_wireless_router.py b/tests/integration_tests/network/test_wireless_router.py index 733de6f6..2f1be930 100644 --- a/tests/integration_tests/network/test_wireless_router.py +++ b/tests/integration_tests/network/test_wireless_router.py @@ -37,8 +37,8 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # Configure PC B pc_b = Computer( diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index 9d12f2cf..d819b511 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -227,7 +227,7 @@ def test_c2_suite_acl_block(basic_network): assert c2_beacon.c2_connection_active == True # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) c2_beacon.apply_timestep(2) c2_beacon.apply_timestep(3) @@ -322,8 +322,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Confirm Default Setup ######################### # Permitting all HTTP & FTP traffic - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=1) c2_beacon.apply_timestep(0) assert c2_beacon.keep_alive_inactivity == 1 @@ -337,7 +337,7 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying HTTP Traffic ######################### # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) blocking_acl: AccessControlList = router.acl.acl[0] # Asserts to show the C2 Suite is unable to maintain connection: @@ -359,8 +359,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port.FTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=Port["FTP"], + masquerade_protocol=IPProtocol["TCP"], ) c2_beacon.establish() @@ -407,8 +407,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying FTP Traffic & Enable HTTP ######################### # Blocking FTP and re-permitting HTTP: - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.FTP, dst_port=Port.FTP, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port["FTP"], dst_port=Port["FTP"], position=1) blocking_acl: AccessControlList = router.acl.acl[1] # Asserts to show the C2 Suite is unable to maintain connection: @@ -430,8 +430,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port.HTTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=Port["HTTP"], + masquerade_protocol=IPProtocol["TCP"], ) c2_beacon.establish() diff --git a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py index 2e87578d..9c0760b7 100644 --- a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py @@ -52,7 +52,7 @@ def data_manipulation_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 ) client_1: Computer = network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py index 68c1fbfe..709a417f 100644 --- a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py @@ -26,7 +26,7 @@ def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseServ dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port.POSTGRES_SERVER, + target_port=Port["POSTGRES_SERVER"], ) # Install DB Server service on server @@ -43,7 +43,7 @@ def dos_bot_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 ) client_1: Computer = network.get_node_by_hostname("client_1") @@ -56,7 +56,7 @@ def dos_bot_db_server_green_client(example_network) -> Network: dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port.POSTGRES_SERVER, + target_port=Port["POSTGRES_SERVER"], ) # install db server service on server diff --git a/tests/integration_tests/system/red_applications/test_ransomware_script.py b/tests/integration_tests/system/red_applications/test_ransomware_script.py index 97abafb5..b34e9b30 100644 --- a/tests/integration_tests/system/red_applications/test_ransomware_script.py +++ b/tests/integration_tests/system/red_applications/test_ransomware_script.py @@ -47,7 +47,7 @@ def ransomware_script_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 ) client_1: Computer = network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index 2b8691cc..1064ed1b 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -73,10 +73,10 @@ def test_port_scan_one_node_one_port(example_network): client_2 = network.get_node_by_hostname("client_2") actual_result = client_1_nmap.port_scan( - target_ip_address=client_2.network_interface[1].ip_address, target_port=Port.DNS, target_protocol=IPProtocol.TCP + target_ip_address=client_2.network_interface[1].ip_address, target_port=Port["DNS"], target_protocol=IPProtocol["TCP"] ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.DNS]}} + expected_result = {IPv4Address("192.168.10.22"): {IPProtocol["TCP"]: [Port["DNS"]]}} assert actual_result == expected_result @@ -101,14 +101,14 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network): actual_result = client_1_nmap.port_scan( target_ip_address=IPv4Network("192.168.10.0/24"), - target_port=[Port.ARP, Port.HTTP, Port.FTP, Port.DNS, Port.NTP], + target_port=[Port["ARP"], Port["HTTP"], Port["FTP"], Port["DNS"], Port["NTP"]], ) expected_result = { - IPv4Address("192.168.10.1"): {IPProtocol.UDP: [Port.ARP]}, + IPv4Address("192.168.10.1"): {IPProtocol["UDP"]: [Port["ARP"]]}, IPv4Address("192.168.10.22"): { - IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS], - IPProtocol.UDP: [Port.ARP, Port.NTP], + IPProtocol["TCP"]: [Port["HTTP"], Port["FTP"], Port["DNS"]], + IPProtocol["UDP"]: [Port["ARP"], Port["NTP"]], }, } @@ -122,10 +122,10 @@ def test_network_service_recon_all_ports_and_protocols(example_network): client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa actual_result = client_1_nmap.network_service_recon( - target_ip_address=IPv4Network("192.168.10.0/24"), target_port=Port.HTTP, target_protocol=IPProtocol.TCP + target_ip_address=IPv4Network("192.168.10.0/24"), target_port=Port["HTTP"], target_protocol=IPProtocol["TCP"] ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol.TCP: [Port.HTTP]}} + expected_result = {IPv4Address("192.168.10.22"): {IPProtocol["TCP"]: [Port["HTTP"]]}} assert sort_dict(actual_result) == sort_dict(expected_result) diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py index fd502a70..5226ab4a 100644 --- a/tests/integration_tests/system/test_service_listening_on_ports.py +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -16,9 +16,9 @@ from tests import TEST_ASSETS_ROOT class _DatabaseListener(Service): name: str = "DatabaseListener" - protocol: IPProtocol = IPProtocol.TCP - port: Port = Port.NONE - listen_on_ports: Set[Port] = {Port.POSTGRES_SERVER} + protocol: str = IPProtocol["TCP"] + port: int = Port["NONE"] + listen_on_ports: Set[int] = {Port["POSTGRES_SERVER"]} payloads_received: List[Any] = Field(default_factory=list) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: @@ -51,8 +51,8 @@ def test_http_listener(client_server): computer.session_manager.receive_payload_from_software_manager( payload="masquerade as Database traffic", dst_ip_address=server.network_interface[1].ip_address, - dst_port=Port.POSTGRES_SERVER, - ip_protocol=IPProtocol.TCP, + dst_port=Port["POSTGRES_SERVER"], + ip_protocol=IPProtocol["TCP"], ) assert len(server_db_listener.payloads_received) == 1 @@ -76,9 +76,9 @@ def test_set_listen_on_ports_from_config(): network = PrimaiteGame.from_config(cfg=config_dict).simulation.network client: Computer = network.get_node_by_hostname("client") - assert Port.SMB in client.software_manager.get_open_ports() - assert Port.IPP in client.software_manager.get_open_ports() + assert Port["SMB"] in client.software_manager.get_open_ports() + assert Port["IPP"] in client.software_manager.get_open_ports() web_browser = client.software_manager.software["WebBrowser"] - assert not web_browser.listen_on_ports.difference({Port.SMB, Port.IPP}) + assert not web_browser.listen_on_ports.difference({Port["SMB"], Port["IPP"]}) diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py index 5a765763..6c37360f 100644 --- a/tests/integration_tests/system/test_web_client_server_and_database.py +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -24,17 +24,17 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer, # add rules to network router router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) # Create Computer computer: Computer = example_network.get_node_by_hostname("client_1") @@ -148,7 +148,7 @@ class TestWebBrowserHistory: assert web_browser.history[-1].response_code == 200 router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) assert not web_browser.get_webpage() assert len(web_browser.history) == 3 # with current NIC behaviour, even if you block communication, you won't get SERVER_UNREACHABLE because @@ -166,7 +166,7 @@ class TestWebBrowserHistory: web_browser.get_webpage() router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) web_browser.get_webpage() state = computer.describe_state() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 95634cf1..ff73e621 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -171,7 +171,7 @@ class TestDataManipulationGreenRequests: assert client_1_browser_execute.status == "success" assert client_2_browser_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + router.acl.add_rule(ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) assert client_1_browser_execute.status == "failure" @@ -182,7 +182,7 @@ class TestDataManipulationGreenRequests: assert client_1_db_client_execute.status == "success" assert client_2_db_client_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER) + router.acl.add_rule(ACLAction.DENY, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]) client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"]) client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"]) assert client_1_db_client_execute.status == "failure" diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py index 9bc1abfd..4c471faa 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py @@ -28,20 +28,20 @@ def router_with_acl_rules(): # Add rules here as needed acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.1", - src_port=Port.HTTPS, + src_port=Port["HTTPS"], dst_ip_address="192.168.1.2", - dst_port=Port.HTTP, + dst_port=Port["HTTP"], position=1, ) acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.3", - src_port=Port(8080), + src_port=8080, dst_ip_address="192.168.1.4", - dst_port=Port(80), + dst_port=80, position=2, ) return router @@ -65,21 +65,21 @@ def router_with_wildcard_acl(): # Rule to permit traffic from a specific source IP and port to a specific destination IP and port acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.1", - src_port=Port(8080), + src_port=8080, dst_ip_address="10.1.1.2", - dst_port=Port(80), + dst_port=80, position=1, ) # Rule to deny traffic from an IP range to a specific destination IP and port acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], src_ip_address="192.168.1.0", src_wildcard_mask="0.0.0.255", dst_ip_address="10.1.1.3", - dst_port=Port(443), + dst_port=443, position=2, ) # Rule to permit any traffic to a range of destination IPs @@ -109,11 +109,11 @@ def test_add_rule(router_with_acl_rules): acl = router_with_acl_rules.acl assert acl.acl[1].action == ACLAction.PERMIT - assert acl.acl[1].protocol == IPProtocol.TCP + assert acl.acl[1].protocol == IPProtocol["TCP"] assert acl.acl[1].src_ip_address == IPv4Address("192.168.1.1") - assert acl.acl[1].src_port == Port.HTTPS + assert acl.acl[1].src_port == Port["HTTPS"] assert acl.acl[1].dst_ip_address == IPv4Address("192.168.1.2") - assert acl.acl[1].dst_port == Port.HTTP + assert acl.acl[1].dst_port == Port["HTTP"] def test_remove_rule(router_with_acl_rules): @@ -136,8 +136,8 @@ def test_traffic_permitted_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.HTTPS, dst_port=Port.HTTP), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=Port["HTTPS"], dst_port=Port["HTTP"]), ) is_permitted, _ = acl.is_permitted(permitted_frame) assert is_permitted @@ -153,8 +153,8 @@ def test_traffic_denied_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(80)), + ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=80), ) is_permitted, _ = acl.is_permitted(not_permitted_frame) assert not is_permitted @@ -173,8 +173,8 @@ def test_default_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.HTTPS, dst_port=Port.HTTP), + ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=IPProtocol["UDP"]), + udp=UDPHeader(src_port=Port["HTTPS"], dst_port=Port["HTTP"]), ) is_permitted, rule = acl.is_permitted(not_permitted_frame) assert not is_permitted @@ -189,8 +189,8 @@ def test_direct_ip_match_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(80)), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=80), ) assert acl.is_permitted(frame)[0], "Direct IP match should be permitted." @@ -204,8 +204,8 @@ def test_ip_range_match_denied_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port(8080), dst_port=Port(443)), + ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=8080, dst_port=443), ) assert not acl.is_permitted(frame)[0], "IP range match with wildcard mask should be denied." @@ -219,8 +219,8 @@ def test_traffic_permitted_to_destination_range_with_acl(router_with_wildcard_ac acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port(1433), dst_port=Port(1433)), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["UDP"]), + udp=UDPHeader(src_port=1433, dst_port=1433), ) assert acl.is_permitted(frame)[0], "Traffic to destination IP range should be permitted." @@ -253,23 +253,23 @@ def test_ip_traffic_from_specific_subnet(): permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]), ) assert acl.is_permitted(permitted_frame_1)[0] permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.NTP, dst_port=Port.NTP), + ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=IPProtocol["UDP"]), + udp=UDPHeader(src_port=Port["NTP"], dst_port=Port["NTP"]), ) assert acl.is_permitted(permitted_frame_2)[0] permitted_frame_3 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=IPProtocol["ICMP"]), icmp=ICMPPacket(identifier=1), ) @@ -277,16 +277,16 @@ def test_ip_traffic_from_specific_subnet(): not_permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=IPProtocol.TCP), - tcp=TCPHeader(src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER), + ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["TCP"]), + tcp=TCPHeader(src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]), ) assert not acl.is_permitted(not_permitted_frame_1)[0] not_permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=IPProtocol.UDP), - udp=UDPHeader(src_port=Port.NTP, dst_port=Port.NTP), + ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=IPProtocol["UDP"]), + udp=UDPHeader(src_port=Port["NTP"], dst_port=Port["NTP"]), ) assert not acl.is_permitted(not_permitted_frame_2)[0] diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py index d4e38ded..3551ce38 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -67,12 +67,12 @@ def test_wireless_router_from_config(): r0 = rt.acl.acl[0] assert r0.action == ACLAction.PERMIT - assert r0.src_port == r0.dst_port == Port.POSTGRES_SERVER + assert r0.src_port == r0.dst_port == Port["POSTGRES_SERVER"] assert r0.src_ip_address == r0.dst_ip_address == r0.dst_wildcard_mask == r0.src_wildcard_mask == r0.protocol == None r1 = rt.acl.acl[1] assert r1.action == ACLAction.PERMIT - assert r1.protocol == IPProtocol.ICMP + assert r1.protocol == IPProtocol["ICMP"] assert ( r1.src_ip_address == r1.dst_ip_address diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py index 92618baa..9fd39dfc 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py @@ -20,7 +20,7 @@ def test_frame_minimal_instantiation(): ) # Check network layer default values - assert frame.ip.protocol == IPProtocol.TCP + assert frame.ip.protocol == IPProtocol["TCP"] assert frame.ip.ttl == 64 assert frame.ip.precedence == Precedence.ROUTINE @@ -40,7 +40,7 @@ def test_frame_creation_fails_tcp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.TCP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["TCP"]), ) @@ -49,7 +49,7 @@ def test_frame_creation_fails_udp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.UDP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["UDP"]), ) @@ -58,7 +58,7 @@ def test_frame_creation_fails_tcp_with_udp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.TCP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["TCP"]), udp=UDPHeader(src_port=8080, dst_port=80), ) @@ -68,7 +68,7 @@ def test_frame_creation_fails_udp_with_tcp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.UDP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["UDP"]), udp=TCPHeader(src_port=8080, dst_port=80), ) @@ -77,7 +77,7 @@ def test_icmp_frame_creation(): """Tests Frame creation for ICMP.""" frame = Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["ICMP"]), icmp=ICMPPacket(), ) assert frame @@ -88,5 +88,5 @@ def test_icmp_frame_creation_fails_without_icmp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol.ICMP), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["ICMP"]), ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index 885a3cb6..6e53aebc 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -129,19 +129,19 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon are configured correctly. assert c2_beacon.c2_config.keep_alive_frequency is 2 - assert c2_beacon.c2_config.masquerade_port is Port.HTTP - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_beacon.c2_config.masquerade_port is Port["HTTP"] + assert c2_beacon.c2_config.masquerade_protocol is IPProtocol["TCP"] assert c2_server.c2_config.keep_alive_frequency is 2 - assert c2_server.c2_config.masquerade_port is Port.HTTP - assert c2_server.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_server.c2_config.masquerade_port is Port["HTTP"] + assert c2_server.c2_config.masquerade_protocol is IPProtocol["TCP"] # Configuring the C2 Beacon. c2_beacon.configure( c2_server_ip_address="192.168.0.1", keep_alive_frequency=2, - masquerade_port=Port.FTP, - masquerade_protocol=IPProtocol.TCP, + masquerade_port=Port["FTP"], + masquerade_protocol=IPProtocol["TCP"], ) # Asserting that the c2 applications have established a c2 connection @@ -150,11 +150,11 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon # Have reconfigured their C2 settings. - assert c2_beacon.c2_config.masquerade_port is Port.FTP - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_beacon.c2_config.masquerade_port is Port["FTP"] + assert c2_beacon.c2_config.masquerade_protocol is IPProtocol["TCP"] - assert c2_server.c2_config.masquerade_port is Port.FTP - assert c2_server.c2_config.masquerade_protocol is IPProtocol.TCP + assert c2_server.c2_config.masquerade_port is Port["FTP"] + assert c2_server.c2_config.masquerade_protocol is IPProtocol["TCP"] def test_c2_handle_switching_frequency(basic_c2_network): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index 0811d2a0..229f98fe 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -27,8 +27,8 @@ def test_create_dm_bot(dm_client): data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot") assert data_manipulation_bot.name == "DataManipulationBot" - assert data_manipulation_bot.port == Port.NONE - assert data_manipulation_bot.protocol == IPProtocol.NONE + assert data_manipulation_bot.port == Port["NONE"] + assert data_manipulation_bot.protocol == IPProtocol["NONE"] assert data_manipulation_bot.payload == "DELETE" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index ce98d164..c274c18e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -39,8 +39,8 @@ def test_create_web_client(): # Web Browser should be pre-installed in computer web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") assert web_browser.name is "WebBrowser" - assert web_browser.port is Port.HTTP - assert web_browser.protocol is IPProtocol.TCP + assert web_browser.port is Port["HTTP"] + assert web_browser.protocol is IPProtocol["TCP"] def test_receive_invalid_payload(web_browser): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index e9ce4884..1a51708d 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -28,8 +28,8 @@ def test_create_dns_client(dns_client): assert dns_client is not None dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") assert dns_client_service.name is "DNSClient" - assert dns_client_service.port is Port.DNS - assert dns_client_service.protocol is IPProtocol.TCP + assert dns_client_service.port is Port["DNS"] + assert dns_client_service.protocol is IPProtocol["TCP"] def test_dns_client_add_domain_to_cache_when_not_running(dns_client): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py index 4658fe76..8cdb1b84 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -32,8 +32,8 @@ def test_create_dns_server(dns_server): assert dns_server is not None dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer") assert dns_server_service.name is "DNSServer" - assert dns_server_service.port is Port.DNS - assert dns_server_service.protocol is IPProtocol.TCP + assert dns_server_service.port is Port["DNS"] + assert dns_server_service.protocol is IPProtocol["TCP"] def test_dns_server_domain_name_registration(dns_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 3ce4d8ee..3c1afb28 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -31,8 +31,8 @@ def test_create_ftp_client(ftp_client): assert ftp_client is not None ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") assert ftp_client_service.name is "FTPClient" - assert ftp_client_service.port is Port.FTP - assert ftp_client_service.protocol is IPProtocol.TCP + assert ftp_client_service.port is Port["FTP"] + assert ftp_client_service.protocol is IPProtocol["TCP"] def test_ftp_client_store_file(ftp_client): @@ -61,7 +61,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client): """Method _process_ftp_command should return false if service is not running.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=Port["FTP"], status_code=FTPStatusCode.OK, ) @@ -102,7 +102,7 @@ def test_offline_ftp_client_receives_request(ftp_client): payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=Port["FTP"], status_code=FTPStatusCode.OK, ) @@ -119,7 +119,7 @@ def test_receive_should_ignore_payload_with_none_status_code(ftp_client): """Receive should ignore payload with no set status code to prevent infinite send/receive loops.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port.FTP, + ftp_command_args=Port["FTP"], status_code=None, ) ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index a1c2ba59..aa13ec5e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -30,8 +30,8 @@ def test_create_ftp_server(ftp_server): assert ftp_server is not None ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") assert ftp_server_service.name is "FTPServer" - assert ftp_server_service.port is Port.FTP - assert ftp_server_service.protocol is IPProtocol.TCP + assert ftp_server_service.port is Port["FTP"] + assert ftp_server_service.protocol is IPProtocol["TCP"] def test_ftp_server_store_file(ftp_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 41858b90..21ed839b 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -77,11 +77,11 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) # add ACL rule to allow SSH traffic - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=21) # Configure PC B pc_b = Computer( @@ -329,7 +329,7 @@ def test_SSH_across_network(wireless_wan_network): terminal_a: Terminal = pc_a.software_manager.software.get("Terminal") terminal_b: Terminal = pc_b.software_manager.software.get("Terminal") - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=21) assert len(terminal_a._connections) == 0 diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index 9af176be..c1df3857 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -33,8 +33,8 @@ def test_create_web_server(web_server): assert web_server is not None web_server_service: WebServer = web_server.software_manager.software.get("WebServer") assert web_server_service.name is "WebServer" - assert web_server_service.port is Port.HTTP - assert web_server_service.protocol is IPProtocol.TCP + assert web_server_service.port is Port["HTTP"] + assert web_server_service.protocol is IPProtocol["TCP"] def test_handling_get_request_not_found_path(web_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index 4cf83370..b7a663af 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -19,10 +19,10 @@ class TestSoftware(Service): def software(file_system): return TestSoftware( name="TestSoftware", - port=Port.ARP, + port=Port["ARP"], file_system=file_system, sys_log=SysLog(hostname="test_service"), - protocol=IPProtocol.TCP, + protocol=IPProtocol["TCP"], ) diff --git a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py index a8fb0a3a..4e40bbd8 100644 --- a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py +++ b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py @@ -11,7 +11,7 @@ def test_simple_conversion(): The original dictionary contains one level of nested dictionary with enums as keys. The expected output should have string values of enums as keys. """ - original_dict = {IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0}}} + original_dict = {IPProtocol["UDP"]: {Port["ARP"]: {"inbound": 0, "outbound": 1016.0}}} expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -36,8 +36,8 @@ def test_mixed_keys(): The expected output should have string values of enums and original string keys. """ original_dict = { - IPProtocol.TCP: {"port": {"inbound": 0, "outbound": 1016.0}}, - "protocol": {Port.HTTP: {"inbound": 10, "outbound": 2020.0}}, + IPProtocol["TCP"]: {"port": {"inbound": 0, "outbound": 1016.0}}, + "protocol": {Port["HTTP"]: {"inbound": 10, "outbound": 2020.0}}, } expected_dict = { "tcp": {"port": {"inbound": 0, "outbound": 1016.0}}, @@ -66,7 +66,7 @@ def test_nested_dicts(): The expected output should have string values of enums as keys at all levels. """ original_dict = { - IPProtocol.UDP: {Port.ARP: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol.TCP: {"latency": "low"}}}} + IPProtocol["UDP"]: {Port["ARP"]: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol["TCP"]: {"latency": "low"}}}} } expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0, "details": {"tcp": {"latency": "low"}}}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -79,6 +79,6 @@ def test_non_dict_values(): The original dictionary contains lists and tuples as values. The expected output should preserve these non-dictionary values while converting enum keys to string values. """ - original_dict = {IPProtocol.UDP: [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)} - expected_dict = {"udp": [Port.ARP, Port.HTTP], "protocols": (IPProtocol.TCP, IPProtocol.UDP)} + original_dict = {IPProtocol["UDP"]: [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])} + expected_dict = {"udp": [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict From 08f1cf1fbd67f5fa8158876915f9ea940def0c7c Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 19 Sep 2024 15:06:29 +0100 Subject: [PATCH 2/9] Fix airspace and remaining port problems from refactor --- .../network/nodes/wireless_router.rst | 6 +- src/primaite/config/load.py | 5 +- .../agent/observations/acl_observation.py | 4 +- .../agent/observations/host_observations.py | 10 +- .../agent/observations/nic_observations.py | 17 +-- .../agent/observations/node_observations.py | 10 +- src/primaite/game/game.py | 26 ++--- src/primaite/simulator/network/airspace.py | 101 +++++++++++------- src/primaite/simulator/network/container.py | 8 +- .../network/hardware/nodes/host/host_node.py | 4 +- .../hardware/nodes/network/network_node.py | 4 +- .../network/hardware/nodes/network/router.py | 11 +- .../hardware/nodes/network/wireless_router.py | 14 +-- .../network/transmission/network_layer.py | 10 +- .../network/transmission/transport_layer.py | 61 ++++++----- .../system/applications/application.py | 4 +- .../red_applications/c2/c2_beacon.py | 1 - .../simulator/system/core/session_manager.py | 4 +- .../system/services/ftp/ftp_service.py | 1 - .../simulator/system/services/service.py | 4 +- src/primaite/simulator/system/software.py | 1 - .../extensions/nodes/giga_switch.py | 3 +- .../extensions/nodes/super_computer.py | 4 +- .../extensions/services/extended_service.py | 12 ++- .../extensions/test_extendable_config.py | 18 ++-- .../observations/test_acl_observations.py | 4 +- .../observations/test_firewall_observation.py | 4 +- .../observations/test_nic_observations.py | 7 +- .../observations/test_router_observation.py | 4 +- .../game_layer/test_rewards.py | 4 +- .../network/test_airspace_config.py | 9 +- .../network/test_firewall.py | 16 ++- tests/integration_tests/system/test_nmap.py | 4 +- .../_utils/test_dict_enum_keys_conversion.py | 9 +- 34 files changed, 227 insertions(+), 177 deletions(-) diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index bd361afa..436852ea 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -49,7 +49,7 @@ additional steps to configure wireless settings: wireless_router.configure_wireless_access_point( port=1, ip_address="192.168.2.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency["WIFI_2_4"], + frequency="WIFI_2_4", ) @@ -130,13 +130,13 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency["WIFI_2_4"], + frequency="WIFI_2_4", ) router_2.configure_wireless_access_point( port=1, ip_address="192.168.1.2", subnet_mask="255.255.255.0", - frequency=AirSpaceFrequency["WIFI_2_4"], + frequency="WIFI_2_4", ) # Configure routes for inter-router communication diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py index b00c26f6..39040d76 100644 --- a/src/primaite/config/load.py +++ b/src/primaite/config/load.py @@ -60,9 +60,10 @@ def data_manipulation_marl_config_path() -> Path: raise FileNotFoundError(msg) return path + def get_extended_config_path() -> Path: """ - Get the path to an 'extended' example config that contains nodes using the extension framework + Get the path to an 'extended' example config that contains nodes using the extension framework. :return: Path to the extended example config :rtype: Path @@ -72,4 +73,4 @@ def get_extended_config_path() -> Path: msg = f"Example config does not exist: {path}. Have you run `primaite setup`?" _LOGGER.error(msg) raise FileNotFoundError(msg) - return path \ No newline at end of file + return path diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index abb6f1f8..41af5a8f 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -10,8 +10,6 @@ from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port _LOGGER = getLogger(__name__) @@ -63,7 +61,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)} self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)} self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)} - self.protocol_to_id: Dict[str, int] = {IPProtocol[p]: i + 2 for i, p in enumerate(protocol_list)} + self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { i + 1: { diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 05b25952..30ccd195 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -58,8 +58,14 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_users: Optional[bool] = True """If True, report user session information.""" - @field_validator('monitored_traffic', mode='before') - def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + @field_validator("monitored_traffic", mode="before") + def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: + """ + Convert monitored_traffic by lookup against Port and Protocol dicts. + + This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. + This method will be removed in PrimAITE >= 4.0 + """ if val is None: return val new_val = {} diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 200187f5..296ce04c 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -26,8 +26,14 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): monitored_traffic: Optional[Dict] = None """A dict containing which traffic types are to be included in the observation.""" - @field_validator('monitored_traffic', mode='before') - def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + @field_validator("monitored_traffic", mode="before") + def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: + """ + Convert monitored_traffic by lookup against Port and Protocol dicts. + + This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. + This method will be removed in PrimAITE >= 4.0 + """ if val is None: return val new_val = {} @@ -41,7 +47,6 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): new_val[proto].append(port) return new_val - def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: """ Initialise a network interface observation instance. @@ -76,7 +81,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): def _default_monitored_traffic_observation(self, monitored_traffic_config: Dict) -> Dict: default_traffic_obs = {"TRAFFIC": {}} - for protocol in monitored_traffic_config: + for protocol in self.monitored_traffic: protocol = str(protocol).lower() default_traffic_obs["TRAFFIC"][protocol] = {} @@ -84,8 +89,8 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): default_traffic_obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0} else: default_traffic_obs["TRAFFIC"][protocol] = {} - for port in monitored_traffic_config[protocol]: - default_traffic_obs["TRAFFIC"][protocol] = {"inbound": 0, "outbound": 0} + for port in self.monitored_traffic[protocol]: + default_traffic_obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0} return default_traffic_obs diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 3e51c3b3..054ffcdb 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -63,8 +63,14 @@ class NodesObservation(AbstractObservation, identifier="NODES"): num_rules: Optional[int] = None """Number of rules ACL rules to show.""" - @field_validator('monitored_traffic', mode='before') - def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]: + @field_validator("monitored_traffic", mode="before") + def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: + """ + Convert monitored_traffic by lookup against Port and Protocol dicts. + + This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. + This method will be removed in PrimAITE >= 4.0 + """ if val is None: return val new_val = {} diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index e8329c63..8e0abb1e 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -17,10 +17,9 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT -from primaite.simulator.network.airspace import AirSpaceFrequency from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC from primaite.simulator.network.hardware.nodes.host.server import Printer, Server from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode @@ -89,8 +88,8 @@ class PrimaiteGameOptions(BaseModel): thresholds: Optional[Dict] = {} """A dict containing the thresholds used for determining what is acceptable during observations.""" - @field_validator('ports', mode='before') - def ports_str2int(cls, vals:Union[List[str],List[int]]) -> List[int]: + @field_validator("ports", mode="before") + def ports_str2int(cls, vals: Union[List[str], List[int]]) -> List[int]: """ Convert named port strings to port integer values. Integer ports remain unaffected. @@ -102,8 +101,8 @@ class PrimaiteGameOptions(BaseModel): vals[i] = Port[port_val] return vals - @field_validator('protocols', mode='before') - def protocols_str2int(cls, vals:List[str]) -> List[str]: + @field_validator("protocols", mode="before") + def protocols_str2int(cls, vals: List[str]) -> List[str]: """ Convert old-style named protocols to their proper values. @@ -116,7 +115,6 @@ class PrimaiteGameOptions(BaseModel): return vals - class PrimaiteGame: """ Primaite game encapsulates the simulation and agents which interact with it. @@ -294,10 +292,7 @@ class PrimaiteGame: network_config = simulation_config.get("network", {}) airspace_cfg = network_config.get("airspace", {}) frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {}) - - frequency_max_capacity_mbps_cfg = {AirSpaceFrequency[k]: v for k, v in frequency_max_capacity_mbps_cfg.items()} - - net.airspace.frequency_max_capacity_mbps_ = frequency_max_capacity_mbps_cfg + net.airspace.set_frequency_max_capacity_mbps(frequency_max_capacity_mbps_cfg) nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) @@ -318,11 +313,10 @@ class PrimaiteGame: dns_server=node_cfg.get("dns_server", None), operating_state=NodeOperatingState.ON if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()]) - elif n_type in NetworkNode._registry: - new_node = NetworkNode._registry[n_type]( - **node_cfg + else NodeOperatingState[p.upper()], ) + elif n_type in NetworkNode._registry: + new_node = NetworkNode._registry[n_type](**node_cfg) # Default PrimAITE nodes elif n_type == "computer": new_node = Computer( @@ -502,7 +496,7 @@ class PrimaiteGame: opt = application_cfg["options"] new_application.configure( target_ip_address=IPv4Address(opt.get("target_ip_address")), - target_port = Port[opt.get("target_port", "POSTGRES_SERVER")], + target_port=Port[opt.get("target_port", "POSTGRES_SERVER")], payload=opt.get("payload"), repeat=bool(opt.get("repeat")), port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 29326df8..65dceeb1 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -1,14 +1,12 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations +import copy from abc import ABC, abstractmethod -from enum import Enum -from typing import Any, ClassVar, Dict, List, Type -from pydantic._internal._generics import PydanticGenericMetadata -from typing_extensions import Unpack +from typing import Any, Dict, List from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, Field, validate_call from primaite import getLogger from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface @@ -42,29 +40,31 @@ def format_hertz(hertz: float, format_terahertz: bool = False, decimals: int = 3 else: # Hertz return format_str.format(hertz) + " Hz" -AirSpaceFrequencyRegistry: Dict[str,Dict] = { - "WIFI_2_4" : {'frequency': 2.4e9, 'data_rate_bps':100_000_000.0}, - "WIFI_5" : {'frequency': 5e9, 'data_rate_bps':500_000_000.0}, + +_default_frequency_set: Dict[str, Dict] = { + "WIFI_2_4": {"frequency": 2.4e9, "data_rate_bps": 100_000_000.0}, + "WIFI_5": {"frequency": 5e9, "data_rate_bps": 500_000_000.0}, } +"""Frequency configuration that is automatically used for any new airspace.""" -def register_frequency(freq_name: str, freq_hz: int, data_rate_bps: int) -> None: - if freq_name in AirSpaceFrequencyRegistry: - raise RuntimeError(f"Cannot register new frequency {freq_name} because it's already registered.") - AirSpaceFrequencyRegistry.update({freq_name:{'frequency': freq_hz, 'data_rate_bps':data_rate_bps}}) -def maximum_data_rate_mbps(frequency_name:str) -> float: +def register_default_frequency(freq_name: str, freq_hz: float, data_rate_bps: float): + """Add to the default frequency configuration. This is intended as a plugin hook. + + If your plugin makes use of bespoke frequencies for wireless communication, you should make a call to this method + whereever you define components that rely on the bespoke frequencies. That way, as soon as your components are + imported, this function automatically updates the default frequency set. + + This should also be run before instances of AirSpace are created. + + :param freq_name: The frequency name. If this clashes with an existing frequency name, it will be overwritten. + :type freq_name: str + :param freq_hz: The frequency itself, measured in Hertz. + :type freq_hz: float + :param data_rate_bps: The transmission capacity over this frequency, in bits per second. + :type data_rate_bps: float """ - Retrieves the maximum data transmission rate in megabits per second (Mbps). - - This is derived by converting the maximum data rate from bits per second, as defined - in `maximum_data_rate_bps`, to megabits per second. - - :return: The maximum data rate in megabits per second. - """ - return AirSpaceFrequencyRegistry[frequency_name]['data_rate_bps'] - return data_rate / 1_000_000.0 - - + _default_frequency_set.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}}) class AirSpace(BaseModel): @@ -77,27 +77,21 @@ class AirSpace(BaseModel): """ wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {}) - wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field( - default_factory=lambda: {} - ) + wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field(default_factory=lambda: {}) bandwidth_load: Dict[int, float] = Field(default_factory=lambda: {}) - frequency_max_capacity_mbps_: Dict[int, float] = Field(default_factory=lambda: {}) + frequencies: Dict[str, Dict] = Field(default_factory=lambda: copy.deepcopy(_default_frequency_set)) - def get_frequency_max_capacity_mbps(self, frequency: str) -> float: + @validate_call + def get_frequency_max_capacity_mbps(self, freq_name: str) -> float: """ Retrieves the maximum data transmission capacity for a specified frequency. - This method checks a dictionary holding custom maximum capacities. If the frequency is found, it returns the - custom set maximum capacity. If the frequency is not found in the dictionary, it defaults to the standard - maximum data rate associated with that frequency. - - :param frequency: The frequency for which the maximum capacity is queried. - + :param freq_name: The frequency for which the maximum capacity is queried. :return: The maximum capacity in Mbps for the specified frequency. """ - if frequency in self.frequency_max_capacity_mbps_: - return self.frequency_max_capacity_mbps_[frequency] - return maximum_data_rate_mbps(frequency) + if freq_name in self.frequencies: + return self.frequencies[freq_name]["data_rate_bps"] / (1024.0 * 1024.0) + return 0.0 def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]): """ @@ -105,10 +99,29 @@ class AirSpace(BaseModel): :param cfg: A dictionary mapping frequencies to their new maximum capacities in Mbps. """ - self.frequency_max_capacity_mbps_ = cfg for freq, mbps in cfg.items(): + self.frequencies[freq]["data_rate_bps"] = mbps * 1024 * 1024 print(f"Overriding {freq} max capacity as {mbps:.3f} mbps") + def register_frequency(self, freq_name: str, freq_hz: float, data_rate_bps: float) -> None: + """ + Define a new frequency for this airspace. + + :param freq_name: The frequency name. If this clashes with an existing frequency name, it will be overwritten. + :type freq_name: str + :param freq_hz: The frequency itself, measured in Hertz. + :type freq_hz: float + :param data_rate_bps: The transmission capacity over this frequency, in bits per second. + :type data_rate_bps: float + """ + if freq_name in self.frequencies: + _LOGGER.info( + f"Overwriting Air space frequency {freq_name}. " + f"Previous data rate: {self.frequencies[freq_name]['data_rate_bps']}. " + f"Current data rate: {data_rate_bps}." + ) + self.frequencies.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}}) + def show_bandwidth_load(self, markdown: bool = False): """ Prints a table of the current bandwidth load for each frequency on the airspace. @@ -130,7 +143,13 @@ class AirSpace(BaseModel): load_percent = load / maximum_capacity if maximum_capacity > 0 else 0.0 if load_percent > 1.0: load_percent = 1.0 - table.add_row([format_hertz(frequency), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"]) + table.add_row( + [ + format_hertz(self.frequencies[frequency]["frequency"]), + f"{load_percent:.0%}", + f"{maximum_capacity:.3f}", + ] + ) print(table) def show_wireless_interfaces(self, markdown: bool = False): @@ -162,7 +181,7 @@ class AirSpace(BaseModel): interface.mac_address, interface.ip_address if hasattr(interface, "ip_address") else None, interface.subnet_mask if hasattr(interface, "subnet_mask") else None, - format_hertz(interface.frequency), + format_hertz(self.frequencies[interface.frequency]["frequency"]), f"{interface.speed:.3f}", status, ] diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 39fbe783..6e019f32 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -130,15 +130,15 @@ class Network(SimComponent): def firewall_nodes(self) -> List[Node]: """The Firewalls in the Network.""" return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"] - + @property def extended_hostnodes(self) -> List[Node]: - """Extended nodes that inherited HostNode in the network""" + """Extended nodes that inherited HostNode in the network.""" return [node for node in self.nodes.values() if node.__class__.__name__.lower() in HostNode._registry] - + @property def extended_networknodes(self) -> List[Node]: - """Extended nodes that inherited NetworkNode in the network""" + """Extended nodes that inherited NetworkNode in the network.""" return [node for node in self.nodes.values() if node.__class__.__name__.lower() in NetworkNode._registry] @property diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index ea162e88..8a420e44 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -332,7 +332,7 @@ class HostNode(Node): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register a hostnode type. @@ -340,7 +340,7 @@ class HostNode(Node): :type identifier: str :raises ValueError: When attempting to register an hostnode with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. identifier = identifier.lower() diff --git a/src/primaite/simulator/network/hardware/nodes/network/network_node.py b/src/primaite/simulator/network/hardware/nodes/network/network_node.py index 6515bb02..a0cb63e1 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/network_node.py +++ b/src/primaite/simulator/network/hardware/nodes/network/network_node.py @@ -19,7 +19,7 @@ class NetworkNode(Node): _registry: ClassVar[Dict[str, Type["NetworkNode"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register a networknode type. @@ -27,7 +27,7 @@ class NetworkNode(Node): :type identifier: str :raises ValueError: When attempting to register an networknode with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return identifier = identifier.lower() super().__init_subclass__(**kwargs) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 013c473e..fded23f9 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -130,19 +130,20 @@ class ACLRule(SimComponent): dst_port: Optional[int] = None match_count: int = 0 - @field_validator('protocol', mode='before') - def protocol_valid(cls, val:Optional[str]) -> Optional[str]: + @field_validator("protocol", mode="before") + def protocol_valid(cls, val: Optional[str]) -> Optional[str]: + """Assert that the protocol for the rule is predefined in the IPProtocol lookup.""" if val is not None: assert val in IPProtocol.values(), f"Cannot create ACL rule with invalid protocol {val}" return val - @field_validator('src_port', 'dst_port', mode='before') - def ports_valid(cls, val:Optional[int]) -> Optional[int]: + @field_validator("src_port", "dst_port", mode="before") + def ports_valid(cls, val: Optional[int]) -> Optional[int]: + """Assert that the port for the rule is predefined in the Port lookup.""" if val is not None: assert val in Port.values(), f"Cannot create ACL rule with invalid port {val}" return val - def __str__(self) -> str: rule_strings = [] for key, value in self.model_dump(exclude={"uuid", "request_manager"}).items(): diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index d73bc756..1969a121 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union from pydantic import validate_call -from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, IPWirelessNetworkInterface +from primaite.simulator.network.airspace import AirSpace, IPWirelessNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface from primaite.simulator.network.transmission.data_link_layer import Frame @@ -116,7 +116,7 @@ class WirelessRouter(Router): >>> wireless_router.configure_wireless_access_point( ... ip_address="10.10.10.1", ... subnet_mask="255.255.255.0" - ... frequency=AirSpaceFrequency["WIFI_2_4"] + ... frequency="WIFI_2_4" ... ) """ @@ -153,7 +153,7 @@ class WirelessRouter(Router): self, ip_address: IPV4Address, subnet_mask: IPV4Address, - frequency: Optional[int] = AirSpaceFrequency["WIFI_2_4"], + frequency: Optional[str] = "WIFI_2_4", ): """ Configures a wireless access point (WAP). @@ -166,12 +166,12 @@ class WirelessRouter(Router): :param ip_address: The IP address to be assigned to the wireless access point. :param subnet_mask: The subnet mask associated with the IP address - :param frequency: The operating frequency of the wireless access point, defined by the AirSpaceFrequency + :param frequency: The operating frequency of the wireless access point, defined by the air space frequency enum. This determines the frequency band (e.g., 2.4 GHz or 5 GHz) the access point will use for wireless - communication. Default is AirSpaceFrequency["WIFI_2_4"]. + communication. Default is "WIFI_2_4". """ if not frequency: - frequency = AirSpaceFrequency["WIFI_2_4"] + frequency = "WIFI_2_4" self.sys_log.info("Configuring wireless access point") self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration @@ -264,7 +264,7 @@ class WirelessRouter(Router): if "wireless_access_point" in cfg: ip_address = cfg["wireless_access_point"]["ip_address"] subnet_mask = cfg["wireless_access_point"]["subnet_mask"] - frequency = AirSpaceFrequency[cfg["wireless_access_point"]["frequency"]] + frequency = cfg["wireless_access_point"]["frequency"] router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency) if "acl" in cfg: diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index 36ff2751..a01b7f42 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -9,11 +9,11 @@ from primaite.utils.validators import IPV4Address _LOGGER = getLogger(__name__) -IPProtocol : dict[str, str] = dict( - NONE = "none", - TCP = "tcp", - UDP = "udp", - ICMP = "icmp", +IPProtocol: dict[str, str] = dict( + NONE="none", + TCP="tcp", + UDP="udp", + ICMP="icmp", ) # class IPProtocol(Enum): diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index c77ef532..60f2f070 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -1,40 +1,39 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from enum import Enum -from typing import List, Union +from typing import List from pydantic import BaseModel - Port: dict[str, int] = dict( - UNUSED = -1, - NONE = 0, - WOL = 9, - FTP_DATA = 20, - FTP = 21, - SSH = 22, - SMTP = 25, - DNS = 53, - HTTP = 80, - POP3 = 110, - SFTP = 115, - NTP = 123, - IMAP = 143, - SNMP = 161, - SNMP_TRAP = 162, - ARP = 219, - LDAP = 389, - HTTPS = 443, - SMB = 445, - IPP = 631, - SQL_SERVER = 1433, - MYSQL = 3306, - RDP = 3389, - RTP = 5004, - RTP_ALT = 5005, - DNS_ALT = 5353, - HTTP_ALT = 8080, - HTTPS_ALT = 8443, - POSTGRES_SERVER = 5432, + UNUSED=-1, + NONE=0, + WOL=9, + FTP_DATA=20, + FTP=21, + SSH=22, + SMTP=25, + DNS=53, + HTTP=80, + POP3=110, + SFTP=115, + NTP=123, + IMAP=143, + SNMP=161, + SNMP_TRAP=162, + ARP=219, + LDAP=389, + HTTPS=443, + SMB=445, + IPP=631, + SQL_SERVER=1433, + MYSQL=3306, + RDP=3389, + RTP=5004, + RTP_ALT=5005, + DNS_ALT=5353, + HTTP_ALT=8080, + HTTPS_ALT=8443, + POSTGRES_SERVER=5432, ) # class Port(): diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index b5284968..a7871315 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -44,7 +44,7 @@ class Application(IOSoftware): _registry: ClassVar[Dict[str, Type["Application"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register an application type. @@ -52,7 +52,7 @@ class Application(IOSoftware): :type identifier: str :raises ValueError: When attempting to register an application with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return super().__init_subclass__(**kwargs) if identifier in cls._registry: diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index 06453330..9178e68a 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -1,5 +1,4 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from enum import Enum from ipaddress import IPv4Address from typing import Dict, Optional diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 172be453..33de3443 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -76,9 +76,7 @@ class SessionManager: """ def __init__(self, sys_log: SysLog): - self.sessions_by_key: Dict[ - Tuple[str, IPv4Address, IPv4Address, Optional[int], Optional[int]], Session - ] = {} + self.sessions_by_key: Dict[Tuple[str, IPv4Address, IPv4Address, Optional[int], Optional[int]], Session] = {} self.sessions_by_uuid: Dict[str, Session] = {} self.sys_log: SysLog = sys_log self.software_manager: SoftwareManager = None # Noqa diff --git a/src/primaite/simulator/system/services/ftp/ftp_service.py b/src/primaite/simulator/system/services/ftp/ftp_service.py index 36245e0f..49678c82 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_service.py +++ b/src/primaite/simulator/system/services/ftp/ftp_service.py @@ -5,7 +5,6 @@ from typing import Dict, Optional from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 74dcb506..4f0b879c 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -52,7 +52,7 @@ class Service(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register a hostnode type. @@ -60,7 +60,7 @@ class Service(IOSoftware): :type identifier: str :raises ValueError: When attempting to register an hostnode with a name that is already allocated. """ - if identifier == 'default': + if identifier == "default": return # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. identifier = identifier.lower() diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 1880d244..084bdaf6 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -14,7 +14,6 @@ from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog diff --git a/tests/integration_tests/extensions/nodes/giga_switch.py b/tests/integration_tests/extensions/nodes/giga_switch.py index b86bea7d..e4100741 100644 --- a/tests/integration_tests/extensions/nodes/giga_switch.py +++ b/tests/integration_tests/extensions/nodes/giga_switch.py @@ -1,3 +1,4 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from typing import Dict from prettytable import MARKDOWN, PrettyTable @@ -27,7 +28,7 @@ class GigaSwitch(NetworkNode, identifier="gigaswitch"): "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." def __init__(self, **kwargs): - print('--- Extended Component: GigaSwitch ---') + print("--- Extended Component: GigaSwitch ---") super().__init__(**kwargs) for i in range(1, self.num_ports + 1): self.connect_nic(SwitchPort()) diff --git a/tests/integration_tests/extensions/nodes/super_computer.py b/tests/integration_tests/extensions/nodes/super_computer.py index 8a1465e9..55bdce09 100644 --- a/tests/integration_tests/extensions/nodes/super_computer.py +++ b/tests/integration_tests/extensions/nodes/super_computer.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from typing import ClassVar, Dict -from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.utils.validators import IPV4Address @@ -37,7 +37,7 @@ class SuperComputer(HostNode, identifier="supercomputer"): SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): - print('--- Extended Component: SuperComputer ---') + print("--- Extended Component: SuperComputer ---") super().__init__(ip_address=ip_address, subnet_mask=subnet_mask, **kwargs) pass diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index d4af600f..b745b774 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -17,7 +17,7 @@ from primaite.simulator.system.software import SoftwareHealthState _LOGGER = getLogger(__name__) -class ExtendedService(Service, identifier='extendedservice'): +class ExtendedService(Service, identifier="extendedservice"): """ A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE. @@ -42,7 +42,7 @@ class ExtendedService(Service, identifier='extendedservice'): kwargs["protocol"] = IPProtocol["TCP"] super().__init__(**kwargs) self._create_db_file() - if kwargs.get('options'): + if kwargs.get("options"): opt = kwargs["options"] self.password = opt.get("db_password", None) if "backup_server_ip" in opt: @@ -139,7 +139,9 @@ class ExtendedService(Service, identifier='extendedservice'): old_visible_state = SoftwareHealthState.GOOD # get db file regardless of whether or not it was deleted - db_file = self.file_system.get_file(folder_name="database", file_name="extended_service_database.db", include_deleted=True) + db_file = self.file_system.get_file( + folder_name="database", file_name="extended_service_database.db", include_deleted=True + ) if db_file is None: self.sys_log.warning("Database file not initialised.") @@ -153,7 +155,9 @@ class ExtendedService(Service, identifier='extendedservice'): self.file_system.delete_file(folder_name="database", file_name="extended_service_database.db") # replace db file - self.file_system.copy_file(src_folder_name="downloads", src_file_name="extended_service_database.db", dst_folder_name="database") + self.file_system.copy_file( + src_folder_name="downloads", src_file_name="extended_service_database.db", dst_folder_name="database" + ) if self.db_file is None: self.sys_log.error("Copying database backup failed.") diff --git a/tests/integration_tests/extensions/test_extendable_config.py b/tests/integration_tests/extensions/test_extendable_config.py index 5d8af64d..8467151b 100644 --- a/tests/integration_tests/extensions/test_extendable_config.py +++ b/tests/integration_tests/extensions/test_extendable_config.py @@ -1,22 +1,22 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import os + from primaite.config.load import get_extended_config_path from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config -import os +from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication +from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch # Import the extended components so that PrimAITE registers them from tests.integration_tests.extensions.nodes.super_computer import SuperComputer -from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch from tests.integration_tests.extensions.services.extended_service import ExtendedService -from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication def test_extended_example_config(): - """Test that the example config can be parsed properly.""" - config_path = os.path.join( "tests", "assets", "configs", "extended_config.yaml") + config_path = os.path.join("tests", "assets", "configs", "extended_config.yaml") game = load_config(config_path) network: Network = game.simulation.network @@ -25,8 +25,8 @@ def test_extended_example_config(): assert len(network.router_nodes) == 1 # 1 router in network assert len(network.switch_nodes) == 1 # 1 switches in network assert len(network.server_nodes) == 5 # 5 servers in network - assert len(network.extended_hostnodes) == 1 # One extended node based on HostNode - assert len(network.extended_networknodes) == 1 # One extended node based on NetworkNode + assert len(network.extended_hostnodes) == 1 # One extended node based on HostNode + assert len(network.extended_networknodes) == 1 # One extended node based on NetworkNode - assert 'ExtendedApplication' in network.extended_hostnodes[0].software_manager.software - assert 'ExtendedService' in network.extended_hostnodes[0].software_manager.software + assert "ExtendedApplication" in network.extended_hostnodes[0].software_manager.software + assert "ExtendedService" in network.extended_hostnodes[0].software_manager.software diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index 398c43a9..28f9ac5a 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -38,8 +38,8 @@ def test_acl_observations(simulation): acl_obs = ACLObservation( where=["network", "nodes", router.hostname, "acl", "acl"], ip_list=[], - port_list=["NTP", "HTTP", "POSTGRES_SERVER"], - protocol_list=["TCP", "UDP", "ICMP"], + port_list=[123, 80, 5432], + protocol_list=["tcp", "udp", "icmp"], num_rules=10, wildcard_list=[], ) diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 68506d59..21fe4bed 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -31,8 +31,8 @@ def test_firewall_observation(): num_rules=7, ip_list=["10.0.0.1", "10.0.0.2"], wildcard_list=["0.0.0.255", "0.0.0.1"], - port_list=["HTTP", "DNS"], - protocol_list=["TCP"], + port_list=[80, 53], + protocol_list=["tcp"], include_users=False, ) diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index bd8dfc4e..8254dad2 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -152,7 +152,12 @@ def test_config_nic_categories(simulation): def test_nic_monitored_traffic(simulation): - monitored_traffic = {"icmp": ["NONE"], "tcp": [53,]} + monitored_traffic = { + "icmp": ["NONE"], + "tcp": [ + 53, + ], + } pc: Computer = simulation.network.get_node_by_hostname("client_1") pc2: Computer = simulation.network.get_node_by_hostname("client_2") diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index 937bb061..c28e1bb8 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -24,8 +24,8 @@ def test_router_observation(): num_rules=7, ip_list=["10.0.0.1", "10.0.0.2"], wildcard_list=["0.0.0.255", "0.0.0.1"], - port_list=["HTTP", "DNS"], - protocol_list=["TCP"], + port_list=[80, 53], + protocol_list=["tcp"], ) router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False) diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index d872c2b0..570c4ad6 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -65,7 +65,9 @@ def test_uc2_rewards(game_and_agent): db_client.run() router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=2) + router.acl.add_rule( + ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=2 + ) comp = GreenAdminDatabaseUnreachablePenalty("client_1") diff --git a/tests/integration_tests/network/test_airspace_config.py b/tests/integration_tests/network/test_airspace_config.py index 1794c4bc..e000f6ae 100644 --- a/tests/integration_tests/network/test_airspace_config.py +++ b/tests/integration_tests/network/test_airspace_config.py @@ -2,7 +2,6 @@ import yaml from primaite.game.game import PrimaiteGame -from primaite.simulator.network.airspace import AirSpaceFrequency from tests import TEST_ASSETS_ROOT @@ -13,8 +12,8 @@ def test_override_freq_max_capacity_mbps(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_2_4"]) == 123.45 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_5"]) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_2_4") == 123.45 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_5") == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") @@ -32,8 +31,8 @@ def test_override_freq_max_capacity_mbps_blocked(): config_dict = yaml.safe_load(f) network = PrimaiteGame.from_config(cfg=config_dict).simulation.network - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_2_4"]) == 0.0 - assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_5"]) == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_2_4") == 0.0 + assert network.airspace.get_frequency_max_capacity_mbps("WIFI_5") == 0.0 pc_a = network.get_node_by_hostname("pc_a") pc_b = network.get_node_by_hostname("pc_b") diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index 8e06ccfb..44b660cf 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -73,8 +73,12 @@ def dmz_external_internal_network() -> Network: firewall_node.external_outbound_acl.add_rule( action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 ) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + firewall_node.dmz_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + ) + firewall_node.dmz_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + ) # external node external_node = Computer( @@ -262,8 +266,12 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not internal_ntp_client.time - firewall.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) - firewall.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) + firewall.internal_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1 + ) + firewall.internal_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1 + ) internal_ntp_client.request_time() diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index 1064ed1b..9d92b660 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -73,7 +73,9 @@ def test_port_scan_one_node_one_port(example_network): client_2 = network.get_node_by_hostname("client_2") actual_result = client_1_nmap.port_scan( - target_ip_address=client_2.network_interface[1].ip_address, target_port=Port["DNS"], target_protocol=IPProtocol["TCP"] + target_ip_address=client_2.network_interface[1].ip_address, + target_port=Port["DNS"], + target_protocol=IPProtocol["TCP"], ) expected_result = {IPv4Address("192.168.10.22"): {IPProtocol["TCP"]: [Port["DNS"]]}} diff --git a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py index 4e40bbd8..8becc6ae 100644 --- a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py +++ b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py @@ -66,7 +66,9 @@ def test_nested_dicts(): The expected output should have string values of enums as keys at all levels. """ original_dict = { - IPProtocol["UDP"]: {Port["ARP"]: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol["TCP"]: {"latency": "low"}}}} + IPProtocol["UDP"]: { + Port["ARP"]: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol["TCP"]: {"latency": "low"}}} + } } expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0, "details": {"tcp": {"latency": "low"}}}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -79,6 +81,9 @@ def test_non_dict_values(): The original dictionary contains lists and tuples as values. The expected output should preserve these non-dictionary values while converting enum keys to string values. """ - original_dict = {IPProtocol["UDP"]: [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])} + original_dict = { + IPProtocol["UDP"]: [Port["ARP"], Port["HTTP"]], + "protocols": (IPProtocol["TCP"], IPProtocol["UDP"]), + } expected_dict = {"udp": [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict From 695891f55c5f8b8cc9ff9cf893fb0c880618af2f Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 20 Sep 2024 11:21:28 +0100 Subject: [PATCH 3/9] Add port and protocol custom validators --- .../agent/observations/host_observations.py | 44 +++++------ .../agent/observations/nic_observations.py | 8 +- .../agent/observations/node_observations.py | 8 +- src/primaite/game/game.py | 22 +++--- src/primaite/simulator/network/creation.py | 10 ++- .../simulator/network/hardware/base.py | 23 +++--- .../hardware/nodes/network/firewall.py | 41 +++++----- .../network/hardware/nodes/network/router.py | 25 +++--- .../hardware/nodes/network/wireless_router.py | 11 ++- src/primaite/simulator/network/networks.py | 29 ++++--- .../network/transmission/data_link_layer.py | 13 ++-- .../network/transmission/network_layer.py | 32 +------- .../network/transmission/transport_layer.py | 77 +------------------ .../system/applications/database_client.py | 9 +-- .../simulator/system/applications/nmap.py | 13 ++-- .../red_applications/c2/abstract_c2.py | 18 ++--- .../red_applications/c2/c2_beacon.py | 12 +-- .../red_applications/data_manipulation_bot.py | 8 +- .../applications/red_applications/dos_bot.py | 6 +- .../red_applications/ransomware_script.py | 8 +- .../system/applications/web_browser.py | 12 +-- .../simulator/system/core/session_manager.py | 19 ++--- .../simulator/system/core/software_manager.py | 8 +- .../simulator/system/services/arp/arp.py | 9 +-- .../services/database/database_service.py | 8 +- .../system/services/dns/dns_client.py | 10 +-- .../system/services/dns/dns_server.py | 8 +- .../system/services/ftp/ftp_client.py | 18 ++--- .../system/services/ftp/ftp_server.py | 8 +- .../simulator/system/services/icmp/icmp.py | 8 +- .../system/services/ntp/ntp_client.py | 10 +-- .../system/services/ntp/ntp_server.py | 8 +- .../system/services/terminal/terminal.py | 8 +- .../system/services/web_server/web_server.py | 8 +- src/primaite/simulator/system/software.py | 4 +- src/primaite/utils/validators.py | 42 +++++++++- tests/conftest.py | 32 ++++---- .../nodes/network/test_firewall_config.py | 36 ++++----- .../nodes/network/test_router_config.py | 10 +-- .../applications/extended_application.py | 12 +-- .../extensions/services/extended_service.py | 8 +- .../actions/test_c2_suite_actions.py | 8 +- .../actions/test_configure_actions.py | 4 +- .../actions/test_terminal_actions.py | 4 +- .../observations/test_acl_observations.py | 4 +- .../observations/test_firewall_observation.py | 10 +-- .../observations/test_router_observation.py | 10 +-- .../observations/test_user_observations.py | 4 +- .../game_layer/test_actions.py | 28 +++---- .../game_layer/test_rewards.py | 13 +++- .../network/test_broadcast.py | 18 +++-- .../network/test_firewall.py | 40 +++++----- .../integration_tests/network/test_routing.py | 18 +++-- .../network/test_wireless_router.py | 10 ++- .../test_c2_suite_integration.py | 24 +++--- .../test_data_manipulation_bot_and_server.py | 7 +- .../test_dos_bot_and_server.py | 11 ++- .../test_ransomware_script.py | 7 +- tests/integration_tests/system/test_nmap.py | 30 +++++--- .../system/test_service_listening_on_ports.py | 20 ++--- .../test_web_client_server_and_database.py | 23 ++++-- .../test_simulation/test_request_response.py | 8 +- .../_network/_hardware/nodes/test_acl.py | 57 +++++++------- .../_network/_hardware/nodes/test_router.py | 8 +- .../_transmission/test_data_link_layer.py | 19 ++--- .../_red_applications/test_c2_suite.py | 24 +++--- .../test_data_manipulation_bot.py | 8 +- .../_red_applications/test_dos_bot.py | 2 +- .../_system/_applications/test_web_browser.py | 8 +- .../_system/_services/test_dns_client.py | 8 +- .../_system/_services/test_dns_server.py | 8 +- .../_system/_services/test_ftp_client.py | 14 ++-- .../_system/_services/test_ftp_server.py | 8 +- .../_system/_services/test_terminal.py | 18 +++-- .../_system/_services/test_web_server.py | 8 +- .../_simulator/_system/test_software.py | 8 +- .../_utils/test_dict_enum_keys_conversion.py | 27 ++++--- 77 files changed, 616 insertions(+), 613 deletions(-) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 30ccd195..0984f008 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -5,7 +5,6 @@ from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType -from pydantic import field_validator from primaite import getLogger from primaite.game.agent.observations.file_system_observations import FolderObservation @@ -13,8 +12,7 @@ from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.utils.validators import IPProtocol, Port _LOGGER = getLogger(__name__) @@ -47,7 +45,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): """Number of spaces for network interface observations on this host.""" include_nmne: Optional[bool] = None """Whether network interface observations should include number of malicious network events.""" - monitored_traffic: Optional[Dict] = None + monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Whether to include the number of accesses to files observations on this host.""" @@ -58,26 +56,26 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_users: Optional[bool] = True """If True, report user session information.""" - @field_validator("monitored_traffic", mode="before") - def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: - """ - Convert monitored_traffic by lookup against Port and Protocol dicts. + # @field_validator("monitored_traffic", mode="before") + # def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: + # """ + # Convert monitored_traffic by lookup against Port and Protocol dicts. - This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. - This method will be removed in PrimAITE >= 4.0 - """ - if val is None: - return val - new_val = {} - for proto, port_list in val.items(): - # convert protocol, for instance ICMP becomes "icmp" - proto = IPProtocol[proto] if proto in IPProtocol else proto - new_val[proto] = [] - for port in port_list: - # convert ports, for instance "HTTP" becomes 80 - port = Port[port] if port in Port else port - new_val[proto].append(port) - return new_val + # This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. + # This method will be removed in PrimAITE >= 4.0 + # """ + # if val is None: + # return val + # new_val = {} + # for proto, port_list in val.items(): + # # convert protocol, for instance ICMP becomes "icmp" + # proto = PROTOCOL_LOOKUP[proto] if proto in PROTOCOL_LOOKUP else proto + # new_val[proto] = [] + # for port in port_list: + # # convert ports, for instance "HTTP" becomes 80 + # port = PORT_LOOKUP[port] if port in PORT_LOOKUP else port + # new_val[proto].append(port) + # return new_val def __init__( self, diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 296ce04c..c51cb427 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -9,8 +9,8 @@ from pydantic import field_validator from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): @@ -39,11 +39,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): new_val = {} for proto, port_list in val.items(): # convert protocol, for instance ICMP becomes "icmp" - proto = IPProtocol[proto] if proto in IPProtocol else proto + proto = PROTOCOL_LOOKUP[proto] if proto in PROTOCOL_LOOKUP else proto new_val[proto] = [] for port in port_list: # convert ports, for instance "HTTP" becomes 80 - port = Port[port] if port in Port else port + port = PORT_LOOKUP[port] if port in PORT_LOOKUP else port new_val[proto].append(port) return new_val diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 054ffcdb..0bb8ea0f 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -12,8 +12,8 @@ from primaite.game.agent.observations.firewall_observation import FirewallObserv from primaite.game.agent.observations.host_observations import HostObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.router_observation import RouterObservation -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -76,11 +76,11 @@ class NodesObservation(AbstractObservation, identifier="NODES"): new_val = {} for proto, port_list in val.items(): # convert protocol, for instance ICMP becomes "icmp" - proto = IPProtocol[proto] if proto in IPProtocol else proto + proto = PROTOCOL_LOOKUP[proto] if proto in PROTOCOL_LOOKUP else proto new_val[proto] = [] for port in port_list: # convert ports, for instance "HTTP" becomes 80 - port = Port[port] if port in Port else port + port = PORT_LOOKUP[port] if port in PORT_LOOKUP else port new_val[proto].append(port) return new_val diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8e0abb1e..a0d2ceb4 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -27,8 +27,7 @@ from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.network.nmne import NMNEConfig -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient # noqa: F401 @@ -51,6 +50,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import Software +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -97,8 +97,8 @@ class PrimaiteGameOptions(BaseModel): :warning: This will be deprecated in PrimAITE 4.0 and configs will need to be converted. """ for i, port_val in enumerate(vals): - if port_val in Port: - vals[i] = Port[port_val] + if port_val in PORT_LOOKUP: + vals[i] = PORT_LOOKUP[port_val] return vals @field_validator("protocols", mode="before") @@ -110,8 +110,8 @@ class PrimaiteGameOptions(BaseModel): :warning: This will be deprecated in PrimAITE 4.0 and configs will need to be converted. """ for i, proto_val in enumerate(vals): - if proto_val in IPProtocol: - vals[i] = IPProtocol[proto_val] + if proto_val in PROTOCOL_LOOKUP: + vals[i] = PROTOCOL_LOOKUP[proto_val] return vals @@ -381,7 +381,7 @@ class PrimaiteGame: if isinstance(port_id, int): port = port_id elif isinstance(port_id, str): - port = Port[port_id] + port = PORT_LOOKUP[port_id] if port: listen_on_ports.append(port) software.listen_on_ports = set(listen_on_ports) @@ -496,7 +496,7 @@ class PrimaiteGame: opt = application_cfg["options"] new_application.configure( target_ip_address=IPv4Address(opt.get("target_ip_address")), - target_port=Port[opt.get("target_port", "POSTGRES_SERVER")], + target_port=PORT_LOOKUP[opt.get("target_port", "POSTGRES_SERVER")], payload=opt.get("payload"), repeat=bool(opt.get("repeat")), port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), @@ -509,8 +509,10 @@ class PrimaiteGame: new_application.configure( c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")), keep_alive_frequency=(opt.get("keep_alive_frequency", 5)), - masquerade_protocol=IPProtocol[(opt.get("masquerade_protocol", IPProtocol["TCP"]))], - masquerade_port=Port[(opt.get("masquerade_port", Port["HTTP"]))], + masquerade_protocol=PROTOCOL_LOOKUP[ + (opt.get("masquerade_protocol", PROTOCOL_LOOKUP["TCP"])) + ], + masquerade_port=PORT_LOOKUP[(opt.get("masquerade_port", PORT_LOOKUP["HTTP"]))], ) if "network_interfaces" in node_cfg: for nic_num, nic_cfg in node_cfg["network_interfaces"].items(): diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index c2524b4b..9e2e5502 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -6,8 +6,8 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int: @@ -98,8 +98,10 @@ def create_office_lan( default_gateway = IPv4Address(f"192.168.{subnet_base}.1") router = Router(hostname=f"router_{lan_name}", start_up_duration=0) router.power_on() - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) network.add_node(router) router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0") router.enable_port(1) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 4154cc08..affaf3cc 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -21,8 +21,7 @@ from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.packet_capture import PacketCapture from primaite.simulator.system.core.session_manager import SessionManager @@ -33,7 +32,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.software import IOSoftware, Software from primaite.utils.converters import convert_dict_enum_keys_to_enum_values -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware) @@ -274,20 +273,20 @@ class NetworkInterface(SimComponent, ABC): # Identify the protocol and port from the frame if frame.tcp: - protocol = IPProtocol["TCP"] + protocol = PROTOCOL_LOOKUP["TCP"] port = frame.tcp.dst_port elif frame.udp: - protocol = IPProtocol["UDP"] + protocol = PROTOCOL_LOOKUP["UDP"] port = frame.udp.dst_port elif frame.icmp: - protocol = IPProtocol["ICMP"] + protocol = PROTOCOL_LOOKUP["ICMP"] # Ensure the protocol is in the capture dict if protocol not in self.traffic: self.traffic[protocol] = {} # Handle non-ICMP protocols that use ports - if protocol != IPProtocol["ICMP"]: + if protocol != PROTOCOL_LOOKUP["ICMP"]: if port not in self.traffic[protocol]: self.traffic[protocol][port] = {"inbound": 0, "outbound": 0} self.traffic[protocol][port][direction] += frame.size_Mbits @@ -843,8 +842,8 @@ class UserManager(Service): :param password: The password for the default admin user """ kwargs["name"] = "UserManager" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["NONE"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self.start() @@ -1166,8 +1165,8 @@ class UserSessionManager(Service): :param password: The password for the default admin user """ kwargs["name"] = "UserSessionManager" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["NONE"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self.start() @@ -1312,7 +1311,7 @@ class UserSessionManager(Service): software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( payload={"type": "user_timeout", "connection_id": session.uuid}, - dest_port=Port["SSH"], + dest_port=PORT_LOOKUP["SSH"], dest_ip_address=session.remote_ip_address, ) diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 6d8e084d..eed1132b 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -14,10 +14,9 @@ from primaite.simulator.network.hardware.nodes.network.router import ( RouterInterface, ) from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.sys_log import SysLog -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP EXTERNAL_PORT_ID: Final[int] = 1 """The Firewall port ID of the external port.""" @@ -596,9 +595,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items(): firewall.internal_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -611,9 +610,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items(): firewall.internal_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -626,9 +625,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items(): firewall.dmz_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -641,9 +640,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items(): firewall.dmz_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -656,9 +655,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items(): firewall.external_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), @@ -671,9 +670,9 @@ class Firewall(Router): for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items(): firewall.external_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index fded23f9..46efe668 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -17,15 +17,14 @@ from primaite.simulator.network.hardware.nodes.network.network_node import Netwo from primaite.simulator.network.protocols.arp import ARPPacket from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.nmap import NMAP from primaite.simulator.system.core.session_manager import SessionManager from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.arp.arp import ARP from primaite.simulator.system.services.icmp.icmp import ICMP from primaite.simulator.system.services.terminal.terminal import Terminal -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP @validate_call() @@ -134,14 +133,14 @@ class ACLRule(SimComponent): def protocol_valid(cls, val: Optional[str]) -> Optional[str]: """Assert that the protocol for the rule is predefined in the IPProtocol lookup.""" if val is not None: - assert val in IPProtocol.values(), f"Cannot create ACL rule with invalid protocol {val}" + assert val in PROTOCOL_LOOKUP.values(), f"Cannot create ACL rule with invalid protocol {val}" return val @field_validator("src_port", "dst_port", mode="before") def ports_valid(cls, val: Optional[int]) -> Optional[int]: """Assert that the port for the rule is predefined in the Port lookup.""" if val is not None: - assert val in Port.values(), f"Cannot create ACL rule with invalid port {val}" + assert val in PORT_LOOKUP.values(), f"Cannot create ACL rule with invalid port {val}" return val def __str__(self) -> str: @@ -1271,8 +1270,10 @@ class Router(NetworkNode): Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP and ICMP, which are necessary for basic network operations and diagnostics. """ - self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + self.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) def setup_for_episode(self, episode: int): """ @@ -1371,9 +1372,9 @@ class Router(NetworkNode): """ dst_ip_address = frame.ip.dst_ip_address dst_port = None - if frame.ip.protocol == IPProtocol["TCP"]: + if frame.ip.protocol == PROTOCOL_LOOKUP["TCP"]: dst_port = frame.tcp.dst_port - elif frame.ip.protocol == IPProtocol["UDP"]: + elif frame.ip.protocol == PROTOCOL_LOOKUP["UDP"]: dst_port = frame.udp.dst_port if self.ip_is_router_interface(dst_ip_address) and ( @@ -1646,9 +1647,9 @@ class Router(NetworkNode): for r_num, r_cfg in cfg["acl"].items(): router.acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), dst_ip_address=r_cfg.get("dst_ip"), diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 1969a121..3615ef54 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -8,9 +8,8 @@ from primaite.simulator.network.airspace import AirSpace, IPWirelessNetworkInter from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.utils.validators import IPV4Address +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP class WirelessAccessPoint(IPWirelessNetworkInterface): @@ -271,9 +270,9 @@ class WirelessRouter(Router): for r_num, r_cfg in cfg["acl"].items(): router.acl.add_rule( action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], src_ip_address=r_cfg.get("src_ip"), dst_ip_address=r_cfg.get("dst_ip"), src_wildcard_mask=r_cfg.get("src_wildcard_mask"), diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index a73f3b12..c3b4a341 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -12,14 +12,14 @@ from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -79,9 +79,11 @@ def client_server_routed() -> Network: server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) return network @@ -271,23 +273,30 @@ def arcd_uc2_network() -> Network: security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0")) network.connect(endpoint_b=security_suite.network_interface[2], endpoint_a=switch_2.network_interface[7]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Allow PostgreSQL requests router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3 + ) return network diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index b9bc48d9..ca212c58 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -7,10 +7,11 @@ from pydantic import BaseModel from primaite import getLogger from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.protocols.packet import DataPacket -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol +from primaite.simulator.network.transmission.network_layer import IPPacket from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPHeader, UDPHeader from primaite.simulator.network.utils import convert_bytes_to_megabits +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -70,15 +71,15 @@ class Frame(BaseModel): msg = "Network Frame cannot have both a TCP header and a UDP header" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol["TCP"] and not kwargs.get("tcp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["TCP"] and not kwargs.get("tcp"): msg = "Cannot build a Frame using the TCP IP Protocol without a TCPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol["UDP"] and not kwargs.get("udp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["UDP"] and not kwargs.get("udp"): msg = "Cannot build a Frame using the UDP IP Protocol without a UDPHeader" _LOGGER.error(msg) raise ValueError(msg) - if kwargs["ip"].protocol == IPProtocol["ICMP"] and not kwargs.get("icmp"): + if kwargs["ip"].protocol == PROTOCOL_LOOKUP["ICMP"] and not kwargs.get("icmp"): msg = "Cannot build a Frame using the ICMP IP Protocol without a ICMPPacket" _LOGGER.error(msg) raise ValueError(msg) @@ -165,7 +166,7 @@ class Frame(BaseModel): :return: True if the Frame is an ARP packet, otherwise False. """ - return self.udp.dst_port == Port["ARP"] + return self.udp.dst_port == PORT_LOOKUP["ARP"] @property def is_icmp(self) -> bool: diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index a01b7f42..47e8a032 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -4,39 +4,11 @@ from enum import Enum from pydantic import BaseModel from primaite import getLogger -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPProtocol, IPV4Address, PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) -IPProtocol: dict[str, str] = dict( - NONE="none", - TCP="tcp", - UDP="udp", - ICMP="icmp", -) - -# class IPProtocol(Enum): -# """ -# Enum representing transport layer protocols in IP header. - -# .. _List of IPProtocols: -# """ - -# NONE = "none" -# """Placeholder for a non-protocol.""" -# TCP = "tcp" -# """Transmission Control Protocol.""" -# UDP = "udp" -# """User Datagram Protocol.""" -# ICMP = "icmp" -# """Internet Control Message Protocol.""" - -# def model_dump(self) -> str: -# """Return as JSON-serialisable string.""" -# return self.name - - class Precedence(Enum): """ Enum representing the Precedence levels in Quality of Service (QoS) for IP packets. @@ -98,7 +70,7 @@ class IPPacket(BaseModel): "Source IP address." dst_ip_address: IPV4Address "Destination IP address." - protocol: str = IPProtocol["TCP"] + protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"] "IPProtocol." ttl: int = 64 "Time to Live (TTL) for the packet." diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index 60f2f070..fbc4b5ad 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -4,7 +4,7 @@ from typing import List from pydantic import BaseModel -Port: dict[str, int] = dict( +PORT_LOOKUP: dict[str, int] = dict( UNUSED=-1, NONE=0, WOL=9, @@ -36,81 +36,6 @@ Port: dict[str, int] = dict( POSTGRES_SERVER=5432, ) -# class Port(): -# def __getattr__() - - -# class Port(Enum): -# """ -# Enumeration of common known TCP/UDP ports used by protocols for operation of network applications. - -# .. _List of Ports: -# """ - -# UNUSED = -1 -# "An unused port stub." - -# NONE = 0 -# "Place holder for a non-port." -# WOL = 9 -# "Wake-on-Lan (WOL) - Used to turn or awaken a computer from sleep mode by a network message." -# FTP_DATA = 20 -# "File Transfer [Default Data]" -# FTP = 21 -# "File Transfer Protocol (FTP) - FTP control (command)" -# SSH = 22 -# "Secure Shell (SSH) - Used for secure remote access and command execution." -# SMTP = 25 -# "Simple Mail Transfer Protocol (SMTP) - Used for email delivery between servers." -# DNS = 53 -# "Domain Name System (DNS) - Used for translating domain names to IP addresses." -# HTTP = 80 -# "HyperText Transfer Protocol (HTTP) - Used for web traffic." -# POP3 = 110 -# "Post Office Protocol version 3 (POP3) - Used for retrieving emails from a mail server." -# SFTP = 115 -# "Secure File Transfer Protocol (SFTP) - Used for secure file transfer over SSH." -# NTP = 123 -# "Network Time Protocol (NTP) - Used for clock synchronization between computer systems." -# IMAP = 143 -# "Internet Message Access Protocol (IMAP) - Used for retrieving emails from a mail server." -# SNMP = 161 -# "Simple Network Management Protocol (SNMP) - Used for network device management." -# SNMP_TRAP = 162 -# "SNMP Trap - Used for sending SNMP notifications (traps) to a network management system." -# ARP = 219 -# "Address resolution Protocol - Used to connect a MAC address to an IP address." -# LDAP = 389 -# "Lightweight Directory Access Protocol (LDAP) - Used for accessing and modifying directory information." -# HTTPS = 443 -# "HyperText Transfer Protocol Secure (HTTPS) - Used for secure web traffic." -# SMB = 445 -# "Server Message Block (SMB) - Used for file sharing and printer sharing in Windows environments." -# IPP = 631 -# "Internet Printing Protocol (IPP) - Used for printing over the internet or an intranet." -# SQL_SERVER = 1433 -# "Microsoft SQL Server Database Engine - Used for communication with the SQL Server." -# MYSQL = 3306 -# "MySQL Database Server - Used for MySQL database communication." -# RDP = 3389 -# "Remote Desktop Protocol (RDP) - Used for remote desktop access to Windows machines." -# RTP = 5004 -# "Real-time Transport Protocol (RTP) - Used for transmitting real-time media, e.g., audio and video." -# RTP_ALT = 5005 -# "Alternative port for RTP (RTP_ALT) - Used in some configurations for transmitting real-time media." -# DNS_ALT = 5353 -# "Alternative port for DNS (DNS_ALT) - Used in some configurations for DNS service." -# HTTP_ALT = 8080 -# "Alternative port for HTTP (HTTP_ALT) - Often used as an alternative HTTP port for web applications." -# HTTPS_ALT = 8443 -# "Alternative port for HTTPS (HTTPS_ALT) - Used in some configurations for secure web traffic." -# POSTGRES_SERVER = 5432 -# "Postgres SQL Server." - -# def model_dump(self) -> str: -# """Return a json-serialisable string.""" -# return self.name - class UDPHeader(BaseModel): """ diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 170e2647..4967f519 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -11,11 +11,10 @@ from pydantic import BaseModel from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.hardware.nodes.host.host_node import HostNode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.software_manager import SoftwareManager -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP class DatabaseClientConnection(BaseModel): @@ -90,8 +89,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"): def __init__(self, **kwargs): kwargs["name"] = "DatabaseClient" - kwargs["port"] = Port["POSTGRES_SERVER"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def _init_request_manager(self) -> RequestManager: diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index 74bce85d..34433e65 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -7,10 +7,9 @@ from pydantic import validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP class PortScanPayload(SimComponent): @@ -64,8 +63,8 @@ class NMAP(Application, identifier="NMAP"): def __init__(self, **kwargs): kwargs["name"] = "NMAP" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["NONE"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) def _can_perform_network_action(self) -> bool: @@ -348,12 +347,12 @@ class NMAP(Application, identifier="NMAP"): if isinstance(target_port, int): target_port = [target_port] elif target_port is None: - target_port = [port for port in Port if port not in {Port["NONE"], Port["UNUSED"]}] + target_port = [port for port in PORT_LOOKUP if port not in {PORT_LOOKUP["NONE"], PORT_LOOKUP["UNUSED"]}] if isinstance(target_protocol, str): target_protocol = [target_protocol] elif target_protocol is None: - target_protocol = [IPProtocol["TCP"], IPProtocol["UDP"]] + target_protocol = [PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]] scan_type = self._determine_port_scan_type(list(ip_addresses), target_port) active_ports = {} diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index d442d968..b0cdefba 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -9,14 +9,14 @@ from pydantic import BaseModel, Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP class C2Command(Enum): @@ -81,10 +81,10 @@ class AbstractC2(Application, identifier="AbstractC2"): keep_alive_frequency: int = Field(default=5, ge=1) """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" - masquerade_protocol: str = Field(default=IPProtocol["TCP"]) + masquerade_protocol: str = Field(default=PROTOCOL_LOOKUP["TCP"]) """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" - masquerade_port: int = Field(default=Port["HTTP"]) + masquerade_port: int = Field(default=PORT_LOOKUP["HTTP"]) """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" c2_config: _C2Opts = _C2Opts() @@ -142,9 +142,9 @@ class AbstractC2(Application, identifier="AbstractC2"): def __init__(self, **kwargs): """Initialise the C2 applications to by default listen for HTTP traffic.""" - kwargs["listen_on_ports"] = {Port["HTTP"], Port["FTP"], Port["DNS"]} - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["listen_on_ports"] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]} + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) @property @@ -410,8 +410,8 @@ class AbstractC2(Application, identifier="AbstractC2"): self.keep_alive_inactivity = 0 self.keep_alive_frequency = 5 self.c2_remote_connection = None - self.c2_config.masquerade_port = Port["HTTP"] - self.c2_config.masquerade_protocol = IPProtocol["TCP"] + self.c2_config.masquerade_port = PORT_LOOKUP["HTTP"] + self.c2_config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"] @abstractmethod def _confirm_remote_connection(self, timestep: int) -> bool: diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index 9178e68a..450c60ad 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -8,12 +8,12 @@ from pydantic import validate_call from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts, RansomwareOpts, TerminalOpts from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection +from primaite.utils.validators import PROTOCOL_LOOKUP class C2Beacon(AbstractC2, identifier="C2Beacon"): @@ -111,8 +111,8 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self.configure( c2_server_ip_address=c2_remote_ip, keep_alive_frequency=frequency, - masquerade_protocol=IPProtocol[protocol], - masquerade_port=Port[port], + masquerade_protocol=PROTOCOL_LOOKUP[protocol], + masquerade_port=PORT_LOOKUP[port], ) ) @@ -129,8 +129,8 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"): self, c2_server_ip_address: IPv4Address = None, keep_alive_frequency: int = 5, - masquerade_protocol: str = IPProtocol["TCP"], - masquerade_port: int = Port["HTTP"], + masquerade_protocol: str = PROTOCOL_LOOKUP["TCP"], + masquerade_port: int = PORT_LOOKUP["HTTP"], ) -> bool: """ Configures the C2 beacon to communicate with the C2 server. diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index d74ae384..c2d19160 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -7,10 +7,10 @@ from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -50,8 +50,8 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"): def __init__(self, **kwargs): kwargs["name"] = "DataManipulationBot" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["NONE"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 2cc99c4a..7e199b48 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -7,7 +7,7 @@ from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient _LOGGER = getLogger(__name__) @@ -85,7 +85,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): if "target_ip_address" in request[-1]: request[-1]["target_ip_address"] = IPv4Address(request[-1]["target_ip_address"]) if "target_port" in request[-1]: - request[-1]["target_port"] = Port[request[-1]["target_port"]] + request[-1]["target_port"] = PORT_LOOKUP[request[-1]["target_port"]] return RequestResponse.from_bool(self.configure(**request[-1])) rm.add_request("configure", request_type=RequestType(func=_configure)) @@ -94,7 +94,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): def configure( self, target_ip_address: IPv4Address, - target_port: Optional[int] = Port["POSTGRES_SERVER"], + target_port: Optional[int] = PORT_LOOKUP["POSTGRES_SERVER"], payload: Optional[str] = None, repeat: bool = False, port_scan_p_of_success: float = 0.1, diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index a819190c..56f885f4 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -6,10 +6,10 @@ from prettytable import MARKDOWN, PrettyTable from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection +from primaite.utils.validators import PROTOCOL_LOOKUP class RansomwareScript(Application, identifier="RansomwareScript"): @@ -27,8 +27,8 @@ class RansomwareScript(Application, identifier="RansomwareScript"): def __init__(self, **kwargs): kwargs["name"] = "RansomwareScript" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["NONE"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"] super().__init__(**kwargs) self._db_connection: Optional[DatabaseClientConnection] = None diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 6707fa52..faa7b5ec 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -15,10 +15,10 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -43,10 +43,10 @@ class WebBrowser(Application, identifier="WebBrowser"): def __init__(self, **kwargs): kwargs["name"] = "WebBrowser" - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port["HTTP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) self.run() @@ -126,7 +126,7 @@ class WebBrowser(Application, identifier="WebBrowser"): if self.send( payload=payload, dest_ip_address=self.domain_name_ip_address, - dest_port=parsed_url.port if parsed_url.port else Port["HTTP"], + dest_port=parsed_url.port if parsed_url.port else PORT_LOOKUP["HTTP"], ): self.sys_log.info( f"{self.name}: Received HTTP {payload.request_method.name} " @@ -154,7 +154,7 @@ class WebBrowser(Application, identifier="WebBrowser"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = Port["HTTP"], + dest_port: Optional[int] = PORT_LOOKUP["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 33de3443..fcf07d9f 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -10,8 +10,9 @@ from primaite.simulator.core import SimComponent from primaite.simulator.network.protocols.arp import ARPPacket from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.network_layer import IPPacket +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPHeader, UDPHeader +from primaite.utils.validators import PROTOCOL_LOOKUP if TYPE_CHECKING: from primaite.simulator.network.hardware.base import NetworkInterface @@ -117,7 +118,7 @@ class SessionManager: """ protocol = frame.ip.protocol with_ip_address = frame.ip.src_ip_address - if protocol == IPProtocol["TCP"]: + if protocol == PROTOCOL_LOOKUP["TCP"]: if inbound_frame: src_port = frame.tcp.src_port dst_port = frame.tcp.dst_port @@ -125,7 +126,7 @@ class SessionManager: dst_port = frame.tcp.src_port src_port = frame.tcp.dst_port with_ip_address = frame.ip.dst_ip_address - elif protocol == IPProtocol["UDP"]: + elif protocol == PROTOCOL_LOOKUP["UDP"]: if inbound_frame: src_port = frame.udp.src_port dst_port = frame.udp.dst_port @@ -260,7 +261,7 @@ class SessionManager: src_port: Optional[int] = None, dst_port: Optional[int] = None, session_id: Optional[str] = None, - ip_protocol: str = IPProtocol["TCP"], + ip_protocol: str = PROTOCOL_LOOKUP["TCP"], icmp_packet: Optional[ICMPPacket] = None, ) -> Union[Any, None]: """ @@ -284,7 +285,7 @@ class SessionManager: dst_mac_address = payload.target_mac_addr outbound_network_interface = self.resolve_outbound_network_interface(payload.target_ip_address) is_broadcast = payload.request - ip_protocol = IPProtocol["UDP"] + ip_protocol = PROTOCOL_LOOKUP["UDP"] else: vals = self.resolve_outbound_transmission_details( dst_ip_address=dst_ip_address, @@ -316,12 +317,12 @@ class SessionManager: tcp_header = None udp_header = None - if ip_protocol == IPProtocol["TCP"]: + if ip_protocol == PROTOCOL_LOOKUP["TCP"]: tcp_header = TCPHeader( src_port=dst_port, dst_port=dst_port, ) - elif ip_protocol == IPProtocol["UDP"]: + elif ip_protocol == PROTOCOL_LOOKUP["UDP"]: udp_header = UDPHeader( src_port=dst_port, dst_port=dst_port, @@ -385,7 +386,7 @@ class SessionManager: elif frame.udp: dst_port = frame.udp.dst_port elif frame.icmp: - dst_port = Port["NONE"] + dst_port = PORT_LOOKUP["NONE"] self.software_manager.receive_payload_from_session_manager( payload=frame.payload, port=dst_port, diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 8eac33fa..abf2ca3a 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -8,12 +8,12 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.core import RequestType from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import IOSoftware +from primaite.utils.validators import PROTOCOL_LOOKUP if TYPE_CHECKING: from primaite.simulator.system.core.session_manager import SessionManager @@ -191,7 +191,7 @@ class SoftwareManager: dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, src_port: Optional[int] = None, dest_port: Optional[int] = None, - ip_protocol: str = IPProtocol["TCP"], + ip_protocol: str = PROTOCOL_LOOKUP["TCP"], session_id: Optional[str] = None, ) -> bool: """ @@ -275,7 +275,7 @@ class SoftwareManager: software_type, software.operating_state.name, software.health_state_actual.name, - software.port if software.port != Port["NONE"] else None, + software.port if software.port != PORT_LOOKUP["NONE"] else None, software.protocol, ] ) diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index b8dd5f89..2641f1c8 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -8,10 +8,9 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service -from primaite.utils.validators import IPV4Address +from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP class ARP(Service): @@ -26,8 +25,8 @@ class ARP(Service): def __init__(self, **kwargs): kwargs["name"] = "ARP" - kwargs["port"] = Port["ARP"] - kwargs["protocol"] = IPProtocol["UDP"] + kwargs["port"] = PORT_LOOKUP["ARP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 11ca9eb2..f9a5d087 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -7,12 +7,12 @@ from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.file_system.folder import Folder -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -38,8 +38,8 @@ class DatabaseService(Service): def __init__(self, **kwargs): kwargs["name"] = "DatabaseService" - kwargs["port"] = Port["POSTGRES_SERVER"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self._create_db_file() diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 62f14366..316189a7 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -4,10 +4,10 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -22,11 +22,11 @@ class DNSClient(Service): def __init__(self, **kwargs): kwargs["name"] = "DNSClient" - kwargs["port"] = Port["DNS"] + kwargs["port"] = PORT_LOOKUP["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() @@ -95,7 +95,7 @@ class DNSClient(Service): # send a request to check if domain name exists in the DNS Server software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload=payload, dest_ip_address=self.dns_server, dest_port=Port["DNS"] + payload=payload, dest_ip_address=self.dns_server, dest_port=PORT_LOOKUP["DNS"] ) # recursively re-call the function passing is_reattempt=True diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 93895825..e0786124 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -6,9 +6,9 @@ from prettytable import MARKDOWN, PrettyTable from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -21,11 +21,11 @@ class DNSServer(Service): def __init__(self, **kwargs): kwargs["name"] = "DNSServer" - kwargs["port"] = Port["DNS"] + kwargs["port"] = PORT_LOOKUP["DNS"] # DNS uses UDP by default # it switches to TCP when the bytes exceed 512 (or 4096) bytes # TCP for now - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 1fce4133..11a926cf 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -7,10 +7,10 @@ from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -25,8 +25,8 @@ class FTPClient(FTPServiceABC): def __init__(self, **kwargs): kwargs["name"] = "FTPClient" - kwargs["port"] = Port["FTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["FTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() @@ -104,7 +104,7 @@ class FTPClient(FTPServiceABC): def _connect_to_server( self, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = Port["FTP"], + dest_port: Optional[int] = PORT_LOOKUP["FTP"], session_id: Optional[str] = None, is_reattempt: Optional[bool] = False, ) -> bool: @@ -124,7 +124,7 @@ class FTPClient(FTPServiceABC): # normally FTP will choose a random port for the transfer, but using the FTP command port will do for now # create FTP packet - payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=Port["FTP"]) + payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=PORT_LOOKUP["FTP"]) if self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id): if payload.status_code == FTPStatusCode.OK: @@ -152,7 +152,7 @@ class FTPClient(FTPServiceABC): return False def _disconnect_from_server( - self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[int] = Port["FTP"] + self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[int] = PORT_LOOKUP["FTP"] ) -> bool: """ Connects the client from a given FTP server. @@ -179,7 +179,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[int] = Port["FTP"], + dest_port: Optional[int] = PORT_LOOKUP["FTP"], session_id: Optional[str] = None, ) -> bool: """ @@ -241,7 +241,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[int] = Port["FTP"], + dest_port: Optional[int] = PORT_LOOKUP["FTP"], ) -> bool: """ Request a file from a target IP address. diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 701bff79..38a253be 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -3,9 +3,9 @@ from typing import Any, Optional from primaite import getLogger from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -23,8 +23,8 @@ class FTPServer(FTPServiceABC): def __init__(self, **kwargs): kwargs["name"] = "FTPServer" - kwargs["port"] = Port["FTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["FTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/icmp/icmp.py b/src/primaite/simulator/system/services/icmp/icmp.py index a2dfac0d..486ba2b0 100644 --- a/src/primaite/simulator/system/services/icmp/icmp.py +++ b/src/primaite/simulator/system/services/icmp/icmp.py @@ -7,9 +7,9 @@ from primaite import getLogger from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -26,8 +26,8 @@ class ICMP(Service): def __init__(self, **kwargs): kwargs["name"] = "ICMP" - kwargs["port"] = Port["NONE"] - kwargs["protocol"] = IPProtocol["ICMP"] + kwargs["port"] = PORT_LOOKUP["NONE"] + kwargs["protocol"] = PROTOCOL_LOOKUP["ICMP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 40b8d273..184833e1 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -5,9 +5,9 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -21,8 +21,8 @@ class NTPClient(Service): def __init__(self, **kwargs): kwargs["name"] = "NTPClient" - kwargs["port"] = Port["NTP"] - kwargs["protocol"] = IPProtocol["UDP"] + kwargs["port"] = PORT_LOOKUP["NTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) self.start() @@ -55,7 +55,7 @@ class NTPClient(Service): payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: int = Port["NTP"], + dest_port: int = PORT_LOOKUP["NTP"], **kwargs, ) -> bool: """Requests NTP data from NTP server. diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index d9de40c6..4764bffb 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -4,9 +4,9 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -16,8 +16,8 @@ class NTPServer(Service): def __init__(self, **kwargs): kwargs["name"] = "NTPServer" - kwargs["port"] = Port["NTP"] - kwargs["protocol"] = IPProtocol["UDP"] + kwargs["port"] = PORT_LOOKUP["NTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["UDP"] super().__init__(**kwargs) self.start() diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 41987aff..2b0bc02b 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -17,10 +17,10 @@ from primaite.simulator.network.protocols.ssh import ( SSHTransportMessage, SSHUserCredentials, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.utils.validators import PROTOCOL_LOOKUP # TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor @@ -137,8 +137,8 @@ class Terminal(Service): def __init__(self, **kwargs): kwargs["name"] = "Terminal" - kwargs["port"] = Port["SSH"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["SSH"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index c021a86e..2805b1b2 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -10,11 +10,11 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClientConnection from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -49,10 +49,10 @@ class WebServer(Service): def __init__(self, **kwargs): kwargs["name"] = "WebServer" - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port["HTTP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) self._install_web_files() diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 084bdaf6..d34678b9 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -13,9 +13,9 @@ from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog +from primaite.utils.validators import PROTOCOL_LOOKUP if TYPE_CHECKING: from primaite.simulator.system.core.software_manager import SoftwareManager @@ -386,7 +386,7 @@ class IOSoftware(Software): session_id: Optional[str] = None, dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, dest_port: Optional[int] = None, - ip_protocol: str = IPProtocol["TCP"], + ip_protocol: str = PROTOCOL_LOOKUP["TCP"], **kwargs, ) -> bool: """ diff --git a/src/primaite/utils/validators.py b/src/primaite/utils/validators.py index 139d303c..f07b475d 100644 --- a/src/primaite/utils/validators.py +++ b/src/primaite/utils/validators.py @@ -6,6 +6,9 @@ from pydantic import BeforeValidator from typing_extensions import Annotated +# Define a custom type IPV4Address using the typing_extensions.Annotated. +# Annotated is used to attach metadata to type hints. In this case, it's used to associate the ipv4_validator +# with the IPv4Address type, ensuring that any usage of IPV4Address undergoes validation before assignment. def ipv4_validator(v: Any) -> IPv4Address: """ Validate the input and ensure it can be converted to an IPv4Address instance. @@ -24,9 +27,6 @@ def ipv4_validator(v: Any) -> IPv4Address: return IPv4Address(v) -# Define a custom type IPV4Address using the typing_extensions.Annotated. -# Annotated is used to attach metadata to type hints. In this case, it's used to associate the ipv4_validator -# with the IPv4Address type, ensuring that any usage of IPV4Address undergoes validation before assignment. IPV4Address: Final[Annotated] = Annotated[IPv4Address, BeforeValidator(ipv4_validator)] """ IPv4Address with with IPv4Address with with pre-validation and auto-conversion from str using ipv4_validator.. @@ -37,3 +37,39 @@ will automatically check and convert the input value to an instance of IPv4Addre any Pydantic model uses it. This ensures that any field marked with this type is not just an IPv4Address in form, but also valid according to the rules defined in ipv4_validator. """ + +# Define a custom port validator +Port: Final[Annotated] = Annotated[int, BeforeValidator(lambda n: 0 <= n <= 65535)] +"""Validates that network ports lie in the appropriate range of [0,65535].""" + +# Define a custom IP protocol validator +PROTOCOL_LOOKUP: dict[str, str] = dict( + NONE="none", + TCP="tcp", + UDP="udp", + ICMP="icmp", +) +""" +Lookup table used for compatibility with PrimAITE <= 3.3. Configs with the capitalised protocol names are converted +to lowercase at runtime. +""" +VALID_PROTOCOLS = ["none", "tcp", "udp", "icmp"] +"""Supported protocols.""" + + +def protocol_validator(v: Any) -> str: + """ + Validate that IP Protocols are chosen from the list of supported IP Protocols. + + The protocol list is dynamic because plugins are able to extend it, therefore it is necessary to use this custom + validator instead of being able to specify a union of string literals. + """ + if v in PROTOCOL_LOOKUP: + return PROTOCOL_LOOKUP(v) + if v in VALID_PROTOCOLS: + return v + raise ValueError(f"{v} is not a valid IP Protocol. It must be one of the following: {VALID_PROTOCOLS}") + + +IPProtocol: Final[Annotated] = Annotated[str, BeforeValidator(protocol_validator)] +"""Validates that IP Protocols used in the simulation belong to the list of supported protocols.""" diff --git a/tests/conftest.py b/tests/conftest.py index 1ffa2146..687bec92 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,8 +18,7 @@ from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.web_browser import WebBrowser @@ -28,6 +27,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT rayinit() @@ -45,8 +45,8 @@ class DummyService(Service): def __init__(self, **kwargs): kwargs["name"] = "DummyService" - kwargs["port"] = Port["HTTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: @@ -58,8 +58,8 @@ class DummyApplication(Application, identifier="DummyApplication"): def __init__(self, **kwargs): kwargs["name"] = "DummyApplication" - kwargs["port"] = Port["HTTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -77,7 +77,7 @@ def uc2_network() -> Network: @pytest.fixture(scope="function") def service(file_system) -> DummyService: return DummyService( - name="DummyService", port=Port["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_service") + name="DummyService", port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_service") ) @@ -90,7 +90,7 @@ def service_class(): def application(file_system) -> DummyApplication: return DummyApplication( name="DummyApplication", - port=Port["ARP"], + port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="dummy_application"), ) @@ -350,10 +350,10 @@ def install_stuff_to_sim(sim: Simulation): network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) # 2: Configure base ACL - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3) # 3: Install server software server_1.software_manager.install(DNSServer) @@ -379,13 +379,13 @@ def install_stuff_to_sim(sim: Simulation): r = sim.network.router_nodes[0] for i, acl_rule in enumerate(r.acl.acl): if i == 1: - assert acl_rule.src_port == acl_rule.dst_port == Port["DNS"] + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["DNS"] elif i == 3: - assert acl_rule.src_port == acl_rule.dst_port == Port["HTTP"] + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["HTTP"] elif i == 22: - assert acl_rule.src_port == acl_rule.dst_port == Port["ARP"] + assert acl_rule.src_port == acl_rule.dst_port == PORT_LOOKUP["ARP"] elif i == 23: - assert acl_rule.protocol == IPProtocol["ICMP"] + assert acl_rule.protocol == PROTOCOL_LOOKUP["ICMP"] elif i == 24: ... else: diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py index d35e2ebb..6d0ef7b0 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py @@ -9,8 +9,8 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP from tests.integration_tests.configuration_file_parsing import BASIC_FIREWALL, DMZ_NETWORK, load_config @@ -68,44 +68,44 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed internal_inbound assert firewall.internal_inbound_acl.num_rules == 2 assert firewall.internal_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.internal_inbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.internal_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.internal_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_inbound_acl.acl[23].protocol == IPProtocol["ICMP"] + assert firewall.internal_inbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.internal_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed internal_outbound assert firewall.internal_outbound_acl.num_rules == 2 assert firewall.internal_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.internal_outbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.internal_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.internal_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.internal_outbound_acl.acl[23].protocol == IPProtocol["ICMP"] + assert firewall.internal_outbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.internal_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_inbound assert firewall.dmz_inbound_acl.num_rules == 2 assert firewall.dmz_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.dmz_inbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.dmz_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.dmz_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.dmz_inbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_inbound_acl.acl[23].protocol == IPProtocol["ICMP"] + assert firewall.dmz_inbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.dmz_inbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed dmz_outbound assert firewall.dmz_outbound_acl.num_rules == 2 assert firewall.dmz_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.dmz_outbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.dmz_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.dmz_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert firewall.dmz_outbound_acl.acl[23].action == ACLAction.PERMIT - assert firewall.dmz_outbound_acl.acl[23].protocol == IPProtocol["ICMP"] + assert firewall.dmz_outbound_acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert firewall.dmz_outbound_acl.implicit_action == ACLAction.DENY # ICMP and ARP should be allowed external_inbound assert firewall.external_inbound_acl.num_rules == 1 assert firewall.external_inbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_inbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.external_inbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.external_inbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.external_inbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] # external_inbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_inbound_acl.implicit_action == ACLAction.PERMIT @@ -113,8 +113,8 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed external_outbound assert firewall.external_outbound_acl.num_rules == 1 assert firewall.external_outbound_acl.acl[22].action == ACLAction.PERMIT - assert firewall.external_outbound_acl.acl[22].src_port == Port["ARP"] - assert firewall.external_outbound_acl.acl[22].dst_port == Port["ARP"] + assert firewall.external_outbound_acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert firewall.external_outbound_acl.acl[22].dst_port == PORT_LOOKUP["ARP"] # external_outbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_outbound_acl.implicit_action == ACLAction.PERMIT diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py index 16543565..c348ee81 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -6,8 +6,8 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config @@ -63,8 +63,8 @@ def test_router_acl_rules_correctly_added(dmz_config): # ICMP and ARP should be allowed assert router_1.acl.num_rules == 2 assert router_1.acl.acl[22].action == ACLAction.PERMIT - assert router_1.acl.acl[22].src_port == Port["ARP"] - assert router_1.acl.acl[22].dst_port == Port["ARP"] + assert router_1.acl.acl[22].src_port == PORT_LOOKUP["ARP"] + assert router_1.acl.acl[22].dst_port == PORT_LOOKUP["ARP"] assert router_1.acl.acl[23].action == ACLAction.PERMIT - assert router_1.acl.acl[23].protocol == IPProtocol["ICMP"] + assert router_1.acl.acl[23].protocol == PROTOCOL_LOOKUP["ICMP"] assert router_1.acl.implicit_action == ACLAction.DENY diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py index 8e3d33e1..28029b32 100644 --- a/tests/integration_tests/extensions/applications/extended_application.py +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -15,11 +15,11 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -44,10 +44,10 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): def __init__(self, **kwargs): kwargs["name"] = "ExtendedApplication" - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] # default for web is port 80 if kwargs.get("port") is None: - kwargs["port"] = Port["HTTP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] super().__init__(**kwargs) self.run() @@ -127,7 +127,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): if self.send( payload=payload, dest_ip_address=self.domain_name_ip_address, - dest_port=parsed_url.port if parsed_url.port else Port["HTTP"], + dest_port=parsed_url.port if parsed_url.port else PORT_LOOKUP["HTTP"], ): self.sys_log.info( f"{self.name}: Received HTTP {payload.request_method.name} " @@ -155,7 +155,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = Port["HTTP"], + dest_port: Optional[int] = PORT_LOOKUP["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index b745b774..70d47aaa 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -7,12 +7,12 @@ from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.file_system.folder import Folder -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP _LOGGER = getLogger(__name__) @@ -38,8 +38,8 @@ class ExtendedService(Service, identifier="extendedservice"): def __init__(self, **kwargs): kwargs["name"] = "ExtendedService" - kwargs["port"] = Port["POSTGRES_SERVER"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self._create_db_file() if kwargs.get("options"): diff --git a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py index 17b0ba8c..2c750621 100644 --- a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py +++ b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py @@ -11,7 +11,7 @@ from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Command, C2Server from primaite.simulator.system.services.database.database_service import DatabaseService @@ -26,9 +26,9 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=4) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=5) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=6) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=5) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=6) c2_server_host = game.simulation.network.get_node_by_hostname("client_1") c2_server_host.software_manager.install(software_class=C2Server) diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index 34ee25d6..b56a4b99 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -11,7 +11,7 @@ from primaite.game.agent.actions import ( ) from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot @@ -200,7 +200,7 @@ class TestConfigureDoSBot: game.step() assert dos_bot.target_ip_address == IPv4Address("192.168.1.99") - assert dos_bot.target_port == Port["POSTGRES_SERVER"] + assert dos_bot.target_port == PORT_LOOKUP["POSTGRES_SERVER"] assert dos_bot.payload == "HACC" assert not dos_bot.repeat assert dos_bot.port_scan_p_of_success == 0.875 diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index 857edd26..bc168c3c 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -9,7 +9,7 @@ from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection @@ -20,7 +20,7 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=4) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) return (game, agent) diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index 28f9ac5a..2bf0486c 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -4,7 +4,7 @@ import pytest from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer @@ -33,7 +33,7 @@ def test_acl_observations(simulation): server.software_manager.install(NTPServer) # add router acl rule - router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port["NTP"], src_port=Port["NTP"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, dst_port=PORT_LOOKUP["NTP"], src_port=PORT_LOOKUP["NTP"], position=1) acl_obs = ACLObservation( where=["network", "nodes", router.hostname, "acl", "acl"], diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 21fe4bed..af8c4669 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -5,8 +5,8 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP def check_default_rules(acl_obs): @@ -62,13 +62,13 @@ def test_firewall_observation(): # add a rule to the internal inbound and check that the observation is correct firewall.internal_inbound_acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port["HTTP"], - dst_port=Port["HTTP"], + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], position=5, ) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index c28e1bb8..cdd428b0 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -8,9 +8,9 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation +from primaite.utils.validators import PROTOCOL_LOOKUP def test_router_observation(): @@ -39,13 +39,13 @@ def test_router_observation(): # Add an ACL rule to the router router.acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="10.0.0.1", src_wildcard_mask="0.0.0.1", dst_ip_address="10.0.0.2", dst_wildcard_mask="0.0.0.1", - src_port=Port["HTTP"], - dst_port=Port["HTTP"], + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], position=5, ) # Observe the state using the RouterObservation instance diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py index 6ca4bc9e..70637b0d 100644 --- a/tests/integration_tests/game_layer/observations/test_user_observations.py +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -3,7 +3,7 @@ import pytest from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from tests import TEST_ASSETS_ROOT DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" @@ -15,7 +15,7 @@ def env_with_ssh() -> PrimaiteGymEnv: env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) env.agent.flatten_obs = False router: Router = env.game.simulation.network.get_node_by_hostname("router_1") - router.acl.add_rule(ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=3) + router.acl.add_rule(ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=3) return env diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index c3e86263..2675b615 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -21,11 +21,11 @@ from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network.yaml" @@ -608,9 +608,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.internal_outbound_acl.acl[1].action.name == "DENY" assert firewall.internal_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.0.10") assert firewall.internal_outbound_acl.acl[1].dst_ip_address is None - assert firewall.internal_outbound_acl.acl[1].dst_port == Port["DNS"] - assert firewall.internal_outbound_acl.acl[1].src_port == Port["ARP"] - assert firewall.internal_outbound_acl.acl[1].protocol == IPProtocol["ICMP"] + assert firewall.internal_outbound_acl.acl[1].dst_port == PORT_LOOKUP["DNS"] + assert firewall.internal_outbound_acl.acl[1].src_port == PORT_LOOKUP["ARP"] + assert firewall.internal_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["ICMP"] env.step(4) # Remove ACL rule from Internal Outbound assert firewall.internal_outbound_acl.num_rules == 2 @@ -620,9 +620,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_inbound_acl.acl[1].action.name == "DENY" assert firewall.dmz_inbound_acl.acl[1].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_inbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_inbound_acl.acl[1].dst_port == Port["HTTP"] - assert firewall.dmz_inbound_acl.acl[1].src_port == Port["HTTP"] - assert firewall.dmz_inbound_acl.acl[1].protocol == IPProtocol["UDP"] + assert firewall.dmz_inbound_acl.acl[1].dst_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].src_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_inbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["UDP"] env.step(6) # Remove ACL rule from DMZ Inbound assert firewall.dmz_inbound_acl.num_rules == 2 @@ -632,9 +632,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.dmz_outbound_acl.acl[2].action.name == "DENY" assert firewall.dmz_outbound_acl.acl[2].src_ip_address == IPv4Address("192.168.10.10") assert firewall.dmz_outbound_acl.acl[2].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.dmz_outbound_acl.acl[2].dst_port == Port["HTTP"] - assert firewall.dmz_outbound_acl.acl[2].src_port == Port["HTTP"] - assert firewall.dmz_outbound_acl.acl[2].protocol == IPProtocol["TCP"] + assert firewall.dmz_outbound_acl.acl[2].dst_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].src_port == PORT_LOOKUP["HTTP"] + assert firewall.dmz_outbound_acl.acl[2].protocol == PROTOCOL_LOOKUP["TCP"] env.step(8) # Remove ACL rule from DMZ Outbound assert firewall.dmz_outbound_acl.num_rules == 2 @@ -644,9 +644,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.external_inbound_acl.acl[10].action.name == "DENY" assert firewall.external_inbound_acl.acl[10].src_ip_address == IPv4Address("192.168.20.10") assert firewall.external_inbound_acl.acl[10].dst_ip_address == IPv4Address("192.168.10.10") - assert firewall.external_inbound_acl.acl[10].dst_port == Port["POSTGRES_SERVER"] - assert firewall.external_inbound_acl.acl[10].src_port == Port["POSTGRES_SERVER"] - assert firewall.external_inbound_acl.acl[10].protocol == IPProtocol["ICMP"] + assert firewall.external_inbound_acl.acl[10].dst_port == PORT_LOOKUP["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].src_port == PORT_LOOKUP["POSTGRES_SERVER"] + assert firewall.external_inbound_acl.acl[10].protocol == PROTOCOL_LOOKUP["ICMP"] env.step(10) # Remove ACL rule from External Inbound assert firewall.external_inbound_acl.num_rules == 1 diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 570c4ad6..0afe666c 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -9,11 +9,11 @@ from primaite.interface.request import RequestResponse from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent @@ -42,7 +42,12 @@ def test_WebpageUnavailablePenalty(game_and_agent): # Block the web traffic, check that failing to fetch the webpage yields a reward of -0.7 router: Router = game.simulation.network.get_node_by_hostname("router") - router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol["TCP"], src_port=Port["HTTP"], dst_port=Port["HTTP"]) + router.acl.add_rule( + action=ACLAction.DENY, + protocol=PROTOCOL_LOOKUP["TCP"], + src_port=PORT_LOOKUP["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], + ) agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})) game.step() assert agent.reward_function.current_reward == -0.7 @@ -66,7 +71,7 @@ def test_uc2_rewards(game_and_agent): router: Router = game.simulation.network.get_node_by_hostname("router") router.acl.add_rule( - ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=2 + ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2 ) comp = GreenAdminDatabaseUnreachablePenalty("client_1") diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index da0af89d..b5b2acbc 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -8,10 +8,10 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP class BroadcastTestService(Service): @@ -20,8 +20,8 @@ class BroadcastTestService(Service): def __init__(self, **kwargs): # Set default service properties for broadcasting kwargs["name"] = "BroadcastService" - kwargs["port"] = Port["HTTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: @@ -33,12 +33,14 @@ class BroadcastTestService(Service): super().send( payload="unicast", dest_ip_address=ip_address, - dest_port=Port["HTTP"], + dest_port=PORT_LOOKUP["HTTP"], ) def broadcast(self, ip_network: IPv4Network): # Send a broadcast payload to an entire IP network - super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port["HTTP"], ip_protocol=self.protocol) + super().send( + payload="broadcast", dest_ip_address=ip_network, dest_port=PORT_LOOKUP["HTTP"], ip_protocol=self.protocol + ) class BroadcastTestClient(Application, identifier="BroadcastTestClient"): @@ -49,8 +51,8 @@ class BroadcastTestClient(Application, identifier="BroadcastTestClient"): def __init__(self, **kwargs): # Set default client properties kwargs["name"] = "BroadcastTestClient" - kwargs["port"] = Port["HTTP"] - kwargs["protocol"] = IPProtocol["TCP"] + kwargs["port"] = PORT_LOOKUP["HTTP"] + kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) def describe_state(self) -> Dict: diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index 44b660cf..58763c3e 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -7,10 +7,10 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -53,31 +53,31 @@ def dmz_external_internal_network() -> Network: ) # Allow ICMP - firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) - firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + firewall_node.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.external_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.external_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Allow ARP firewall_node.internal_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.internal_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.external_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.external_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.dmz_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) firewall_node.dmz_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) # external node @@ -267,10 +267,10 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not internal_ntp_client.time firewall.internal_outbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 ) firewall.internal_inbound_acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1 + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 ) internal_ntp_client.request_time() @@ -279,8 +279,12 @@ def test_service_allowed_with_rule(dmz_external_internal_network): assert not dmz_ntp_client.time - firewall.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) - firewall.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1) + firewall.dmz_outbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) + firewall.dmz_inbound_acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=1 + ) dmz_ntp_client.request_time() diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index 641342e2..dde66a43 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -6,10 +6,10 @@ import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -73,8 +73,10 @@ def multi_hop_network() -> Network: router_1.enable_port(2) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B pc_b = Computer( @@ -197,8 +199,12 @@ def test_routing_services(multi_hop_network): router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=21) - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=21) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=21 + ) + router_2.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"], position=21 + ) assert ntp_client.time is None ntp_client.request_time() diff --git a/tests/integration_tests/network/test_wireless_router.py b/tests/integration_tests/network/test_wireless_router.py index 2f1be930..520ec21a 100644 --- a/tests/integration_tests/network/test_wireless_router.py +++ b/tests/integration_tests/network/test_wireless_router.py @@ -7,8 +7,8 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT @@ -37,8 +37,10 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B pc_b = Computer( diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index d819b511..b1979154 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -13,8 +13,7 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import AccessControlList, ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon @@ -25,6 +24,7 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT @@ -227,7 +227,7 @@ def test_c2_suite_acl_block(basic_network): assert c2_beacon.c2_connection_active == True # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) c2_beacon.apply_timestep(2) c2_beacon.apply_timestep(3) @@ -322,8 +322,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Confirm Default Setup ######################### # Permitting all HTTP & FTP traffic - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=1) c2_beacon.apply_timestep(0) assert c2_beacon.keep_alive_inactivity == 1 @@ -337,7 +337,7 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying HTTP Traffic ######################### # Now we add a HTTP blocking acl (Thus preventing a keep alive) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) blocking_acl: AccessControlList = router.acl.acl[0] # Asserts to show the C2 Suite is unable to maintain connection: @@ -359,8 +359,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port["FTP"], - masquerade_protocol=IPProtocol["TCP"], + masquerade_port=PORT_LOOKUP["FTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) c2_beacon.establish() @@ -407,8 +407,8 @@ def test_c2_suite_acl_bypass(basic_network): ################ Denying FTP Traffic & Enable HTTP ######################### # Blocking FTP and re-permitting HTTP: - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port["FTP"], dst_port=Port["FTP"], position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0) + router.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=1) blocking_acl: AccessControlList = router.acl.acl[1] # Asserts to show the C2 Suite is unable to maintain connection: @@ -430,8 +430,8 @@ def test_c2_suite_acl_bypass(basic_network): c2_beacon.configure( c2_server_ip_address="192.168.0.2", keep_alive_frequency=2, - masquerade_port=Port["HTTP"], - masquerade_protocol=IPProtocol["TCP"], + masquerade_port=PORT_LOOKUP["HTTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) c2_beacon.establish() diff --git a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py index 9c0760b7..54c372e4 100644 --- a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py @@ -9,7 +9,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( @@ -52,7 +52,10 @@ def data_manipulation_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py index 709a417f..ad0a519b 100644 --- a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py @@ -8,7 +8,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot @@ -26,7 +26,7 @@ def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseServ dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port["POSTGRES_SERVER"], + target_port=PORT_LOOKUP["POSTGRES_SERVER"], ) # Install DB Server service on server @@ -43,7 +43,10 @@ def dos_bot_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") @@ -56,7 +59,7 @@ def dos_bot_db_server_green_client(example_network) -> Network: dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot") dos_bot.configure( target_ip_address=IPv4Address(server.network_interface[1].ip_address), - target_port=Port["POSTGRES_SERVER"], + target_port=PORT_LOOKUP["POSTGRES_SERVER"], ) # install db server service on server diff --git a/tests/integration_tests/system/red_applications/test_ransomware_script.py b/tests/integration_tests/system/red_applications/test_ransomware_script.py index b34e9b30..09cbcf85 100644 --- a/tests/integration_tests/system/red_applications/test_ransomware_script.py +++ b/tests/integration_tests/system/red_applications/test_ransomware_script.py @@ -9,7 +9,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.database.database_service import DatabaseService @@ -47,7 +47,10 @@ def ransomware_script_db_server_green_client(example_network) -> Network: router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) client_1: Computer = network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index 9d92b660..c1c4df82 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -5,9 +5,9 @@ from ipaddress import IPv4Address, IPv4Network import yaml from primaite.game.game import PrimaiteGame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.nmap import NMAP +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT @@ -74,11 +74,11 @@ def test_port_scan_one_node_one_port(example_network): actual_result = client_1_nmap.port_scan( target_ip_address=client_2.network_interface[1].ip_address, - target_port=Port["DNS"], - target_protocol=IPProtocol["TCP"], + target_port=PORT_LOOKUP["DNS"], + target_protocol=PROTOCOL_LOOKUP["TCP"], ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol["TCP"]: [Port["DNS"]]}} + expected_result = {IPv4Address("192.168.10.22"): {PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["DNS"]]}} assert actual_result == expected_result @@ -103,14 +103,20 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network): actual_result = client_1_nmap.port_scan( target_ip_address=IPv4Network("192.168.10.0/24"), - target_port=[Port["ARP"], Port["HTTP"], Port["FTP"], Port["DNS"], Port["NTP"]], + target_port=[ + PORT_LOOKUP["ARP"], + PORT_LOOKUP["HTTP"], + PORT_LOOKUP["FTP"], + PORT_LOOKUP["DNS"], + PORT_LOOKUP["NTP"], + ], ) expected_result = { - IPv4Address("192.168.10.1"): {IPProtocol["UDP"]: [Port["ARP"]]}, + IPv4Address("192.168.10.1"): {PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"]]}, IPv4Address("192.168.10.22"): { - IPProtocol["TCP"]: [Port["HTTP"], Port["FTP"], Port["DNS"]], - IPProtocol["UDP"]: [Port["ARP"], Port["NTP"]], + PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]], + PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"], PORT_LOOKUP["NTP"]], }, } @@ -124,10 +130,12 @@ def test_network_service_recon_all_ports_and_protocols(example_network): client_1_nmap: NMAP = client_1.software_manager.software["NMAP"] # noqa actual_result = client_1_nmap.network_service_recon( - target_ip_address=IPv4Network("192.168.10.0/24"), target_port=Port["HTTP"], target_protocol=IPProtocol["TCP"] + target_ip_address=IPv4Network("192.168.10.0/24"), + target_port=PORT_LOOKUP["HTTP"], + target_protocol=PROTOCOL_LOOKUP["TCP"], ) - expected_result = {IPv4Address("192.168.10.22"): {IPProtocol["TCP"]: [Port["HTTP"]]}} + expected_result = {IPv4Address("192.168.10.22"): {PROTOCOL_LOOKUP["TCP"]: [PORT_LOOKUP["HTTP"]]}} assert sort_dict(actual_result) == sort_dict(expected_result) diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py index 5226ab4a..4108041d 100644 --- a/tests/integration_tests/system/test_service_listening_on_ports.py +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -6,19 +6,19 @@ from pydantic import Field from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.service import Service +from primaite.utils.validators import PROTOCOL_LOOKUP from tests import TEST_ASSETS_ROOT class _DatabaseListener(Service): name: str = "DatabaseListener" - protocol: str = IPProtocol["TCP"] - port: int = Port["NONE"] - listen_on_ports: Set[int] = {Port["POSTGRES_SERVER"]} + protocol: str = PROTOCOL_LOOKUP["TCP"] + port: int = PORT_LOOKUP["NONE"] + listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]} payloads_received: List[Any] = Field(default_factory=list) def receive(self, payload: Any, session_id: str, **kwargs) -> bool: @@ -51,8 +51,8 @@ def test_http_listener(client_server): computer.session_manager.receive_payload_from_software_manager( payload="masquerade as Database traffic", dst_ip_address=server.network_interface[1].ip_address, - dst_port=Port["POSTGRES_SERVER"], - ip_protocol=IPProtocol["TCP"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + ip_protocol=PROTOCOL_LOOKUP["TCP"], ) assert len(server_db_listener.payloads_received) == 1 @@ -76,9 +76,9 @@ def test_set_listen_on_ports_from_config(): network = PrimaiteGame.from_config(cfg=config_dict).simulation.network client: Computer = network.get_node_by_hostname("client") - assert Port["SMB"] in client.software_manager.get_open_ports() - assert Port["IPP"] in client.software_manager.get_open_ports() + assert PORT_LOOKUP["SMB"] in client.software_manager.get_open_ports() + assert PORT_LOOKUP["IPP"] in client.software_manager.get_open_ports() web_browser = client.software_manager.software["WebBrowser"] - assert not web_browser.listen_on_ports.difference({Port["SMB"], Port["IPP"]}) + assert not web_browser.listen_on_ports.difference({PORT_LOOKUP["SMB"], PORT_LOOKUP["IPP"]}) diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py index 6c37360f..854ef41b 100644 --- a/tests/integration_tests/system/test_web_client_server_and_database.py +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -9,7 +9,7 @@ from primaite.simulator.network.hardware.base import Link from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService @@ -24,17 +24,22 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer, # add rules to network router router_1: Router = example_network.get_node_by_hostname("router_1") router_1.acl.add_rule( - action=ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=0 + action=ACLAction.PERMIT, + src_port=PORT_LOOKUP["POSTGRES_SERVER"], + dst_port=PORT_LOOKUP["POSTGRES_SERVER"], + position=0, ) # Allow DNS requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["DNS"], dst_port=Port["DNS"], position=1) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["DNS"], dst_port=PORT_LOOKUP["DNS"], position=1) # Allow FTP requests - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["FTP"], dst_port=Port["FTP"], position=2) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["FTP"], dst_port=PORT_LOOKUP["FTP"], position=2) # Open port 80 for web server - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3 + ) # Create Computer computer: Computer = example_network.get_node_by_hostname("client_1") @@ -148,7 +153,9 @@ class TestWebBrowserHistory: assert web_browser.history[-1].response_code == 200 router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule( + action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0 + ) assert not web_browser.get_webpage() assert len(web_browser.history) == 3 # with current NIC behaviour, even if you block communication, you won't get SERVER_UNREACHABLE because @@ -166,7 +173,9 @@ class TestWebBrowserHistory: web_browser.get_webpage() router = network.get_node_by_hostname("router_1") - router.acl.add_rule(action=ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=0) + router.acl.add_rule( + action=ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=0 + ) web_browser.get_webpage() state = computer.describe_state() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index ff73e621..7813628c 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -12,7 +12,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from tests.conftest import DummyApplication, DummyService @@ -171,7 +171,7 @@ class TestDataManipulationGreenRequests: assert client_1_browser_execute.status == "success" assert client_2_browser_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port["HTTP"], dst_port=Port["HTTP"], position=3) + router.acl.add_rule(ACLAction.DENY, src_port=PORT_LOOKUP["HTTP"], dst_port=PORT_LOOKUP["HTTP"], position=3) client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) assert client_1_browser_execute.status == "failure" @@ -182,7 +182,9 @@ class TestDataManipulationGreenRequests: assert client_1_db_client_execute.status == "success" assert client_2_db_client_execute.status == "success" - router.acl.add_rule(ACLAction.DENY, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]) + router.acl.add_rule( + ACLAction.DENY, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"] + ) client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"]) client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"]) assert client_1_db_client_execute.status == "failure" diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py index 4c471faa..ba7628c2 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py @@ -7,8 +7,9 @@ from primaite.simulator.network.hardware.base import generate_mac_address from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.network_layer import IPPacket +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPHeader, UDPHeader +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -28,16 +29,16 @@ def router_with_acl_rules(): # Add rules here as needed acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.1", - src_port=Port["HTTPS"], + src_port=PORT_LOOKUP["HTTPS"], dst_ip_address="192.168.1.2", - dst_port=Port["HTTP"], + dst_port=PORT_LOOKUP["HTTP"], position=1, ) acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.3", src_port=8080, dst_ip_address="192.168.1.4", @@ -65,7 +66,7 @@ def router_with_wildcard_acl(): # Rule to permit traffic from a specific source IP and port to a specific destination IP and port acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.1", src_port=8080, dst_ip_address="10.1.1.2", @@ -75,7 +76,7 @@ def router_with_wildcard_acl(): # Rule to deny traffic from an IP range to a specific destination IP and port acl.add_rule( action=ACLAction.DENY, - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], src_ip_address="192.168.1.0", src_wildcard_mask="0.0.0.255", dst_ip_address="10.1.1.3", @@ -109,11 +110,11 @@ def test_add_rule(router_with_acl_rules): acl = router_with_acl_rules.acl assert acl.acl[1].action == ACLAction.PERMIT - assert acl.acl[1].protocol == IPProtocol["TCP"] + assert acl.acl[1].protocol == PROTOCOL_LOOKUP["TCP"] assert acl.acl[1].src_ip_address == IPv4Address("192.168.1.1") - assert acl.acl[1].src_port == Port["HTTPS"] + assert acl.acl[1].src_port == PORT_LOOKUP["HTTPS"] assert acl.acl[1].dst_ip_address == IPv4Address("192.168.1.2") - assert acl.acl[1].dst_port == Port["HTTP"] + assert acl.acl[1].dst_port == PORT_LOOKUP["HTTP"] def test_remove_rule(router_with_acl_rules): @@ -136,8 +137,8 @@ def test_traffic_permitted_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=IPProtocol["TCP"]), - tcp=TCPHeader(src_port=Port["HTTPS"], dst_port=Port["HTTP"]), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="192.168.1.2", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["HTTPS"], dst_port=PORT_LOOKUP["HTTP"]), ) is_permitted, _ = acl.is_permitted(permitted_frame) assert is_permitted @@ -153,7 +154,7 @@ def test_traffic_denied_by_specific_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=IPProtocol["TCP"]), + ip=IPPacket(src_ip_address="192.168.1.3", dst_ip_address="192.168.1.4", protocol=PROTOCOL_LOOKUP["TCP"]), tcp=TCPHeader(src_port=8080, dst_port=80), ) is_permitted, _ = acl.is_permitted(not_permitted_frame) @@ -173,8 +174,8 @@ def test_default_rule(router_with_acl_rules): acl = router_with_acl_rules.acl not_permitted_frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=IPProtocol["UDP"]), - udp=UDPHeader(src_port=Port["HTTPS"], dst_port=Port["HTTP"]), + ip=IPPacket(src_ip_address="192.168.1.5", dst_ip_address="192.168.1.12", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["HTTPS"], dst_port=PORT_LOOKUP["HTTP"]), ) is_permitted, rule = acl.is_permitted(not_permitted_frame) assert not is_permitted @@ -189,7 +190,7 @@ def test_direct_ip_match_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=IPProtocol["TCP"]), + ip=IPPacket(src_ip_address="192.168.1.1", dst_ip_address="10.1.1.2", protocol=PROTOCOL_LOOKUP["TCP"]), tcp=TCPHeader(src_port=8080, dst_port=80), ) assert acl.is_permitted(frame)[0], "Direct IP match should be permitted." @@ -204,7 +205,7 @@ def test_ip_range_match_denied_with_acl(router_with_wildcard_acl): acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=IPProtocol["TCP"]), + ip=IPPacket(src_ip_address="192.168.1.100", dst_ip_address="10.1.1.3", protocol=PROTOCOL_LOOKUP["TCP"]), tcp=TCPHeader(src_port=8080, dst_port=443), ) assert not acl.is_permitted(frame)[0], "IP range match with wildcard mask should be denied." @@ -219,7 +220,7 @@ def test_traffic_permitted_to_destination_range_with_acl(router_with_wildcard_ac acl = router_with_wildcard_acl.acl frame = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["UDP"]), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["UDP"]), udp=UDPHeader(src_port=1433, dst_port=1433), ) assert acl.is_permitted(frame)[0], "Traffic to destination IP range should be permitted." @@ -253,23 +254,23 @@ def test_ip_traffic_from_specific_subnet(): permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["TCP"]), - tcp=TCPHeader(src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]), + ip=IPPacket(src_ip_address="192.168.1.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"]), ) assert acl.is_permitted(permitted_frame_1)[0] permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=IPProtocol["UDP"]), - udp=UDPHeader(src_port=Port["NTP"], dst_port=Port["NTP"]), + ip=IPPacket(src_ip_address="192.168.1.10", dst_ip_address="85.199.214.101", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"]), ) assert acl.is_permitted(permitted_frame_2)[0] permitted_frame_3 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=IPProtocol["ICMP"]), + ip=IPPacket(src_ip_address="192.168.1.200", dst_ip_address="192.168.1.1", protocol=PROTOCOL_LOOKUP["ICMP"]), icmp=ICMPPacket(identifier=1), ) @@ -277,16 +278,16 @@ def test_ip_traffic_from_specific_subnet(): not_permitted_frame_1 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=IPProtocol["TCP"]), - tcp=TCPHeader(src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"]), + ip=IPPacket(src_ip_address="192.168.0.50", dst_ip_address="10.2.200.200", protocol=PROTOCOL_LOOKUP["TCP"]), + tcp=TCPHeader(src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"]), ) assert not acl.is_permitted(not_permitted_frame_1)[0] not_permitted_frame_2 = Frame( ethernet=EthernetHeader(src_mac_addr=generate_mac_address(), dst_mac_addr=generate_mac_address()), - ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=IPProtocol["UDP"]), - udp=UDPHeader(src_port=Port["NTP"], dst_port=Port["NTP"]), + ip=IPPacket(src_ip_address="192.168.2.10", dst_ip_address="85.199.214.101", protocol=PROTOCOL_LOOKUP["UDP"]), + udp=UDPHeader(src_port=PORT_LOOKUP["NTP"], dst_port=PORT_LOOKUP["NTP"]), ) assert not acl.is_permitted(not_permitted_frame_2)[0] diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py index 3551ce38..0e1844c4 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -2,8 +2,8 @@ from ipaddress import IPv4Address from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validators import PROTOCOL_LOOKUP def test_wireless_router_from_config(): @@ -67,12 +67,12 @@ def test_wireless_router_from_config(): r0 = rt.acl.acl[0] assert r0.action == ACLAction.PERMIT - assert r0.src_port == r0.dst_port == Port["POSTGRES_SERVER"] + assert r0.src_port == r0.dst_port == PORT_LOOKUP["POSTGRES_SERVER"] assert r0.src_ip_address == r0.dst_ip_address == r0.dst_wildcard_mask == r0.src_wildcard_mask == r0.protocol == None r1 = rt.acl.acl[1] assert r1.action == ACLAction.PERMIT - assert r1.protocol == IPProtocol["ICMP"] + assert r1.protocol == PROTOCOL_LOOKUP["ICMP"] assert ( r1.src_ip_address == r1.dst_ip_address diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py index 9fd39dfc..9e9a1f72 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py @@ -3,9 +3,10 @@ import pytest from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame -from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol, Precedence +from primaite.simulator.network.transmission.network_layer import IPPacket, Precedence from primaite.simulator.network.transmission.primaite_layer import AgentSource, DataStatus -from primaite.simulator.network.transmission.transport_layer import Port, TCPFlags, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPFlags, TCPHeader, UDPHeader +from primaite.utils.validators import PROTOCOL_LOOKUP def test_frame_minimal_instantiation(): @@ -20,7 +21,7 @@ def test_frame_minimal_instantiation(): ) # Check network layer default values - assert frame.ip.protocol == IPProtocol["TCP"] + assert frame.ip.protocol == PROTOCOL_LOOKUP["TCP"] assert frame.ip.ttl == 64 assert frame.ip.precedence == Precedence.ROUTINE @@ -40,7 +41,7 @@ def test_frame_creation_fails_tcp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["TCP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["TCP"]), ) @@ -49,7 +50,7 @@ def test_frame_creation_fails_udp_without_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["UDP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["UDP"]), ) @@ -58,7 +59,7 @@ def test_frame_creation_fails_tcp_with_udp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["TCP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["TCP"]), udp=UDPHeader(src_port=8080, dst_port=80), ) @@ -68,7 +69,7 @@ def test_frame_creation_fails_udp_with_tcp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["UDP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["UDP"]), udp=TCPHeader(src_port=8080, dst_port=80), ) @@ -77,7 +78,7 @@ def test_icmp_frame_creation(): """Tests Frame creation for ICMP.""" frame = Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["ICMP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["ICMP"]), icmp=ICMPPacket(), ) assert frame @@ -88,5 +89,5 @@ def test_icmp_frame_creation_fails_without_icmp_header(): with pytest.raises(ValueError): Frame( ethernet=EthernetHeader(src_mac_addr="aa:bb:cc:dd:ee:ff", dst_mac_addr="11:22:33:44:55:66"), - ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=IPProtocol["ICMP"]), + ip=IPPacket(src_ip_address="192.168.0.10", dst_ip_address="192.168.0.20", protocol=PROTOCOL_LOOKUP["ICMP"]), ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index 6e53aebc..fde70616 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -4,11 +4,11 @@ import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Command, C2Server +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -129,19 +129,19 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon are configured correctly. assert c2_beacon.c2_config.keep_alive_frequency is 2 - assert c2_beacon.c2_config.masquerade_port is Port["HTTP"] - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol["TCP"] + assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] assert c2_server.c2_config.keep_alive_frequency is 2 - assert c2_server.c2_config.masquerade_port is Port["HTTP"] - assert c2_server.c2_config.masquerade_protocol is IPProtocol["TCP"] + assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["HTTP"] + assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] # Configuring the C2 Beacon. c2_beacon.configure( c2_server_ip_address="192.168.0.1", keep_alive_frequency=2, - masquerade_port=Port["FTP"], - masquerade_protocol=IPProtocol["TCP"], + masquerade_port=PORT_LOOKUP["FTP"], + masquerade_protocol=PROTOCOL_LOOKUP["TCP"], ) # Asserting that the c2 applications have established a c2 connection @@ -150,11 +150,11 @@ def test_c2_handle_switching_port(basic_c2_network): # Assert to confirm that both the C2 server and the C2 beacon # Have reconfigured their C2 settings. - assert c2_beacon.c2_config.masquerade_port is Port["FTP"] - assert c2_beacon.c2_config.masquerade_protocol is IPProtocol["TCP"] + assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] - assert c2_server.c2_config.masquerade_port is Port["FTP"] - assert c2_server.c2_config.masquerade_protocol is IPProtocol["TCP"] + assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["FTP"] + assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"] def test_c2_handle_switching_frequency(basic_c2_network): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index 229f98fe..f4750158 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -3,13 +3,13 @@ import pytest from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( DataManipulationAttackStage, DataManipulationBot, ) +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -27,8 +27,8 @@ def test_create_dm_bot(dm_client): data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot") assert data_manipulation_bot.name == "DataManipulationBot" - assert data_manipulation_bot.port == Port["NONE"] - assert data_manipulation_bot.protocol == IPProtocol["NONE"] + assert data_manipulation_bot.port == PORT_LOOKUP["NONE"] + assert data_manipulation_bot.protocol == PROTOCOL_LOOKUP["NONE"] assert data_manipulation_bot.payload == "DELETE" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index 2acd991a..d0c65266 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -5,7 +5,7 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index c274c18e..f5781485 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -4,10 +4,10 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.http import HttpResponsePacket, HttpStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -39,8 +39,8 @@ def test_create_web_client(): # Web Browser should be pre-installed in computer web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") assert web_browser.name is "WebBrowser" - assert web_browser.port is Port["HTTP"] - assert web_browser.protocol is IPProtocol["TCP"] + assert web_browser.port is PORT_LOOKUP["HTTP"] + assert web_browser.protocol is PROTOCOL_LOOKUP["TCP"] def test_receive_invalid_payload(web_browser): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index 1a51708d..09099c5c 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -6,10 +6,10 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -28,8 +28,8 @@ def test_create_dns_client(dns_client): assert dns_client is not None dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") assert dns_client_service.name is "DNSClient" - assert dns_client_service.port is Port["DNS"] - assert dns_client_service.protocol is IPProtocol["TCP"] + assert dns_client_service.port is PORT_LOOKUP["DNS"] + assert dns_client_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_dns_client_add_domain_to_cache_when_not_running(dns_client): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py index 8cdb1b84..688bfd7d 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -8,10 +8,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -32,8 +32,8 @@ def test_create_dns_server(dns_server): assert dns_server is not None dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer") assert dns_server_service.name is "DNSServer" - assert dns_server_service.port is Port["DNS"] - assert dns_server_service.protocol is IPProtocol["TCP"] + assert dns_server_service.port is PORT_LOOKUP["DNS"] + assert dns_server_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_dns_server_domain_name_registration(dns_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 3c1afb28..b4fe8633 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -8,10 +8,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -31,8 +31,8 @@ def test_create_ftp_client(ftp_client): assert ftp_client is not None ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") assert ftp_client_service.name is "FTPClient" - assert ftp_client_service.port is Port["FTP"] - assert ftp_client_service.protocol is IPProtocol["TCP"] + assert ftp_client_service.port is PORT_LOOKUP["FTP"] + assert ftp_client_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_ftp_client_store_file(ftp_client): @@ -61,7 +61,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client): """Method _process_ftp_command should return false if service is not running.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port["FTP"], + ftp_command_args=PORT_LOOKUP["FTP"], status_code=FTPStatusCode.OK, ) @@ -102,7 +102,7 @@ def test_offline_ftp_client_receives_request(ftp_client): payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port["FTP"], + ftp_command_args=PORT_LOOKUP["FTP"], status_code=FTPStatusCode.OK, ) @@ -119,7 +119,7 @@ def test_receive_should_ignore_payload_with_none_status_code(ftp_client): """Receive should ignore payload with no set status code to prevent infinite send/receive loops.""" payload: FTPPacket = FTPPacket( ftp_command=FTPCommand.PORT, - ftp_command_args=Port["FTP"], + ftp_command_args=PORT_LOOKUP["FTP"], status_code=None, ) ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index aa13ec5e..3f10db4d 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -6,10 +6,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -30,8 +30,8 @@ def test_create_ftp_server(ftp_server): assert ftp_server is not None ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") assert ftp_server_service.name is "FTPServer" - assert ftp_server_service.port is Port["FTP"] - assert ftp_server_service.protocol is IPProtocol["TCP"] + assert ftp_server_service.port is PORT_LOOKUP["FTP"] + assert ftp_server_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_ftp_server_store_file(ftp_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 21ed839b..f2895091 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -18,13 +18,13 @@ from primaite.simulator.network.protocols.ssh import ( SSHTransportMessage, SSHUserCredentials, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection, Terminal from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -77,11 +77,15 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 + ) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # add ACL rule to allow SSH traffic - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=21) + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21 + ) # Configure PC B pc_b = Computer( @@ -329,7 +333,9 @@ def test_SSH_across_network(wireless_wan_network): terminal_a: Terminal = pc_a.software_manager.software.get("Terminal") terminal_b: Terminal = pc_b.software_manager.software.get("Terminal") - router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["SSH"], dst_port=Port["SSH"], position=21) + router_2.acl.add_rule( + action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21 + ) assert len(terminal_a._connections) == 0 diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index c1df3857..c78a381e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -9,9 +9,9 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.utils.validators import PROTOCOL_LOOKUP @pytest.fixture(scope="function") @@ -33,8 +33,8 @@ def test_create_web_server(web_server): assert web_server is not None web_server_service: WebServer = web_server.software_manager.software.get("WebServer") assert web_server_service.name is "WebServer" - assert web_server_service.port is Port["HTTP"] - assert web_server_service.protocol is IPProtocol["TCP"] + assert web_server_service.port is PORT_LOOKUP["HTTP"] + assert web_server_service.protocol is PROTOCOL_LOOKUP["TCP"] def test_handling_get_request_not_found_path(web_server): diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index b7a663af..1baaf88e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -3,11 +3,11 @@ from typing import Dict import pytest -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import IOSoftware, SoftwareHealthState +from primaite.utils.validators import PROTOCOL_LOOKUP class TestSoftware(Service): @@ -19,10 +19,10 @@ class TestSoftware(Service): def software(file_system): return TestSoftware( name="TestSoftware", - port=Port["ARP"], + port=PORT_LOOKUP["ARP"], file_system=file_system, sys_log=SysLog(hostname="test_service"), - protocol=IPProtocol["TCP"], + protocol=PROTOCOL_LOOKUP["TCP"], ) diff --git a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py index 8becc6ae..10ed36e0 100644 --- a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py +++ b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.utils.converters import convert_dict_enum_keys_to_enum_values +from primaite.utils.validators import PROTOCOL_LOOKUP def test_simple_conversion(): @@ -11,7 +11,7 @@ def test_simple_conversion(): The original dictionary contains one level of nested dictionary with enums as keys. The expected output should have string values of enums as keys. """ - original_dict = {IPProtocol["UDP"]: {Port["ARP"]: {"inbound": 0, "outbound": 1016.0}}} + original_dict = {PROTOCOL_LOOKUP["UDP"]: {PORT_LOOKUP["ARP"]: {"inbound": 0, "outbound": 1016.0}}} expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0}}} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict @@ -36,8 +36,8 @@ def test_mixed_keys(): The expected output should have string values of enums and original string keys. """ original_dict = { - IPProtocol["TCP"]: {"port": {"inbound": 0, "outbound": 1016.0}}, - "protocol": {Port["HTTP"]: {"inbound": 10, "outbound": 2020.0}}, + PROTOCOL_LOOKUP["TCP"]: {"port": {"inbound": 0, "outbound": 1016.0}}, + "protocol": {PORT_LOOKUP["HTTP"]: {"inbound": 10, "outbound": 2020.0}}, } expected_dict = { "tcp": {"port": {"inbound": 0, "outbound": 1016.0}}, @@ -66,8 +66,12 @@ def test_nested_dicts(): The expected output should have string values of enums as keys at all levels. """ original_dict = { - IPProtocol["UDP"]: { - Port["ARP"]: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol["TCP"]: {"latency": "low"}}} + PROTOCOL_LOOKUP["UDP"]: { + PORT_LOOKUP["ARP"]: { + "inbound": 0, + "outbound": 1016.0, + "details": {PROTOCOL_LOOKUP["TCP"]: {"latency": "low"}}, + } } } expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0, "details": {"tcp": {"latency": "low"}}}}} @@ -82,8 +86,11 @@ def test_non_dict_values(): The expected output should preserve these non-dictionary values while converting enum keys to string values. """ original_dict = { - IPProtocol["UDP"]: [Port["ARP"], Port["HTTP"]], - "protocols": (IPProtocol["TCP"], IPProtocol["UDP"]), + PROTOCOL_LOOKUP["UDP"]: [PORT_LOOKUP["ARP"], PORT_LOOKUP["HTTP"]], + "protocols": (PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]), + } + expected_dict = { + "udp": [PORT_LOOKUP["ARP"], PORT_LOOKUP["HTTP"]], + "protocols": (PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]), } - expected_dict = {"udp": [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])} assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict From f1b911bc651f152ae0426dac13e26b6002668e46 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 25 Sep 2024 16:28:22 +0100 Subject: [PATCH 4/9] Change port and protocol to annotated validators --- .../agent/observations/host_observations.py | 3 +- .../agent/observations/nic_observations.py | 30 ++------ .../agent/observations/node_observations.py | 29 ++------ src/primaite/game/game.py | 36 ++-------- src/primaite/simulator/network/creation.py | 4 +- .../simulator/network/hardware/base.py | 5 +- .../network/hardware/nodes/host/host_node.py | 2 +- .../hardware/nodes/network/firewall.py | 5 +- .../network/hardware/nodes/network/router.py | 61 +++++++--------- .../hardware/nodes/network/wireless_router.py | 5 +- src/primaite/simulator/network/networks.py | 4 +- .../simulator/network/protocols/masquerade.py | 6 +- .../network/transmission/data_link_layer.py | 5 +- .../network/transmission/network_layer.py | 3 +- .../network/transmission/transport_layer.py | 32 --------- .../system/applications/database_client.py | 5 +- .../simulator/system/applications/nmap.py | 47 +++++++------ .../red_applications/c2/abstract_c2.py | 10 +-- .../red_applications/c2/c2_beacon.py | 4 +- .../red_applications/data_manipulation_bot.py | 4 +- .../applications/red_applications/dos_bot.py | 4 +- .../red_applications/ransomware_script.py | 4 +- .../system/applications/web_browser.py | 6 +- .../simulator/system/core/session_manager.py | 49 +++++++------ .../simulator/system/core/software_manager.py | 22 +++--- .../simulator/system/services/arp/arp.py | 5 +- .../services/database/database_service.py | 4 +- .../system/services/dns/dns_client.py | 6 +- .../system/services/dns/dns_server.py | 4 +- .../system/services/ftp/ftp_client.py | 18 ++--- .../system/services/ftp/ftp_server.py | 6 +- .../system/services/ftp/ftp_service.py | 7 +- .../simulator/system/services/icmp/icmp.py | 4 +- .../system/services/ntp/ntp_client.py | 6 +- .../system/services/ntp/ntp_server.py | 4 +- .../system/services/terminal/terminal.py | 4 +- .../system/services/web_server/web_server.py | 6 +- src/primaite/simulator/system/software.py | 11 +-- src/primaite/utils/validation/__init__.py | 1 + src/primaite/utils/validation/ip_protocol.py | 47 +++++++++++++ .../ipv4_address.py} | 38 +--------- src/primaite/utils/validation/port.py | 70 +++++++++++++++++++ tests/conftest.py | 4 +- .../nodes/network/test_firewall_config.py | 4 +- .../nodes/network/test_router_config.py | 4 +- .../applications/extended_application.py | 4 +- .../extensions/nodes/super_computer.py | 2 +- .../extensions/services/extended_service.py | 4 +- .../actions/test_c2_suite_actions.py | 2 +- .../actions/test_configure_actions.py | 2 +- .../actions/test_terminal_actions.py | 2 +- .../observations/test_acl_observations.py | 2 +- .../observations/test_firewall_observation.py | 4 +- .../observations/test_router_observation.py | 4 +- .../observations/test_user_observations.py | 2 +- .../game_layer/test_actions.py | 4 +- .../game_layer/test_rewards.py | 4 +- .../network/test_broadcast.py | 4 +- .../network/test_firewall.py | 4 +- .../integration_tests/network/test_routing.py | 4 +- .../network/test_wireless_router.py | 4 +- .../test_c2_suite_integration.py | 4 +- .../test_data_manipulation_bot_and_server.py | 2 +- .../test_dos_bot_and_server.py | 2 +- .../test_ransomware_script.py | 2 +- tests/integration_tests/system/test_nmap.py | 4 +- .../system/test_service_listening_on_ports.py | 4 +- .../test_web_client_server_and_database.py | 2 +- .../test_simulation/test_request_response.py | 2 +- .../_network/_hardware/nodes/test_acl.py | 5 +- .../_network/_hardware/nodes/test_router.py | 4 +- .../_transmission/test_data_link_layer.py | 5 +- .../_red_applications/test_c2_suite.py | 4 +- .../test_data_manipulation_bot.py | 4 +- .../_red_applications/test_dos_bot.py | 2 +- .../_system/_applications/test_web_browser.py | 4 +- .../_system/_services/test_dns_client.py | 4 +- .../_system/_services/test_dns_server.py | 4 +- .../_system/_services/test_ftp_client.py | 4 +- .../_system/_services/test_ftp_server.py | 4 +- .../_system/_services/test_terminal.py | 4 +- .../_system/_services/test_web_server.py | 4 +- .../_simulator/_system/test_software.py | 4 +- .../_utils/test_dict_enum_keys_conversion.py | 4 +- 84 files changed, 380 insertions(+), 392 deletions(-) create mode 100644 src/primaite/utils/validation/__init__.py create mode 100644 src/primaite/utils/validation/ip_protocol.py rename src/primaite/utils/{validators.py => validation/ipv4_address.py} (59%) create mode 100644 src/primaite/utils/validation/port.py diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 0984f008..96c5f40d 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -12,7 +12,8 @@ from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.utils.validators import IPProtocol, Port +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index c51cb427..d180b641 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,16 +1,15 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, Optional +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType -from pydantic import field_validator from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): @@ -23,30 +22,9 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """Number of the network interface.""" include_nmne: Optional[bool] = None """Whether to include number of malicious network events (NMNE) in the observation.""" - monitored_traffic: Optional[Dict] = None + monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" - @field_validator("monitored_traffic", mode="before") - def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: - """ - Convert monitored_traffic by lookup against Port and Protocol dicts. - - This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. - This method will be removed in PrimAITE >= 4.0 - """ - if val is None: - return val - new_val = {} - for proto, port_list in val.items(): - # convert protocol, for instance ICMP becomes "icmp" - proto = PROTOCOL_LOOKUP[proto] if proto in PROTOCOL_LOOKUP else proto - new_val[proto] = [] - for port in port_list: - # convert ports, for instance "HTTP" becomes 80 - port = PORT_LOOKUP[port] if port in PORT_LOOKUP else port - new_val[proto].append(port) - return new_val - def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: """ Initialise a network interface observation instance. diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 0bb8ea0f..e11521b6 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -5,15 +5,15 @@ from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType -from pydantic import field_validator, model_validator +from pydantic import model_validator from primaite import getLogger from primaite.game.agent.observations.firewall_observation import FirewallObservation from primaite.game.agent.observations.host_observations import HostObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.router_observation import RouterObservation -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) @@ -42,7 +42,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """Number of network interface cards (NICs).""" include_nmne: Optional[bool] = None """Flag to include nmne.""" - monitored_traffic: Optional[Dict] = None + monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" @@ -63,27 +63,6 @@ class NodesObservation(AbstractObservation, identifier="NODES"): num_rules: Optional[int] = None """Number of rules ACL rules to show.""" - @field_validator("monitored_traffic", mode="before") - def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: - """ - Convert monitored_traffic by lookup against Port and Protocol dicts. - - This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. - This method will be removed in PrimAITE >= 4.0 - """ - if val is None: - return val - new_val = {} - for proto, port_list in val.items(): - # convert protocol, for instance ICMP becomes "icmp" - proto = PROTOCOL_LOOKUP[proto] if proto in PROTOCOL_LOOKUP else proto - new_val[proto] = [] - for port in port_list: - # convert ports, for instance "HTTP" becomes 80 - port = PORT_LOOKUP[port] if port in PORT_LOOKUP else port - new_val[proto].append(port) - return new_val - @model_validator(mode="after") def force_optional_fields(self) -> NodesObservation.ConfigSchema: """Check that options are specified only if they are needed for the nodes that are part of the config.""" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index a0d2ceb4..6d1c0920 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -4,7 +4,7 @@ from ipaddress import IPv4Address from typing import Dict, List, Optional, Union import numpy as np -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict from primaite import DEFAULT_BANDWIDTH, getLogger from primaite.game.agent.actions import ActionManager @@ -27,7 +27,6 @@ from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.network.nmne import NMNEConfig -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient # noqa: F401 @@ -50,7 +49,8 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import Software -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -81,39 +81,13 @@ class PrimaiteGameOptions(BaseModel): """Random number seed for RNGs.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" - ports: List[int] + ports: List[Port] """A whitelist of available ports in the simulation.""" - protocols: List[str] + protocols: List[IPProtocol] """A whitelist of available protocols in the simulation.""" thresholds: Optional[Dict] = {} """A dict containing the thresholds used for determining what is acceptable during observations.""" - @field_validator("ports", mode="before") - def ports_str2int(cls, vals: Union[List[str], List[int]]) -> List[int]: - """ - Convert named port strings to port integer values. Integer ports remain unaffected. - - This is necessary to retain backwards compatibility with configs written for PrimAITE<=3.3. - :warning: This will be deprecated in PrimAITE 4.0 and configs will need to be converted. - """ - for i, port_val in enumerate(vals): - if port_val in PORT_LOOKUP: - vals[i] = PORT_LOOKUP[port_val] - return vals - - @field_validator("protocols", mode="before") - def protocols_str2int(cls, vals: List[str]) -> List[str]: - """ - Convert old-style named protocols to their proper values. - - This is necessary to retain backwards compatibility with configs written for PrimAITE<=3.3. - :warning: This will be deprecated in PrimAITE 4.0 and configs will need to be converted. - """ - for i, proto_val in enumerate(vals): - if proto_val in PROTOCOL_LOOKUP: - vals[i] = PROTOCOL_LOOKUP[proto_val] - return vals - class PrimaiteGame: """ diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index 9e2e5502..891c445e 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -6,8 +6,8 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index affaf3cc..778cffa2 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -21,7 +21,6 @@ from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.packet_capture import PacketCapture from primaite.simulator.system.core.session_manager import SessionManager @@ -32,7 +31,9 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.software import IOSoftware, Software from primaite.utils.converters import convert_dict_enum_keys_to_enum_values -from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware) diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 8a420e44..5699721b 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -22,7 +22,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.icmp.icmp import ICMP from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.terminal.terminal import Terminal -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ipv4_address import IPV4Address _LOGGER = getLogger(__name__) diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index eed1132b..47cfae57 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -14,9 +14,10 @@ from primaite.simulator.network.hardware.nodes.network.router import ( RouterInterface, ) from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.sys_log import SysLog -from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP EXTERNAL_PORT_ID: Final[int] = 1 """The Firewall port ID of the external port.""" diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 46efe668..244f40ce 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -7,7 +7,7 @@ from ipaddress import IPv4Address, IPv4Network from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import field_validator, validate_call +from pydantic import validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -17,14 +17,15 @@ from primaite.simulator.network.hardware.nodes.network.network_node import Netwo from primaite.simulator.network.protocols.arp import ARPPacket from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.nmap import NMAP from primaite.simulator.system.core.session_manager import SessionManager from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.arp.arp import ARP from primaite.simulator.system.services.icmp.icmp import ICMP from primaite.simulator.system.services.terminal.terminal import Terminal -from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import Port, PORT_LOOKUP @validate_call() @@ -120,29 +121,15 @@ class ACLRule(SimComponent): """ action: ACLAction = ACLAction.DENY - protocol: Optional[str] = None + protocol: Optional[IPProtocol] = None src_ip_address: Optional[IPV4Address] = None src_wildcard_mask: Optional[IPV4Address] = None dst_ip_address: Optional[IPV4Address] = None dst_wildcard_mask: Optional[IPV4Address] = None - src_port: Optional[int] = None - dst_port: Optional[int] = None + src_port: Optional[Port] = None + dst_port: Optional[Port] = None match_count: int = 0 - @field_validator("protocol", mode="before") - def protocol_valid(cls, val: Optional[str]) -> Optional[str]: - """Assert that the protocol for the rule is predefined in the IPProtocol lookup.""" - if val is not None: - assert val in PROTOCOL_LOOKUP.values(), f"Cannot create ACL rule with invalid protocol {val}" - return val - - @field_validator("src_port", "dst_port", mode="before") - def ports_valid(cls, val: Optional[int]) -> Optional[int]: - """Assert that the port for the rule is predefined in the Port lookup.""" - if val is not None: - assert val in PORT_LOOKUP.values(), f"Cannot create ACL rule with invalid port {val}" - return val - def __str__(self) -> str: rule_strings = [] for key, value in self.model_dump(exclude={"uuid", "request_manager"}).items(): @@ -390,13 +377,13 @@ class AccessControlList(SimComponent): def add_rule( self, action: ACLAction = ACLAction.DENY, - protocol: Optional[str] = None, + protocol: Optional[IPProtocol] = None, src_ip_address: Optional[IPV4Address] = None, src_wildcard_mask: Optional[IPV4Address] = None, dst_ip_address: Optional[IPV4Address] = None, dst_wildcard_mask: Optional[IPV4Address] = None, - src_port: Optional[int] = None, - dst_port: Optional[int] = None, + src_port: Optional[Port] = None, + dst_port: Optional[Port] = None, position: int = 0, ) -> bool: """ @@ -498,11 +485,11 @@ class AccessControlList(SimComponent): def get_relevant_rules( self, - protocol: str, + protocol: IPProtocol, src_ip_address: Union[str, IPv4Address], - src_port: int, + src_port: Port, dst_ip_address: Union[str, IPv4Address], - dst_port: int, + dst_port: Port, ) -> List[ACLRule]: """ Get the list of relevant rules for a packet with given properties. @@ -1101,17 +1088,17 @@ class RouterSessionManager(SessionManager): def resolve_outbound_transmission_details( self, dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[int] = None, - dst_port: Optional[int] = None, - protocol: Optional[str] = None, + src_port: Optional[Port] = None, + dst_port: Optional[Port] = None, + protocol: Optional[IPProtocol] = None, session_id: Optional[str] = None, ) -> Tuple[ Optional[RouterInterface], Optional[str], IPv4Address, - Optional[int], - Optional[int], - Optional[str], + Optional[Port], + Optional[Port], + Optional[IPProtocol], bool, ]: """ @@ -1131,19 +1118,19 @@ class RouterSessionManager(SessionManager): treats the transmission as a broadcast to that network. Optional. :type dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] :param src_port: The source port number for the transmission. Optional. - :type src_port: Optional[int] + :type src_port: Optional[Port] :param dst_port: The destination port number for the transmission. Optional. - :type dst_port: Optional[int] + :type dst_port: Optional[Port] :param protocol: The IP protocol to be used for the transmission. Optional. - :type protocol: Optional[str] + :type protocol: Optional[IPProtocol] :param session_id: The session ID associated with the transmission. If provided, the session details override other parameters. Optional. :type session_id: Optional[str] :return: A tuple containing the resolved outbound network interface, destination MAC address, destination IP address, source port, destination port, protocol, and a boolean indicating whether the transmission is a broadcast. - :rtype: Tuple[Optional[RouterInterface], Optional[str], IPv4Address, Optional[int], Optional[int], - Optional[str], bool] + :rtype: Tuple[Optional[RouterInterface], Optional[str], IPv4Address, Optional[Port], Optional[Port], + Optional[IPProtocol], bool] """ if dst_ip_address and not isinstance(dst_ip_address, (IPv4Address, IPv4Network)): dst_ip_address = IPv4Address(dst_ip_address) diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 3615ef54..27a13154 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -8,8 +8,9 @@ from primaite.simulator.network.airspace import AirSpace, IPWirelessNetworkInter from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP -from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP class WirelessAccessPoint(IPWirelessNetworkInterface): diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index c3b4a341..2c3c15b4 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -12,14 +12,14 @@ from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.web_server.web_server import WebServer -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) diff --git a/src/primaite/simulator/network/protocols/masquerade.py b/src/primaite/simulator/network/protocols/masquerade.py index ef060bc7..5c5f03b2 100644 --- a/src/primaite/simulator/network/protocols/masquerade.py +++ b/src/primaite/simulator/network/protocols/masquerade.py @@ -3,14 +3,16 @@ from enum import Enum from typing import Optional from primaite.simulator.network.protocols.packet import DataPacket +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.port import Port class MasqueradePacket(DataPacket): """Represents an generic malicious packet that is masquerading as another protocol.""" - masquerade_protocol: str # The 'Masquerade' protocol that is currently in use + masquerade_protocol: IPProtocol # The 'Masquerade' protocol that is currently in use - masquerade_port: int # The 'Masquerade' port that is currently in use + masquerade_port: Port # The 'Masquerade' port that is currently in use class C2Packet(MasqueradePacket): diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index ca212c58..259d62e3 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -9,9 +9,10 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.protocols.packet import DataPacket from primaite.simulator.network.transmission.network_layer import IPPacket from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPHeader, UDPHeader +from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader from primaite.simulator.network.utils import convert_bytes_to_megabits -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index 47e8a032..49dcd1f5 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -4,7 +4,8 @@ from enum import Enum from pydantic import BaseModel from primaite import getLogger -from primaite.utils.validators import IPProtocol, IPV4Address, PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address _LOGGER = getLogger(__name__) diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index fbc4b5ad..10cf802c 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -4,38 +4,6 @@ from typing import List from pydantic import BaseModel -PORT_LOOKUP: dict[str, int] = dict( - UNUSED=-1, - NONE=0, - WOL=9, - FTP_DATA=20, - FTP=21, - SSH=22, - SMTP=25, - DNS=53, - HTTP=80, - POP3=110, - SFTP=115, - NTP=123, - IMAP=143, - SNMP=161, - SNMP_TRAP=162, - ARP=219, - LDAP=389, - HTTPS=443, - SMB=445, - IPP=631, - SQL_SERVER=1433, - MYSQL=3306, - RDP=3389, - RTP=5004, - RTP_ALT=5005, - DNS_ALT=5353, - HTTP_ALT=8080, - HTTPS_ALT=8443, - POSTGRES_SERVER=5432, -) - class UDPHeader(BaseModel): """ diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 4967f519..cd4b2a03 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -11,10 +11,11 @@ from pydantic import BaseModel from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.hardware.nodes.host.host_node import HostNode -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.software_manager import SoftwareManager -from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP class DatabaseClientConnection(BaseModel): diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index 34433e65..a04067c4 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -7,9 +7,10 @@ from pydantic import validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application -from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol, is_valid_protocol, PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import is_valid_port, Port, PORT_LOOKUP class PortScanPayload(SimComponent): @@ -23,8 +24,8 @@ class PortScanPayload(SimComponent): """ ip_address: IPV4Address - port: int - protocol: str + port: Port + protocol: IPProtocol request: bool = True def describe_state(self) -> Dict: @@ -217,14 +218,14 @@ class NMAP(Application, identifier="NMAP"): print(table.get_string(sortby="IP Address")) return active_nodes - def _determine_port_scan_type(self, target_ip_addresses: List[IPV4Address], target_ports: List[int]) -> str: + def _determine_port_scan_type(self, target_ip_addresses: List[IPV4Address], target_ports: List[Port]) -> str: """ Determine the type of port scan based on the number of target IP addresses and ports. :param target_ip_addresses: The list of target IP addresses. :type target_ip_addresses: List[IPV4Address] :param target_ports: The list of target ports. - :type target_ports: List[int] + :type target_ports: List[Port] :return: The type of port scan. :rtype: str @@ -237,8 +238,8 @@ class NMAP(Application, identifier="NMAP"): def _check_port_open_on_ip_address( self, ip_address: IPv4Address, - port: int, - protocol: str, + port: Port, + protocol: IPProtocol, is_re_attempt: bool = False, port_scan_uuid: Optional[str] = None, ) -> bool: @@ -250,7 +251,7 @@ class NMAP(Application, identifier="NMAP"): :param port: The target port. :type port: Port :param protocol: The protocol used for the port scan. - :type protocol: str + :type protocol: IPProtocol :param is_re_attempt: Flag indicating if this is a reattempt. Defaults to False. :type is_re_attempt: bool :param port_scan_uuid: The UUID of the port scan payload. Defaults to None. @@ -319,20 +320,20 @@ class NMAP(Application, identifier="NMAP"): def port_scan( self, target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]], - target_protocol: Optional[Union[str, List[str]]] = None, - target_port: Optional[Union[int, List[int]]] = None, + target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None, + target_port: Optional[Union[Port, List[Port]]] = None, show: bool = True, json_serializable: bool = False, - ) -> Dict[IPv4Address, Dict[str, List[int]]]: + ) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]: """ Perform a port scan on the target IP address(es). :param target_ip_address: The target IP address(es) or network(s) for the port scan. :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]] :param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP. - :type target_protocol: Optional[Union[str, List[str]]] + :type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] :param target_port: The port(s) to scan. Defaults to None, which includes all valid ports. - :type target_port: Optional[Union[int, List[int]]] + :type target_port: Optional[Union[Port, List[Port]]] :param show: Flag indicating whether to display the scan results. Defaults to True. :type show: bool :param json_serializable: Flag indicating whether the return value should be JSON serializable. Defaults to @@ -340,16 +341,16 @@ class NMAP(Application, identifier="NMAP"): :type json_serializable: bool :return: A dictionary mapping IP addresses to protocols and lists of open ports. - :rtype: Dict[IPv4Address, Dict[str, List[int]]] + :rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]] """ ip_addresses = self._explode_ip_address_network_array(target_ip_address) - if isinstance(target_port, int): + if is_valid_port(target_port): target_port = [target_port] elif target_port is None: target_port = [port for port in PORT_LOOKUP if port not in {PORT_LOOKUP["NONE"], PORT_LOOKUP["UNUSED"]}] - if isinstance(target_protocol, str): + if is_valid_protocol(target_protocol): target_protocol = [target_protocol] elif target_protocol is None: target_protocol = [PROTOCOL_LOOKUP["TCP"], PROTOCOL_LOOKUP["UDP"]] @@ -389,12 +390,12 @@ class NMAP(Application, identifier="NMAP"): def network_service_recon( self, target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]], - target_protocol: Optional[Union[str, List[str]]] = None, - target_port: Optional[Union[int, List[int]]] = None, + target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None, + target_port: Optional[Union[Port, List[Port]]] = None, show: bool = True, show_online_only: bool = True, json_serializable: bool = False, - ) -> Dict[IPv4Address, Dict[str, List[int]]]: + ) -> Dict[IPv4Address, Dict[IPProtocol, List[Port]]]: """ Perform a network service reconnaissance which includes a ping scan followed by a port scan. @@ -407,9 +408,9 @@ class NMAP(Application, identifier="NMAP"): :param target_ip_address: The target IP address(es) or network(s) for the port scan. :type target_ip_address: Union[IPV4Address, List[IPV4Address], IPv4Network, List[IPv4Network]] :param target_protocol: The protocol(s) to use for the port scan. Defaults to None, which includes TCP and UDP. - :type target_protocol: Optional[Union[str, List[str]]] + :type target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] :param target_port: The port(s) to scan. Defaults to None, which includes all valid ports. - :type target_port: Optional[Union[int, List[int]]] + :type target_port: Optional[Union[Port, List[Port]]] :param show: Flag indicating whether to display the scan results. Defaults to True. :type show: bool :param show_online_only: Flag indicating whether to show only the online hosts. Defaults to True. @@ -419,7 +420,7 @@ class NMAP(Application, identifier="NMAP"): :type json_serializable: bool :return: A dictionary mapping IP addresses to protocols and lists of open ports. - :rtype: Dict[IPv4Address, Dict[str, List[int]]] + :rtype: Dict[IPv4Address, Dict[IPProtocol, List[Port]]] """ ping_scan_results = self.ping_scan( target_ip_address=target_ip_address, show=show, show_online_only=show_online_only, json_serializable=False diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index b0cdefba..aff12748 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -9,14 +9,14 @@ from pydantic import BaseModel, Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol, is_valid_protocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import is_valid_port, Port, PORT_LOOKUP class C2Command(Enum): @@ -81,10 +81,10 @@ class AbstractC2(Application, identifier="AbstractC2"): keep_alive_frequency: int = Field(default=5, ge=1) """The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon.""" - masquerade_protocol: str = Field(default=PROTOCOL_LOOKUP["TCP"]) + masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"]) """The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP.""" - masquerade_port: int = Field(default=PORT_LOOKUP["HTTP"]) + masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"]) """The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP.""" c2_config: _C2Opts = _C2Opts() @@ -367,7 +367,7 @@ class AbstractC2(Application, identifier="AbstractC2"): :rtype: bool """ # Validating that they are valid Enums. - if not isinstance(payload.masquerade_port, int) or not isinstance(payload.masquerade_protocol, str): + if not is_valid_port(payload.masquerade_port) or not is_valid_protocol(payload.masquerade_protocol): self.sys_log.warning( f"{self.name}: Received invalid Masquerade Values within Keep Alive." f"Port: {payload.masquerade_port} Protocol: {payload.masquerade_protocol}." diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index 450c60ad..c0c3d872 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -8,12 +8,12 @@ from pydantic import validate_call from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.masquerade import C2Packet -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts, RansomwareOpts, TerminalOpts from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP class C2Beacon(AbstractC2, identifier="C2Beacon"): diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index c2d19160..9fdbae57 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -7,10 +7,10 @@ from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 7e199b48..fb2c8847 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -7,8 +7,8 @@ from primaite import getLogger from primaite.game.science import simulate_trial from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -35,7 +35,7 @@ class DoSBot(DatabaseClient, identifier="DoSBot"): target_ip_address: Optional[IPv4Address] = None """IP address of the target service.""" - target_port: Optional[int] = None + target_port: Optional[Port] = None """Port of the target service.""" payload: Optional[str] = None diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index 56f885f4..93b4c50d 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -6,10 +6,10 @@ from prettytable import MARKDOWN, PrettyTable from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP class RansomwareScript(Application, identifier="RansomwareScript"): diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index faa7b5ec..c57a9bd3 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -15,10 +15,10 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.dns.dns_client import DNSClient -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -154,7 +154,7 @@ class WebBrowser(Application, identifier="WebBrowser"): self, payload: HttpRequestPacket, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = PORT_LOOKUP["HTTP"], + dest_port: Optional[Port] = PORT_LOOKUP["HTTP"], session_id: Optional[str] = None, **kwargs, ) -> bool: diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index fcf07d9f..75322e86 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -11,8 +11,9 @@ from primaite.simulator.network.protocols.arp import ARPPacket from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame from primaite.simulator.network.transmission.network_layer import IPPacket -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPHeader, UDPHeader -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP if TYPE_CHECKING: from primaite.simulator.network.hardware.base import NetworkInterface @@ -37,12 +38,12 @@ class Session(SimComponent): protocol: str with_ip_address: IPv4Address - src_port: Optional[int] - dst_port: Optional[int] + src_port: Optional[Port] + dst_port: Optional[Port] connected: bool = False @classmethod - def from_session_key(cls, session_key: Tuple[str, IPv4Address, Optional[int], Optional[int]]) -> Session: + def from_session_key(cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]) -> Session: """ Create a Session instance from a session key tuple. @@ -77,7 +78,9 @@ class SessionManager: """ def __init__(self, sys_log: SysLog): - self.sessions_by_key: Dict[Tuple[str, IPv4Address, IPv4Address, Optional[int], Optional[int]], Session] = {} + self.sessions_by_key: Dict[ + Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]], Session + ] = {} self.sessions_by_uuid: Dict[str, Session] = {} self.sys_log: SysLog = sys_log self.software_manager: SoftwareManager = None # Noqa @@ -102,7 +105,7 @@ class SessionManager: @staticmethod def _get_session_key( frame: Frame, inbound_frame: bool = True - ) -> Tuple[str, IPv4Address, Optional[int], Optional[int]]: + ) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]: """ Extracts the session key from the given frame. @@ -110,8 +113,8 @@ class SessionManager: - IPProtocol: The transport protocol (e.g. TCP, UDP, ICMP). - IPv4Address: The source IP address. - IPv4Address: The destination IP address. - - Optional[int]: The source port number (if applicable). - - Optional[int]: The destination port number (if applicable). + - Optional[Port]: The source port number (if applicable). + - Optional[Port]: The destination port number (if applicable). :param frame: The network frame from which to extract the session key. :return: A tuple containing the session key. @@ -166,17 +169,17 @@ class SessionManager: def resolve_outbound_transmission_details( self, dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[int] = None, - dst_port: Optional[int] = None, - protocol: Optional[str] = None, + src_port: Optional[Port] = None, + dst_port: Optional[Port] = None, + protocol: Optional[IPProtocol] = None, session_id: Optional[str] = None, ) -> Tuple[ Optional["NetworkInterface"], Optional[str], IPv4Address, - Optional[int], - Optional[int], - Optional[str], + Optional[Port], + Optional[Port], + Optional[IPProtocol], bool, ]: """ @@ -195,19 +198,19 @@ class SessionManager: treats the transmission as a broadcast to that network. Optional. :type dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] :param src_port: The source port number for the transmission. Optional. - :type src_port: Optional[int] + :type src_port: Optional[Port] :param dst_port: The destination port number for the transmission. Optional. - :type dst_port: Optional[int] + :type dst_port: Optional[Port] :param protocol: The IP protocol to be used for the transmission. Optional. - :type protocol: Optional[str] + :type protocol: Optional[IPProtocol] :param session_id: The session ID associated with the transmission. If provided, the session details override other parameters. Optional. :type session_id: Optional[str] :return: A tuple containing the resolved outbound network interface, destination MAC address, destination IP address, source port, destination port, protocol, and a boolean indicating whether the transmission is a broadcast. - :rtype: Tuple[Optional["NetworkInterface"], Optional[str], IPv4Address, Optional[int], Optional[int], - Optional[str], bool] + :rtype: Tuple[Optional["NetworkInterface"], Optional[str], IPv4Address, Optional[Port], Optional[Port], + Optional[IPProtocol], bool] """ if dst_ip_address and not isinstance(dst_ip_address, (IPv4Address, IPv4Network)): dst_ip_address = IPv4Address(dst_ip_address) @@ -258,10 +261,10 @@ class SessionManager: self, payload: Any, dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[int] = None, - dst_port: Optional[int] = None, + src_port: Optional[Port] = None, + dst_port: Optional[Port] = None, session_id: Optional[str] = None, - ip_protocol: str = PROTOCOL_LOOKUP["TCP"], + ip_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"], icmp_packet: Optional[ICMPPacket] = None, ) -> Union[Any, None]: """ diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index abf2ca3a..60621384 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -8,12 +8,12 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.core import RequestType from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import IOSoftware -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP if TYPE_CHECKING: from primaite.simulator.system.core.session_manager import SessionManager @@ -52,7 +52,7 @@ class SoftwareManager: self.session_manager = session_manager self.software: Dict[str, Union[Service, Application]] = {} self._software_class_to_name_map: Dict[Type[IOSoftware], str] = {} - self.port_protocol_mapping: Dict[Tuple[int, str], Union[Service, Application]] = {} + self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {} self.sys_log: SysLog = sys_log self.file_system: FileSystem = file_system self.dns_server: Optional[IPv4Address] = dns_server @@ -67,7 +67,7 @@ class SoftwareManager: """Provides access to the ICMP service instance, if installed.""" return self.software.get("ICMP") # noqa - def get_open_ports(self) -> List[int]: + def get_open_ports(self) -> List[Port]: """ Get a list of open ports. @@ -81,7 +81,7 @@ class SoftwareManager: open_ports += list(software.listen_on_ports) return open_ports - def check_port_is_open(self, port: int, protocol: str) -> bool: + def check_port_is_open(self, port: Port, protocol: IPProtocol) -> bool: """ Check if a specific port is open and running a service using the specified protocol. @@ -93,7 +93,7 @@ class SoftwareManager: :param port: The port to check. :type port: Port :param protocol: The protocol to check (e.g., TCP, UDP). - :type protocol: str + :type protocol: IPProtocol :return: True if the port is open and a service is running on it using the specified protocol, False otherwise. :rtype: bool """ @@ -189,9 +189,9 @@ class SoftwareManager: self, payload: Any, dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[int] = None, - dest_port: Optional[int] = None, - ip_protocol: str = PROTOCOL_LOOKUP["TCP"], + src_port: Optional[Port] = None, + dest_port: Optional[Port] = None, + ip_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"], session_id: Optional[str] = None, ) -> bool: """ @@ -219,8 +219,8 @@ class SoftwareManager: def receive_payload_from_session_manager( self, payload: Any, - port: int, - protocol: str, + port: Port, + protocol: IPProtocol, session_id: str, from_network_interface: "NIC", frame: Frame, diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index 2641f1c8..816eb99e 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -8,9 +8,10 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service -from primaite.utils.validators import IPV4Address, PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import PORT_LOOKUP class ARP(Service): diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index f9a5d087..b7cd8886 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -7,12 +7,12 @@ from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.file_system.folder import Folder -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 316189a7..78642fa6 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -4,10 +4,10 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -110,7 +110,7 @@ class DNSClient(Service): payload: DNSPacket, session_id: Optional[str] = None, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = None, + dest_port: Optional[Port] = None, **kwargs, ) -> bool: """ diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index e0786124..5b380320 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -6,9 +6,9 @@ from prettytable import MARKDOWN, PrettyTable from primaite import getLogger from primaite.simulator.network.protocols.dns import DNSPacket -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 11a926cf..00b70332 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -7,10 +7,10 @@ from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -104,7 +104,7 @@ class FTPClient(FTPServiceABC): def _connect_to_server( self, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = PORT_LOOKUP["FTP"], + dest_port: Optional[Port] = PORT_LOOKUP["FTP"], session_id: Optional[str] = None, is_reattempt: Optional[bool] = False, ) -> bool: @@ -114,7 +114,7 @@ class FTPClient(FTPServiceABC): :param: dest_ip_address: IP address of the FTP server the client needs to connect to. Optional. :type: dest_ip_address: Optional[IPv4Address] :param: dest_port: Port of the FTP server the client needs to connect to. Optional. - :type: dest_port: Optional[int] + :type: dest_port: Optional[Port] :param: is_reattempt: Set to True if attempt to connect to FTP Server has been attempted. Default False. :type: is_reattempt: Optional[bool] """ @@ -152,7 +152,7 @@ class FTPClient(FTPServiceABC): return False def _disconnect_from_server( - self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[int] = PORT_LOOKUP["FTP"] + self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[Port] = PORT_LOOKUP["FTP"] ) -> bool: """ Connects the client from a given FTP server. @@ -160,7 +160,7 @@ class FTPClient(FTPServiceABC): :param: dest_ip_address: IP address of the FTP server the client needs to disconnect from. Optional. :type: dest_ip_address: Optional[IPv4Address] :param: dest_port: Port of the FTP server the client needs to disconnect from. Optional. - :type: dest_port: Optional[int] + :type: dest_port: Optional[Port] :param: is_reattempt: Set to True if attempt to disconnect from FTP Server has been attempted. Default False. :type: is_reattempt: Optional[bool] """ @@ -179,7 +179,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[int] = PORT_LOOKUP["FTP"], + dest_port: Optional[Port] = PORT_LOOKUP["FTP"], session_id: Optional[str] = None, ) -> bool: """ @@ -204,7 +204,7 @@ class FTPClient(FTPServiceABC): :type: dest_file_name: str :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. - :type: dest_port: Optional[int] + :type: dest_port: Optional[Port] :param: session_id: The id of the session :type: session_id: Optional[str] @@ -241,7 +241,7 @@ class FTPClient(FTPServiceABC): src_file_name: str, dest_folder_name: str, dest_file_name: str, - dest_port: Optional[int] = PORT_LOOKUP["FTP"], + dest_port: Optional[Port] = PORT_LOOKUP["FTP"], ) -> bool: """ Request a file from a target IP address. diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 38a253be..671200f5 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -3,9 +3,9 @@ from typing import Any, Optional from primaite import getLogger from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -52,7 +52,7 @@ class FTPServer(FTPServiceABC): # process server specific commands, otherwise call super if payload.ftp_command == FTPCommand.PORT: # check that the port is valid - if isinstance(payload.ftp_command_args, int) and (0 <= payload.ftp_command_args < 65535): + if is_valid_port(payload.ftp_command_args): # return successful connection self.add_connection(connection_id=session_id, session_id=session_id) payload.status_code = FTPStatusCode.OK diff --git a/src/primaite/simulator/system/services/ftp/ftp_service.py b/src/primaite/simulator/system/services/ftp/ftp_service.py index 49678c82..77d82997 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_service.py +++ b/src/primaite/simulator/system/services/ftp/ftp_service.py @@ -6,6 +6,7 @@ from typing import Dict, Optional from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode from primaite.simulator.system.services.service import Service +from primaite.utils.validation.port import Port class FTPServiceABC(Service, ABC): @@ -77,7 +78,7 @@ class FTPServiceABC(Service, ABC): dest_folder_name: str, dest_file_name: str, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = None, + dest_port: Optional[Port] = None, session_id: Optional[str] = None, is_response: bool = False, ) -> bool: @@ -97,7 +98,7 @@ class FTPServiceABC(Service, ABC): :type: dest_ip_address: Optional[IPv4Address] :param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"]. - :type: dest_port: Optional[int] + :type: dest_port: Optional[Port] :param: session_id: session ID linked to the FTP Packet. Optional. :type: session_id: Optional[str] @@ -167,7 +168,7 @@ class FTPServiceABC(Service, ABC): payload: FTPPacket, session_id: Optional[str] = None, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = None, + dest_port: Optional[Port] = None, **kwargs, ) -> bool: """ diff --git a/src/primaite/simulator/system/services/icmp/icmp.py b/src/primaite/simulator/system/services/icmp/icmp.py index 486ba2b0..84ad995d 100644 --- a/src/primaite/simulator/system/services/icmp/icmp.py +++ b/src/primaite/simulator/system/services/icmp/icmp.py @@ -7,9 +7,9 @@ from primaite import getLogger from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 184833e1..ed89971f 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -5,9 +5,9 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service, ServiceOperatingState -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -55,7 +55,7 @@ class NTPClient(Service): payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: int = PORT_LOOKUP["NTP"], + dest_port: Port = PORT_LOOKUP["NTP"], **kwargs, ) -> bool: """Requests NTP data from NTP server. diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 4764bffb..b674a296 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -4,9 +4,9 @@ from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import Service -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 2b0bc02b..ae3557f7 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -17,10 +17,10 @@ from primaite.simulator.network.protocols.ssh import ( SSHTransportMessage, SSHUserCredentials, ) -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP # TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 2805b1b2..75d9c472 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -10,11 +10,11 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClientConnection from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import SoftwareHealthState -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -145,7 +145,7 @@ class WebServer(Service): payload: HttpResponsePacket, session_id: Optional[str] = None, dest_ip_address: Optional[IPv4Address] = None, - dest_port: Optional[int] = None, + dest_port: Optional[Port] = None, **kwargs, ) -> bool: """ diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index d34678b9..6fb09a16 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -15,7 +15,8 @@ from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP +from primaite.utils.validation.port import Port if TYPE_CHECKING: from primaite.simulator.system.core.software_manager import SoftwareManager @@ -250,11 +251,11 @@ class IOSoftware(Software): "Indicates if the software uses TCP protocol for communication. Default is True." udp: bool = True "Indicates if the software uses UDP protocol for communication. Default is True." - port: int + port: Port "The port to which the software is connected." - listen_on_ports: Set[int] = Field(default_factory=set) + listen_on_ports: Set[Port] = Field(default_factory=set) "The set of ports to listen on." - protocol: str + protocol: IPProtocol "The IP Protocol the Software operates on." _connections: Dict[str, Dict] = {} "Active connections." @@ -386,7 +387,7 @@ class IOSoftware(Software): session_id: Optional[str] = None, dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, dest_port: Optional[int] = None, - ip_protocol: str = PROTOCOL_LOOKUP["TCP"], + ip_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"], **kwargs, ) -> bool: """ diff --git a/src/primaite/utils/validation/__init__.py b/src/primaite/utils/validation/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/src/primaite/utils/validation/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/src/primaite/utils/validation/ip_protocol.py b/src/primaite/utils/validation/ip_protocol.py new file mode 100644 index 00000000..4e358305 --- /dev/null +++ b/src/primaite/utils/validation/ip_protocol.py @@ -0,0 +1,47 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# Define a custom IP protocol validator +from typing import Any + +from pydantic import BeforeValidator, TypeAdapter, ValidationError +from typing_extensions import Annotated, Final + +PROTOCOL_LOOKUP: dict[str, str] = dict( + NONE="none", + TCP="tcp", + UDP="udp", + ICMP="icmp", +) +""" +Lookup table used for compatibility with PrimAITE <= 3.3. Configs with the capitalised protocol names are converted +to lowercase at runtime. +""" +VALID_PROTOCOLS = ["none", "tcp", "udp", "icmp"] +"""Supported protocols.""" + + +def protocol_validator(v: Any) -> str: + """ + Validate that IP Protocols are chosen from the list of supported IP Protocols. + + The protocol list is dynamic because plugins are able to extend it, therefore it is necessary to use this custom + validator instead of being able to specify a union of string literals. + """ + if isinstance(v, str) and v in PROTOCOL_LOOKUP: + return PROTOCOL_LOOKUP[v] + if v in VALID_PROTOCOLS: + return v + raise ValueError(f"{v} is not a valid IP Protocol. It must be one of the following: {VALID_PROTOCOLS}") + + +IPProtocol: Final[Annotated] = Annotated[str, BeforeValidator(protocol_validator)] +"""Validates that IP Protocols used in the simulation belong to the list of supported protocols.""" +_IPProtocolTypeAdapter = TypeAdapter(IPProtocol) + + +def is_valid_protocol(v: Any) -> bool: + """Convenience method to return true if the value matches the schema, and false otherwise.""" + try: + _IPProtocolTypeAdapter.validate_python(v) + return True + except ValidationError: + return False diff --git a/src/primaite/utils/validators.py b/src/primaite/utils/validation/ipv4_address.py similarity index 59% rename from src/primaite/utils/validators.py rename to src/primaite/utils/validation/ipv4_address.py index f07b475d..eb0e2574 100644 --- a/src/primaite/utils/validators.py +++ b/src/primaite/utils/validation/ipv4_address.py @@ -1,4 +1,6 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + + from ipaddress import IPv4Address from typing import Any, Final @@ -37,39 +39,3 @@ will automatically check and convert the input value to an instance of IPv4Addre any Pydantic model uses it. This ensures that any field marked with this type is not just an IPv4Address in form, but also valid according to the rules defined in ipv4_validator. """ - -# Define a custom port validator -Port: Final[Annotated] = Annotated[int, BeforeValidator(lambda n: 0 <= n <= 65535)] -"""Validates that network ports lie in the appropriate range of [0,65535].""" - -# Define a custom IP protocol validator -PROTOCOL_LOOKUP: dict[str, str] = dict( - NONE="none", - TCP="tcp", - UDP="udp", - ICMP="icmp", -) -""" -Lookup table used for compatibility with PrimAITE <= 3.3. Configs with the capitalised protocol names are converted -to lowercase at runtime. -""" -VALID_PROTOCOLS = ["none", "tcp", "udp", "icmp"] -"""Supported protocols.""" - - -def protocol_validator(v: Any) -> str: - """ - Validate that IP Protocols are chosen from the list of supported IP Protocols. - - The protocol list is dynamic because plugins are able to extend it, therefore it is necessary to use this custom - validator instead of being able to specify a union of string literals. - """ - if v in PROTOCOL_LOOKUP: - return PROTOCOL_LOOKUP(v) - if v in VALID_PROTOCOLS: - return v - raise ValueError(f"{v} is not a valid IP Protocol. It must be one of the following: {VALID_PROTOCOLS}") - - -IPProtocol: Final[Annotated] = Annotated[str, BeforeValidator(protocol_validator)] -"""Validates that IP Protocols used in the simulation belong to the list of supported protocols.""" diff --git a/src/primaite/utils/validation/port.py b/src/primaite/utils/validation/port.py new file mode 100644 index 00000000..90c36add --- /dev/null +++ b/src/primaite/utils/validation/port.py @@ -0,0 +1,70 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +# Define a custom port validator +from typing import Any + +from pydantic import BeforeValidator, TypeAdapter, ValidationError +from typing_extensions import Annotated, Final + +PORT_LOOKUP: dict[str, int] = dict( + UNUSED=-1, + NONE=0, + WOL=9, + FTP_DATA=20, + FTP=21, + SSH=22, + SMTP=25, + DNS=53, + HTTP=80, + POP3=110, + SFTP=115, + NTP=123, + IMAP=143, + SNMP=161, + SNMP_TRAP=162, + ARP=219, + LDAP=389, + HTTPS=443, + SMB=445, + IPP=631, + SQL_SERVER=1433, + MYSQL=3306, + RDP=3389, + RTP=5004, + RTP_ALT=5005, + DNS_ALT=5353, + HTTP_ALT=8080, + HTTPS_ALT=8443, + POSTGRES_SERVER=5432, +) +""" +Lookup table used for compatibility with PrimAITE <= 3.3. Configs with named ports names are converted +to port integers at runtime. +""" + + +def port_validator(v: Any) -> int: + """ + Validate that Ports are chosen from the list of supported Ports. + + The protocol list is dynamic because plugins are able to extend it, therefore it is necessary to use this custom + validator instead of being able to specify a union of string literals. + """ + if isinstance(v, str) and v in PORT_LOOKUP: + v = PORT_LOOKUP[v] + if isinstance(v, int) and (0 <= v <= 65535): + return v + raise ValueError(f"{v} is not a valid Port. It must be an integer in the range [0,65535] or ") + + +Port: Final[Annotated] = Annotated[int, BeforeValidator(port_validator)] +"""Validates that network ports lie in the appropriate range of [0,65535].""" +_PortTypeAdapter = TypeAdapter(Port) + + +def is_valid_port(v: Any) -> bool: + """Convenience method to return true if the value matches the schema, and false otherwise.""" + try: + _PortTypeAdapter.validate_python(v) + return True + except ValidationError: + return False diff --git a/tests/conftest.py b/tests/conftest.py index 687bec92..64fe0699 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,6 @@ from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.web_browser import WebBrowser @@ -27,7 +26,8 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT rayinit() diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py index 6d0ef7b0..7f251613 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py @@ -9,8 +9,8 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.configuration_file_parsing import BASIC_FIREWALL, DMZ_NETWORK, load_config diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py index c348ee81..d10c7dbb 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -6,8 +6,8 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py index 28029b32..70dc7cba 100644 --- a/tests/integration_tests/extensions/applications/extended_application.py +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -15,11 +15,11 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.dns.dns_client import DNSClient -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) diff --git a/tests/integration_tests/extensions/nodes/super_computer.py b/tests/integration_tests/extensions/nodes/super_computer.py index 55bdce09..80f7e3c3 100644 --- a/tests/integration_tests/extensions/nodes/super_computer.py +++ b/tests/integration_tests/extensions/nodes/super_computer.py @@ -3,7 +3,7 @@ from typing import ClassVar, Dict from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC from primaite.simulator.system.services.ftp.ftp_client import FTPClient -from primaite.utils.validators import IPV4Address +from primaite.utils.validation.ipv4_address import IPV4Address class SuperComputer(HostNode, identifier="supercomputer"): diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index 70d47aaa..ddaf4a1e 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -7,12 +7,12 @@ from primaite import getLogger from primaite.simulator.file_system.file_system import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.file_system.folder import Folder -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) diff --git a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py index 2c750621..187fb1fe 100644 --- a/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py +++ b/tests/integration_tests/game_layer/actions/test_c2_suite_actions.py @@ -11,13 +11,13 @@ from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Command, C2Server from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture diff --git a/tests/integration_tests/game_layer/actions/test_configure_actions.py b/tests/integration_tests/game_layer/actions/test_configure_actions.py index b56a4b99..508bd5a4 100644 --- a/tests/integration_tests/game_layer/actions/test_configure_actions.py +++ b/tests/integration_tests/game_layer/actions/test_configure_actions.py @@ -11,12 +11,12 @@ from primaite.game.agent.actions import ( ) from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index bc168c3c..a70cea72 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -9,9 +9,9 @@ from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index 2bf0486c..e7212f3c 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -4,10 +4,10 @@ import pytest from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index af8c4669..05cf910c 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -5,8 +5,8 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def check_default_rules(acl_obs): diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index cdd428b0..4ced02f5 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -8,9 +8,9 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.sim_container import Simulation -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_router_observation(): diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py index 70637b0d..e7287eee 100644 --- a/tests/integration_tests/game_layer/observations/test_user_observations.py +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -3,7 +3,7 @@ import pytest from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 2675b615..e03a7d26 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -21,11 +21,11 @@ from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.software import SoftwareHealthState -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network.yaml" diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 0afe666c..0005b508 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -9,11 +9,11 @@ from primaite.interface.request import RequestResponse from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index b5b2acbc..f07f02e7 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -8,10 +8,10 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP class BroadcastTestService(Service): diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index 58763c3e..79452318 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -7,10 +7,10 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index dde66a43..04cdbe78 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -6,10 +6,10 @@ import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/integration_tests/network/test_wireless_router.py b/tests/integration_tests/network/test_wireless_router.py index 520ec21a..fb0035e9 100644 --- a/tests/integration_tests/network/test_wireless_router.py +++ b/tests/integration_tests/network/test_wireless_router.py @@ -7,8 +7,8 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index b1979154..2cbd4d11 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -13,7 +13,6 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import AccessControlList, ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon @@ -24,7 +23,8 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.web_server.web_server import WebServer -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT diff --git a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py index 54c372e4..50b0ceac 100644 --- a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py @@ -9,7 +9,6 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( @@ -19,6 +18,7 @@ from primaite.simulator.system.applications.red_applications.data_manipulation_b from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py index ad0a519b..1a09e875 100644 --- a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py @@ -8,12 +8,12 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/integration_tests/system/red_applications/test_ransomware_script.py b/tests/integration_tests/system/red_applications/test_ransomware_script.py index 09cbcf85..a5adbb04 100644 --- a/tests/integration_tests/system/red_applications/test_ransomware_script.py +++ b/tests/integration_tests/system/red_applications/test_ransomware_script.py @@ -9,11 +9,11 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index c1c4df82..c52b5caa 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -5,9 +5,9 @@ from ipaddress import IPv4Address, IPv4Network import yaml from primaite.game.game import PrimaiteGame -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.nmap import NMAP -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py index 4108041d..7a085ee1 100644 --- a/tests/integration_tests/system/test_service_listening_on_ports.py +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -6,11 +6,11 @@ from pydantic import Field from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.service import Service -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py index 854ef41b..f2ac1183 100644 --- a/tests/integration_tests/system/test_web_client_server_and_database.py +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -9,7 +9,6 @@ from primaite.simulator.network.hardware.base import Link from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService @@ -17,6 +16,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import SoftwareHealthState +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 7813628c..a767f365 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -12,7 +12,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP from tests.conftest import DummyApplication, DummyService diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py index ba7628c2..6eca0c44 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py @@ -8,8 +8,9 @@ from primaite.simulator.network.hardware.nodes.network.router import ACLAction, from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame from primaite.simulator.network.transmission.network_layer import IPPacket -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPHeader, UDPHeader -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py index 0e1844c4..fe9387de 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -2,8 +2,8 @@ from ipaddress import IPv4Address from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_wireless_router_from_config(): diff --git a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py index 9e9a1f72..e7e425b1 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_transmission/test_data_link_layer.py @@ -5,8 +5,9 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame from primaite.simulator.network.transmission.network_layer import IPPacket, Precedence from primaite.simulator.network.transmission.primaite_layer import AgentSource, DataStatus -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP, TCPFlags, TCPHeader, UDPHeader -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.simulator.network.transmission.transport_layer import TCPFlags, TCPHeader, UDPHeader +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_frame_minimal_instantiation(): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index fde70616..12dddf67 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -4,11 +4,11 @@ import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Command, C2Server -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index f4750158..34a29cd0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -3,13 +3,13 @@ import pytest from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( DataManipulationAttackStage, DataManipulationBot, ) -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index d0c65266..e9762476 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -5,9 +5,9 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index f5781485..f1be475a 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -4,10 +4,10 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.http import HttpResponsePacket, HttpStatusCode -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index 09099c5c..db7e8d58 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -6,10 +6,10 @@ import pytest from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.service import ServiceOperatingState -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py index 688bfd7d..c64602c0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -8,10 +8,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index b4fe8633..95788834 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -8,10 +8,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.service import ServiceOperatingState -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index 3f10db4d..291cdede 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -6,10 +6,10 @@ from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index f2895091..9b6a4bf3 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -18,13 +18,13 @@ from primaite.simulator.network.protocols.ssh import ( SSHTransportMessage, SSHUserCredentials, ) -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection, Terminal from primaite.simulator.system.services.web_server.web_server import WebServer -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index c78a381e..54f86ec8 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -9,9 +9,9 @@ from primaite.simulator.network.protocols.http import ( HttpResponsePacket, HttpStatusCode, ) -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.services.web_server.web_server import WebServer -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index 1baaf88e..300f8d9d 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -3,11 +3,11 @@ from typing import Dict import pytest -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import IOSoftware, SoftwareHealthState -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP class TestSoftware(Service): diff --git a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py index 10ed36e0..1a1848ac 100644 --- a/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py +++ b/tests/unit_tests/_primaite/_utils/test_dict_enum_keys_conversion.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.transmission.transport_layer import PORT_LOOKUP from primaite.utils.converters import convert_dict_enum_keys_to_enum_values -from primaite.utils.validators import PROTOCOL_LOOKUP +from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP +from primaite.utils.validation.port import PORT_LOOKUP def test_simple_conversion(): From c3eb093144bf96277b2d62f3ddd1797861602234 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 25 Sep 2024 16:50:01 +0100 Subject: [PATCH 5/9] remove temporary notebook --- notebooks/test.ipynb | 157 ------------------------------------------- 1 file changed, 157 deletions(-) delete mode 100644 notebooks/test.ipynb diff --git a/notebooks/test.ipynb b/notebooks/test.ipynb deleted file mode 100644 index 5afe04b0..00000000 --- a/notebooks/test.ipynb +++ /dev/null @@ -1,157 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import yaml\n", - "\n", - "from primaite.game.game import PrimaiteGame\n", - "from primaite.session.environment import PrimaiteGymEnv\n", - "from primaite.simulator.network.hardware.nodes.host.computer import Computer\n", - "from primaite.simulator.network.hardware.nodes.host.server import Server\n", - "from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection\n", - "from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot\n", - "from primaite.simulator.system.services.database.database_service import DatabaseService\n", - "from primaite import getLogger, PRIMAITE_PATHS" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "with open(PRIMAITE_PATHS.user_config_path / \"example_config\" / \"data_manipulation.yaml\") as f:\n", - " cfg = yaml.safe_load(f)\n", - "game = PrimaiteGame.from_config(cfg)\n", - "uc2_network = game.simulation.network" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "client_1: Computer = uc2_network.get_node_by_hostname(\"client_1\")\n", - "db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get(\"DataManipulationBot\")\n", - "\n", - "database_server: Server = uc2_network.get_node_by_hostname(\"database_server\")\n", - "db_service: DatabaseService = database_server.software_manager.software.get(\"DatabaseService\")\n", - "\n", - "web_server: Server = uc2_network.get_node_by_hostname(\"web_server\")\n", - "db_client: DatabaseClient = web_server.software_manager.software.get(\"DatabaseClient\")\n", - "db_connection: DatabaseClientConnection = db_client.get_new_connection()\n", - "db_service.backup_database()\n", - "\n", - "# First check that the DB client on the web_server can successfully query the users table on the database\n", - "assert db_connection.query(\"SELECT\")\n", - "\n", - "db_manipulation_bot.data_manipulation_p_of_success = 1.0\n", - "db_manipulation_bot.port_scan_p_of_success = 1.0\n", - "\n", - "# Now we run the DataManipulationBot\n", - "db_manipulation_bot.attack()\n", - "\n", - "# Now check that the DB client on the web_server cannot query the users table on the database\n", - "assert not db_connection.query(\"SELECT\")\n", - "\n", - "# Now restore the database\n", - "db_service.restore_backup()\n", - "\n", - "# Now check that the DB client on the web_server can successfully query the users table on the database\n", - "assert db_connection.query(\"SELECT\")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "router = uc2_network.get_node_by_hostname(\"router_1\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+-----------------------------------------------------------------------------------------------------------------+\n", - "| router_1 Access Control List |\n", - "+-------+--------+----------+--------+--------------+-------------+--------+--------------+-------------+---------+\n", - "| Index | Action | Protocol | Src IP | Src Wildcard | Src Port | Dst IP | Dst Wildcard | Dst Port | Matched |\n", - "+-------+--------+----------+--------+--------------+-------------+--------+--------------+-------------+---------+\n", - "| 18 | PERMIT | ANY | ANY | ANY | 5432 (5432) | ANY | ANY | 5432 (5432) | 4 |\n", - "| 19 | PERMIT | ANY | ANY | ANY | 53 (53) | ANY | ANY | 53 (53) | 0 |\n", - "| 20 | PERMIT | ANY | ANY | ANY | 21 (21) | ANY | ANY | 21 (21) | 0 |\n", - "| 21 | PERMIT | ANY | ANY | ANY | 80 (80) | ANY | ANY | 80 (80) | 0 |\n", - "| 22 | PERMIT | ANY | ANY | ANY | 219 (219) | ANY | ANY | 219 (219) | 9 |\n", - "| 23 | PERMIT | icmp | ANY | ANY | ANY | ANY | ANY | ANY | 0 |\n", - "| 24 | DENY | ANY | ANY | ANY | ANY | ANY | ANY | ANY | 0 |\n", - "+-------+--------+----------+--------+--------------+-------------+--------+--------------+-------------+---------+\n" - ] - } - ], - "source": [ - "router.acl.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "AccessControlList.is_permitted() missing 1 required positional argument: 'frame'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mrouter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_permitted\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mTypeError\u001b[0m: AccessControlList.is_permitted() missing 1 required positional argument: 'frame'" - ] - } - ], - "source": [ - "router.acl.is_permitted()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 50e2234a6951d3305d36f4877bcc966159fda53f Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 25 Sep 2024 16:51:58 +0100 Subject: [PATCH 6/9] Remove commented out code --- .../agent/observations/host_observations.py | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 96c5f40d..617e8eee 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -57,27 +57,6 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_users: Optional[bool] = True """If True, report user session information.""" - # @field_validator("monitored_traffic", mode="before") - # def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]: - # """ - # Convert monitored_traffic by lookup against Port and Protocol dicts. - - # This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3. - # This method will be removed in PrimAITE >= 4.0 - # """ - # if val is None: - # return val - # new_val = {} - # for proto, port_list in val.items(): - # # convert protocol, for instance ICMP becomes "icmp" - # proto = PROTOCOL_LOOKUP[proto] if proto in PROTOCOL_LOOKUP else proto - # new_val[proto] = [] - # for port in port_list: - # # convert ports, for instance "HTTP" becomes 80 - # port = PORT_LOOKUP[port] if port in PORT_LOOKUP else port - # new_val[proto].append(port) - # return new_val - def __init__( self, where: WhereType, From f2b6d68b14621d7a2badc44bfaf373eb8b9d466a Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 26 Sep 2024 15:35:50 +0100 Subject: [PATCH 7/9] Fix Port scan --- src/primaite/simulator/system/applications/nmap.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index a04067c4..e2b9117d 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -348,7 +348,7 @@ class NMAP(Application, identifier="NMAP"): if is_valid_port(target_port): target_port = [target_port] elif target_port is None: - target_port = [port for port in PORT_LOOKUP if port not in {PORT_LOOKUP["NONE"], PORT_LOOKUP["UNUSED"]}] + target_port = [PORT_LOOKUP[port] for port in PORT_LOOKUP if port not in {"NONE", "UNUSED"}] if is_valid_protocol(target_protocol): target_protocol = [target_protocol] @@ -358,7 +358,7 @@ class NMAP(Application, identifier="NMAP"): scan_type = self._determine_port_scan_type(list(ip_addresses), target_port) active_ports = {} if show: - table = PrettyTable(["IP Address", "Port", "Name", "Protocol"]) + table = PrettyTable(["IP Address", "Port", "Protocol"]) table.align = "l" table.title = f"{self.software_manager.node.hostname} NMAP Port Scan ({scan_type})" self.sys_log.info(f"{self.name}: Starting port scan") @@ -369,13 +369,12 @@ class NMAP(Application, identifier="NMAP"): for protocol in target_protocol: for port in set(target_port): port_open = self._check_port_open_on_ip_address(ip_address=ip_address, port=port, protocol=protocol) - if port_open: if show: - table.add_row([ip_address, port, port, protocol]) + table.add_row([ip_address, port, protocol]) _ip_address = ip_address if not json_serializable else str(ip_address) - _protocol = protocol if not json_serializable else protocol - _port = port if not json_serializable else port + _protocol = protocol + _port = port if _ip_address not in active_ports: active_ports[_ip_address] = dict() if _protocol not in active_ports[_ip_address]: From 203ec5ec856556fd0d5f81dc782e53ece2385893 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 26 Sep 2024 16:00:59 +0100 Subject: [PATCH 8/9] Add tests for port and protocol validation and update changelog --- CHANGELOG.md | 1 + src/primaite/simulator/network/airspace.py | 6 ++--- .../simulator/network/hardware/base.py | 5 ++-- .../red_applications/c2/abstract_c2.py | 2 +- .../_primaite/_utils/_validation/__init__.py | 1 + .../_utils/_validation/test_ip_protocol.py | 23 +++++++++++++++++ .../_primaite/_utils/_validation/test_port.py | 25 +++++++++++++++++++ 7 files changed, 56 insertions(+), 7 deletions(-) create mode 100644 tests/unit_tests/_primaite/_utils/_validation/__init__.py create mode 100644 tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py create mode 100644 tests/unit_tests/_primaite/_utils/_validation/test_port.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d08974c..9493dec4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - File and folder observations can now be configured to always show the true health status, or require scanning like before. - It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where agent don't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty` - Node observations can now be configured to show the number of active local and remote logins. +- Ports, IP Protocols, and airspace frequencies no longer use enums. They defined in dictionary lookups and are handled by custom validation to enable extendability with plugins. ### Fixed - Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 65dceeb1..03d43130 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -48,11 +48,11 @@ _default_frequency_set: Dict[str, Dict] = { """Frequency configuration that is automatically used for any new airspace.""" -def register_default_frequency(freq_name: str, freq_hz: float, data_rate_bps: float): +def register_default_frequency(freq_name: str, freq_hz: float, data_rate_bps: float) -> None: """Add to the default frequency configuration. This is intended as a plugin hook. If your plugin makes use of bespoke frequencies for wireless communication, you should make a call to this method - whereever you define components that rely on the bespoke frequencies. That way, as soon as your components are + wherever you define components that rely on the bespoke frequencies. That way, as soon as your components are imported, this function automatically updates the default frequency set. This should also be run before instances of AirSpace are created. @@ -93,7 +93,7 @@ class AirSpace(BaseModel): return self.frequencies[freq_name]["data_rate_bps"] / (1024.0 * 1024.0) return 0.0 - def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]): + def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]) -> None: """ Sets custom maximum data transmission capacities for multiple frequencies. diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 778cffa2..050f4667 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1839,15 +1839,14 @@ class Node(SimComponent): def show_open_ports(self, markdown: bool = False): """Prints a table of the open ports on the Node.""" - table = PrettyTable(["Port", "Name"]) + table = PrettyTable(["Port"]) if markdown: table.set_style(MARKDOWN) table.align = "l" table.title = f"{self.hostname} Open Ports" for port in self.software_manager.get_open_ports(): if port > 0: - # TODO: do a reverse lookup for port name, or change this to only show port int - table.add_row([port, port]) + table.add_row([port]) print(table.get_string(sortby="Port")) @property diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index aff12748..f77bc33a 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -366,7 +366,7 @@ class AbstractC2(Application, identifier="AbstractC2"): :return: True on successful configuration, false otherwise. :rtype: bool """ - # Validating that they are valid Enums. + # Validating that they are valid Ports and Protocols. if not is_valid_port(payload.masquerade_port) or not is_valid_protocol(payload.masquerade_protocol): self.sys_log.warning( f"{self.name}: Received invalid Masquerade Values within Keep Alive." diff --git a/tests/unit_tests/_primaite/_utils/_validation/__init__.py b/tests/unit_tests/_primaite/_utils/_validation/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py b/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py new file mode 100644 index 00000000..27829570 --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py @@ -0,0 +1,23 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.utils.validation.ip_protocol import IPProtocol, is_valid_protocol, PROTOCOL_LOOKUP, protocol_validator + + +def test_port_conversion(): + for proto_name, proto_val in PROTOCOL_LOOKUP.items(): + assert protocol_validator(proto_name) == proto_val + assert is_valid_protocol(proto_name) + + +def test_port_passthrough(): + for proto_val in PROTOCOL_LOOKUP.values(): + assert protocol_validator(proto_val) == proto_val + assert is_valid_protocol(proto_val) + + +def test_invalid_ports(): + for port in (123, "abcdefg", "NONEXISTENT_PROTO"): + with pytest.raises(ValueError): + protocol_validator(port) + assert not is_valid_protocol(port) diff --git a/tests/unit_tests/_primaite/_utils/_validation/test_port.py b/tests/unit_tests/_primaite/_utils/_validation/test_port.py new file mode 100644 index 00000000..6a8a2429 --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/test_port.py @@ -0,0 +1,25 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.utils.validation.port import is_valid_port, Port, PORT_LOOKUP, port_validator + + +def test_port_conversion(): + valid_port_lookup = {k: v for k, v in PORT_LOOKUP.items() if k != "UNUSED"} + for port_name, port_val in valid_port_lookup.items(): + assert port_validator(port_name) == port_val + assert is_valid_port(port_name) + + +def test_port_passthrough(): + valid_port_lookup = {k: v for k, v in PORT_LOOKUP.items() if k != "UNUSED"} + for port_val in valid_port_lookup.values(): + assert port_validator(port_val) == port_val + assert is_valid_port(port_val) + + +def test_invalid_ports(): + for port in (999999, -20, 3.214, "NONEXISTENT_PORT"): + with pytest.raises(ValueError): + port_validator(port) + assert not is_valid_port(port) From c74d5ac227d8bf947f2545870bc08a362671f52c Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 27 Sep 2024 09:28:26 +0100 Subject: [PATCH 9/9] Fix changelog typo and remove repitition in ACL show method --- CHANGELOG.md | 2 +- .../simulator/network/hardware/nodes/network/router.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9493dec4..f51fd648 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - File and folder observations can now be configured to always show the true health status, or require scanning like before. - It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where agent don't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty` - Node observations can now be configured to show the number of active local and remote logins. -- Ports, IP Protocols, and airspace frequencies no longer use enums. They defined in dictionary lookups and are handled by custom validation to enable extendability with plugins. +- Ports, IP Protocols, and airspace frequencies no longer use enums. They are defined in dictionary lookups and are handled by custom validation to enable extendability with plugins. ### Fixed - Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 244f40ce..1080dca8 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -555,10 +555,10 @@ class AccessControlList(SimComponent): rule.protocol if rule.protocol else "ANY", rule.src_ip_address if rule.src_ip_address else "ANY", rule.src_wildcard_mask if rule.src_wildcard_mask else "ANY", - f"{rule.src_port} ({rule.src_port})" if rule.src_port else "ANY", + f"{rule.src_port}" if rule.src_port else "ANY", rule.dst_ip_address if rule.dst_ip_address else "ANY", rule.dst_wildcard_mask if rule.dst_wildcard_mask else "ANY", - f"{rule.dst_port} ({rule.dst_port})" if rule.dst_port else "ANY", + f"{rule.dst_port}" if rule.dst_port else "ANY", rule.match_count, ] )