diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index e00afba6..7b36097b 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -1,4 +1,5 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from copy import deepcopy from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union @@ -76,6 +77,8 @@ class SoftwareManager: for software in self.port_protocol_mapping.values(): if software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}: open_ports.append(software.port) + if software.listen_on_ports: + open_ports += list(software.listen_on_ports) return open_ports def check_port_is_open(self, port: Port, protocol: IPProtocol) -> bool: @@ -223,7 +226,9 @@ class SoftwareManager: frame: Frame, ): """ - Receive a payload from the SessionManager and forward it to the corresponding service or application. + Receive a payload from the SessionManager and forward it to the corresponding service or applications. + + This function handles both software assigned a specific port, and software listening in on other ports. :param payload: The payload being received. :param session: The transport session the payload originates from. @@ -231,11 +236,17 @@ class SoftwareManager: if payload.__class__.__name__ == "PortScanPayload": self.software.get("NMAP").receive(payload=payload, session_id=session_id) return - receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None) - if receiver: - receiver.receive( - payload=payload, session_id=session_id, from_network_interface=from_network_interface, frame=frame - ) + main_receiver = self.port_protocol_mapping.get((port, protocol), None) + listening_receivers = [software for software in self.software.values() if port in software.listen_on_ports] + receivers = [main_receiver] + listening_receivers if main_receiver else listening_receivers + if receivers: + for receiver in receivers: + receiver.receive( + payload=deepcopy(payload), + session_id=session_id, + from_network_interface=from_network_interface, + frame=frame, + ) else: self.sys_log.warning(f"No service or application found for port {port} and protocol {protocol}") pass diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 22ae0ff3..56edcf89 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -377,6 +377,8 @@ class DatabaseService(Service): ) else: result = {"status_code": 401, "type": "sql"} + else: + self.sys_log.info(f"{self.name}: Ignoring payload as it is not a Database payload") self.send(payload=result, session_id=session_id) return True diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 7c27534a..7a3d675c 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -4,9 +4,10 @@ from abc import abstractmethod from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv4Network -from typing import Any, Dict, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Union from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -252,6 +253,8 @@ class IOSoftware(Software): "Indicates if the software uses UDP protocol for communication. Default is True." port: Port "The port to which the software is connected." + listen_on_ports: Set[Port] = Field(default_factory=set) + "The set of ports to listen on." protocol: IPProtocol "The IP Protocol the Software operates on." _connections: Dict[str, Dict] = {} diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py new file mode 100644 index 00000000..0cb1ad54 --- /dev/null +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -0,0 +1,64 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Any, Dict, List, Set + +from pydantic import Field + +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.simulator.system.services.service import Service + + +class _DatabaseListener(Service): + name: str = "DatabaseListener" + protocol: IPProtocol = IPProtocol.TCP + port: Port = Port.NONE + listen_on_ports: Set[Port] = {Port.POSTGRES_SERVER} + payloads_received: List[Any] = Field(default_factory=list) + + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + self.payloads_received.append(payload) + self.sys_log.info(f"{self.name}: received payload {payload}") + return True + + def describe_state(self) -> Dict: + return super().describe_state() + + +def test_http_listener(client_server): + computer, server = client_server + + server.software_manager.install(DatabaseService) + server_db = server.software_manager.software["DatabaseService"] + server_db.start() + + server.software_manager.install(_DatabaseListener) + server_db_listener: _DatabaseListener = server.software_manager.software["DatabaseListener"] + server_db_listener.start() + + computer.software_manager.install(DatabaseClient) + computer_db_client: DatabaseClient = computer.software_manager.software["DatabaseClient"] + + computer_db_client.run() + computer_db_client.server_ip_address = server.network_interface[1].ip_address + + assert len(server_db_listener.payloads_received) == 0 + computer.session_manager.receive_payload_from_software_manager( + payload="masquerade as Database traffic", + dst_ip_address=server.network_interface[1].ip_address, + dst_port=Port.POSTGRES_SERVER, + ip_protocol=IPProtocol.TCP, + ) + + assert len(server_db_listener.payloads_received) == 1 + + db_connection = computer_db_client.get_new_connection() + + assert db_connection + + assert len(server_db_listener.payloads_received) == 2 + + assert db_connection.query("SELECT") + + assert len(server_db_listener.payloads_received) == 3