diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py index 144e0733..b00c26f6 100644 --- a/src/primaite/config/load.py +++ b/src/primaite/config/load.py @@ -59,3 +59,17 @@ def data_manipulation_marl_config_path() -> Path: _LOGGER.error(msg) raise FileNotFoundError(msg) return path + +def get_extended_config_path() -> Path: + """ + Get the path to an 'extended' example config that contains nodes using the extension framework + + :return: Path to the extended example config + :rtype: Path + """ + path = _EXAMPLE_CFG / "extended_config.yaml" + if not path.exists(): + msg = f"Example config does not exist: {path}. Have you run `primaite setup`?" + _LOGGER.error(msg) + raise FileNotFoundError(msg) + return path \ No newline at end of file diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 045b2467..11c968af 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -20,9 +20,10 @@ from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.airspace import AirSpaceFrequency from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.host_node import NIC +from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode from primaite.simulator.network.hardware.nodes.host.server import Printer, Server from primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter @@ -278,8 +279,25 @@ class PrimaiteGame: for node_cfg in nodes_cfg: n_type = node_cfg["type"] + new_node = None - if n_type == "computer": + # Handle extended nodes + if n_type.lower() in HostNode._registry: + new_node = HostNode._registry[n_type]( + hostname=node_cfg["hostname"], + ip_address=node_cfg["ip_address"], + subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), + default_gateway=node_cfg.get("default_gateway"), + dns_server=node_cfg.get("dns_server", None), + operating_state=NodeOperatingState.ON + if not (p := node_cfg.get("operating_state")) + else NodeOperatingState[p.upper()]) + elif n_type in NetworkNode._registry: + new_node = NetworkNode._registry[n_type]( + **node_cfg + ) + # Default PrimAITE nodes + elif n_type == "computer": new_node = Computer( hostname=node_cfg["hostname"], ip_address=node_cfg["ip_address"], @@ -351,10 +369,18 @@ class PrimaiteGame: for service_cfg in node_cfg["services"]: new_service = None service_type = service_cfg["type"] - if service_type in SERVICE_TYPES_MAPPING: + + service_class = None + # Handle extended services + if service_type.lower() in Service._registry: + service_class = Service._registry[service_type.lower()] + elif service_type in SERVICE_TYPES_MAPPING: + service_class = SERVICE_TYPES_MAPPING[service_type] + + if service_class is not None: _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") - new_node.software_manager.install(SERVICE_TYPES_MAPPING[service_type]) - new_service = new_node.software_manager.software[service_type] + new_node.software_manager.install(service_class) + new_service = new_node.software_manager.software[service_class.__name__] # fixing duration for the service if "fix_duration" in service_cfg.get("options", {}): @@ -398,8 +424,8 @@ class PrimaiteGame: new_application = None application_type = application_cfg["type"] - if application_type in Application._application_registry: - new_node.software_manager.install(Application._application_registry[application_type]) + if application_type in Application._registry: + new_node.software_manager.install(Application._registry[application_type]) new_application = new_node.software_manager.software[application_type] # grab the instance # fixing duration for the application diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 0408acde..39fbe783 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -12,7 +12,9 @@ from primaite import getLogger from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.network.airspace import AirSpace from primaite.simulator.network.hardware.base import Link, Node, WiredNetworkInterface +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.host.server import Printer +from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service @@ -128,6 +130,16 @@ class Network(SimComponent): def firewall_nodes(self) -> List[Node]: """The Firewalls in the Network.""" return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"] + + @property + def extended_hostnodes(self) -> List[Node]: + """Extended nodes that inherited HostNode in the network""" + return [node for node in self.nodes.values() if node.__class__.__name__.lower() in HostNode._registry] + + @property + def extended_networknodes(self) -> List[Node]: + """Extended nodes that inherited NetworkNode in the network""" + return [node for node in self.nodes.values() if node.__class__.__name__.lower() in NetworkNode._registry] @property def printer_nodes(self) -> List[Node]: @@ -160,6 +172,7 @@ class Network(SimComponent): "Printer": self.printer_nodes, "Wireless Router": self.wireless_router_nodes, } + if nodes: table = PrettyTable(["Node", "Type", "Operating State"]) if markdown: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index ef2d47c3..bf230e07 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1699,7 +1699,7 @@ class Node(SimComponent): if self.software_manager.software.get(application_name): self.sys_log.warning(f"Can't install {application_name}. It's already installed.") return RequestResponse(status="success", data={"reason": "already installed"}) - application_class = Application._application_registry[application_name] + application_class = Application._registry[application_name] self.software_manager.install(application_class) application_instance = self.software_manager.software.get(application_name) self.applications[application_instance.uuid] = application_instance diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index c197d30b..ea162e88 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -2,7 +2,7 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, ClassVar, Dict, Optional +from typing import Any, ClassVar, Dict, Optional, Type from primaite import getLogger from primaite.simulator.network.hardware.base import ( @@ -325,10 +325,30 @@ class HostNode(Node): network_interface: Dict[int, NIC] = {} "The NICs on the node by port id." + _registry: ClassVar[Dict[str, Type["HostNode"]]] = {} + """Registry of application types. Automatically populated when subclasses are defined.""" + def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) + def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + """ + Register a hostnode type. + + :param identifier: Uniquely specifies an hostnode class by name. Used for finding items by config. + :type identifier: str + :raises ValueError: When attempting to register an hostnode with a name that is already allocated. + """ + if identifier == 'default': + return + # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. + identifier = identifier.lower() + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.") + cls._registry[identifier] = cls + @property def nmap(self) -> Optional[NMAP]: """ diff --git a/src/primaite/simulator/network/hardware/nodes/network/network_node.py b/src/primaite/simulator/network/hardware/nodes/network/network_node.py index 5ff791cc..6515bb02 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/network_node.py +++ b/src/primaite/simulator/network/hardware/nodes/network/network_node.py @@ -1,6 +1,6 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from abc import abstractmethod -from typing import Optional +from typing import Any, ClassVar, Dict, Optional, Type from primaite.simulator.network.hardware.base import NetworkInterface, Node from primaite.simulator.network.transmission.data_link_layer import Frame @@ -16,6 +16,25 @@ class NetworkNode(Node): provide functionality for receiving and processing frames received on their network interfaces. """ + _registry: ClassVar[Dict[str, Type["NetworkNode"]]] = {} + """Registry of application types. Automatically populated when subclasses are defined.""" + + def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + """ + Register a networknode type. + + :param identifier: Uniquely specifies an networknode class by name. Used for finding items by config. + :type identifier: str + :raises ValueError: When attempting to register an networknode with a name that is already allocated. + """ + if identifier == 'default': + return + identifier = identifier.lower() + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Tried to define new networknode {identifier}, but this name is already reserved.") + cls._registry[identifier] = cls + @abstractmethod def receive_frame(self, frame: Frame, from_network_interface: NetworkInterface): """ diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 741f491d..b5284968 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -41,10 +41,10 @@ class Application(IOSoftware): install_countdown: Optional[int] = None "The countdown to the end of the installation process. None if not currently installing" - _application_registry: ClassVar[Dict[str, Type["Application"]]] = {} + _registry: ClassVar[Dict[str, Type["Application"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: """ Register an application type. @@ -52,10 +52,12 @@ class Application(IOSoftware): :type identifier: str :raises ValueError: When attempting to register an application with a name that is already allocated. """ + if identifier == 'default': + return super().__init_subclass__(**kwargs) - if identifier in cls._application_registry: + if identifier in cls._registry: raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.") - cls._application_registry[identifier] = cls + cls._registry[identifier] = cls def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 5adea6e7..74dcb506 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import abstractmethod from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional, Type from primaite import getLogger from primaite.interface.request import RequestFormat, RequestResponse @@ -46,9 +46,29 @@ class Service(IOSoftware): restart_countdown: Optional[int] = None "If currently restarting, how many timesteps remain until the restart is finished." + _registry: ClassVar[Dict[str, Type["Service"]]] = {} + """Registry of service types. Automatically populated when subclasses are defined.""" + def __init__(self, **kwargs): super().__init__(**kwargs) + def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None: + """ + Register a hostnode type. + + :param identifier: Uniquely specifies an hostnode class by name. Used for finding items by config. + :type identifier: str + :raises ValueError: When attempting to register an hostnode with a name that is already allocated. + """ + if identifier == 'default': + return + # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. + identifier = identifier.lower() + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.") + cls._registry[identifier] = cls + def _can_perform_action(self) -> bool: """ Checks if the service can perform actions. diff --git a/tests/assets/configs/extended_config.yaml b/tests/assets/configs/extended_config.yaml new file mode 100644 index 00000000..e1a06938 --- /dev/null +++ b/tests/assets/configs/extended_config.yaml @@ -0,0 +1,951 @@ +io_settings: + save_agent_actions: true + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + sys_log_level: WARNING + + +game: + max_episode_length: 128 + ports: + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + + + + - ref: data_manipulation_attacker + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "NONE" + label: ICS + options: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_SERVICE_FIX + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: ROUTER_ACL_ADDRULE + - type: ROUTER_ACL_REMOVERULE + - type: HOST_NIC_ENABLE + - type: HOST_NIC_DISABLE + + action_map: + 0: + action: DONOTHING + options: {} + # scan webapp service + 1: + action: NODE_SERVICE_SCAN + options: + node_id: 1 + service_id: 0 + # stop webapp service + 2: + action: NODE_SERVICE_STOP + options: + node_id: 1 + service_id: 0 + # start webapp service + 3: + action: "NODE_SERVICE_START" + options: + node_id: 1 + service_id: 0 + 4: + action: "NODE_SERVICE_PAUSE" + options: + node_id: 1 + service_id: 0 + 5: + action: "NODE_SERVICE_RESUME" + options: + node_id: 1 + service_id: 0 + 6: + action: "NODE_SERVICE_RESTART" + options: + node_id: 1 + service_id: 0 + 7: + action: "NODE_SERVICE_DISABLE" + options: + node_id: 1 + service_id: 0 + 8: + action: "NODE_SERVICE_ENABLE" + options: + node_id: 1 + service_id: 0 + 9: # check database.db file + action: "NODE_FILE_SCAN" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 10: + action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 11: + action: "NODE_FILE_DELETE" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 12: + action: "NODE_FILE_REPAIR" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 13: + action: "NODE_SERVICE_FIX" + options: + node_id: 2 + service_id: 0 + 14: + action: "NODE_FOLDER_SCAN" + options: + node_id: 2 + folder_id: 0 + 15: + action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + 16: + action: "NODE_FOLDER_REPAIR" + options: + node_id: 2 + folder_id: 0 + 17: + action: "NODE_FOLDER_RESTORE" + options: + node_id: 2 + folder_id: 0 + 18: + action: "NODE_OS_SCAN" + options: + node_id: 0 + 19: + action: "NODE_SHUTDOWN" + options: + node_id: 0 + 20: + action: NODE_STARTUP + options: + node_id: 0 + 21: + action: NODE_RESET + options: + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 50: # old action num: 26 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 51: # old action num: 27 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 52: # old action num: 28 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 0 + 53: # old action num: 29 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 1 + 54: # old action num: 30 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 2 + 55: # old action num: 31 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 3 + 56: # old action num: 32 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 4 + 57: # old action num: 33 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 5 + 58: # old action num: 34 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 6 + 59: # old action num: 35 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 7 + 60: # old action num: 36 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 8 + 61: # old action num: 37 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "HOST_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "HOST_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "HOST_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "HOST_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "HOST_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "HOST_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "HOST_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + + + + options: + nodes: + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.40 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_1_green_user + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user + + agent_settings: + flatten_obs: true + action_masking: true + + + + + +simulation: + network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE + nodes: + + - hostname: router_1 + type: router + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + 18: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 19: + action: PERMIT + src_port: DNS + dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - hostname: switch_1 + type: switch + num_ports: 8 + + - hostname: switch_2 + type: gigaswitch + num_ports: 8 + + - hostname: domain_controller + type: server + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.12 # web server + + - hostname: web_server + type: server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: WebServer + applications: + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + + + - hostname: database_server + type: server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - type: FTPClient + + - hostname: backup_server + type: server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: FTPServer + + - hostname: security_suite + type: server + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + network_interfaces: + 2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + - hostname: client_1 + type: supercomputer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: ExtendedApplication + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - type: ExtendedService + options: + backup_server_ip: 192.168.1.16 + + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + links: + - endpoint_a_hostname: router_1 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 8 + - endpoint_a_hostname: router_1 + endpoint_a_port: 2 + endpoint_b_hostname: switch_2 + endpoint_b_port: 8 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: domain_controller + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: web_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 3 + endpoint_b_hostname: database_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 4 + endpoint_b_hostname: backup_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 2 diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index 3e06d371..a642564c 100644 --- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -86,7 +86,7 @@ def test_node_software_install(): assert client_2.software_manager.software.get(software.__name__) is not None # check that applications have been installed on client 1 - for applications in Application._application_registry: + for applications in Application._registry: assert client_1.software_manager.software.get(applications) is not None # check that services have been installed on client 1 diff --git a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py index dd38fafd..168ebee0 100644 --- a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py +++ b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py @@ -51,7 +51,7 @@ def test_fix_duration_set_from_config(): # in config - applications take 1 timestep to fix # remove test applications from list - applications = set(Application._application_registry) - set(TestApplications) + applications = set(Application._registry) - set(TestApplications) for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]: assert client_1.software_manager.software.get(application) is not None diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py new file mode 100644 index 00000000..c9b3006d --- /dev/null +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -0,0 +1,220 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from enum import Enum +from ipaddress import IPv4Address +from typing import Dict, List, Optional +from urllib.parse import urlparse + +from pydantic import BaseModel, ConfigDict + +from primaite import getLogger +from primaite.interface.request import RequestResponse +from primaite.simulator.core import RequestManager, RequestType +from primaite.simulator.network.protocols.http import ( + HttpRequestMethod, + HttpRequestPacket, + HttpResponsePacket, + HttpStatusCode, +) +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import Application +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.dns.dns_client import DNSClient + +_LOGGER = getLogger(__name__) + + +class ExtendedApplication(Application, identifier="ExtendedApplication"): + """ + Clone of web browser that uses the extension framework instead of being part of PrimAITE directly. + + The application requests and loads web pages using its domain name and requesting IP addresses using DNS. + """ + + target_url: Optional[str] = None + + domain_name_ip_address: Optional[IPv4Address] = None + "The IP address of the domain name for the webpage." + + latest_response: Optional[HttpResponsePacket] = None + """Keeps track of the latest HTTP response.""" + + history: List["BrowserHistoryItem"] = [] + """Keep a log of visited websites and information about the visit, such as response code.""" + + def __init__(self, **kwargs): + kwargs["name"] = "ExtendedApplication" + kwargs["protocol"] = IPProtocol.TCP + # default for web is port 80 + if kwargs.get("port") is None: + kwargs["port"] = Port.HTTP + + super().__init__(**kwargs) + self.run() + + def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ + rm = super()._init_request_manager() + rm.add_request( + name="execute", + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self.get_webpage()) + ), # noqa + ) + + return rm + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of the WebBrowser. + + :return: A dictionary capturing the current state of the WebBrowser and its child objects. + """ + state = super().describe_state() + state["history"] = [hist_item.state() for hist_item in self.history] + return state + + def get_webpage(self, url: Optional[str] = None) -> bool: + """ + Retrieve the webpage. + + This should send a request to the web server which also requests for a list of users + + :param: url: The address of the web page the browser requests + :type: url: str + """ + url = url or self.target_url + if not self._can_perform_action(): + return False + + self.num_executions += 1 # trying to connect counts as an execution + + # reset latest response + self.latest_response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND) + + try: + parsed_url = urlparse(url) + except Exception: + self.sys_log.warning(f"{url} is not a valid URL") + return False + + # get the IP address of the domain name via DNS + dns_client: DNSClient = self.software_manager.software.get("DNSClient") + domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname) + + # if domain does not exist, the request fails + if domain_exists: + # set current domain name IP address + self.domain_name_ip_address = dns_client.dns_cache[parsed_url.hostname] + else: + # check if url is an ip address + try: + self.domain_name_ip_address = IPv4Address(parsed_url.hostname) + except Exception: + # unable to deal with this request + self.sys_log.warning(f"{self.name}: Unable to resolve URL {url}") + return False + + # create HTTPRequest payload + payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url=url) + + # send request - As part of the self.send call, a response will be received and stored in the + # self.latest_response variable + if self.send( + payload=payload, + dest_ip_address=self.domain_name_ip_address, + dest_port=parsed_url.port if parsed_url.port else Port.HTTP, + ): + self.sys_log.info( + f"{self.name}: Received HTTP {payload.request_method.name} " + f"Response {payload.request_url} - {self.latest_response.status_code.value}" + ) + self.history.append( + WebBrowser.BrowserHistoryItem( + url=url, + status=self.BrowserHistoryItem._HistoryItemStatus.LOADED, + response_code=self.latest_response.status_code, + ) + ) + return self.latest_response.status_code is HttpStatusCode.OK + else: + self.sys_log.warning(f"{self.name}: Error sending Http Packet") + self.sys_log.debug(f"{self.name}: {payload=}") + self.history.append( + WebBrowser.BrowserHistoryItem( + url=url, status=self.BrowserHistoryItem._HistoryItemStatus.SERVER_UNREACHABLE + ) + ) + return False + + def send( + self, + payload: HttpRequestPacket, + dest_ip_address: Optional[IPv4Address] = None, + dest_port: Optional[Port] = Port.HTTP, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: + """ + Sends a payload to the SessionManager. + + :param payload: The payload to be sent. + :param dest_ip_address: The ip address of the payload destination. + :param dest_port: The port of the payload destination. + :param session_id: The Session ID the payload is to originate from. Optional. + + :return: True if successful, False otherwise. + """ + self.sys_log.info(f"{self.name}: Sending HTTP {payload.request_method.name} {payload.request_url}") + + return super().send( + payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id, **kwargs + ) + + def receive(self, payload: HttpResponsePacket, session_id: Optional[str] = None, **kwargs) -> bool: + """ + Receives a payload from the SessionManager. + + :param payload: The payload to be sent. + :param session_id: The Session ID the payload is to originate from. Optional. + :return: True if successful, False otherwise. + """ + if not isinstance(payload, HttpResponsePacket): + self.sys_log.warning(f"{self.name} received a packet that is not an HttpResponsePacket") + self.sys_log.debug(f"{self.name}: {payload=}") + return False + self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}") + self.latest_response = payload + return True + + class BrowserHistoryItem(BaseModel): + """Simple representation of browser history, used for tracking success of web requests to calculate rewards.""" + + model_config = ConfigDict(extra="forbid") + """Error if incorrect specification.""" + + url: str + """The URL that was attempted to be fetched by the browser""" + + class _HistoryItemStatus(Enum): + NOT_SENT = "NOT_SENT" + PENDING = "PENDING" + SERVER_UNREACHABLE = "SERVER_UNREACHABLE" + LOADED = "LOADED" + + status: _HistoryItemStatus = _HistoryItemStatus.PENDING + + response_code: Optional[HttpStatusCode] = None + """HTTP response code that was received, or PENDING if a response was not yet received.""" + + def state(self) -> Dict: + """Return the contents of this dataclass as a dict for use with describe_state method.""" + if self.status == self._HistoryItemStatus.LOADED: + outcome = self.response_code.value + else: + outcome = self.status.value + return {"url": self.url, "outcome": outcome} diff --git a/tests/integration_tests/extensions/nodes/giga_switch.py b/tests/integration_tests/extensions/nodes/giga_switch.py new file mode 100644 index 00000000..b86bea7d --- /dev/null +++ b/tests/integration_tests/extensions/nodes/giga_switch.py @@ -0,0 +1,121 @@ +from typing import Dict + +from prettytable import MARKDOWN, PrettyTable + +from primaite import _LOGGER +from primaite.exceptions import NetworkError +from primaite.simulator.network.hardware.base import Link +from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode +from primaite.simulator.network.hardware.nodes.network.switch import SwitchPort +from primaite.simulator.network.transmission.data_link_layer import Frame + + +class GigaSwitch(NetworkNode, identifier="gigaswitch"): + """ + A class representing a Layer 2 network switch. + + :ivar num_ports: The number of ports on the switch. Default is 24. + """ + + num_ports: int = 24 + "The number of ports on the switch." + network_interfaces: Dict[str, SwitchPort] = {} + "The SwitchPorts on the Switch." + network_interface: Dict[int, SwitchPort] = {} + "The SwitchPorts on the Switch by port id." + mac_address_table: Dict[str, SwitchPort] = {} + "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." + + def __init__(self, **kwargs): + print('--- Extended Component: GigaSwitch ---') + super().__init__(**kwargs) + for i in range(1, self.num_ports + 1): + self.connect_nic(SwitchPort()) + + def _install_system_software(self): + pass + + def show(self, markdown: bool = False): + """ + Prints a table of the SwitchPorts on the Switch. + + :param markdown: If True, outputs the table in markdown format. Default is False. + """ + table = PrettyTable(["Port", "MAC Address", "Speed", "Status"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.hostname} Switch Ports" + for port_num, port in self.network_interface.items(): + table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"]) + print(table) + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. + + :return: Current state of this object and child objects. + """ + state = super().describe_state() + state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()} + state["num_ports"] = self.num_ports # redundant? + state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()} + return state + + def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort): + """ + Private method to add an entry to the MAC address table. + + :param mac_address: MAC address to be added. + :param switch_port: Corresponding SwitchPort object. + """ + mac_table_port = self.mac_address_table.get(mac_address) + if not mac_table_port: + self.mac_address_table[mac_address] = switch_port + self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}") + else: + if mac_table_port != switch_port: + self.mac_address_table.pop(mac_address) + self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}") + self._add_mac_table_entry(mac_address, switch_port) + + def receive_frame(self, frame: Frame, from_network_interface: SwitchPort): + """ + Forward a frame to the appropriate port based on the destination MAC address. + + :param frame: The Frame being received. + :param from_network_interface: The SwitchPort that received the frame. + """ + src_mac = frame.ethernet.src_mac_addr + dst_mac = frame.ethernet.dst_mac_addr + self._add_mac_table_entry(src_mac, from_network_interface) + + outgoing_port = self.mac_address_table.get(dst_mac) + if outgoing_port and dst_mac.lower() != "ff:ff:ff:ff:ff:ff": + outgoing_port.send_frame(frame) + else: + # If the destination MAC is not in the table, flood to all ports except incoming + for port in self.network_interface.values(): + if port.enabled and port != from_network_interface: + port.send_frame(frame) + + def disconnect_link_from_port(self, link: Link, port_number: int): + """ + Disconnect a given link from the specified port number on the switch. + + :param link: The Link object to be disconnected. + :param port_number: The port number on the switch from where the link should be disconnected. + :raise NetworkError: When an invalid port number is provided or the link does not match the connection. + """ + port = self.network_interface.get(port_number) + if port is None: + msg = f"Invalid port number {port_number} on the switch" + _LOGGER.error(msg) + raise NetworkError(msg) + + if port._connected_link != link: + msg = f"The link does not match the connection at port number {port_number}" + _LOGGER.error(msg) + raise NetworkError(msg) + + port.disconnect_link() diff --git a/tests/integration_tests/extensions/nodes/super_computer.py b/tests/integration_tests/extensions/nodes/super_computer.py new file mode 100644 index 00000000..8a1465e9 --- /dev/null +++ b/tests/integration_tests/extensions/nodes/super_computer.py @@ -0,0 +1,43 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import ClassVar, Dict + +from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode +from primaite.simulator.system.services.ftp.ftp_client import FTPClient +from primaite.utils.validators import IPV4Address + + +class SuperComputer(HostNode, identifier="supercomputer"): + """ + A basic Computer class. + + Example: + >>> pc_a = Computer( + hostname="pc_a", + ip_address="192.168.1.10", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1" + ) + >>> pc_a.power_on() + + Instances of computer come 'pre-packaged' with the following: + + * Core Functionality: + * Packet Capture + * Sys Log + * Services: + * ARP Service + * ICMP Service + * DNS Client + * FTP Client + * NTP Client + * Applications: + * Web Browser + """ + + SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} + + def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): + print('--- Extended Component: SuperComputer ---') + super().__init__(ip_address=ip_address, subnet_mask=subnet_mask, **kwargs) + + pass diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py new file mode 100644 index 00000000..3151571b --- /dev/null +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -0,0 +1,426 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from ipaddress import IPv4Address +from typing import Any, Dict, List, Literal, Optional, Union +from uuid import uuid4 + +from primaite import getLogger +from primaite.simulator.file_system.file_system import File +from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus +from primaite.simulator.file_system.folder import Folder +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.core.software_manager import SoftwareManager +from primaite.simulator.system.services.ftp.ftp_client import FTPClient +from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.simulator.system.software import SoftwareHealthState + +_LOGGER = getLogger(__name__) + + +class ExtendedService(Service, identifier='extendedservice'): + """ + A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE. + + This class inherits from the `Service` class and provides methods to simulate a SQL database. + """ + + password: Optional[str] = None + """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" + + backup_server_ip: IPv4Address = None + """IP address of the backup server.""" + + latest_backup_directory: str = None + """Directory of latest backup.""" + + latest_backup_file_name: str = None + """File name of latest backup.""" + + def __init__(self, **kwargs): + kwargs["name"] = "ExtendedService" + kwargs["port"] = Port.POSTGRES_SERVER + kwargs["protocol"] = IPProtocol.TCP + super().__init__(**kwargs) + self._create_db_file() + if kwargs.get('options'): + opt = kwargs["options"] + self.password = opt.get("db_password", None) + if "backup_server_ip" in opt: + self.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip"))) + + def install(self): + """ + Perform first-time setup of the ExtendedService. + + Installs an instance of FTPClient on the Node to enable database backup if it isn't installed already. + """ + super().install() + + if not self.parent.software_manager.software.get("FTPClient"): + self.parent.sys_log.info(f"{self.name}: Installing FTPClient to enable database backups") + self.parent.software_manager.install(FTPClient) + + def configure_backup(self, backup_server: IPv4Address): + """ + Set up the database backup. + + :param: backup_server_ip: The IP address of the backup server + """ + self.backup_server_ip = backup_server + + def backup_database(self) -> bool: + """Create a backup of the database to the configured backup server.""" + # check if this action can be performed + if not self._can_perform_action(): + return False + + # check if the backup server was configured + if self.backup_server_ip is None: + self.sys_log.warning(f"{self.name} - {self.sys_log.hostname}: not configured.") + return False + + software_manager: SoftwareManager = self.software_manager + ftp_client_service: FTPClient = software_manager.software.get("FTPClient") + + if not ftp_client_service: + self.sys_log.error( + f"{self.name}: Failed to perform database backup as the FTPClient software is not installed" + ) + return False + + # send backup copy of database file to FTP server + if not self.db_file: + self.sys_log.error(f"{self.name}: Attempted to backup database file but it doesn't exist.") + return False + + response = ftp_client_service.send_file( + dest_ip_address=self.backup_server_ip, + src_file_name=self.db_file.name, + src_folder_name="database", + dest_folder_name=str(self.uuid), + # Prevent's a filename clash with the real DatabaseService service implementation + dest_file_name="extended_service_database.db", + ) + + if response: + return True + + self.sys_log.error("Unable to create database backup.") + return False + + def restore_backup(self) -> bool: + """Restore a backup from backup server.""" + # check if this action can be performed + if not self._can_perform_action(): + return False + + software_manager: SoftwareManager = self.software_manager + ftp_client_service: FTPClient = software_manager.software.get("FTPClient") + + if not ftp_client_service: + self.sys_log.error( + f"{self.name}: Failed to restore database backup as the FTPClient software is not installed" + ) + return False + + # retrieve backup file from backup server + response = ftp_client_service.request_file( + src_folder_name=str(self.uuid), + src_file_name="extended_service_database.db", + dest_folder_name="downloads", + dest_file_name="extended_service_database.db", + dest_ip_address=self.backup_server_ip, + ) + + if not response: + self.sys_log.error("Unable to restore database backup.") + return False + + old_visible_state = SoftwareHealthState.GOOD + + # get db file regardless of whether or not it was deleted + db_file = self.file_system.get_file(folder_name="database", file_name="extended_service_database.db", include_deleted=True) + + if db_file is None: + self.sys_log.warning("Database file not initialised.") + return False + + # if the file was deleted, get the old visible health state + if db_file.deleted: + old_visible_state = db_file.visible_health_status + else: + old_visible_state = self.db_file.visible_health_status + self.file_system.delete_file(folder_name="database", file_name="extended_service_database.db") + + # replace db file + self.file_system.copy_file(src_folder_name="downloads", src_file_name="extended_service_database.db", dst_folder_name="database") + + if self.db_file is None: + self.sys_log.error("Copying database backup failed.") + return False + + self.db_file.visible_health_status = old_visible_state + self.set_health_state(SoftwareHealthState.GOOD) + + return True + + def _create_db_file(self): + """Creates the Simulation File and sqlite file in the file system.""" + self.file_system.create_file(folder_name="database", file_name="extended_service_database.db") + + @property + def db_file(self) -> File: + """Returns the database file.""" + return self.file_system.get_file(folder_name="database", file_name="extended_service_database.db") + + def _return_database_folder(self) -> Folder: + """Returns the database folder.""" + return self.file_system.get_folder_by_id(self.db_file.folder_id) + + def _generate_connection_id(self) -> str: + """Generate a unique connection ID.""" + return str(uuid4()) + + def _process_connect( + self, + src_ip: IPv4Address, + connection_request_id: str, + password: Optional[str] = None, + session_id: Optional[str] = None, + ) -> Dict[str, Union[int, Dict[str, bool]]]: + """Process an incoming connection request. + + :param connection_id: A unique identifier for the connection + :type connection_id: str + :param password: Supplied password. It must match self.password for connection success, defaults to None + :type password: Optional[str], optional + :return: Response to connection request containing success info. + :rtype: Dict[str, Union[int, Dict[str, bool]]] + """ + self.sys_log.info(f"{self.name}: Processing new connection request ({connection_request_id}) from {src_ip}") + status_code = 500 # Default internal server error + connection_id = None + if self.operating_state == ServiceOperatingState.RUNNING: + status_code = 503 # service unavailable + if self.health_state_actual == SoftwareHealthState.OVERWHELMED: + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, service is at " + f"capacity." + ) + if self.health_state_actual in [ + SoftwareHealthState.GOOD, + SoftwareHealthState.FIXING, + SoftwareHealthState.COMPROMISED, + ]: + if self.password == password: + status_code = 200 # ok + connection_id = self._generate_connection_id() + # try to create connection + if not self.add_connection(connection_id=connection_id, session_id=session_id): + status_code = 500 + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, " + f"returning status code 500" + ) + else: + status_code = 401 # Unauthorised + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} unauthorised " + f"(incorrect password), returning status code 401" + ) + else: + status_code = 404 # service not found + return { + "status_code": status_code, + "type": "connect_response", + "response": status_code == 200, + "connection_id": connection_id, + "connection_request_id": connection_request_id, + } + + def _process_sql( + self, + query: Literal["SELECT", "DELETE", "INSERT", "ENCRYPT"], + query_id: str, + connection_id: Optional[str] = None, + ) -> Dict[str, Union[int, List[Any]]]: + """ + Executes the given SQL query and returns the result. + + Possible queries: + - SELECT : returns the data + - DELETE : deletes the data + - INSERT : inserts the data + - ENCRYPT : corrupts the data + + :param query: The SQL query to be executed. + :return: Dictionary containing status code and data fetched. + """ + self.sys_log.info(f"{self.name}: Running {query}") + + if not self.db_file: + self.sys_log.error(f"{self.name}: Failed to run {query} because the database file is missing.") + return {"status_code": 404, "type": "sql", "data": False} + + if self.health_state_actual is not SoftwareHealthState.GOOD: + self.sys_log.error(f"{self.name}: Failed to run {query} because the database service is unavailable.") + return {"status_code": 500, "type": "sql", "data": False} + + if query == "SELECT": + if self.db_file.health_status == FileSystemItemHealthStatus.CORRUPT: + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + elif self.db_file.health_status == FileSystemItemHealthStatus.GOOD: + return { + "status_code": 200, + "type": "sql", + "data": True, + "uuid": query_id, + "connection_id": connection_id, + } + else: + return {"status_code": 404, "type": "sql", "data": False} + elif query == "DELETE": + self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + elif query == "ENCRYPT": + self.file_system.num_file_creations += 1 + self.db_file.health_status = FileSystemItemHealthStatus.CORRUPT + self.db_file.num_access += 1 + database_folder = self._return_database_folder() + database_folder.health_status = FileSystemItemHealthStatus.CORRUPT + self.file_system.num_file_deletions += 1 + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + elif query == "INSERT": + if self.health_state_actual == SoftwareHealthState.GOOD: + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + else: + return {"status_code": 404, "type": "sql", "data": False} + elif query == "SELECT * FROM pg_stat_activity": + # Check if the connection is active. + if self.health_state_actual == SoftwareHealthState.GOOD: + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + else: + return {"status_code": 401, "data": False} + else: + # Invalid query + self.sys_log.warning(f"{self.name}: Invalid {query}") + return {"status_code": 500, "data": False} + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. + + Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. + + :return: Current state of this object and child objects. + :rtype: Dict + """ + return super().describe_state() + + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Processes the incoming SQL payload and sends the result back. + + :param payload: The SQL query to be executed. + :param session_id: The session identifier. + :return: True if the Status Code is 200, otherwise False. + """ + result = {"status_code": 500, "data": []} + # if server service is down, return error + if not self._can_perform_action(): + return False + + if isinstance(payload, dict) and payload.get("type"): + if payload["type"] == "connect_request": + src_ip = kwargs.get("frame").ip.src_ip_address + result = self._process_connect( + src_ip=src_ip, + password=payload.get("password"), + connection_request_id=payload.get("connection_request_id"), + session_id=session_id, + ) + elif payload["type"] == "disconnect": + if payload["connection_id"] in self.connections: + connection_id = payload["connection_id"] + connected_ip_address = self.connections[connection_id]["ip_address"] + frame = kwargs.get("frame") + if connected_ip_address == frame.ip.src_ip_address: + self.sys_log.info( + f"{self.name}: Received disconnect command for {connection_id=} from {connected_ip_address}" + ) + self.terminate_connection(connection_id=payload["connection_id"], send_disconnect=False) + else: + self.sys_log.warning( + f"{self.name}: Ignoring disconnect command for {connection_id=} as the command source " + f"({frame.ip.src_ip_address}) doesn't match the connection source ({connected_ip_address})" + ) + elif payload["type"] == "sql": + if payload.get("connection_id") in self.connections: + result = self._process_sql( + query=payload["sql"], query_id=payload["uuid"], connection_id=payload["connection_id"] + ) + else: + result = {"status_code": 401, "type": "sql"} + else: + self.sys_log.info(f"{self.name}: Ignoring payload as it is not a Database payload") + self.send(payload=result, session_id=session_id) + return True + + def send(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Send a SQL response back down to the SessionManager. + + :param payload: The SQL query results. + :param session_id: The session identifier. + :return: True if the Status Code is 200, otherwise False. + """ + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) + + return payload["status_code"] == 200 + + def apply_timestep(self, timestep: int) -> None: + """ + Apply a single timestep of simulation dynamics to this service. + + Here at the first step, the database backup is created, in addition to normal service update logic. + """ + if timestep == 1: + self.backup_database() + return super().apply_timestep(timestep) + + def _update_fix_status(self) -> None: + """Perform a database restore when the FIXING countdown is finished.""" + super()._update_fix_status() + if self._fixing_countdown is None: + self.restore_backup() diff --git a/tests/integration_tests/extensions/test_extendable_config.py b/tests/integration_tests/extensions/test_extendable_config.py new file mode 100644 index 00000000..5d8af64d --- /dev/null +++ b/tests/integration_tests/extensions/test_extendable_config.py @@ -0,0 +1,32 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.config.load import get_extended_config_path +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config +import os + +# Import the extended components so that PrimAITE registers them +from tests.integration_tests.extensions.nodes.super_computer import SuperComputer +from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch +from tests.integration_tests.extensions.services.extended_service import ExtendedService +from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication + + +def test_extended_example_config(): + + """Test that the example config can be parsed properly.""" + config_path = os.path.join( "tests", "assets", "configs", "extended_config.yaml") + game = load_config(config_path) + network: Network = game.simulation.network + + assert len(network.nodes) == 10 # 10 nodes in example network + assert len(network.computer_nodes) == 1 + assert len(network.router_nodes) == 1 # 1 router in network + assert len(network.switch_nodes) == 1 # 1 switches in network + assert len(network.server_nodes) == 5 # 5 servers in network + assert len(network.extended_hostnodes) == 1 # One extended node based on HostNode + assert len(network.extended_networknodes) == 1 # One extended node based on NetworkNode + + assert 'ExtendedApplication' in network.extended_hostnodes[0].software_manager.software + assert 'ExtendedService' in network.extended_hostnodes[0].software_manager.software diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py index d8d7dfab..f97e915e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py @@ -8,7 +8,7 @@ def test_adding_to_app_registry(): class temp_application(Application, identifier="temp_app"): pass - assert Application._application_registry["temp_app"] is temp_application + assert Application._registry["temp_app"] is temp_application with pytest.raises(ValueError): @@ -19,4 +19,4 @@ def test_adding_to_app_registry(): # Because pytest doesn't reimport classes from modules, registering this temporary test application will change the # state of the Application registry for all subsequently run tests. So, we have to delete and unregister the class. del temp_application - Application._application_registry.pop("temp_app") + Application._registry.pop("temp_app")