diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index f369bc2b..48d9df87 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -19,14 +19,8 @@ from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.creation import NetworkNodeAdder from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager -from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC -from primaite.simulator.network.hardware.nodes.host.server import Printer, Server -from primaite.simulator.network.hardware.nodes.network.firewall import Firewall -from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode -from primaite.simulator.network.hardware.nodes.network.router import Router +from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application @@ -277,68 +271,73 @@ class PrimaiteGame: for node_cfg in nodes_cfg: n_type = node_cfg["type"] + node_config: dict = node_cfg["config"] new_node = None + if n_type in Node._registry: + # simplify down Node creation: + new_node = Node._registry["n_type"].from_config(config=node_config) + # Default PrimAITE nodes - if n_type == "computer": - new_node = Computer( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "server": - new_node = Server( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "switch": - new_node = Switch( - hostname=node_cfg["hostname"], - num_ports=int(node_cfg.get("num_ports", "8")), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "router": - new_node = Router.from_config(node_cfg) - elif n_type == "firewall": - new_node = Firewall.from_config(node_cfg) - elif n_type == "wireless_router": - new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace) - elif n_type == "printer": - new_node = Printer( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=node_cfg["subnet_mask"], - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - # Handle extended nodes - elif n_type.lower() in Node._registry: - new_node = HostNode._registry[n_type]( - hostname=node_cfg["hostname"], - ip_address=node_cfg.get("ip_address"), - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type in NetworkNode._registry: - new_node = NetworkNode._registry[n_type](**node_cfg) + # if n_type == "computer": + # new_node = Computer( + # hostname=node_cfg["hostname"], + # ip_address=node_cfg["ip_address"], + # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), + # default_gateway=node_cfg.get("default_gateway"), + # dns_server=node_cfg.get("dns_server", None), + # operating_state=NodeOperatingState.ON + # if not (p := node_cfg.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + # elif n_type == "server": + # new_node = Server( + # hostname=node_cfg["hostname"], + # ip_address=node_cfg["ip_address"], + # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), + # default_gateway=node_cfg.get("default_gateway"), + # dns_server=node_cfg.get("dns_server", None), + # operating_state=NodeOperatingState.ON + # if not (p := node_cfg.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + # elif n_type == "switch": + # new_node = Switch( + # hostname=node_cfg["hostname"], + # num_ports=int(node_cfg.get("num_ports", "8")), + # operating_state=NodeOperatingState.ON + # if not (p := node_cfg.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + # elif n_type == "router": + # new_node = Router.from_config(node_cfg) + # elif n_type == "firewall": + # new_node = Firewall.from_config(node_cfg) + # elif n_type == "wireless_router": + # new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace) + # elif n_type == "printer": + # new_node = Printer( + # hostname=node_cfg["hostname"], + # ip_address=node_cfg["ip_address"], + # subnet_mask=node_cfg["subnet_mask"], + # operating_state=NodeOperatingState.ON + # if not (p := node_cfg.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + # # Handle extended nodes + # elif n_type.lower() in Node._registry: + # new_node = HostNode._registry[n_type]( + # hostname=node_cfg["hostname"], + # ip_address=node_cfg.get("ip_address"), + # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), + # default_gateway=node_cfg.get("default_gateway"), + # dns_server=node_cfg.get("dns_server", None), + # operating_state=NodeOperatingState.ON + # if not (p := node_cfg.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + # elif n_type in NetworkNode._registry: + # new_node = NetworkNode._registry[n_type](**node_cfg) else: msg = f"invalid node type {n_type} in config" _LOGGER.error(msg) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 8324715f..b003009b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1469,7 +1469,7 @@ class UserSessionManager(Service): return self.local_session is not None -class Node(SimComponent): +class Node(SimComponent, ABC): """ A basic Node class that represents a node on the network. @@ -1492,7 +1492,6 @@ class Node(SimComponent): "The Network Interfaces on the node by port id." dns_server: Optional[IPv4Address] = None "List of IP addresses of DNS servers used for name resolution." - accounts: Dict[str, Account] = {} "All accounts on the node." applications: Dict[str, Application] = {} @@ -1509,33 +1508,6 @@ class Node(SimComponent): session_manager: SessionManager software_manager: SoftwareManager - revealed_to_red: bool = False - "Informs whether the node has been revealed to a red agent." - - start_up_duration: int = 3 - "Time steps needed for the node to start up." - - start_up_countdown: int = 0 - "Time steps needed until node is booted up." - - shut_down_duration: int = 3 - "Time steps needed for the node to shut down." - - shut_down_countdown: int = 0 - "Time steps needed until node is shut down." - - is_resetting: bool = False - "If true, the node will try turning itself off then back on again." - - node_scan_duration: int = 10 - "How many timesteps until the whole node is scanned. Default 10 time steps." - - node_scan_countdown: int = 0 - "Time steps until scan is complete" - - red_scan_countdown: int = 0 - "Time steps until reveal to red scan is complete." - SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {} "Base system software that must be preinstalled." @@ -1545,6 +1517,46 @@ class Node(SimComponent): _identifier: ClassVar[str] = "unknown" """Identifier for this particular class, used for printing and logging. Each subclass redefines this.""" + config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) + + class ConfigSchema: + """Configuration Schema for Node based classes.""" + + revealed_to_red: bool = False + "Informs whether the node has been revealed to a red agent." + + start_up_duration: int = 3 + "Time steps needed for the node to start up." + + start_up_countdown: int = 0 + "Time steps needed until node is booted up." + + shut_down_duration: int = 3 + "Time steps needed for the node to shut down." + + shut_down_countdown: int = 0 + "Time steps needed until node is shut down." + + is_resetting: bool = False + "If true, the node will try turning itself off then back on again." + + node_scan_duration: int = 10 + "How many timesteps until the whole node is scanned. Default 10 time steps." + + node_scan_countdown: int = 0 + "Time steps until scan is complete" + + red_scan_countdown: int = 0 + "Time steps until reveal to red scan is complete." + + def from_config(cls, config: Dict) -> Node: + """Create Node object from a given configuration.""" + if config["type"] not in cls._registry: + msg = f"Configuration contains an invalid Node type: {config['type']}" + return ValueError(msg) + obj = cls(config=cls.ConfigSchema(**config)) + return obj + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register a node type. @@ -1850,7 +1862,7 @@ class Node(SimComponent): "applications": {app.name: app.describe_state() for app in self.applications.values()}, "services": {svc.name: svc.describe_state() for svc in self.services.values()}, "process": {proc.name: proc.describe_state() for proc in self.processes.values()}, - "revealed_to_red": self.revealed_to_red, + "revealed_to_red": self.config.revealed_to_red, } ) return state @@ -1928,8 +1940,8 @@ class Node(SimComponent): network_interface.apply_timestep(timestep=timestep) # count down to boot up - if self.start_up_countdown > 0: - self.start_up_countdown -= 1 + if self.config.start_up_countdown > 0: + self.config.start_up_countdown -= 1 else: if self.operating_state == NodeOperatingState.BOOTING: self.operating_state = NodeOperatingState.ON @@ -1940,8 +1952,8 @@ class Node(SimComponent): self._start_up_actions() # count down to shut down - if self.shut_down_countdown > 0: - self.shut_down_countdown -= 1 + if self.config.shut_down_countdown > 0: + self.config.shut_down_countdown -= 1 else: if self.operating_state == NodeOperatingState.SHUTTING_DOWN: self.operating_state = NodeOperatingState.OFF @@ -1949,17 +1961,17 @@ class Node(SimComponent): self._shut_down_actions() # if resetting turn back on - if self.is_resetting: - self.is_resetting = False + if self.config.is_resetting: + self.config.is_resetting = False self.power_on() # time steps which require the node to be on if self.operating_state == NodeOperatingState.ON: # node scanning - if self.node_scan_countdown > 0: - self.node_scan_countdown -= 1 + if self.config.node_scan_countdown > 0: + self.config.node_scan_countdown -= 1 - if self.node_scan_countdown == 0: + if self.config.node_scan_countdown == 0: # scan everything! for process_id in self.processes: self.processes[process_id].scan() @@ -1975,10 +1987,10 @@ class Node(SimComponent): # scan file system self.file_system.scan(instant_scan=True) - if self.red_scan_countdown > 0: - self.red_scan_countdown -= 1 + if self.config.red_scan_countdown > 0: + self.config.red_scan_countdown -= 1 - if self.red_scan_countdown == 0: + if self.config.red_scan_countdown == 0: # scan processes for process_id in self.processes: self.processes[process_id].reveal_to_red() @@ -2035,7 +2047,7 @@ class Node(SimComponent): to the red agent. """ - self.node_scan_countdown = self.node_scan_duration + self.config.node_scan_countdown = self.config.node_scan_duration return True def reveal_to_red(self) -> bool: @@ -2051,12 +2063,12 @@ class Node(SimComponent): `revealed_to_red` to `True`. """ - self.red_scan_countdown = self.node_scan_duration + self.config.red_scan_countdown = self.config.node_scan_duration return True def power_on(self) -> bool: """Power on the Node, enabling its NICs if it is in the OFF state.""" - if self.start_up_duration <= 0: + if self.config.start_up_duration <= 0: self.operating_state = NodeOperatingState.ON self._start_up_actions() self.sys_log.info("Power on") @@ -2065,14 +2077,14 @@ class Node(SimComponent): return True if self.operating_state == NodeOperatingState.OFF: self.operating_state = NodeOperatingState.BOOTING - self.start_up_countdown = self.start_up_duration + self.config.start_up_countdown = self.config.start_up_duration return True return False def power_off(self) -> bool: """Power off the Node, disabling its NICs if it is in the ON state.""" - if self.shut_down_duration <= 0: + if self.config.shut_down_duration <= 0: self._shut_down_actions() self.operating_state = NodeOperatingState.OFF self.sys_log.info("Power off") @@ -2081,7 +2093,7 @@ class Node(SimComponent): for network_interface in self.network_interfaces.values(): network_interface.disable() self.operating_state = NodeOperatingState.SHUTTING_DOWN - self.shut_down_countdown = self.shut_down_duration + self.config.shut_down_countdown = self.config.shut_down_duration return True return False @@ -2093,7 +2105,7 @@ class Node(SimComponent): Applying more timesteps will eventually turn the node back on. """ if self.operating_state.ON: - self.is_resetting = True + self.config.is_resetting = True self.sys_log.info("Resetting") self.power_off() return True diff --git a/src/primaite/simulator/network/hardware/nodes/host/computer.py b/src/primaite/simulator/network/hardware/nodes/host/computer.py index 11b925b9..1fb63b2e 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/computer.py +++ b/src/primaite/simulator/network/hardware/nodes/host/computer.py @@ -1,6 +1,8 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import ClassVar, Dict +from pydantic import Field + from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.system.services.ftp.ftp_client import FTPClient @@ -35,4 +37,11 @@ class Computer(HostNode, identifier="computer"): SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} + config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema()) + + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Computer class.""" + + pass + pass diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index c51afbca..fa73bf10 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -4,6 +4,8 @@ from __future__ import annotations from ipaddress import IPv4Address from typing import Any, ClassVar, Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.hardware.base import ( IPWiredNetworkInterface, @@ -325,6 +327,13 @@ class HostNode(Node, identifier="HostNode"): network_interface: Dict[int, NIC] = {} "The NICs on the node by port id." + config: HostNode.ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema()) + + class ConfigSchema(Node.ConfigSchema): + """Configuration Schema for HostNode class.""" + + pass + def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index f1ca4930..0a397c49 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -99,6 +99,13 @@ class Firewall(Router, identifier="firewall"): ) """Access Control List for managing traffic leaving towards an external network.""" + config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema()) + + class ConfigSchema(Router.ConfigSChema): + """Configuration Schema for Firewall 'Nodes' within PrimAITE.""" + + pass + def __init__(self, hostname: str, **kwargs): if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(hostname) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 4a049f99..132b1462 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -7,7 +7,7 @@ from ipaddress import IPv4Address, IPv4Network from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -1207,7 +1207,6 @@ class Router(NetworkNode, identifier="router"): "Terminal": Terminal, } - num_ports: int network_interfaces: Dict[str, RouterInterface] = {} "The Router Interfaces on the node." network_interface: Dict[int, RouterInterface] = {} @@ -1215,6 +1214,15 @@ class Router(NetworkNode, identifier="router"): acl: AccessControlList route_table: RouteTable + config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSChema()) + + class ConfigSChema(NetworkNode.ConfigSchema): + """Configuration Schema for Router Objects.""" + + num_ports: int = 10 + hostname: str = "Router" + ports: list = [] + def __init__(self, hostname: str, num_ports: int = 5, **kwargs): if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(hostname) @@ -1227,11 +1235,11 @@ class Router(NetworkNode, identifier="router"): self.session_manager.node = self self.software_manager.session_manager = self.session_manager self.session_manager.software_manager = self.software_manager - for i in range(1, self.num_ports + 1): + for i in range(1, self.config.num_ports + 1): network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0") self.connect_nic(network_interface) self.network_interface[i] = network_interface - + self.operating_state = NodeOperatingState.ON self._set_default_acl() def _install_system_software(self): @@ -1337,7 +1345,7 @@ class Router(NetworkNode, identifier="router"): :return: A dictionary representing the current state. """ state = super().describe_state() - state["num_ports"] = self.num_ports + state["num_ports"] = self.config.num_ports state["acl"] = self.acl.describe_state() return state @@ -1558,6 +1566,8 @@ class Router(NetworkNode, identifier="router"): ) print(table) + # TODO: Remove - Cover normal config items with ConfigSchema. Move additional setup components to __init__ ? + @classmethod def from_config(cls, cfg: dict, **kwargs) -> "Router": """Create a router based on a config dict. diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index db923f1a..b73af3cb 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -4,6 +4,7 @@ from __future__ import annotations from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite import getLogger from primaite.exceptions import NetworkError @@ -94,8 +95,6 @@ class Switch(NetworkNode, identifier="switch"): :ivar num_ports: The number of ports on the switch. Default is 24. """ - num_ports: int = 24 - "The number of ports on the switch." network_interfaces: Dict[str, SwitchPort] = {} "The SwitchPorts on the Switch." network_interface: Dict[int, SwitchPort] = {} @@ -103,9 +102,17 @@ class Switch(NetworkNode, identifier="switch"): mac_address_table: Dict[str, SwitchPort] = {} "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." + config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema()) + + class ConfigSchema(NetworkNode.ConfigSchema): + """Configuration Schema for Switch nodes within PrimAITE.""" + + num_ports: int = 24 + "The number of ports on the switch." + def __init__(self, **kwargs): super().__init__(**kwargs) - for i in range(1, self.num_ports + 1): + for i in range(1, self.config.num_ports + 1): self.connect_nic(SwitchPort()) def _install_system_software(self): @@ -134,7 +141,7 @@ class Switch(NetworkNode, identifier="switch"): """ state = super().describe_state() state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()} - state["num_ports"] = self.num_ports # redundant? + state["num_ports"] = self.config.num_ports # redundant? state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()} return state diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 804a570e..0a527ab8 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -2,7 +2,7 @@ from ipaddress import IPv4Address from typing import Any, Dict, Optional, Union -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, FREQ_WIFI_2_4, IPWirelessNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState @@ -124,6 +124,13 @@ class WirelessRouter(Router, identifier="wireless_router"): network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {} airspace: AirSpace + config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.Configschema()) + + class ConfigSchema(Router.ConfigSChema): + """Configuration Schema for WirelessRouter nodes within PrimAITE.""" + + pass + def __init__(self, hostname: str, airspace: AirSpace, **kwargs): super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs)