From 582e7cfec784fdfaf240eeb28d58218a89b9d7eb Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 15 Jan 2025 11:21:18 +0000 Subject: [PATCH 01/23] #2887 - Initial commit of Node refactor for extensibility in version 4.0.0. Addition of ConfigSchema and changes to how Nodes are generated within Game.py --- src/primaite/game/game.py | 131 +++++++++--------- .../simulator/network/hardware/base.py | 110 ++++++++------- .../network/hardware/nodes/host/computer.py | 9 ++ .../network/hardware/nodes/host/host_node.py | 9 ++ .../hardware/nodes/network/firewall.py | 7 + .../network/hardware/nodes/network/router.py | 20 ++- .../network/hardware/nodes/network/switch.py | 15 +- .../hardware/nodes/network/wireless_router.py | 9 +- 8 files changed, 185 insertions(+), 125 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index f369bc2b..48d9df87 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -19,14 +19,8 @@ from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.creation import NetworkNodeAdder from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager -from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC -from primaite.simulator.network.hardware.nodes.host.server import Printer, Server -from primaite.simulator.network.hardware.nodes.network.firewall import Firewall -from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode -from primaite.simulator.network.hardware.nodes.network.router import Router +from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application @@ -277,68 +271,73 @@ class PrimaiteGame: for node_cfg in nodes_cfg: n_type = node_cfg["type"] + node_config: dict = node_cfg["config"] new_node = None + if n_type in Node._registry: + # simplify down Node creation: + new_node = Node._registry["n_type"].from_config(config=node_config) + # Default PrimAITE nodes - if n_type == "computer": - new_node = Computer( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "server": - new_node = Server( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "switch": - new_node = Switch( - hostname=node_cfg["hostname"], - num_ports=int(node_cfg.get("num_ports", "8")), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "router": - new_node = Router.from_config(node_cfg) - elif n_type == "firewall": - new_node = Firewall.from_config(node_cfg) - elif n_type == "wireless_router": - new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace) - elif n_type == "printer": - new_node = Printer( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=node_cfg["subnet_mask"], - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - # Handle extended nodes - elif n_type.lower() in Node._registry: - new_node = HostNode._registry[n_type]( - hostname=node_cfg["hostname"], - ip_address=node_cfg.get("ip_address"), - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type in NetworkNode._registry: - new_node = NetworkNode._registry[n_type](**node_cfg) + # if n_type == "computer": + # new_node = Computer( + # hostname=node_cfg["hostname"], + # ip_address=node_cfg["ip_address"], + # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), + # default_gateway=node_cfg.get("default_gateway"), + # dns_server=node_cfg.get("dns_server", None), + # operating_state=NodeOperatingState.ON + # if not (p := node_cfg.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + # elif n_type == "server": + # new_node = Server( + # hostname=node_cfg["hostname"], + # ip_address=node_cfg["ip_address"], + # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), + # default_gateway=node_cfg.get("default_gateway"), + # dns_server=node_cfg.get("dns_server", None), + # operating_state=NodeOperatingState.ON + # if not (p := node_cfg.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + # elif n_type == "switch": + # new_node = Switch( + # hostname=node_cfg["hostname"], + # num_ports=int(node_cfg.get("num_ports", "8")), + # operating_state=NodeOperatingState.ON + # if not (p := node_cfg.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + # elif n_type == "router": + # new_node = Router.from_config(node_cfg) + # elif n_type == "firewall": + # new_node = Firewall.from_config(node_cfg) + # elif n_type == "wireless_router": + # new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace) + # elif n_type == "printer": + # new_node = Printer( + # hostname=node_cfg["hostname"], + # ip_address=node_cfg["ip_address"], + # subnet_mask=node_cfg["subnet_mask"], + # operating_state=NodeOperatingState.ON + # if not (p := node_cfg.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + # # Handle extended nodes + # elif n_type.lower() in Node._registry: + # new_node = HostNode._registry[n_type]( + # hostname=node_cfg["hostname"], + # ip_address=node_cfg.get("ip_address"), + # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), + # default_gateway=node_cfg.get("default_gateway"), + # dns_server=node_cfg.get("dns_server", None), + # operating_state=NodeOperatingState.ON + # if not (p := node_cfg.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + # elif n_type in NetworkNode._registry: + # new_node = NetworkNode._registry[n_type](**node_cfg) else: msg = f"invalid node type {n_type} in config" _LOGGER.error(msg) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 8324715f..b003009b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1469,7 +1469,7 @@ class UserSessionManager(Service): return self.local_session is not None -class Node(SimComponent): +class Node(SimComponent, ABC): """ A basic Node class that represents a node on the network. @@ -1492,7 +1492,6 @@ class Node(SimComponent): "The Network Interfaces on the node by port id." dns_server: Optional[IPv4Address] = None "List of IP addresses of DNS servers used for name resolution." - accounts: Dict[str, Account] = {} "All accounts on the node." applications: Dict[str, Application] = {} @@ -1509,33 +1508,6 @@ class Node(SimComponent): session_manager: SessionManager software_manager: SoftwareManager - revealed_to_red: bool = False - "Informs whether the node has been revealed to a red agent." - - start_up_duration: int = 3 - "Time steps needed for the node to start up." - - start_up_countdown: int = 0 - "Time steps needed until node is booted up." - - shut_down_duration: int = 3 - "Time steps needed for the node to shut down." - - shut_down_countdown: int = 0 - "Time steps needed until node is shut down." - - is_resetting: bool = False - "If true, the node will try turning itself off then back on again." - - node_scan_duration: int = 10 - "How many timesteps until the whole node is scanned. Default 10 time steps." - - node_scan_countdown: int = 0 - "Time steps until scan is complete" - - red_scan_countdown: int = 0 - "Time steps until reveal to red scan is complete." - SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {} "Base system software that must be preinstalled." @@ -1545,6 +1517,46 @@ class Node(SimComponent): _identifier: ClassVar[str] = "unknown" """Identifier for this particular class, used for printing and logging. Each subclass redefines this.""" + config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) + + class ConfigSchema: + """Configuration Schema for Node based classes.""" + + revealed_to_red: bool = False + "Informs whether the node has been revealed to a red agent." + + start_up_duration: int = 3 + "Time steps needed for the node to start up." + + start_up_countdown: int = 0 + "Time steps needed until node is booted up." + + shut_down_duration: int = 3 + "Time steps needed for the node to shut down." + + shut_down_countdown: int = 0 + "Time steps needed until node is shut down." + + is_resetting: bool = False + "If true, the node will try turning itself off then back on again." + + node_scan_duration: int = 10 + "How many timesteps until the whole node is scanned. Default 10 time steps." + + node_scan_countdown: int = 0 + "Time steps until scan is complete" + + red_scan_countdown: int = 0 + "Time steps until reveal to red scan is complete." + + def from_config(cls, config: Dict) -> Node: + """Create Node object from a given configuration.""" + if config["type"] not in cls._registry: + msg = f"Configuration contains an invalid Node type: {config['type']}" + return ValueError(msg) + obj = cls(config=cls.ConfigSchema(**config)) + return obj + def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: """ Register a node type. @@ -1850,7 +1862,7 @@ class Node(SimComponent): "applications": {app.name: app.describe_state() for app in self.applications.values()}, "services": {svc.name: svc.describe_state() for svc in self.services.values()}, "process": {proc.name: proc.describe_state() for proc in self.processes.values()}, - "revealed_to_red": self.revealed_to_red, + "revealed_to_red": self.config.revealed_to_red, } ) return state @@ -1928,8 +1940,8 @@ class Node(SimComponent): network_interface.apply_timestep(timestep=timestep) # count down to boot up - if self.start_up_countdown > 0: - self.start_up_countdown -= 1 + if self.config.start_up_countdown > 0: + self.config.start_up_countdown -= 1 else: if self.operating_state == NodeOperatingState.BOOTING: self.operating_state = NodeOperatingState.ON @@ -1940,8 +1952,8 @@ class Node(SimComponent): self._start_up_actions() # count down to shut down - if self.shut_down_countdown > 0: - self.shut_down_countdown -= 1 + if self.config.shut_down_countdown > 0: + self.config.shut_down_countdown -= 1 else: if self.operating_state == NodeOperatingState.SHUTTING_DOWN: self.operating_state = NodeOperatingState.OFF @@ -1949,17 +1961,17 @@ class Node(SimComponent): self._shut_down_actions() # if resetting turn back on - if self.is_resetting: - self.is_resetting = False + if self.config.is_resetting: + self.config.is_resetting = False self.power_on() # time steps which require the node to be on if self.operating_state == NodeOperatingState.ON: # node scanning - if self.node_scan_countdown > 0: - self.node_scan_countdown -= 1 + if self.config.node_scan_countdown > 0: + self.config.node_scan_countdown -= 1 - if self.node_scan_countdown == 0: + if self.config.node_scan_countdown == 0: # scan everything! for process_id in self.processes: self.processes[process_id].scan() @@ -1975,10 +1987,10 @@ class Node(SimComponent): # scan file system self.file_system.scan(instant_scan=True) - if self.red_scan_countdown > 0: - self.red_scan_countdown -= 1 + if self.config.red_scan_countdown > 0: + self.config.red_scan_countdown -= 1 - if self.red_scan_countdown == 0: + if self.config.red_scan_countdown == 0: # scan processes for process_id in self.processes: self.processes[process_id].reveal_to_red() @@ -2035,7 +2047,7 @@ class Node(SimComponent): to the red agent. """ - self.node_scan_countdown = self.node_scan_duration + self.config.node_scan_countdown = self.config.node_scan_duration return True def reveal_to_red(self) -> bool: @@ -2051,12 +2063,12 @@ class Node(SimComponent): `revealed_to_red` to `True`. """ - self.red_scan_countdown = self.node_scan_duration + self.config.red_scan_countdown = self.config.node_scan_duration return True def power_on(self) -> bool: """Power on the Node, enabling its NICs if it is in the OFF state.""" - if self.start_up_duration <= 0: + if self.config.start_up_duration <= 0: self.operating_state = NodeOperatingState.ON self._start_up_actions() self.sys_log.info("Power on") @@ -2065,14 +2077,14 @@ class Node(SimComponent): return True if self.operating_state == NodeOperatingState.OFF: self.operating_state = NodeOperatingState.BOOTING - self.start_up_countdown = self.start_up_duration + self.config.start_up_countdown = self.config.start_up_duration return True return False def power_off(self) -> bool: """Power off the Node, disabling its NICs if it is in the ON state.""" - if self.shut_down_duration <= 0: + if self.config.shut_down_duration <= 0: self._shut_down_actions() self.operating_state = NodeOperatingState.OFF self.sys_log.info("Power off") @@ -2081,7 +2093,7 @@ class Node(SimComponent): for network_interface in self.network_interfaces.values(): network_interface.disable() self.operating_state = NodeOperatingState.SHUTTING_DOWN - self.shut_down_countdown = self.shut_down_duration + self.config.shut_down_countdown = self.config.shut_down_duration return True return False @@ -2093,7 +2105,7 @@ class Node(SimComponent): Applying more timesteps will eventually turn the node back on. """ if self.operating_state.ON: - self.is_resetting = True + self.config.is_resetting = True self.sys_log.info("Resetting") self.power_off() return True diff --git a/src/primaite/simulator/network/hardware/nodes/host/computer.py b/src/primaite/simulator/network/hardware/nodes/host/computer.py index 11b925b9..1fb63b2e 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/computer.py +++ b/src/primaite/simulator/network/hardware/nodes/host/computer.py @@ -1,6 +1,8 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import ClassVar, Dict +from pydantic import Field + from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.system.services.ftp.ftp_client import FTPClient @@ -35,4 +37,11 @@ class Computer(HostNode, identifier="computer"): SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} + config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema()) + + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Computer class.""" + + pass + pass diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index c51afbca..fa73bf10 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -4,6 +4,8 @@ from __future__ import annotations from ipaddress import IPv4Address from typing import Any, ClassVar, Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.hardware.base import ( IPWiredNetworkInterface, @@ -325,6 +327,13 @@ class HostNode(Node, identifier="HostNode"): network_interface: Dict[int, NIC] = {} "The NICs on the node by port id." + config: HostNode.ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema()) + + class ConfigSchema(Node.ConfigSchema): + """Configuration Schema for HostNode class.""" + + pass + def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index f1ca4930..0a397c49 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -99,6 +99,13 @@ class Firewall(Router, identifier="firewall"): ) """Access Control List for managing traffic leaving towards an external network.""" + config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema()) + + class ConfigSchema(Router.ConfigSChema): + """Configuration Schema for Firewall 'Nodes' within PrimAITE.""" + + pass + def __init__(self, hostname: str, **kwargs): if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(hostname) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 4a049f99..132b1462 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -7,7 +7,7 @@ from ipaddress import IPv4Address, IPv4Network from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -1207,7 +1207,6 @@ class Router(NetworkNode, identifier="router"): "Terminal": Terminal, } - num_ports: int network_interfaces: Dict[str, RouterInterface] = {} "The Router Interfaces on the node." network_interface: Dict[int, RouterInterface] = {} @@ -1215,6 +1214,15 @@ class Router(NetworkNode, identifier="router"): acl: AccessControlList route_table: RouteTable + config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSChema()) + + class ConfigSChema(NetworkNode.ConfigSchema): + """Configuration Schema for Router Objects.""" + + num_ports: int = 10 + hostname: str = "Router" + ports: list = [] + def __init__(self, hostname: str, num_ports: int = 5, **kwargs): if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(hostname) @@ -1227,11 +1235,11 @@ class Router(NetworkNode, identifier="router"): self.session_manager.node = self self.software_manager.session_manager = self.session_manager self.session_manager.software_manager = self.software_manager - for i in range(1, self.num_ports + 1): + for i in range(1, self.config.num_ports + 1): network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0") self.connect_nic(network_interface) self.network_interface[i] = network_interface - + self.operating_state = NodeOperatingState.ON self._set_default_acl() def _install_system_software(self): @@ -1337,7 +1345,7 @@ class Router(NetworkNode, identifier="router"): :return: A dictionary representing the current state. """ state = super().describe_state() - state["num_ports"] = self.num_ports + state["num_ports"] = self.config.num_ports state["acl"] = self.acl.describe_state() return state @@ -1558,6 +1566,8 @@ class Router(NetworkNode, identifier="router"): ) print(table) + # TODO: Remove - Cover normal config items with ConfigSchema. Move additional setup components to __init__ ? + @classmethod def from_config(cls, cfg: dict, **kwargs) -> "Router": """Create a router based on a config dict. diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index db923f1a..b73af3cb 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -4,6 +4,7 @@ from __future__ import annotations from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite import getLogger from primaite.exceptions import NetworkError @@ -94,8 +95,6 @@ class Switch(NetworkNode, identifier="switch"): :ivar num_ports: The number of ports on the switch. Default is 24. """ - num_ports: int = 24 - "The number of ports on the switch." network_interfaces: Dict[str, SwitchPort] = {} "The SwitchPorts on the Switch." network_interface: Dict[int, SwitchPort] = {} @@ -103,9 +102,17 @@ class Switch(NetworkNode, identifier="switch"): mac_address_table: Dict[str, SwitchPort] = {} "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." + config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema()) + + class ConfigSchema(NetworkNode.ConfigSchema): + """Configuration Schema for Switch nodes within PrimAITE.""" + + num_ports: int = 24 + "The number of ports on the switch." + def __init__(self, **kwargs): super().__init__(**kwargs) - for i in range(1, self.num_ports + 1): + for i in range(1, self.config.num_ports + 1): self.connect_nic(SwitchPort()) def _install_system_software(self): @@ -134,7 +141,7 @@ class Switch(NetworkNode, identifier="switch"): """ state = super().describe_state() state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()} - state["num_ports"] = self.num_ports # redundant? + state["num_ports"] = self.config.num_ports # redundant? state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()} return state diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 804a570e..0a527ab8 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -2,7 +2,7 @@ from ipaddress import IPv4Address from typing import Any, Dict, Optional, Union -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, FREQ_WIFI_2_4, IPWirelessNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState @@ -124,6 +124,13 @@ class WirelessRouter(Router, identifier="wireless_router"): network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {} airspace: AirSpace + config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.Configschema()) + + class ConfigSchema(Router.ConfigSChema): + """Configuration Schema for WirelessRouter nodes within PrimAITE.""" + + pass + def __init__(self, hostname: str, airspace: AirSpace, **kwargs): super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs) From 70d9fe2fd97317f2516b1629ca222da4e8008147 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 15 Jan 2025 16:33:11 +0000 Subject: [PATCH 02/23] #2887 - End of day commit. Updates to ConfigSchema inheritance, and some initials changes to Router to remove the custom from_config method --- src/primaite/game/game.py | 61 ---------- .../simulator/network/hardware/base.py | 18 +-- .../network/hardware/nodes/host/computer.py | 2 +- .../network/hardware/nodes/host/host_node.py | 2 +- .../hardware/nodes/network/firewall.py | 9 +- .../network/hardware/nodes/network/router.py | 112 ++++-------------- .../network/hardware/nodes/network/switch.py | 3 +- .../hardware/nodes/network/wireless_router.py | 4 +- tests/conftest.py | 3 + 9 files changed, 46 insertions(+), 168 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 48d9df87..6599430a 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -277,67 +277,6 @@ class PrimaiteGame: if n_type in Node._registry: # simplify down Node creation: new_node = Node._registry["n_type"].from_config(config=node_config) - - # Default PrimAITE nodes - # if n_type == "computer": - # new_node = Computer( - # hostname=node_cfg["hostname"], - # ip_address=node_cfg["ip_address"], - # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - # default_gateway=node_cfg.get("default_gateway"), - # dns_server=node_cfg.get("dns_server", None), - # operating_state=NodeOperatingState.ON - # if not (p := node_cfg.get("operating_state")) - # else NodeOperatingState[p.upper()], - # ) - # elif n_type == "server": - # new_node = Server( - # hostname=node_cfg["hostname"], - # ip_address=node_cfg["ip_address"], - # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - # default_gateway=node_cfg.get("default_gateway"), - # dns_server=node_cfg.get("dns_server", None), - # operating_state=NodeOperatingState.ON - # if not (p := node_cfg.get("operating_state")) - # else NodeOperatingState[p.upper()], - # ) - # elif n_type == "switch": - # new_node = Switch( - # hostname=node_cfg["hostname"], - # num_ports=int(node_cfg.get("num_ports", "8")), - # operating_state=NodeOperatingState.ON - # if not (p := node_cfg.get("operating_state")) - # else NodeOperatingState[p.upper()], - # ) - # elif n_type == "router": - # new_node = Router.from_config(node_cfg) - # elif n_type == "firewall": - # new_node = Firewall.from_config(node_cfg) - # elif n_type == "wireless_router": - # new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace) - # elif n_type == "printer": - # new_node = Printer( - # hostname=node_cfg["hostname"], - # ip_address=node_cfg["ip_address"], - # subnet_mask=node_cfg["subnet_mask"], - # operating_state=NodeOperatingState.ON - # if not (p := node_cfg.get("operating_state")) - # else NodeOperatingState[p.upper()], - # ) - # # Handle extended nodes - # elif n_type.lower() in Node._registry: - # new_node = HostNode._registry[n_type]( - # hostname=node_cfg["hostname"], - # ip_address=node_cfg.get("ip_address"), - # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - # default_gateway=node_cfg.get("default_gateway"), - # dns_server=node_cfg.get("dns_server", None), - # operating_state=NodeOperatingState.ON - # if not (p := node_cfg.get("operating_state")) - # else NodeOperatingState[p.upper()], - # ) - # elif n_type in NetworkNode._registry: - # new_node = NetworkNode._registry[n_type](**node_cfg) else: msg = f"invalid node type {n_type} in config" _LOGGER.error(msg) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index b003009b..822714cb 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, Field, validate_call +from pydantic import BaseModel, ConfigDict, Field, validate_call from primaite import getLogger from primaite.exceptions import NetworkError @@ -1480,8 +1480,6 @@ class Node(SimComponent, ABC): :param operating_state: The node operating state, either ON or OFF. """ - hostname: str - "The node hostname on the network." default_gateway: Optional[IPV4Address] = None "The default gateway IP address for forwarding network traffic to other networks." operating_state: NodeOperatingState = NodeOperatingState.OFF @@ -1519,13 +1517,18 @@ class Node(SimComponent, ABC): config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) - class ConfigSchema: + class ConfigSchema(BaseModel, ABC): """Configuration Schema for Node based classes.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + """Configure pydantic to allow arbitrary types and to let the instance have attributes not present in the model.""" + hostname: str + "The node hostname on the network." + revealed_to_red: bool = False "Informs whether the node has been revealed to a red agent." - start_up_duration: int = 3 + start_up_duration: int = 0 "Time steps needed for the node to start up." start_up_countdown: int = 0 @@ -1549,8 +1552,9 @@ class Node(SimComponent, ABC): red_scan_countdown: int = 0 "Time steps until reveal to red scan is complete." - def from_config(cls, config: Dict) -> Node: - """Create Node object from a given configuration.""" + @classmethod + def from_config(cls, config: Dict) -> "Node": + """Create Node object from a given configuration dictionary.""" if config["type"] not in cls._registry: msg = f"Configuration contains an invalid Node type: {config['type']}" return ValueError(msg) diff --git a/src/primaite/simulator/network/hardware/nodes/host/computer.py b/src/primaite/simulator/network/hardware/nodes/host/computer.py index 1fb63b2e..85857a44 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/computer.py +++ b/src/primaite/simulator/network/hardware/nodes/host/computer.py @@ -42,6 +42,6 @@ class Computer(HostNode, identifier="computer"): class ConfigSchema(HostNode.ConfigSchema): """Configuration Schema for Computer class.""" - pass + hostname: str = "Computer" pass diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index fa73bf10..00f21342 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -332,7 +332,7 @@ class HostNode(Node, identifier="HostNode"): class ConfigSchema(Node.ConfigSchema): """Configuration Schema for HostNode class.""" - pass + hostname: str = "HostNode" def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 0a397c49..c7e22d49 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -104,13 +104,14 @@ class Firewall(Router, identifier="firewall"): class ConfigSchema(Router.ConfigSChema): """Configuration Schema for Firewall 'Nodes' within PrimAITE.""" - pass + hostname: str = "Firewall" + num_ports: int = 0 - def __init__(self, hostname: str, **kwargs): + def __init__(self, **kwargs): if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(hostname) + kwargs["sys_log"] = SysLog(self.config.hostname) - super().__init__(hostname=hostname, num_ports=0, **kwargs) + super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs) self.connect_nic( RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external") diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 132b1462..83fa066d 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1211,27 +1211,22 @@ class Router(NetworkNode, identifier="router"): "The Router Interfaces on the node." network_interface: Dict[int, RouterInterface] = {} "The Router Interfaces on the node by port id." - acl: AccessControlList - route_table: RouteTable - config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSChema()) + config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema()) - class ConfigSChema(NetworkNode.ConfigSchema): + class ConfigSchema(NetworkNode.ConfigSchema): """Configuration Schema for Router Objects.""" - num_ports: int = 10 + num_ports: int = 5 hostname: str = "Router" ports: list = [] + sys_log: SysLog = SysLog(hostname) + acl: AccessControlList = AccessControlList(sys_log=sys_log, implicit_action=ACLAction.DENY, name=hostname) + route_table: RouteTable = RouteTable(sys_log=sys_log) - def __init__(self, hostname: str, num_ports: int = 5, **kwargs): - if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(hostname) - if not kwargs.get("acl"): - kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=hostname) - if not kwargs.get("route_table"): - kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"]) - super().__init__(hostname=hostname, num_ports=num_ports, **kwargs) - self.session_manager = RouterSessionManager(sys_log=self.sys_log) + def __init__(self, **kwargs): + super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs) + self.session_manager = RouterSessionManager(sys_log=self.config.sys_log) self.session_manager.node = self self.software_manager.session_manager = self.session_manager self.session_manager.software_manager = self.software_manager @@ -1265,10 +1260,10 @@ class Router(NetworkNode, identifier="router"): Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP and ICMP, which are necessary for basic network operations and diagnostics. """ - self.acl.add_rule( + self.config.acl.add_rule( action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) - self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + self.config.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) def setup_for_episode(self, episode: int): """ @@ -1292,7 +1287,7 @@ class Router(NetworkNode, identifier="router"): More information in user guide and docstring for SimComponent._init_request_manager. """ rm = super()._init_request_manager() - rm.add_request("acl", RequestType(func=self.acl._request_manager)) + rm.add_request("acl", RequestType(func=self.config.acl._request_manager)) return rm def ip_is_router_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool: @@ -1346,7 +1341,7 @@ class Router(NetworkNode, identifier="router"): """ state = super().describe_state() state["num_ports"] = self.config.num_ports - state["acl"] = self.acl.describe_state() + state["acl"] = self.config.acl.describe_state() return state def check_send_frame_to_session_manager(self, frame: Frame) -> bool: @@ -1393,7 +1388,7 @@ class Router(NetworkNode, identifier="router"): return # Check if it's permitted - permitted, rule = self.acl.is_permitted(frame) + permitted, rule = self.config.acl.is_permitted(frame) if not permitted: at_port = self._get_port_of_nic(from_network_interface) @@ -1566,83 +1561,18 @@ class Router(NetworkNode, identifier="router"): ) print(table) - # TODO: Remove - Cover normal config items with ConfigSchema. Move additional setup components to __init__ ? - - @classmethod - def from_config(cls, cfg: dict, **kwargs) -> "Router": - """Create a router based on a config dict. - - Schema: - - hostname (str): unique name for this router. - - num_ports (int, optional): Number of network ports on the router. 8 by default - - ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying - ip_address and subnet_mask assigned to that ports (as strings) - - acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL - where the rule will be added (lower number is resolved first). The values should describe valid ACL - Rules as: - - action (str): either PERMIT or DENY - - src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER - - dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER - - protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP - - src_ip_address (str, optional): IP address octet written in base 10 - - dst_ip_address (str, optional): IP address octet written in base 10 - - routes (list[dict]): List of route dicts with values: - - address (str): The destination address of the route. - - subnet_mask (str): The subnet mask of the route. - - next_hop_ip_address (str): The next hop IP for the route. - - metric (int): The metric of the route. Optional. - - default_route: - - next_hop_ip_address (str): The next hop IP for the route. - - Example config: - ``` - { - 'hostname': 'router_1', - 'num_ports': 5, - 'ports': { - 1: { - 'ip_address' : '192.168.1.1', - 'subnet_mask' : '255.255.255.0', - }, - 2: { - 'ip_address' : '192.168.0.1', - 'subnet_mask' : '255.255.255.252', - } - }, - 'acl' : { - 21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'}, - 22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'}, - 23: {'action': 'PERMIT', 'protocol': 'ICMP'}, - }, - 'routes' : [ - {'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'} - ], - 'default_route': {'next_hop_ip_address': '192.168.0.2'} - } - ``` - - :param cfg: Router config adhering to schema described in main docstring body - :type cfg: dict - :return: Configured router. - :rtype: Router - """ - router = Router( - hostname=cfg["hostname"], - num_ports=int(cfg.get("num_ports", "5")), - operating_state=NodeOperatingState.ON - if not (p := cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) + def setup_router(self, cfg: dict) -> Router: + """ TODO: This is the extra bit of Router's from_config metho. Needs sorting.""" if "ports" in cfg: for port_num, port_cfg in cfg["ports"].items(): - router.configure_port( + self.configure_port( port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")), ) if "acl" in cfg: for r_num, r_cfg in cfg["acl"].items(): - router.acl.add_rule( + self.config.acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], @@ -1655,7 +1585,7 @@ class Router(NetworkNode, identifier="router"): ) if "routes" in cfg: for route in cfg.get("routes"): - router.route_table.add_route( + self.config.route_table.add_route( address=IPv4Address(route.get("address")), subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), @@ -1664,5 +1594,5 @@ class Router(NetworkNode, identifier="router"): if "default_route" in cfg: next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: - router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) - return router + self.config.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) + return self diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index b73af3cb..a2d0050b 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -107,8 +107,9 @@ class Switch(NetworkNode, identifier="switch"): class ConfigSchema(NetworkNode.ConfigSchema): """Configuration Schema for Switch nodes within PrimAITE.""" + hostname: str = "Switch" num_ports: int = 24 - "The number of ports on the switch." + "The number of ports on the switch. Default is 24." def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 0a527ab8..2c4b5976 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -124,12 +124,12 @@ class WirelessRouter(Router, identifier="wireless_router"): network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {} airspace: AirSpace - config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.Configschema()) + config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema()) class ConfigSchema(Router.ConfigSChema): """Configuration Schema for WirelessRouter nodes within PrimAITE.""" - pass + hostname: str = "WirelessRouter" def __init__(self, hostname: str, airspace: AirSpace, **kwargs): super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 0d2cc363..6cbcfa84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -195,12 +195,14 @@ def example_network() -> Network: network = Network() # Router 1 + # router_1 = Router(hostname="router_1", start_up_duration=0) router_1 = Router(hostname="router_1", start_up_duration=0) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") # Switch 1 + # switch_1_config = Switch.ConfigSchema() switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) switch_1.power_on() @@ -208,6 +210,7 @@ def example_network() -> Network: router_1.enable_port(1) # Switch 2 + # switch_2_config = Switch.ConfigSchema() switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8]) From 3957142afdf5083d2af032d6198f8c4dbbfc5827 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 22 Jan 2025 17:20:38 +0000 Subject: [PATCH 03/23] #2887 - Updates to Node components to use rom_config and allow for extensibility. Router and Firewall continue to have custom from_config. Some test fixes to reflect changes to functionality. --- .../source/how_to_guides/extensible_nodes.rst | 58 ++++++++ src/primaite/game/game.py | 7 +- src/primaite/simulator/core.py | 2 +- src/primaite/simulator/network/container.py | 28 ++-- .../simulator/network/hardware/base.py | 17 ++- .../network/hardware/nodes/host/host_node.py | 7 +- .../network/hardware/nodes/host/server.py | 17 +++ .../hardware/nodes/network/firewall.py | 82 ++++++----- .../network/hardware/nodes/network/router.py | 138 ++++++++++++++++-- .../network/hardware/nodes/network/switch.py | 6 +- src/primaite/simulator/system/software.py | 2 +- tests/assets/configs/data_manipulation.yaml | 2 +- tests/conftest.py | 99 ++++++++----- .../nodes/network/test_router_config.py | 1 + .../observations/test_acl_observations.py | 2 +- .../test_file_system_observations.py | 4 +- .../_network/_hardware/nodes/test_router.py | 2 +- 17 files changed, 350 insertions(+), 124 deletions(-) create mode 100644 docs/source/how_to_guides/extensible_nodes.rst diff --git a/docs/source/how_to_guides/extensible_nodes.rst b/docs/source/how_to_guides/extensible_nodes.rst new file mode 100644 index 00000000..21907767 --- /dev/null +++ b/docs/source/how_to_guides/extensible_nodes.rst @@ -0,0 +1,58 @@ +.. only:: comment + + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +.. _about: + + +Extensible Nodes +**************** + +Node classes within PrimAITE have been updated to allow for easier generation of custom nodes within simulations. + + +Changes to Node Class structure. +================================ + +Node classes all inherit from the base Node Class, though new classes should inherit from either HostNode or NetworkNode, subject to the intended application of the Node. + +The use of an `__init__` method is not necessary, as configurable variables for the class should be specified within the `config` of the class, and passed at run time via your YAML configuration using the `from_config` method. + + +An example of how additional Node classes is below, taken from `router.py` withing PrimAITE. + +.. code-block:: Python + +class Router(NetworkNode, identifier="router"): + """ Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces.""" + + SYSTEM_SOFTWARE: ClassVar[Dict] = { + "UserSessionManager": UserSessionManager, + "UserManager": UserManager, + "Terminal": Terminal, + } + + network_interfaces: Dict[str, RouterInterface] = {} + "The Router Interfaces on the node." + network_interface: Dict[int, RouterInterface] = {} + "The Router Interfaces on the node by port id." + + sys_log: SysLog + + config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema()) + + class ConfigSchema(NetworkNode.ConfigSchema): + """Configuration Schema for Router Objects.""" + + num_ports: int = 5 + + hostname: ClassVar[str] = "Router" + + ports: list = [] + + + +Changes to YAML file. +===================== + +Nodes defined within configuration YAML files for use with PrimAITE 3.X should still be compatible following these changes. \ No newline at end of file diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 6599430a..b9dc9c4d 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -271,12 +271,13 @@ class PrimaiteGame: for node_cfg in nodes_cfg: n_type = node_cfg["type"] - node_config: dict = node_cfg["config"] + # node_config: dict = node_cfg["config"] + print(f"{n_type}:{node_cfg}") new_node = None if n_type in Node._registry: # simplify down Node creation: - new_node = Node._registry["n_type"].from_config(config=node_config) + new_node = Node._registry[n_type].from_config(config=node_cfg) else: msg = f"invalid node type {n_type} in config" _LOGGER.error(msg) @@ -313,7 +314,7 @@ class PrimaiteGame: service_class = SERVICE_TYPES_MAPPING[service_type] if service_class is not None: - _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") + _LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}") new_node.software_manager.install(service_class, **service_cfg.get("options", {})) new_service = new_node.software_manager.software[service_class.__name__] diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 567a0493..dc4ae73b 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -3,7 +3,7 @@ """Core of the PrimAITE Simulator.""" import warnings from abc import abstractmethod -from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union from uuid import uuid4 from prettytable import PrettyTable diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index bf677d5c..aac82633 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -180,7 +180,7 @@ class Network(SimComponent): table.align = "l" table.title = "Nodes" for node in self.nodes.values(): - table.add_row((node.hostname, type(node)._identifier, node.operating_state.name)) + table.add_row((node.config.hostname, type(node)._identifier, node.operating_state.name)) print(table) if ip_addresses: @@ -196,7 +196,7 @@ class Network(SimComponent): if port.ip_address != IPv4Address("127.0.0.1"): port_str = port.port_name if port.port_name else port.port_num table.add_row( - [node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway] + [node.config.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway] ) print(table) @@ -215,9 +215,9 @@ class Network(SimComponent): if node in [link.endpoint_a.parent, link.endpoint_b.parent]: table.add_row( [ - link.endpoint_a.parent.hostname, + link.endpoint_a.parent.config.hostname, str(link.endpoint_a), - link.endpoint_b.parent.hostname, + link.endpoint_b.parent.config.hostname, str(link.endpoint_b), link.is_up, link.bandwidth, @@ -251,7 +251,7 @@ class Network(SimComponent): state = super().describe_state() state.update( { - "nodes": {node.hostname: node.describe_state() for node in self.nodes.values()}, + "nodes": {node.config.hostname: node.describe_state() for node in self.nodes.values()}, "links": {}, } ) @@ -259,8 +259,8 @@ class Network(SimComponent): for _, link in self.links.items(): node_a = link.endpoint_a._connected_node node_b = link.endpoint_b._connected_node - hostname_a = node_a.hostname if node_a else None - hostname_b = node_b.hostname if node_b else None + hostname_a = node_a.config.hostname if node_a else None + hostname_b = node_b.config.hostname if node_b else None port_a = link.endpoint_a.port_num port_b = link.endpoint_b.port_num link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}" @@ -286,9 +286,9 @@ class Network(SimComponent): self.nodes[node.uuid] = node self._node_id_map[len(self.nodes)] = node node.parent = self - self._nx_graph.add_node(node.hostname) + self._nx_graph.add_node(node.config.hostname) _LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}") - self._node_request_manager.add_request(name=node.hostname, request_type=RequestType(func=node._request_manager)) + self._node_request_manager.add_request(name=node.config.hostname, request_type=RequestType(func=node._request_manager)) def get_node_by_hostname(self, hostname: str) -> Optional[Node]: """ @@ -300,7 +300,7 @@ class Network(SimComponent): :return: The Node if it exists in the network. """ for node in self.nodes.values(): - if node.hostname == hostname: + if node.config.hostname == hostname: return node def remove_node(self, node: Node) -> None: @@ -313,7 +313,7 @@ class Network(SimComponent): :type node: Node """ if node not in self: - _LOGGER.warning(f"Can't remove node {node.hostname}. It's not in the network.") + _LOGGER.warning(f"Can't remove node {node.config.hostname}. It's not in the network.") return self.nodes.pop(node.uuid) for i, _node in self._node_id_map.items(): @@ -321,8 +321,8 @@ class Network(SimComponent): self._node_id_map.pop(i) break node.parent = None - self._node_request_manager.remove_request(name=node.hostname) - _LOGGER.info(f"Removed node {node.hostname} from network {self.uuid}") + self._node_request_manager.remove_request(name=node.config.hostname) + _LOGGER.info(f"Removed node {node.config.hostname} from network {self.uuid}") def connect( self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, bandwidth: int = 100, **kwargs @@ -352,7 +352,7 @@ class Network(SimComponent): link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth, **kwargs) self.links[link.uuid] = link self._link_id_map[len(self.links)] = link - self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname) + self._nx_graph.add_edge(endpoint_a.parent.config.hostname, endpoint_b.parent.config.hostname) link.parent = self _LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}") return link diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 822714cb..dbe9705b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -431,7 +431,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): self.enabled = True self._connected_node.sys_log.info(f"Network Interface {self} enabled") self.pcap = PacketCapture( - hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name + hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name ) if self._connected_link: self._connected_link.endpoint_up() @@ -1515,14 +1515,16 @@ class Node(SimComponent, ABC): _identifier: ClassVar[str] = "unknown" """Identifier for this particular class, used for printing and logging. Each subclass redefines this.""" - config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) + config: Node.ConfigSchema + """Configuration items within Node""" class ConfigSchema(BaseModel, ABC): """Configuration Schema for Node based classes.""" model_config = ConfigDict(arbitrary_types_allowed=True) """Configure pydantic to allow arbitrary types and to let the instance have attributes not present in the model.""" - hostname: str + + hostname: str = "default" "The node hostname on the network." revealed_to_red: bool = False @@ -1552,6 +1554,7 @@ class Node(SimComponent, ABC): red_scan_countdown: int = 0 "Time steps until reveal to red scan is complete." + @classmethod def from_config(cls, config: Dict) -> "Node": """Create Node object from a given configuration dictionary.""" @@ -1586,11 +1589,11 @@ class Node(SimComponent, ABC): provided. """ if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(kwargs["hostname"]) + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) if not kwargs.get("session_manager"): kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log")) if not kwargs.get("root"): - kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"] + kwargs["root"] = SIM_OUTPUT.path / kwargs["config"].hostname if not kwargs.get("file_system"): kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs") if not kwargs.get("software_manager"): @@ -1601,10 +1604,12 @@ class Node(SimComponent, ABC): file_system=kwargs.get("file_system"), dns_server=kwargs.get("dns_server"), ) + super().__init__(**kwargs) self._install_system_software() self.session_manager.node = self self.session_manager.software_manager = self.software_manager + self.power_on() @property def user_manager(self) -> Optional[UserManager]: @@ -1856,7 +1861,7 @@ class Node(SimComponent, ABC): state = super().describe_state() state.update( { - "hostname": self.hostname, + "hostname": self.config.hostname, "operating_state": self.operating_state.value, "NICs": { eth_num: network_interface.describe_state() diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 00f21342..23db025d 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -333,10 +333,13 @@ class HostNode(Node, identifier="HostNode"): """Configuration Schema for HostNode class.""" hostname: str = "HostNode" + ip_address: IPV4Address = "192.168.0.1" + subnet_mask: IPV4Address = "255.255.255.0" + default_gateway: IPV4Address = "192.168.10.1" - def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) - self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) + self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask)) @property def nmap(self) -> Optional[NMAP]: diff --git a/src/primaite/simulator/network/hardware/nodes/host/server.py b/src/primaite/simulator/network/hardware/nodes/host/server.py index e16cfd8f..1b3f6c58 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/server.py +++ b/src/primaite/simulator/network/hardware/nodes/host/server.py @@ -1,4 +1,6 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from typing import ClassVar +from pydantic import Field from primaite.simulator.network.hardware.nodes.host.host_node import HostNode @@ -30,8 +32,23 @@ class Server(HostNode, identifier="server"): * Web Browser """ + config: "Server.ConfigSchema" = Field(default_factory=lambda: Server.ConfigSchema()) + + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Server class.""" + + hostname: str = "server" + class Printer(HostNode, identifier="printer"): """Printer? I don't even know her!.""" # TODO: Implement printer-specific behaviour + + + config: "Printer.ConfigSchema" = Field(default_factory=lambda: Printer.ConfigSchema()) + + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Printer class.""" + + hostname: ClassVar[str] = "printer" \ No newline at end of file diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index c7e22d49..2ebfe44a 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -99,19 +99,22 @@ class Firewall(Router, identifier="firewall"): ) """Access Control List for managing traffic leaving towards an external network.""" + _identifier: str = "firewall" + config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema()) - class ConfigSchema(Router.ConfigSChema): + class ConfigSchema(Router.ConfigSchema): """Configuration Schema for Firewall 'Nodes' within PrimAITE.""" - hostname: str = "Firewall" + hostname: str = "firewall" num_ports: int = 0 + operating_state: NodeOperatingState = NodeOperatingState.ON def __init__(self, **kwargs): if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(self.config.hostname) + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) - super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs) + super().__init__(**kwargs) self.connect_nic( RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external") @@ -124,22 +127,22 @@ class Firewall(Router, identifier="firewall"): ) # Update ACL objects with firewall's hostname and syslog to allow accurate logging self.internal_inbound_acl.sys_log = kwargs["sys_log"] - self.internal_inbound_acl.name = f"{hostname} - Internal Inbound" + self.internal_inbound_acl.name = f"{kwargs['config'].hostname} - Internal Inbound" self.internal_outbound_acl.sys_log = kwargs["sys_log"] - self.internal_outbound_acl.name = f"{hostname} - Internal Outbound" + self.internal_outbound_acl.name = f"{kwargs['config'].hostname} - Internal Outbound" self.dmz_inbound_acl.sys_log = kwargs["sys_log"] - self.dmz_inbound_acl.name = f"{hostname} - DMZ Inbound" + self.dmz_inbound_acl.name = f"{kwargs['config'].hostname} - DMZ Inbound" self.dmz_outbound_acl.sys_log = kwargs["sys_log"] - self.dmz_outbound_acl.name = f"{hostname} - DMZ Outbound" + self.dmz_outbound_acl.name = f"{kwargs['config'].hostname} - DMZ Outbound" self.external_inbound_acl.sys_log = kwargs["sys_log"] - self.external_inbound_acl.name = f"{hostname} - External Inbound" + self.external_inbound_acl.name = f"{kwargs['config'].hostname} - External Inbound" self.external_outbound_acl.sys_log = kwargs["sys_log"] - self.external_outbound_acl.name = f"{hostname} - External Outbound" + self.external_outbound_acl.name = f"{kwargs['config'].hostname} - External Outbound" def _init_request_manager(self) -> RequestManager: """ @@ -567,18 +570,21 @@ class Firewall(Router, identifier="firewall"): self.dmz_port.enable() @classmethod - def from_config(cls, cfg: dict) -> "Firewall": + def from_config(cls, config: dict) -> "Firewall": """Create a firewall based on a config dict.""" - firewall = Firewall( - hostname=cfg["hostname"], - operating_state=NodeOperatingState.ON - if not (p := cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - if "ports" in cfg: - internal_port = cfg["ports"]["internal_port"] - external_port = cfg["ports"]["external_port"] - dmz_port = cfg["ports"].get("dmz_port") + # firewall = Firewall( + # hostname=config["hostname"], + # operating_state=NodeOperatingState.ON + # if not (p := config.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + + firewall = Firewall(config = cls.ConfigSchema(**config)) + + if "ports" in config: + internal_port = config["ports"]["internal_port"] + external_port = config["ports"]["external_port"] + dmz_port = config["ports"].get("dmz_port") # configure internal port firewall.configure_internal_port( @@ -598,10 +604,10 @@ class Firewall(Router, identifier="firewall"): ip_address=IPV4Address(dmz_port.get("ip_address")), subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")), ) - if "acl" in cfg: + if "acl" in config: # acl rules for internal_inbound_acl - if cfg["acl"]["internal_inbound_acl"]: - for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items(): + if config["acl"]["internal_inbound_acl"]: + for r_num, r_cfg in config["acl"]["internal_inbound_acl"].items(): firewall.internal_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -615,8 +621,8 @@ class Firewall(Router, identifier="firewall"): ) # acl rules for internal_outbound_acl - if cfg["acl"]["internal_outbound_acl"]: - for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items(): + if config["acl"]["internal_outbound_acl"]: + for r_num, r_cfg in config["acl"]["internal_outbound_acl"].items(): firewall.internal_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -630,8 +636,8 @@ class Firewall(Router, identifier="firewall"): ) # acl rules for dmz_inbound_acl - if cfg["acl"]["dmz_inbound_acl"]: - for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items(): + if config["acl"]["dmz_inbound_acl"]: + for r_num, r_cfg in config["acl"]["dmz_inbound_acl"].items(): firewall.dmz_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -645,8 +651,8 @@ class Firewall(Router, identifier="firewall"): ) # acl rules for dmz_outbound_acl - if cfg["acl"]["dmz_outbound_acl"]: - for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items(): + if config["acl"]["dmz_outbound_acl"]: + for r_num, r_cfg in config["acl"]["dmz_outbound_acl"].items(): firewall.dmz_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -660,8 +666,8 @@ class Firewall(Router, identifier="firewall"): ) # acl rules for external_inbound_acl - if cfg["acl"].get("external_inbound_acl"): - for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items(): + if config["acl"].get("external_inbound_acl"): + for r_num, r_cfg in config["acl"]["external_inbound_acl"].items(): firewall.external_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -675,8 +681,8 @@ class Firewall(Router, identifier="firewall"): ) # acl rules for external_outbound_acl - if cfg["acl"].get("external_outbound_acl"): - for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items(): + if config["acl"].get("external_outbound_acl"): + for r_num, r_cfg in config["acl"]["external_outbound_acl"].items(): firewall.external_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -689,16 +695,16 @@ class Firewall(Router, identifier="firewall"): position=r_num, ) - if "routes" in cfg: - for route in cfg.get("routes"): + if "routes" in config: + for route in config.get("routes"): firewall.route_table.add_route( address=IPv4Address(route.get("address")), subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), metric=float(route.get("metric", 0)), ) - if "default_route" in cfg: - next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None) + if "default_route" in config: + next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: firewall.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 83fa066d..e475df66 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1212,21 +1212,34 @@ class Router(NetworkNode, identifier="router"): network_interface: Dict[int, RouterInterface] = {} "The Router Interfaces on the node by port id." - config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema()) + sys_log: SysLog = None + + acl: AccessControlList = None + + route_table: RouteTable = None + + config: "Router.ConfigSchema" class ConfigSchema(NetworkNode.ConfigSchema): """Configuration Schema for Router Objects.""" num_ports: int = 5 + """Number of ports available for this Router. Default is 5""" + hostname: str = "Router" - ports: list = [] - sys_log: SysLog = SysLog(hostname) - acl: AccessControlList = AccessControlList(sys_log=sys_log, implicit_action=ACLAction.DENY, name=hostname) - route_table: RouteTable = RouteTable(sys_log=sys_log) + + ports: Dict[Union[int, str], Dict] = {} + def __init__(self, **kwargs): - super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs) - self.session_manager = RouterSessionManager(sys_log=self.config.sys_log) + if not kwargs.get("sys_log"): + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) + if not kwargs.get("acl"): + kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname) + if not kwargs.get("route_table"): + kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"]) + super().__init__(**kwargs) + self.session_manager = RouterSessionManager(sys_log=self.sys_log) self.session_manager.node = self self.software_manager.session_manager = self.session_manager self.session_manager.software_manager = self.software_manager @@ -1234,9 +1247,11 @@ class Router(NetworkNode, identifier="router"): network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0") self.connect_nic(network_interface) self.network_interface[i] = network_interface - self.operating_state = NodeOperatingState.ON + self._set_default_acl() + + def _install_system_software(self): """ Installs essential system software and network services on the router. @@ -1260,10 +1275,10 @@ class Router(NetworkNode, identifier="router"): Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP and ICMP, which are necessary for basic network operations and diagnostics. """ - self.config.acl.add_rule( + self.acl.add_rule( action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) - self.config.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) def setup_for_episode(self, episode: int): """ @@ -1287,7 +1302,7 @@ class Router(NetworkNode, identifier="router"): More information in user guide and docstring for SimComponent._init_request_manager. """ rm = super()._init_request_manager() - rm.add_request("acl", RequestType(func=self.config.acl._request_manager)) + rm.add_request("acl", RequestType(func=self.acl._request_manager)) return rm def ip_is_router_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool: @@ -1341,7 +1356,7 @@ class Router(NetworkNode, identifier="router"): """ state = super().describe_state() state["num_ports"] = self.config.num_ports - state["acl"] = self.config.acl.describe_state() + state["acl"] = self.acl.describe_state() return state def check_send_frame_to_session_manager(self, frame: Frame) -> bool: @@ -1562,7 +1577,7 @@ class Router(NetworkNode, identifier="router"): print(table) def setup_router(self, cfg: dict) -> Router: - """ TODO: This is the extra bit of Router's from_config metho. Needs sorting.""" + """TODO: This is the extra bit of Router's from_config metho. Needs sorting.""" if "ports" in cfg: for port_num, port_cfg in cfg["ports"].items(): self.configure_port( @@ -1594,5 +1609,100 @@ class Router(NetworkNode, identifier="router"): if "default_route" in cfg: next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: - self.config.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) + self.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) return self + + + @classmethod + def from_config(cls, config: dict, **kwargs) -> "Router": + """Create a router based on a config dict. + + Schema: + - hostname (str): unique name for this router. + - num_ports (int, optional): Number of network ports on the router. 8 by default + - ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying + ip_address and subnet_mask assigned to that ports (as strings) + - acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL + where the rule will be added (lower number is resolved first). The values should describe valid ACL + Rules as: + - action (str): either PERMIT or DENY + - src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER + - dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER + - protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP + - src_ip_address (str, optional): IP address octet written in base 10 + - dst_ip_address (str, optional): IP address octet written in base 10 + - routes (list[dict]): List of route dicts with values: + - address (str): The destination address of the route. + - subnet_mask (str): The subnet mask of the route. + - next_hop_ip_address (str): The next hop IP for the route. + - metric (int): The metric of the route. Optional. + - default_route: + - next_hop_ip_address (str): The next hop IP for the route. + + Example config: + ``` + { + 'hostname': 'router_1', + 'num_ports': 5, + 'ports': { + 1: { + 'ip_address' : '192.168.1.1', + 'subnet_mask' : '255.255.255.0', + }, + 2: { + 'ip_address' : '192.168.0.1', + 'subnet_mask' : '255.255.255.252', + } + }, + 'acl' : { + 21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'}, + 22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'}, + 23: {'action': 'PERMIT', 'protocol': 'ICMP'}, + }, + 'routes' : [ + {'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'} + ], + 'default_route': {'next_hop_ip_address': '192.168.0.2'} + } + ``` + + :param cfg: Router config adhering to schema described in main docstring body + :type cfg: dict + :return: Configured router. + :rtype: Router + """ + router = Router(config=Router.ConfigSchema(**config) + ) + if "ports" in config: + for port_num, port_cfg in config["ports"].items(): + router.configure_port( + port=port_num, + ip_address=port_cfg["ip_address"], + subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")), + ) + if "acl" in config: + for r_num, r_cfg in config["acl"].items(): + router.acl.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], + src_ip_address=r_cfg.get("src_ip"), + src_wildcard_mask=r_cfg.get("src_wildcard_mask"), + dst_ip_address=r_cfg.get("dst_ip"), + dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"), + position=r_num, + ) + if "routes" in config: + for route in config.get("routes"): + router.route_table.add_route( + address=IPv4Address(route.get("address")), + subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), + next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), + metric=float(route.get("metric", 0)), + ) + if "default_route" in config: + next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None) + if next_hop_ip_address: + router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) + return router \ No newline at end of file diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index a2d0050b..2ca0cafd 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, Optional +from typing import ClassVar, Dict, Optional from prettytable import MARKDOWN, PrettyTable from pydantic import Field @@ -102,7 +102,7 @@ class Switch(NetworkNode, identifier="switch"): mac_address_table: Dict[str, SwitchPort] = {} "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." - config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema()) + config: "Switch.ConfigSchema" class ConfigSchema(NetworkNode.ConfigSchema): """Configuration Schema for Switch nodes within PrimAITE.""" @@ -113,7 +113,7 @@ class Switch(NetworkNode, identifier="switch"): def __init__(self, **kwargs): super().__init__(**kwargs) - for i in range(1, self.config.num_ports + 1): + for i in range(1, kwargs["config"].num_ports + 1): self.connect_nic(SwitchPort()) def _install_system_software(self): diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 34c893eb..9a30f3e3 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -294,7 +294,7 @@ class IOSoftware(Software): """ if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON: self.software_manager.node.sys_log.error( - f"{self.name} Error: {self.software_manager.node.hostname} is not powered on." + f"{self.name} Error: {self.software_manager.node.config.hostname} is not powered on." ) return False return True diff --git a/tests/assets/configs/data_manipulation.yaml b/tests/assets/configs/data_manipulation.yaml index 97442903..bddea1a0 100644 --- a/tests/assets/configs/data_manipulation.yaml +++ b/tests/assets/configs/data_manipulation.yaml @@ -187,7 +187,7 @@ agents: num_files: 1 num_nics: 2 include_num_access: false - include_nmne: true + include_nmne: true monitored_traffic: icmp: - NONE diff --git a/tests/conftest.py b/tests/conftest.py index 6cbcfa84..08c16537 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -195,68 +195,91 @@ def example_network() -> Network: network = Network() # Router 1 + + router_1_cfg = {"hostname":"router_1", "type":"router"} + # router_1 = Router(hostname="router_1", start_up_duration=0) - router_1 = Router(hostname="router_1", start_up_duration=0) + router_1 = Router.from_config(config=router_1_cfg) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") # Switch 1 - # switch_1_config = Switch.ConfigSchema() - switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) + + switch_1_cfg = {"hostname": "switch_1", "type": "switch"} + + switch_1 = Switch.from_config(config=switch_1_cfg) + + # switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8]) router_1.enable_port(1) # Switch 2 - # switch_2_config = Switch.ConfigSchema() - switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) + switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8} + # switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) + switch_2 = Switch.from_config(config=switch_2_config) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8]) router_1.enable_port(2) - # Client 1 - client_1 = Computer( - hostname="client_1", - ip_address="192.168.10.21", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - start_up_duration=0, - ) + # # Client 1 + + client_1_cfg = {"type": "computer", + "hostname": "client_1", + "ip_address": "192.168.10.21", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "start_up_duration": 0} + + client_1=Computer.from_config(config=client_1_cfg) + client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) - # Client 2 - client_2 = Computer( - hostname="client_2", - ip_address="192.168.10.22", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - start_up_duration=0, - ) + # # Client 2 + + client_2_cfg = {"type": "computer", + "hostname": "client_2", + "ip_address": "192.168.10.22", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } + + client_2 = Computer.from_config(config=client_2_cfg) + client_2.power_on() network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2]) - # Server 1 - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + # # Server 1 + + server_1_cfg = {"type": "server", + "hostname": "server_1", + "ip_address":"192.168.1.10", + "subnet_mask":"255.255.255.0", + "default_gateway":"192.168.1.1", + "start_up_duration":0, + } + + server_1 = Server.from_config(config=server_1_cfg) + server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) - # DServer 2 - server_2 = Server( - hostname="server_2", - ip_address="192.168.1.14", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + # # DServer 2 + + server_2_cfg = {"type": "server", + "hostname": "server_2", + "ip_address":"192.168.1.14", + "subnet_mask":"255.255.255.0", + "default_gateway":"192.168.1.1", + "start_up_duration":0, + } + + server_2 = Server.from_config(config=server_2_cfg) + server_2.power_on() network.connect(endpoint_b=server_2.network_interface[1], endpoint_a=switch_1.network_interface[2]) @@ -264,6 +287,8 @@ def example_network() -> Network: assert all(link.is_up for link in network.links.values()) + client_1.software_manager.show() + return network diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py index 16f4dee5..c9691fab 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -6,6 +6,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index 02cf005a..68964b90 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -36,7 +36,7 @@ def test_acl_observations(simulation): router.acl.add_rule(action=ACLAction.PERMIT, dst_port=PORT_LOOKUP["NTP"], src_port=PORT_LOOKUP["NTP"], position=1) acl_obs = ACLObservation( - where=["network", "nodes", router.hostname, "acl", "acl"], + where=["network", "nodes", router.config.hostname, "acl", "acl"], ip_list=[], port_list=[123, 80, 5432], protocol_list=["tcp", "udp", "icmp"], diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 0268cb95..a56deb2b 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -24,7 +24,7 @@ def test_file_observation(simulation): file = pc.file_system.create_file(file_name="dog.png") dog_file_obs = FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, file_system_requires_scan=True, ) @@ -52,7 +52,7 @@ def test_folder_observation(simulation): file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder") root_folder_obs = FolderObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "test_folder"], include_num_access=False, file_system_requires_scan=True, num_files=1, diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py index fe0c3a57..5a0ebe8f 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -50,7 +50,7 @@ def test_wireless_router_from_config(): }, } - rt = Router.from_config(cfg=cfg) + rt = Router.from_config(config=cfg) assert rt.num_ports == 6 From 65355f83e85874dfa1f3ca941ecc90cf10254332 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 23 Jan 2025 09:52:14 +0000 Subject: [PATCH 04/23] #2887 - Commit before switching branch --- src/primaite/game/game.py | 2 - .../network/hardware/nodes/host/server.py | 2 +- .../hardware/nodes/network/firewall.py | 7 --- .../network/hardware/nodes/network/router.py | 47 ++----------------- 4 files changed, 5 insertions(+), 53 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 009a376f..6a902c40 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -266,8 +266,6 @@ class PrimaiteGame: for node_cfg in nodes_cfg: n_type = node_cfg["type"] - # node_config: dict = node_cfg["config"] - print(f"{n_type}:{node_cfg}") new_node = None if n_type in Node._registry: diff --git a/src/primaite/simulator/network/hardware/nodes/host/server.py b/src/primaite/simulator/network/hardware/nodes/host/server.py index 1b3f6c58..f1abefc2 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/server.py +++ b/src/primaite/simulator/network/hardware/nodes/host/server.py @@ -51,4 +51,4 @@ class Printer(HostNode, identifier="printer"): class ConfigSchema(HostNode.ConfigSchema): """Configuration Schema for Printer class.""" - hostname: ClassVar[str] = "printer" \ No newline at end of file + hostname: str = "printer" \ No newline at end of file diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 2ebfe44a..a30c49bd 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -572,13 +572,6 @@ class Firewall(Router, identifier="firewall"): @classmethod def from_config(cls, config: dict) -> "Firewall": """Create a firewall based on a config dict.""" - # firewall = Firewall( - # hostname=config["hostname"], - # operating_state=NodeOperatingState.ON - # if not (p := config.get("operating_state")) - # else NodeOperatingState[p.upper()], - # ) - firewall = Firewall(config = cls.ConfigSchema(**config)) if "ports" in config: diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index e475df66..4f9d9ca4 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1212,11 +1212,11 @@ class Router(NetworkNode, identifier="router"): network_interface: Dict[int, RouterInterface] = {} "The Router Interfaces on the node by port id." - sys_log: SysLog = None + sys_log: SysLog - acl: AccessControlList = None + acl: AccessControlList - route_table: RouteTable = None + route_table: RouteTable config: "Router.ConfigSchema" @@ -1250,8 +1250,6 @@ class Router(NetworkNode, identifier="router"): self._set_default_acl() - - def _install_system_software(self): """ Installs essential system software and network services on the router. @@ -1403,7 +1401,7 @@ class Router(NetworkNode, identifier="router"): return # Check if it's permitted - permitted, rule = self.config.acl.is_permitted(frame) + permitted, rule = self.acl.is_permitted(frame) if not permitted: at_port = self._get_port_of_nic(from_network_interface) @@ -1576,43 +1574,6 @@ class Router(NetworkNode, identifier="router"): ) print(table) - def setup_router(self, cfg: dict) -> Router: - """TODO: This is the extra bit of Router's from_config metho. Needs sorting.""" - if "ports" in cfg: - for port_num, port_cfg in cfg["ports"].items(): - self.configure_port( - port=port_num, - ip_address=port_cfg["ip_address"], - subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")), - ) - if "acl" in cfg: - for r_num, r_cfg in cfg["acl"].items(): - self.config.acl.add_rule( - action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], - protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], - src_ip_address=r_cfg.get("src_ip"), - src_wildcard_mask=r_cfg.get("src_wildcard_mask"), - dst_ip_address=r_cfg.get("dst_ip"), - dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"), - position=r_num, - ) - if "routes" in cfg: - for route in cfg.get("routes"): - self.config.route_table.add_route( - address=IPv4Address(route.get("address")), - subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), - next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), - metric=float(route.get("metric", 0)), - ) - if "default_route" in cfg: - next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None) - if next_hop_ip_address: - self.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) - return self - - @classmethod def from_config(cls, config: dict, **kwargs) -> "Router": """Create a router based on a config dict. From b9d2cd25f3b992349bb156c646dbab7f66106388 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 23 Jan 2025 15:28:10 +0000 Subject: [PATCH 05/23] #2887 - Unit test fixes ahead of raising PR. --- src/primaite/game/agent/actions/node.py | 1 - src/primaite/simulator/network/creation.py | 23 +-- .../simulator/network/hardware/base.py | 2 +- .../network/hardware/nodes/network/switch.py | 2 +- src/primaite/simulator/network/networks.py | 135 ++++++++++-------- tests/conftest.py | 9 +- .../_network/_hardware/nodes/test_acl.py | 9 +- .../_network/_hardware/nodes/test_router.py | 2 +- .../_network/_hardware/nodes/test_switch.py | 6 +- .../test_network_interface_actions.py | 8 +- .../_network/_hardware/test_node_actions.py | 18 ++- .../_simulator/_network/test_container.py | 6 +- .../_simulator/_network/test_creation.py | 2 +- .../_red_applications/test_c2_suite.py | 24 ++-- .../_red_applications/test_dos_bot.py | 11 +- .../_applications/test_database_client.py | 8 +- .../_system/_services/test_database.py | 9 +- .../_system/_services/test_dns_client.py | 22 ++- .../_system/_services/test_dns_server.py | 22 +-- .../_system/_services/test_ftp_client.py | 17 +-- .../_system/_services/test_ftp_server.py | 14 +- .../_system/_services/test_web_server.py | 14 +- 22 files changed, 222 insertions(+), 142 deletions(-) diff --git a/src/primaite/game/agent/actions/node.py b/src/primaite/game/agent/actions/node.py index fbab18f0..5e1b6725 100644 --- a/src/primaite/game/agent/actions/node.py +++ b/src/primaite/game/agent/actions/node.py @@ -36,7 +36,6 @@ class NodeAbstractAction(AbstractAction, identifier="node_abstract"): @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - print(config) return ["network", "node", config.node_name, config.verb] diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index ebd17638..3221939b 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -153,7 +153,7 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): # Create a core switch if more than one edge switch is needed if num_of_switches > 1: - core_switch = Switch(hostname=f"switch_core_{config.lan_name}", start_up_duration=0) + core_switch = Switch.from_config(config = {"type":"switch","hostname":f"switch_core_{config.lan_name}", "start_up_duration": 0 }) core_switch.power_on() network.add_node(core_switch) core_switch_port = 1 @@ -164,7 +164,8 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): # Optionally include a router in the LAN if config.include_router: default_gateway = IPv4Address(f"192.168.{config.subnet_base}.1") - router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0) + # router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0) + router = Router.from_config(config={"hostname":f"router_{config.lan_name}", "type": "router", "start_up_duration": 0}) router.power_on() router.acl.add_rule( action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 @@ -177,7 +178,7 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): # Initialise the first edge switch and connect to the router or core switch switch_port = 0 switch_n = 1 - switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0) + switch = Switch.from_config(config={"type": "switch","hostname":f"switch_edge_{switch_n}_{config.lan_name}", "start_up_duration":0}) switch.power_on() network.add_node(switch) if num_of_switches > 1: @@ -195,7 +196,7 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): if switch_port == effective_network_interface: switch_n += 1 switch_port = 0 - switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0) + switch = Switch.from_config(config={"type": "switch","hostname":f"switch_edge_{switch_n}_{config.lan_name}", "start_up_duration":0}) switch.power_on() network.add_node(switch) # Connect the new switch to the router or core switch @@ -212,13 +213,13 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): ) # Create and add a PC to the network - pc = Computer( - hostname=f"pc_{i}_{config.lan_name}", - ip_address=f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}", - subnet_mask="255.255.255.0", - default_gateway=default_gateway, - start_up_duration=0, - ) + pc_cfg = {"type": "computer", + "hostname": f"pc_{i}_{config.lan_name}", + "ip_address": f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } + pc = Computer.from_config(config = pc_cfg) pc.power_on() network.add_node(pc) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 872b1bdf..f68b627a 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1979,7 +1979,7 @@ class Node(SimComponent, ABC): else: if self.operating_state == NodeOperatingState.SHUTTING_DOWN: self.operating_state = NodeOperatingState.OFF - self.sys_log.info(f"{self.hostname}: Turned off") + self.sys_log.info(f"{self.config.hostname}: Turned off") self._shut_down_actions() # if resetting turn back on diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 2ca0cafd..e97c5321 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -129,7 +129,7 @@ class Switch(NetworkNode, identifier="switch"): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Switch Ports" + table.title = f"{self.config.hostname} Switch Ports" for port_num, port in self.network_interface.items(): table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"]) print(table) diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index c840748e..4d881343 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -128,32 +128,34 @@ def arcd_uc2_network() -> Network: network = Network() # Router 1 - router_1 = Router(hostname="router_1", num_ports=5, start_up_duration=0) + router_1 = Router.from_config(config={"type":"router", "hostname":"router_1", "num_ports":5, "start_up_duration":0}) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") # Switch 1 - switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) + switch_1 = Switch.from_config(config={"type":"switch", "hostname":"switch_1", "num_ports":8, "start_up_duration":0}) switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8]) router_1.enable_port(1) # Switch 2 - switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) + switch_2 = Switch.from_config(config={"type":"switch", "hostname":"switch_2", "num_ports":8, "start_up_duration":0}) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8]) router_1.enable_port(2) # Client 1 - client_1 = Computer( - hostname="client_1", - ip_address="192.168.10.21", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + client_1_cfg = {"type": "computer", + "hostname": "client_1", + "ip_address": "192.168.10.21", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + client_1: Computer = Computer.from_config(config = client_1_cfg) + client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) client_1.software_manager.install(DatabaseClient) @@ -172,14 +174,17 @@ def arcd_uc2_network() -> Network: ) # Client 2 - client_2 = Computer( - hostname="client_2", - ip_address="192.168.10.22", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + + client_2_cfg = {"type": "computer", + "hostname": "client_2", + "ip_address": "192.168.10.22", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + client_2: Computer = Computer.from_config(config = client_2_cfg) + client_2.power_on() client_2.software_manager.install(DatabaseClient) db_client_2 = client_2.software_manager.software.get("DatabaseClient") @@ -193,27 +198,34 @@ def arcd_uc2_network() -> Network: ) # Domain Controller - domain_controller = Server( - hostname="domain_controller", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + + domain_controller_cfg = {"type": "server", + "hostname": "domain_controller", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0 + } + + domain_controller = Server.from_config(config=domain_controller_cfg) domain_controller.power_on() domain_controller.software_manager.install(DNSServer) network.connect(endpoint_b=domain_controller.network_interface[1], endpoint_a=switch_1.network_interface[1]) # Database Server - database_server = Server( - hostname="database_server", - ip_address="192.168.1.14", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + + database_server_cfg = {"type": "server", + "hostname": "database_server", + "ip_address": "192.168.1.14", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0 + } + + database_server = Server.from_config(config=database_server_cfg) + database_server.power_on() network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3]) @@ -223,14 +235,18 @@ def arcd_uc2_network() -> Network: database_service.configure_backup(backup_server=IPv4Address("192.168.1.16")) # Web Server - web_server = Server( - hostname="web_server", - ip_address="192.168.1.12", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + + + web_server_cfg = {"type": "server", + "hostname": "web_server", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0 + } + web_server = Server.from_config(config=web_server_cfg) + web_server.power_on() web_server.software_manager.install(DatabaseClient) @@ -247,27 +263,30 @@ def arcd_uc2_network() -> Network: dns_server_service.dns_register("arcd.com", web_server.network_interface[1].ip_address) # Backup Server - backup_server = Server( - hostname="backup_server", - ip_address="192.168.1.16", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + backup_server_cfg = {"type": "server", + "hostname": "backup_server", + "ip_address": "192.168.1.16", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0 + } + backup_server: Server = Server.from_config(config=backup_server_cfg) + backup_server.power_on() backup_server.software_manager.install(FTPServer) network.connect(endpoint_b=backup_server.network_interface[1], endpoint_a=switch_1.network_interface[4]) # Security Suite - security_suite = Server( - hostname="security_suite", - ip_address="192.168.1.110", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + security_suite_cfg = {"type": "server", + "hostname": "backup_server", + "ip_address": "192.168.1.110", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0 + } + security_suite: Server = Server.from_config(config=security_suite_cfg) security_suite.power_on() network.connect(endpoint_b=security_suite.network_interface[1], endpoint_a=switch_1.network_interface[7]) security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0")) diff --git a/tests/conftest.py b/tests/conftest.py index 287f216c..fc86bb4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,7 +119,14 @@ def application_class(): @pytest.fixture(scope="function") def file_system() -> FileSystem: - computer = Computer(hostname="fs_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + # computer = Computer(hostname="fs_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + computer_cfg = {"type": "computer", + "hostname": "fs_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + computer = Computer.from_config(config=computer_cfg) computer.power_on() return computer.file_system diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py index 79392d66..ee7eb08f 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py @@ -25,7 +25,8 @@ def router_with_acl_rules(): :return: A configured Router object with ACL rules. """ - router = Router("Router") + router_cfg = {"hostname": "router_1", "type": "router"} + router = Router.from_config(config=router_cfg) acl = router.acl # Add rules here as needed acl.add_rule( @@ -62,7 +63,8 @@ def router_with_wildcard_acl(): :return: A Router object with configured ACL rules, including rules with wildcard masking. """ - router = Router("Router") + router_cfg = {"hostname": "router_1", "type": "router"} + router = Router.from_config(config=router_cfg) acl = router.acl # Rule to permit traffic from a specific source IP and port to a specific destination IP and port acl.add_rule( @@ -243,7 +245,8 @@ def test_ip_traffic_from_specific_subnet(): - Traffic from outside the 192.168.1.0/24 subnet is denied. """ - router = Router("Router") + router_cfg = {"hostname": "router_1", "type": "router"} + router = Router.from_config(config=router_cfg) acl = router.acl # Add rules here as needed acl.add_rule( diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py index 5a0ebe8f..e9d16533 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -52,7 +52,7 @@ def test_wireless_router_from_config(): rt = Router.from_config(config=cfg) - assert rt.num_ports == 6 + assert rt.config.num_ports == 6 assert rt.network_interface[1].ip_address == IPv4Address("192.168.1.1") assert rt.network_interface[1].subnet_mask == IPv4Address("255.255.255.0") diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py index e6bff60e..dbc04f6d 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py @@ -7,7 +7,11 @@ from primaite.simulator.network.hardware.nodes.network.switch import Switch @pytest.fixture(scope="function") def switch() -> Switch: - switch: Switch = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) + switch_cfg = {"type": "switch", + "hostname": "switch_1", + "num_ports": 8, + "start_up_duration": 0} + switch: Switch = Switch.from_config(config=switch_cfg) switch.power_on() switch.show() return switch diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py index 5cff4407..0e0023cd 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py @@ -7,7 +7,13 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer @pytest.fixture def node() -> Node: - return Computer(hostname="test", ip_address="192.168.1.2", subnet_mask="255.255.255.0") + computer_cfg = {"type": "computer", + "hostname": "test", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0"} + computer = Computer.from_config(config=computer_cfg) + + return computer def test_nic_enabled_validator(node): diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py index 672a4b5f..d077f46b 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py @@ -12,8 +12,16 @@ from tests.conftest import DummyApplication, DummyService @pytest.fixture def node() -> Node: - return Computer(hostname="test", ip_address="192.168.1.2", subnet_mask="255.255.255.0") + computer_cfg = {"type": "computer", + "hostname": "test", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "shut_down_duration": 3, + "operating_state": "OFF", + } + computer = Computer.from_config(config=computer_cfg) + return computer def test_node_startup(node): assert node.operating_state == NodeOperatingState.OFF @@ -166,7 +174,7 @@ def test_node_is_on_validator(node): """Test that the node is on validator.""" node.power_on() - for i in range(node.start_up_duration + 1): + for i in range(node.config.start_up_duration + 1): node.apply_timestep(i) validator = Node._NodeIsOnValidator(node=node) @@ -174,7 +182,7 @@ def test_node_is_on_validator(node): assert validator(request=[], context={}) node.power_off() - for i in range(node.shut_down_duration + 1): + for i in range(node.config.shut_down_duration + 1): node.apply_timestep(i) assert validator(request=[], context={}) is False @@ -184,7 +192,7 @@ def test_node_is_off_validator(node): """Test that the node is on validator.""" node.power_on() - for i in range(node.start_up_duration + 1): + for i in range(node.config.start_up_duration + 1): node.apply_timestep(i) validator = Node._NodeIsOffValidator(node=node) @@ -192,7 +200,7 @@ def test_node_is_off_validator(node): assert validator(request=[], context={}) is False node.power_off() - for i in range(node.shut_down_duration + 1): + for i in range(node.config.shut_down_duration + 1): node.apply_timestep(i) assert validator(request=[], context={}) diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index b1de710a..d175b865 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -61,12 +61,12 @@ def test_apply_timestep_to_nodes(network): client_1.power_off() assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN - for i in range(client_1.shut_down_duration + 1): + for i in range(client_1.config.shut_down_duration + 1): network.apply_timestep(timestep=i) assert client_1.operating_state is NodeOperatingState.OFF - network.apply_timestep(client_1.shut_down_duration + 2) + network.apply_timestep(client_1.config.shut_down_duration + 2) assert client_1.operating_state is NodeOperatingState.OFF @@ -74,7 +74,7 @@ def test_removing_node_that_does_not_exist(network): """Node that does not exist on network should not affect existing nodes.""" assert len(network.nodes) is 7 - network.remove_node(Computer(hostname="new_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0")) + network.remove_node(Computer.from_config(config = {"type":"computer","hostname":"new_node", "ip_address":"192.168.1.2", "subnet_mask":"255.255.255.0"})) assert len(network.nodes) is 7 diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_creation.py b/tests/unit_tests/_primaite/_simulator/_network/test_creation.py index 9885df67..29331d08 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_creation.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_creation.py @@ -19,7 +19,7 @@ def _assert_valid_creation(net: Network, lan_name, subnet_base, pcs_ip_block_sta num_routers = 1 if include_router else 0 total_nodes = num_pcs + num_switches + num_routers - assert all((n.hostname.endswith(lan_name) for n in net.nodes.values())) + assert all((n.config.hostname.endswith(lan_name) for n in net.nodes.values())) assert len(net.computer_nodes) == num_pcs assert len(net.switch_nodes) == num_switches assert len(net.router_nodes) == num_routers diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index 17f8445a..5d8bea80 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -16,19 +16,25 @@ def basic_c2_network() -> Network: network = Network() # Creating two generic nodes for the C2 Server and the C2 Beacon. + computer_a_cfg = {"type": "computer", + "hostname": "computer_a", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.252", + "start_up_duration": 0} + computer_a = Computer.from_config(config = computer_a_cfg) - computer_a = Computer( - hostname="computer_a", - ip_address="192.168.0.1", - subnet_mask="255.255.255.252", - start_up_duration=0, - ) computer_a.power_on() computer_a.software_manager.install(software_class=C2Server) - computer_b = Computer( - hostname="computer_b", ip_address="192.168.0.2", subnet_mask="255.255.255.252", start_up_duration=0 - ) + + computer_b_cfg = {"type": "computer", + "hostname": "computer_b", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.252", + "start_up_duration": 0, + } + + computer_b = Computer.from_config(config=computer_b_cfg) computer_b.power_on() computer_b.software_manager.install(software_class=C2Beacon) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index 9d8b7809..02b13724 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -12,9 +12,14 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dos_bot() -> DoSBot: - computer = Computer( - hostname="compromised_pc", ip_address="192.168.0.1", subnet_mask="255.255.255.0", start_up_duration=0 - ) + computer_cfg = {"type":"computer", + "hostname": "compromised_pc", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + computer: Computer = Computer.from_config(config=computer_cfg) + computer.power_on() computer.software_manager.install(DoSBot) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py index 5917fde7..6e32b646 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py @@ -17,14 +17,14 @@ from primaite.simulator.system.services.database.database_service import Databas def database_client_on_computer() -> Tuple[DatabaseClient, Computer]: network = Network() - db_server = Server(hostname="db_server", ip_address="192.168.0.1", subnet_mask="255.255.255.0", start_up_duration=0) + db_server: Server = Server.from_config(config={"type": "server", "hostname":"db_server", "ip_address":"192.168.0.1", "subnet_mask":"255.255.255.0", "start_up_duration":0}) db_server.power_on() db_server.software_manager.install(DatabaseService) db_server.software_manager.software["DatabaseService"].start() - db_client = Computer( - hostname="db_client", ip_address="192.168.0.2", subnet_mask="255.255.255.0", start_up_duration=0 - ) + db_client: Computer = Computer.from_config(config = {"type":"computer", + "hostname":"db_client", "ip_address":"192.168.0.2", "subnet_mask":"255.255.255.0", "start_up_duration":0 + }) db_client.power_on() db_client.software_manager.install(DatabaseClient) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py index ef165c8f..b7ba2d04 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py @@ -8,7 +8,14 @@ from primaite.simulator.system.services.database.database_service import Databas @pytest.fixture(scope="function") def database_server() -> Node: - node = Computer(hostname="db_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + node_cfg = {"type": "computer", + "hostname": "db_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + + node = Computer.from_config(config=node_cfg) node.power_on() node.software_manager.install(DatabaseService) node.software_manager.software.get("DatabaseService").start() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index 1bc5b353..3f621331 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -14,13 +14,21 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dns_client() -> Computer: - node = Computer( - hostname="dns_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - ) + + node_cfg = {"type": "computer", + "hostname": "dns_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10")} + node = Computer.from_config(config=node_cfg) + # node = Computer( + # hostname="dns_client", + # ip_address="192.168.1.11", + # subnet_mask="255.255.255.0", + # default_gateway="192.168.1.1", + # dns_server=IPv4Address("192.168.1.10"), + # ) return node diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py index 3bc2b1a4..8df96099 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -16,13 +16,13 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dns_server() -> Node: - node = Server( - hostname="dns_server", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = {"type": "server", + "hostname": "dns_server", + "ip_address": "192.168.1.10", + "subnet_mask":"255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration":0} + node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(software_class=DNSServer) return node @@ -55,7 +55,13 @@ def test_dns_server_receive(dns_server): # register the web server in the domain controller dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) - client = Computer(hostname="client", ip_address="192.168.1.11", subnet_mask="255.255.255.0", start_up_duration=0) + client_cfg = {"type": "computer", + "hostname": "client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + client = Computer.from_config(config=client_cfg) client.power_on() client.dns_server = IPv4Address("192.168.1.10") network = Network() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index d3e679db..c6e10b7d 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -16,13 +16,14 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def ftp_client() -> Node: - node = Computer( - hostname="ftp_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = {"type": "computer", + "hostname": "ftp_client", + "ip_address":"192.168.1.11", + "subnet_mask":"255.255.255.0", + "default_gateway":"192.168.1.1", + "start_up_duration": 0, + } + node = Computer.from_config(config=node_cfg) node.power_on() return node @@ -94,7 +95,7 @@ def test_offline_ftp_client_receives_request(ftp_client): ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") ftp_client.power_off() - for i in range(ftp_client.shut_down_duration + 1): + for i in range(ftp_client.config.shut_down_duration + 1): ftp_client.apply_timestep(timestep=i) assert ftp_client.operating_state is NodeOperatingState.OFF diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index 37c3d019..5cae88e0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -14,13 +14,13 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def ftp_server() -> Node: - node = Server( - hostname="ftp_server", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = {"type": "server", + "hostname":"ftp_server", + "ip_address":"192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration":0} + node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(software_class=FTPServer) return node diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index 606a195c..f0901b70 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -16,13 +16,13 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def web_server() -> Server: - node = Server( - hostname="web_server", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = {"type": "server", + "hostname":"web_server", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway":"192.168.1.1", + "start_up_duration":0 } + node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(WebServer) node.software_manager.software.get("WebServer").start() From 30c177c2722ea43a80dbf918d2839b883f57a63c Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 23 Jan 2025 17:07:15 +0000 Subject: [PATCH 06/23] #2887 - Additional test failure fixes --- .../simulator/network/hardware/base.py | 2 +- .../hardware/nodes/network/wireless_router.py | 8 +-- tests/conftest.py | 57 ++++++++++--------- .../observations/test_firewall_observation.py | 5 +- .../observations/test_link_observations.py | 10 ++-- .../observations/test_nic_observations.py | 6 +- .../observations/test_node_observations.py | 4 +- .../observations/test_router_observation.py | 4 +- .../test_software_observations.py | 4 +- .../game_layer/test_action_mask.py | 1 + .../game_layer/test_actions.py | 1 + .../network/test_airspace_config.py | 1 + .../network/test_broadcast.py | 47 ++++++++------- .../_system/_applications/test_web_browser.py | 22 +++---- 14 files changed, 88 insertions(+), 84 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index f68b627a..d462f75c 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1760,7 +1760,7 @@ class Node(SimComponent, ABC): self.software_manager.install(application_class) application_instance = self.software_manager.software.get(application_name) self.applications[application_instance.uuid] = application_instance - _LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}") + _LOGGER.debug(f"Added application {application_instance.name} to node {self.config.hostname}") self._application_request_manager.add_request( application_name, RequestType(func=application_instance._request_manager) ) diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 2c4b5976..75e4d5ea 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -126,16 +126,16 @@ class WirelessRouter(Router, identifier="wireless_router"): config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema()) - class ConfigSchema(Router.ConfigSChema): + class ConfigSchema(Router.ConfigSchema): """Configuration Schema for WirelessRouter nodes within PrimAITE.""" hostname: str = "WirelessRouter" - def __init__(self, hostname: str, airspace: AirSpace, **kwargs): - super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs) + def __init__(self, **kwargs): + super().__init__(hostname=kwargs["config"].hostname, num_ports=0, airspace=kwargs["config"].airspace, **kwargs) self.connect_nic( - WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=airspace) + WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=kwargs["config"].airspace) ) self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")) diff --git a/tests/conftest.py b/tests/conftest.py index fc86bb4d..1bdc217c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -348,29 +348,29 @@ def install_stuff_to_sim(sim: Simulation): # 1: Set up network hardware # 1.1: Configure the router - router = Router(hostname="router", num_ports=3, start_up_duration=0) + router = Router.from_config(config={"type":"router", "hostname":"router", "num_ports":3, "start_up_duration":0}) router.power_on() router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") # 1.2: Create and connect switches - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1 = Switch.from_config(config={"type":"switch", "hostname":"switch_1", "num_ports":6, "start_up_duration":0}) switch_1.power_on() network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6]) router.enable_port(1) - switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0) + switch_2 = Switch.from_config(config={"type":"switch", "hostname":"switch_2", "num_ports":6, "start_up_duration":0}) switch_2.power_on() network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6]) router.enable_port(2) # 1.3: Create and connect computer - client_1 = Computer( - hostname="client_1", - ip_address="10.0.1.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.1.1", - start_up_duration=0, - ) + client_1_cfg = {"type": "computer", + "hostname": "client_1", + "ip_address":"10.0.1.2", + "subnet_mask":"255.255.255.0", + "default_gateway": "10.0.1.1", + "start_up_duration":0} + client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() network.connect( endpoint_a=client_1.network_interface[1], @@ -378,23 +378,26 @@ def install_stuff_to_sim(sim: Simulation): ) # 1.4: Create and connect servers - server_1 = Server( - hostname="server_1", - ip_address="10.0.2.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - start_up_duration=0, - ) + server_1_cfg = {"type": "server", + "hostname":"server_1", + "ip_address": "10.0.2.2", + "subnet_mask":"255.255.255.0", + "default_gateway":"10.0.2.1", + "start_up_duration": 0} + + + server_1: Server = Server.from_config(config=server_1_cfg) server_1.power_on() network.connect(endpoint_a=server_1.network_interface[1], endpoint_b=switch_2.network_interface[1]) + server_2_cfg = {"type": "server", + "hostname":"server_2", + "ip_address": "10.0.2.3", + "subnet_mask":"255.255.255.0", + "default_gateway":"10.0.2.1", + "start_up_duration": 0} - server_2 = Server( - hostname="server_2", - ip_address="10.0.2.3", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - start_up_duration=0, - ) + + server_2: Server = Server.from_config(config=server_2_cfg) server_2.power_on() network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) @@ -442,18 +445,18 @@ def install_stuff_to_sim(sim: Simulation): assert acl_rule is None # 5.2: Assert the client is correctly configured - c: Computer = [node for node in sim.network.nodes.values() if node.hostname == "client_1"][0] + c: Computer = [node for node in sim.network.nodes.values() if node.config.hostname == "client_1"][0] assert c.software_manager.software.get("WebBrowser") is not None assert c.software_manager.software.get("DNSClient") is not None assert str(c.network_interface[1].ip_address) == "10.0.1.2" # 5.3: Assert that server_1 is correctly configured - s1: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_1"][0] + s1: Server = [node for node in sim.network.nodes.values() if node.config.hostname == "server_1"][0] assert str(s1.network_interface[1].ip_address) == "10.0.2.2" assert s1.software_manager.software.get("DNSServer") is not None # 5.4: Assert that server_2 is correctly configured - s2: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_2"][0] + s2: Server = [node for node in sim.network.nodes.values() if node.config.hostname == "server_2"][0] assert str(s2.network_interface[1].ip_address) == "10.0.2.3" assert s2.software_manager.software.get("WebServer") is not None diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 97608132..6b0d4359 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -25,7 +25,8 @@ def check_default_rules(acl_obs): def test_firewall_observation(): """Test adding/removing acl rules and enabling/disabling ports.""" net = Network() - firewall = Firewall(hostname="firewall", operating_state=NodeOperatingState.ON) + firewall_cfg = {"type": "firewall", "hostname": "firewall", "opertating_state": NodeOperatingState.ON} + firewall = Firewall.from_config(config=firewall_cfg) firewall_observation = FirewallObservation( where=[], num_rules=7, @@ -116,7 +117,7 @@ def test_firewall_observation(): assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4)) # connect a switch to the firewall and check that only the correct port is updated - switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + switch: Switch = Switch.from_config(config={"type": "switch", "hostname":"switch", "num_ports":1, "operating_state":NodeOperatingState.ON}) link = net.connect(firewall.network_interface[1], switch.network_interface[1]) assert firewall.network_interface[1].enabled observation = firewall_observation.observe(firewall.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py index 630e29ea..b5cd6134 100644 --- a/tests/integration_tests/game_layer/observations/test_link_observations.py +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -56,12 +56,12 @@ def test_link_observation(): """Check the shape and contents of the link observation.""" net = Network() sim = Simulation(network=net) - switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON) - computer_1 = Computer( - hostname="computer_1", ip_address="10.0.0.1", subnet_mask="255.255.255.0", start_up_duration=0 + switch: Switch = Switch.from_config(config={"type":"switch", "hostname":"switch", "num_ports":5, "operating_state":NodeOperatingState.ON}) + computer_1: Computer = Computer.from_config(config={"type": "computer", + "hostname":"computer_1", "ip_address":"10.0.0.1", "subnet_mask":"255.255.255.0", "start_up_duration":0} ) - computer_2 = Computer( - hostname="computer_2", ip_address="10.0.0.2", subnet_mask="255.255.255.0", start_up_duration=0 + computer_2: Computer = Computer.from_config(config={"type":"computer", + "hostname":"computer_2", "ip_address":"10.0.0.2", "subnet_mask":"255.255.255.0", "start_up_duration":0} ) computer_1.power_on() computer_2.power_on() diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index bd9417ba..2a311853 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -75,7 +75,7 @@ def test_nic(simulation): nic: NIC = pc.network_interface[1] - nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True) # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs nmne_config = { @@ -108,7 +108,7 @@ def test_nic_categories(simulation): """Test the NIC observation nmne count categories.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") - nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True) assert nic_obs.high_nmne_threshold == 10 # default assert nic_obs.med_nmne_threshold == 5 # default @@ -163,7 +163,7 @@ def test_nic_monitored_traffic(simulation): pc2: Computer = simulation.network.get_node_by_hostname("client_2") nic_obs = NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic + where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic ) simulation.pre_timestep(0) # apply timestep to whole sim diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 63ca8f6b..09eb3fe4 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -25,7 +25,7 @@ def test_host_observation(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") host_obs = HostObservation( - where=["network", "nodes", pc.hostname], + where=["network", "nodes", pc.config.hostname], num_applications=0, num_files=1, num_folders=1, @@ -56,7 +56,7 @@ def test_host_observation(simulation): observation_state = host_obs.observe(simulation.describe_state()) assert observation_state.get("operating_status") == 4 # shutting down - for i in range(pc.shut_down_duration + 1): + for i in range(pc.config.shut_down_duration + 1): pc.apply_timestep(i) observation_state = host_obs.observe(simulation.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index f4bfb193..131af57f 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -16,7 +16,7 @@ from primaite.utils.validation.port import PORT_LOOKUP def test_router_observation(): """Test adding/removing acl rules and enabling/disabling ports.""" net = Network() - router = Router(hostname="router", num_ports=5, operating_state=NodeOperatingState.ON) + router = Router.from_config(config={"type": "router", "hostname":"router", "num_ports":5, "operating_state":NodeOperatingState.ON}) ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)] acl = ACLObservation( @@ -89,7 +89,7 @@ def test_router_observation(): assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6)) # connect a switch to the router and check that only the correct port is updated - switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + switch: Switch = Switch.from_config(config={"type": "switch", "hostname":"switch", "num_ports":1, "operating_state":NodeOperatingState.ON}) link = net.connect(router.network_interface[1], switch.network_interface[1]) assert router.network_interface[1].enabled observed_output = router_observation.observe(router.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 291ee395..7957625a 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -29,7 +29,7 @@ def test_service_observation(simulation): ntp_server = pc.software_manager.software.get("NTPServer") assert ntp_server - service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"]) + service_obs = ServiceObservation(where=["network", "nodes", pc.config.hostname, "services", "NTPServer"]) assert service_obs.space["operating_status"] == spaces.Discrete(7) assert service_obs.space["health_status"] == spaces.Discrete(5) @@ -54,7 +54,7 @@ def test_application_observation(simulation): web_browser: WebBrowser = pc.software_manager.software.get("WebBrowser") assert web_browser - app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "WebBrowser"]) + app_obs = ApplicationObservation(where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"]) web_browser.close() observation_state = app_obs.observe(simulation.describe_state()) diff --git a/tests/integration_tests/game_layer/test_action_mask.py b/tests/integration_tests/game_layer/test_action_mask.py index 75965f16..ebba1119 100644 --- a/tests/integration_tests/game_layer/test_action_mask.py +++ b/tests/integration_tests/game_layer/test_action_mask.py @@ -3,6 +3,7 @@ from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests.conftest import TEST_ASSETS_ROOT CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 5a308cf8..9d9b528c 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -17,6 +17,7 @@ from typing import Tuple import pytest import yaml +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv diff --git a/tests/integration_tests/network/test_airspace_config.py b/tests/integration_tests/network/test_airspace_config.py index e8abc0f2..fd3f6f28 100644 --- a/tests/integration_tests/network/test_airspace_config.py +++ b/tests/integration_tests/network/test_airspace_config.py @@ -2,6 +2,7 @@ import yaml from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests import TEST_ASSETS_ROOT diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index ed40334f..5c30d2ac 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -84,44 +84,47 @@ class BroadcastTestClient(Application, identifier="BroadcastTestClient"): def broadcast_network() -> Network: network = Network() - client_1 = Computer( - hostname="client_1", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + client_1_cfg = {"type": "computer", + "hostname": "client_1", + "ip_address":"192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration":0} + + client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() client_1.software_manager.install(BroadcastTestClient) application_1 = client_1.software_manager.software["BroadcastTestClient"] application_1.run() + client_2_cfg = {"type": "computer", + "hostname": "client_2", + "ip_address":"192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration":0} - client_2 = Computer( - hostname="client_2", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + client_2: Computer = Computer.from_config(config=client_2_cfg) client_2.power_on() client_2.software_manager.install(BroadcastTestClient) application_2 = client_2.software_manager.software["BroadcastTestClient"] application_2.run() - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.1", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + server_1_cfg = {"type": "server", + "hostname": "server_1", + "ip_address":"192.168.1.1", + "subnet_mask": "255.255.255.0", + "default_gateway":"192.168.1.1", + "start_up_duration": 0} + + server_1 :Server = Server.from_config(config=server_1_cfg) + server_1.power_on() server_1.software_manager.install(BroadcastTestService) service: BroadcastTestService = server_1.software_manager.software["BroadcastService"] service.start() - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1: Switch = Switch.from_config(config={"type": "switch", "hostname":"switch_1", "num_ports":6, "start_up_duration":0}) switch_1.power_on() network.connect(endpoint_a=client_1.network_interface[1], endpoint_b=switch_1.network_interface[1]) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index f78b3261..85cd369f 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -12,13 +12,10 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def web_browser() -> WebBrowser: - computer = Computer( - hostname="web_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = {"type": "computer", "hostname": "web_client", "ip_address": "192.168.1.11", "subnet_mask": "255.255.255.0", "default_gateway": "192.168.1.1", "start_up_duration": 0} + + computer: Computer = Computer.from_config(config=computer_cfg) + computer.power_on() # Web Browser should be pre-installed in computer web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") @@ -28,13 +25,10 @@ def web_browser() -> WebBrowser: def test_create_web_client(): - computer = Computer( - hostname="web_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = {"type": "computer", "hostname": "web_client", "ip_address": "192.168.1.11", "subnet_mask": "255.255.255.0", "default_gateway": "192.168.1.1", "start_up_duration": 0} + + computer: Computer = Computer.from_config(config=computer_cfg) + computer.power_on() # Web Browser should be pre-installed in computer web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") From a7395c466e6d27cb671d047fe4cdf6350af8c468 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 23 Jan 2025 17:42:59 +0000 Subject: [PATCH 07/23] #2887 - Final test changes before end of day --- .../test_action_integration.py | 10 ++-- .../test_c2_suite_integration.py | 52 ++++++++++--------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/tests/integration_tests/component_creation/test_action_integration.py b/tests/integration_tests/component_creation/test_action_integration.py index 8b81b7d3..2d493045 100644 --- a/tests/integration_tests/component_creation/test_action_integration.py +++ b/tests/integration_tests/component_creation/test_action_integration.py @@ -12,12 +12,12 @@ def test_passing_actions_down(monkeypatch) -> None: sim = Simulation() - pc1 = Computer(hostname="PC-1", ip_address="10.10.1.1", subnet_mask="255.255.255.0") + pc1 = Computer.from_config(config={"type":"computer", "hostname":"PC-1", "ip_address":"10.10.1.1", "subnet_mask":"255.255.255.0"}) pc1.start_up_duration = 0 pc1.power_on() - pc2 = Computer(hostname="PC-2", ip_address="10.10.1.2", subnet_mask="255.255.255.0") - srv = Server(hostname="WEBSERVER", ip_address="10.10.1.100", subnet_mask="255.255.255.0") - s1 = Switch(hostname="switch1") + pc2 = Computer.from_config(config={"type":"computer", "hostname":"PC-2", "ip_address":"10.10.1.2", "subnet_mask":"255.255.255.0"}) + srv = Server.from_config(config={"type":"server", "hostname":"WEBSERVER", "ip_address":"10.10.1.100", "subnet_mask":"255.255.255.0"}) + s1 = Switch.from_config(config={"type":"switch", "hostname":"switch1"}) for n in [pc1, pc2, srv, s1]: sim.network.add_node(n) @@ -48,6 +48,6 @@ def test_passing_actions_down(monkeypatch) -> None: assert not action_invoked # call the patched method - sim.apply_request(["network", "node", pc1.hostname, "file_system", "folder", "downloads", "repair"]) + sim.apply_request(["network", "node", pc1.config.hostname, "file_system", "folder", "downloads", "repair"]) assert action_invoked diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index 40226be6..86a68865 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -34,52 +34,54 @@ def basic_network() -> Network: # Creating two generic nodes for the C2 Server and the C2 Beacon. - node_a = Computer( - hostname="node_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.252", - default_gateway="192.168.0.1", - start_up_duration=0, - ) + node_a_cfg = {"type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.252", + "default_gateway": "192.168.0.1", + "start_up_duration": 0} + + node_a: Computer = Computer.from_config(config=node_a_cfg) node_a.power_on() node_a.software_manager.get_open_ports() node_a.software_manager.install(software_class=C2Server) - node_b = Computer( - hostname="node_b", - ip_address="192.168.255.2", - subnet_mask="255.255.255.248", - default_gateway="192.168.255.1", - start_up_duration=0, - ) - + node_b_cfg = {"type": "computer", + "hostname": "node_b", + "ip_address": "192.168.255.2", + "subnet_mask": "255.255.255.248", + "default_gateway": "192.168.255.1", + "start_up_duration": 0} + + node_b: Computer = Computer.from_config(config=node_b_cfg) node_b.power_on() node_b.software_manager.install(software_class=C2Beacon) # Creating a generic computer for testing remote terminal connections. - node_c = Computer( - hostname="node_c", - ip_address="192.168.255.3", - subnet_mask="255.255.255.248", - default_gateway="192.168.255.1", - start_up_duration=0, - ) + node_c_cfg = {"type": "computer", + "hostname": "node_c", + "ip_address": "192.168.255.3", + "subnet_mask": "255.255.255.248", + "default_gateway": "192.168.255.1", + "start_up_duration": 0} + + node_c: Computer = Computer.from_config(config=node_c_cfg) node_c.power_on() # Creating a router to sit between node 1 and node 2. - router = Router(hostname="router", num_ports=3, start_up_duration=0) + router = Router.from_config(config={"type":"router", "hostname":"router", "num_ports":3, "start_up_duration":0}) # Default allow all. router.acl.add_rule(action=ACLAction.PERMIT) router.power_on() # Creating switches for each client. - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1 = Switch.from_config(config={"type":"switch", "hostname":"switch_1", "num_ports":6, "start_up_duration":0}) switch_1.power_on() # Connecting the switches to the router. router.configure_port(port=1, ip_address="192.168.0.1", subnet_mask="255.255.255.252") network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6]) - switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0) + switch_2 = Switch.from_config(config={"type":"switch", "hostname":"switch_2", "num_ports":6, "start_up_duration":0}) switch_2.power_on() network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6]) From 0570ab984d99b35dd6fa8ac4e4065bb6d39f503b Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 27 Jan 2025 16:35:40 +0000 Subject: [PATCH 08/23] #2887 - Node class changes to address some test failures. Addressed some inconsistencies around operating_state, amended instantiation of some Nodes in test environments --- .../source/how_to_guides/extensible_nodes.rst | 4 +- src/primaite/simulator/network/airspace.py | 2 +- src/primaite/simulator/network/container.py | 12 +- src/primaite/simulator/network/creation.py | 35 +++-- .../simulator/network/hardware/base.py | 19 +-- .../network/hardware/nodes/host/computer.py | 2 +- .../network/hardware/nodes/host/server.py | 6 +- .../hardware/nodes/network/firewall.py | 5 +- .../network/hardware/nodes/network/router.py | 12 +- .../network/hardware/nodes/network/switch.py | 10 +- .../hardware/nodes/network/wireless_router.py | 36 +++-- src/primaite/simulator/network/networks.py | 134 ++++++++-------- tests/conftest.py | 145 ++++++++++-------- .../test_action_integration.py | 14 +- .../nodes/network/test_router_config.py | 2 +- .../nodes/test_node_config.py | 2 + .../test_episode_scheduler.py | 1 + .../actions/test_node_request_permission.py | 6 +- .../observations/test_firewall_observation.py | 4 +- .../observations/test_link_observations.py | 24 ++- .../observations/test_nic_observations.py | 4 +- .../observations/test_router_observation.py | 8 +- .../game_layer/test_action_mask.py | 2 +- .../game_layer/test_actions.py | 2 +- .../game_layer/test_observations.py | 7 +- .../network/test_broadcast.py | 48 +++--- .../network/test_firewall.py | 4 +- .../test_c2_suite_integration.py | 58 ++++--- .../system/test_dns_client_server.py | 2 +- .../system/test_service_on_node.py | 34 ++-- .../system/test_web_client_server.py | 2 +- .../_network/_hardware/nodes/test_switch.py | 5 +- .../test_network_interface_actions.py | 5 +- .../_network/_hardware/test_node_actions.py | 13 +- .../_simulator/_network/test_container.py | 11 +- .../_red_applications/test_c2_suite.py | 28 ++-- .../_red_applications/test_dos_bot.py | 15 +- .../_applications/test_database_client.py | 22 ++- .../_system/_applications/test_web_browser.py | 22 ++- .../_system/_services/test_database.py | 13 +- .../_system/_services/test_dns_client.py | 24 ++- .../_system/_services/test_dns_server.py | 27 ++-- .../_system/_services/test_ftp_client.py | 15 +- .../_system/_services/test_ftp_server.py | 14 +- .../_system/_services/test_terminal.py | 65 +++++--- .../_system/_services/test_web_server.py | 14 +- 46 files changed, 548 insertions(+), 391 deletions(-) diff --git a/docs/source/how_to_guides/extensible_nodes.rst b/docs/source/how_to_guides/extensible_nodes.rst index 21907767..f0b78b08 100644 --- a/docs/source/how_to_guides/extensible_nodes.rst +++ b/docs/source/how_to_guides/extensible_nodes.rst @@ -46,7 +46,7 @@ class Router(NetworkNode, identifier="router"): num_ports: int = 5 - hostname: ClassVar[str] = "Router" + hostname: str = "Router" ports: list = [] @@ -55,4 +55,4 @@ class Router(NetworkNode, identifier="router"): Changes to YAML file. ===================== -Nodes defined within configuration YAML files for use with PrimAITE 3.X should still be compatible following these changes. \ No newline at end of file +Nodes defined within configuration YAML files for use with PrimAITE 3.X should still be compatible following these changes. diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 1f6fe6b0..5549eb78 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -320,7 +320,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC): self.enabled = True self._connected_node.sys_log.info(f"Network Interface {self} enabled") self.pcap = PacketCapture( - hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name + hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name ) self.airspace.add_wireless_interface(self) diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index aac82633..982495a4 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -196,7 +196,13 @@ class Network(SimComponent): if port.ip_address != IPv4Address("127.0.0.1"): port_str = port.port_name if port.port_name else port.port_num table.add_row( - [node.config.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway] + [ + node.config.hostname, + port_str, + port.ip_address, + port.subnet_mask, + node.default_gateway, + ] ) print(table) @@ -288,7 +294,9 @@ class Network(SimComponent): node.parent = self self._nx_graph.add_node(node.config.hostname) _LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}") - self._node_request_manager.add_request(name=node.config.hostname, request_type=RequestType(func=node._request_manager)) + self._node_request_manager.add_request( + name=node.config.hostname, request_type=RequestType(func=node._request_manager) + ) def get_node_by_hostname(self, hostname: str) -> Optional[Node]: """ diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index 3221939b..2a981d59 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -153,7 +153,9 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): # Create a core switch if more than one edge switch is needed if num_of_switches > 1: - core_switch = Switch.from_config(config = {"type":"switch","hostname":f"switch_core_{config.lan_name}", "start_up_duration": 0 }) + core_switch = Switch.from_config( + config={"type": "switch", "hostname": f"switch_core_{config.lan_name}", "start_up_duration": 0} + ) core_switch.power_on() network.add_node(core_switch) core_switch_port = 1 @@ -165,7 +167,9 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): if config.include_router: default_gateway = IPv4Address(f"192.168.{config.subnet_base}.1") # router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0) - router = Router.from_config(config={"hostname":f"router_{config.lan_name}", "type": "router", "start_up_duration": 0}) + router = Router.from_config( + config={"hostname": f"router_{config.lan_name}", "type": "router", "start_up_duration": 0} + ) router.power_on() router.acl.add_rule( action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 @@ -178,7 +182,9 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): # Initialise the first edge switch and connect to the router or core switch switch_port = 0 switch_n = 1 - switch = Switch.from_config(config={"type": "switch","hostname":f"switch_edge_{switch_n}_{config.lan_name}", "start_up_duration":0}) + switch = Switch.from_config( + config={"type": "switch", "hostname": f"switch_edge_{switch_n}_{config.lan_name}", "start_up_duration": 0} + ) switch.power_on() network.add_node(switch) if num_of_switches > 1: @@ -196,7 +202,13 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): if switch_port == effective_network_interface: switch_n += 1 switch_port = 0 - switch = Switch.from_config(config={"type": "switch","hostname":f"switch_edge_{switch_n}_{config.lan_name}", "start_up_duration":0}) + switch = Switch.from_config( + config={ + "type": "switch", + "hostname": f"switch_edge_{switch_n}_{config.lan_name}", + "start_up_duration": 0, + } + ) switch.power_on() network.add_node(switch) # Connect the new switch to the router or core switch @@ -213,13 +225,14 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): ) # Create and add a PC to the network - pc_cfg = {"type": "computer", - "hostname": f"pc_{i}_{config.lan_name}", - "ip_address": f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}", - "default_gateway": "192.168.10.1", - "start_up_duration": 0, - } - pc = Computer.from_config(config = pc_cfg) + pc_cfg = { + "type": "computer", + "hostname": f"pc_{i}_{config.lan_name}", + "ip_address": f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } + pc = Computer.from_config(config=pc_cfg) pc.power_on() network.add_node(pc) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index d462f75c..de97f22b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1529,22 +1529,21 @@ class Node(SimComponent, ABC): _identifier: ClassVar[str] = "unknown" """Identifier for this particular class, used for printing and logging. Each subclass redefines this.""" - config: Node.ConfigSchema + config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) """Configuration items within Node""" class ConfigSchema(BaseModel, ABC): """Configuration Schema for Node based classes.""" model_config = ConfigDict(arbitrary_types_allowed=True) - """Configure pydantic to allow arbitrary types and to let the instance have attributes not present in the model.""" - + """Configure pydantic to allow arbitrary types, let the instance have attributes not present in the model.""" hostname: str = "default" "The node hostname on the network." revealed_to_red: bool = False "Informs whether the node has been revealed to a red agent." - start_up_duration: int = 0 + start_up_duration: int = 3 "Time steps needed for the node to start up." start_up_countdown: int = 0 @@ -1617,12 +1616,10 @@ class Node(SimComponent, ABC): file_system=kwargs.get("file_system"), dns_server=kwargs.get("dns_server"), ) - super().__init__(**kwargs) self._install_system_software() self.session_manager.node = self self.session_manager.software_manager = self.software_manager - self.power_on() @property def user_manager(self) -> Optional[UserManager]: @@ -1713,7 +1710,7 @@ class Node(SimComponent, ABC): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return f"Cannot perform request on node '{self.node.hostname}' because it is not powered on." + return f"Cannot perform request on node '{self.node.config.hostname}' because it is not powered on." class _NodeIsOffValidator(RequestPermissionValidator): """ @@ -1732,7 +1729,7 @@ class Node(SimComponent, ABC): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return f"Cannot perform request on node '{self.node.hostname}' because it is not turned off." + return f"Cannot perform request on node '{self.node.config.hostname}' because it is not turned off." def _init_request_manager(self) -> RequestManager: """ @@ -1900,7 +1897,7 @@ class Node(SimComponent, ABC): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Open Ports" + table.title = f"{self.config.hostname} Open Ports" for port in self.software_manager.get_open_ports(): if port > 0: table.add_row([port]) @@ -1927,7 +1924,7 @@ class Node(SimComponent, ABC): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Network Interface Cards" + table.title = f"{self.config.hostname} Network Interface Cards" for port, network_interface in self.network_interface.items(): ip_address = "" if hasattr(network_interface, "ip_address"): @@ -1967,7 +1964,7 @@ class Node(SimComponent, ABC): else: if self.operating_state == NodeOperatingState.BOOTING: self.operating_state = NodeOperatingState.ON - self.sys_log.info(f"{self.hostname}: Turned on") + self.sys_log.info(f"{self.config.hostname}: Turned on") for network_interface in self.network_interfaces.values(): network_interface.enable() diff --git a/src/primaite/simulator/network/hardware/nodes/host/computer.py b/src/primaite/simulator/network/hardware/nodes/host/computer.py index 85857a44..1aebc3af 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/computer.py +++ b/src/primaite/simulator/network/hardware/nodes/host/computer.py @@ -37,7 +37,7 @@ class Computer(HostNode, identifier="computer"): SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} - config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema()) + config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema()) class ConfigSchema(HostNode.ConfigSchema): """Configuration Schema for Computer class.""" diff --git a/src/primaite/simulator/network/hardware/nodes/host/server.py b/src/primaite/simulator/network/hardware/nodes/host/server.py index f1abefc2..bdf4e8c2 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/server.py +++ b/src/primaite/simulator/network/hardware/nodes/host/server.py @@ -1,6 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK -from typing import ClassVar + from pydantic import Field + from primaite.simulator.network.hardware.nodes.host.host_node import HostNode @@ -45,10 +46,9 @@ class Printer(HostNode, identifier="printer"): # TODO: Implement printer-specific behaviour - config: "Printer.ConfigSchema" = Field(default_factory=lambda: Printer.ConfigSchema()) class ConfigSchema(HostNode.ConfigSchema): """Configuration Schema for Printer class.""" - hostname: str = "printer" \ No newline at end of file + hostname: str = "printer" diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index a30c49bd..6c582cd5 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -108,7 +108,6 @@ class Firewall(Router, identifier="firewall"): hostname: str = "firewall" num_ports: int = 0 - operating_state: NodeOperatingState = NodeOperatingState.ON def __init__(self, **kwargs): if not kwargs.get("sys_log"): @@ -242,7 +241,7 @@ class Firewall(Router, identifier="firewall"): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Network Interfaces" + table.title = f"{self.config.hostname} Network Interfaces" ports = {"External": self.external_port, "Internal": self.internal_port, "DMZ": self.dmz_port} for port, network_interface in ports.items(): table.add_row( @@ -572,7 +571,7 @@ class Firewall(Router, identifier="firewall"): @classmethod def from_config(cls, config: dict) -> "Firewall": """Create a firewall based on a config dict.""" - firewall = Firewall(config = cls.ConfigSchema(**config)) + firewall = Firewall(config=cls.ConfigSchema(**config)) if "ports" in config: internal_port = config["ports"]["internal_port"] diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 4f9d9ca4..3ecb761b 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1218,7 +1218,7 @@ class Router(NetworkNode, identifier="router"): route_table: RouteTable - config: "Router.ConfigSchema" + config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema()) class ConfigSchema(NetworkNode.ConfigSchema): """Configuration Schema for Router Objects.""" @@ -1230,12 +1230,13 @@ class Router(NetworkNode, identifier="router"): ports: Dict[Union[int, str], Dict] = {} - def __init__(self, **kwargs): if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(kwargs["config"].hostname) if not kwargs.get("acl"): - kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname) + kwargs["acl"] = AccessControlList( + sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname + ) if not kwargs.get("route_table"): kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"]) super().__init__(**kwargs) @@ -1632,8 +1633,7 @@ class Router(NetworkNode, identifier="router"): :return: Configured router. :rtype: Router """ - router = Router(config=Router.ConfigSchema(**config) - ) + router = Router(config=Router.ConfigSchema(**config)) if "ports" in config: for port_num, port_cfg in config["ports"].items(): router.configure_port( @@ -1666,4 +1666,4 @@ class Router(NetworkNode, identifier="router"): next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) - return router \ No newline at end of file + return router diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index e97c5321..3cb335f7 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import ClassVar, Dict, Optional +from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable from pydantic import Field @@ -89,11 +89,7 @@ class SwitchPort(WiredNetworkInterface): class Switch(NetworkNode, identifier="switch"): - """ - A class representing a Layer 2 network switch. - - :ivar num_ports: The number of ports on the switch. Default is 24. - """ + """A class representing a Layer 2 network switch.""" network_interfaces: Dict[str, SwitchPort] = {} "The SwitchPorts on the Switch." @@ -102,7 +98,7 @@ class Switch(NetworkNode, identifier="switch"): mac_address_table: Dict[str, SwitchPort] = {} "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." - config: "Switch.ConfigSchema" + config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema()) class ConfigSchema(NetworkNode.ConfigSchema): """Configuration Schema for Switch nodes within PrimAITE.""" diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 75e4d5ea..70e655ac 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -122,7 +122,6 @@ class WirelessRouter(Router, identifier="wireless_router"): network_interfaces: Dict[str, Union[RouterInterface, WirelessAccessPoint]] = {} network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {} - airspace: AirSpace config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema()) @@ -130,12 +129,15 @@ class WirelessRouter(Router, identifier="wireless_router"): """Configuration Schema for WirelessRouter nodes within PrimAITE.""" hostname: str = "WirelessRouter" + airspace: Optional[AirSpace] = None def __init__(self, **kwargs): - super().__init__(hostname=kwargs["config"].hostname, num_ports=0, airspace=kwargs["config"].airspace, **kwargs) + super().__init__(**kwargs) self.connect_nic( - WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=kwargs["config"].airspace) + WirelessAccessPoint( + ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=kwargs["config"].airspace + ) ) self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")) @@ -233,7 +235,7 @@ class WirelessRouter(Router, identifier="wireless_router"): ) @classmethod - def from_config(cls, cfg: Dict, **kwargs) -> "WirelessRouter": + def from_config(cls, config: Dict, **kwargs) -> "WirelessRouter": """Generate the wireless router from config. Schema: @@ -261,21 +263,21 @@ class WirelessRouter(Router, identifier="wireless_router"): :rtype: WirelessRouter """ operating_state = ( - NodeOperatingState.ON if not (p := cfg.get("operating_state")) else NodeOperatingState[p.upper()] + NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] ) - router = cls(hostname=cfg["hostname"], operating_state=operating_state, airspace=kwargs["airspace"]) - if "router_interface" in cfg: - ip_address = cfg["router_interface"]["ip_address"] - subnet_mask = cfg["router_interface"]["subnet_mask"] + router = cls(config=cls.ConfigSchema(**config)) + if "router_interface" in config: + ip_address = config["router_interface"]["ip_address"] + subnet_mask = config["router_interface"]["subnet_mask"] router.configure_router_interface(ip_address=ip_address, subnet_mask=subnet_mask) - if "wireless_access_point" in cfg: - ip_address = cfg["wireless_access_point"]["ip_address"] - subnet_mask = cfg["wireless_access_point"]["subnet_mask"] - frequency = AirSpaceFrequency._registry[cfg["wireless_access_point"]["frequency"]] + if "wireless_access_point" in config: + ip_address = config["wireless_access_point"]["ip_address"] + subnet_mask = config["wireless_access_point"]["subnet_mask"] + frequency = AirSpaceFrequency._registry[config["wireless_access_point"]["frequency"]] router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency) - if "acl" in cfg: - for r_num, r_cfg in cfg["acl"].items(): + if "acl" in config: + for r_num, r_cfg in config["acl"].items(): router.acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -287,8 +289,8 @@ class WirelessRouter(Router, identifier="wireless_router"): dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"), position=r_num, ) - if "routes" in cfg: - for route in cfg.get("routes"): + if "routes" in config: + for route in config.get("routes"): router.route_table.add_route( address=IPv4Address(route.get("address")), subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 4d881343..0579f137 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -128,33 +128,40 @@ def arcd_uc2_network() -> Network: network = Network() # Router 1 - router_1 = Router.from_config(config={"type":"router", "hostname":"router_1", "num_ports":5, "start_up_duration":0}) + router_1 = Router.from_config( + config={"type": "router", "hostname": "router_1", "num_ports": 5, "start_up_duration": 0} + ) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") # Switch 1 - switch_1 = Switch.from_config(config={"type":"switch", "hostname":"switch_1", "num_ports":8, "start_up_duration":0}) + switch_1 = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 8, "start_up_duration": 0} + ) switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8]) router_1.enable_port(1) # Switch 2 - switch_2 = Switch.from_config(config={"type":"switch", "hostname":"switch_2", "num_ports":8, "start_up_duration":0}) + switch_2 = Switch.from_config( + config={"type": "switch", "hostname": "switch_2", "num_ports": 8, "start_up_duration": 0} + ) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8]) router_1.enable_port(2) # Client 1 - client_1_cfg = {"type": "computer", - "hostname": "client_1", - "ip_address": "192.168.10.21", - "subnet_mask": "255.255.255.0", - "default_gateway": "192.168.10.1", - "dns_server": IPv4Address("192.168.1.10"), - "start_up_duration": 0, - } - client_1: Computer = Computer.from_config(config = client_1_cfg) + client_1_cfg = { + "type": "computer", + "hostname": "client_1", + "ip_address": "192.168.10.21", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) @@ -175,15 +182,16 @@ def arcd_uc2_network() -> Network: # Client 2 - client_2_cfg = {"type": "computer", - "hostname": "client_2", - "ip_address": "192.168.10.22", - "subnet_mask": "255.255.255.0", - "default_gateway": "192.168.10.1", - "dns_server": IPv4Address("192.168.1.10"), - "start_up_duration": 0, - } - client_2: Computer = Computer.from_config(config = client_2_cfg) + client_2_cfg = { + "type": "computer", + "hostname": "client_2", + "ip_address": "192.168.10.22", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + client_2: Computer = Computer.from_config(config=client_2_cfg) client_2.power_on() client_2.software_manager.install(DatabaseClient) @@ -199,13 +207,14 @@ def arcd_uc2_network() -> Network: # Domain Controller - domain_controller_cfg = {"type": "server", - "hostname": "domain_controller", - "ip_address": "192.168.1.10", - "subnet_mask": "255.255.255.0", - "default_gateway": "192.168.1.1", - "start_up_duration": 0 - } + domain_controller_cfg = { + "type": "server", + "hostname": "domain_controller", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } domain_controller = Server.from_config(config=domain_controller_cfg) domain_controller.power_on() @@ -215,14 +224,15 @@ def arcd_uc2_network() -> Network: # Database Server - database_server_cfg = {"type": "server", - "hostname": "database_server", - "ip_address": "192.168.1.14", - "subnet_mask": "255.255.255.0", - "default_gateway": "192.168.1.1", - "dns_server": IPv4Address("192.168.1.10"), - "start_up_duration": 0 - } + database_server_cfg = { + "type": "server", + "hostname": "database_server", + "ip_address": "192.168.1.14", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } database_server = Server.from_config(config=database_server_cfg) @@ -236,15 +246,15 @@ def arcd_uc2_network() -> Network: # Web Server - - web_server_cfg = {"type": "server", - "hostname": "web_server", - "ip_address": "192.168.1.11", - "subnet_mask": "255.255.255.0", - "default_gateway": "192.168.1.1", - "dns_server": IPv4Address("192.168.1.10"), - "start_up_duration": 0 - } + web_server_cfg = { + "type": "server", + "hostname": "web_server", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } web_server = Server.from_config(config=web_server_cfg) web_server.power_on() @@ -263,14 +273,15 @@ def arcd_uc2_network() -> Network: dns_server_service.dns_register("arcd.com", web_server.network_interface[1].ip_address) # Backup Server - backup_server_cfg = {"type": "server", - "hostname": "backup_server", - "ip_address": "192.168.1.16", - "subnet_mask": "255.255.255.0", - "default_gateway": "192.168.1.1", - "dns_server": IPv4Address("192.168.1.10"), - "start_up_duration": 0 - } + backup_server_cfg = { + "type": "server", + "hostname": "backup_server", + "ip_address": "192.168.1.16", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } backup_server: Server = Server.from_config(config=backup_server_cfg) backup_server.power_on() @@ -278,14 +289,15 @@ def arcd_uc2_network() -> Network: network.connect(endpoint_b=backup_server.network_interface[1], endpoint_a=switch_1.network_interface[4]) # Security Suite - security_suite_cfg = {"type": "server", - "hostname": "backup_server", - "ip_address": "192.168.1.110", - "subnet_mask": "255.255.255.0", - "default_gateway": "192.168.1.1", - "dns_server": IPv4Address("192.168.1.10"), - "start_up_duration": 0 - } + security_suite_cfg = { + "type": "server", + "hostname": "backup_server", + "ip_address": "192.168.1.110", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } security_suite: Server = Server.from_config(config=security_suite_cfg) security_suite.power_on() network.connect(endpoint_b=security_suite.network_interface[1], endpoint_a=switch_1.network_interface[7]) diff --git a/tests/conftest.py b/tests/conftest.py index 1bdc217c..6ac227ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,13 +119,13 @@ def application_class(): @pytest.fixture(scope="function") def file_system() -> FileSystem: - # computer = Computer(hostname="fs_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) - computer_cfg = {"type": "computer", - "hostname": "fs_node", - "ip_address": "192.168.1.2", - "subnet_mask": "255.255.255.0", - "start_up_duration": 0, - } + computer_cfg = { + "type": "computer", + "hostname": "fs_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } computer = Computer.from_config(config=computer_cfg) computer.power_on() return computer.file_system @@ -136,23 +136,29 @@ def client_server() -> Tuple[Computer, Server]: network = Network() # Create Computer - computer = Computer( - hostname="computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() # Create Server - server = Server( - hostname="server", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + server_cfg = { + "type": "server", + "hostname": "server", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server: Server = Server.from_config(config=server_cfg) server.power_on() # Connect Computer and Server @@ -169,26 +175,33 @@ def client_switch_server() -> Tuple[Computer, Switch, Server]: network = Network() # Create Computer - computer = Computer( - hostname="computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() # Create Server - server = Server( - hostname="server", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + server_cfg = { + "type": "server", + "hostname": "server", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server: Server = Server.from_config(config=server_cfg) server.power_on() - switch = Switch(hostname="switch", start_up_duration=0) + # Create Switch + switch: Switch = Switch.from_config(config={"type": "switch", "hostname": "switch", "start_up_duration": 0}) switch.power_on() network.connect(endpoint_a=computer.network_interface[1], endpoint_b=switch.network_interface[1]) @@ -219,7 +232,7 @@ def example_network() -> Network: # Router 1 - router_1_cfg = {"hostname": "router_1", "type": "router"} + router_1_cfg = {"hostname": "router_1", "type": "router", "start_up_duration":0} # router_1 = Router(hostname="router_1", start_up_duration=0) router_1 = Router.from_config(config=router_1_cfg) @@ -229,7 +242,7 @@ def example_network() -> Network: # Switch 1 - switch_1_cfg = {"hostname": "switch_1", "type": "switch"} + switch_1_cfg = {"hostname": "switch_1", "type": "switch", "start_up_duration": 0} switch_1 = Switch.from_config(config=switch_1_cfg) @@ -240,7 +253,7 @@ def example_network() -> Network: router_1.enable_port(1) # Switch 2 - switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8} + switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8, "start_up_duration":0} # switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) switch_2 = Switch.from_config(config=switch_2_config) switch_2.power_on() @@ -348,28 +361,34 @@ def install_stuff_to_sim(sim: Simulation): # 1: Set up network hardware # 1.1: Configure the router - router = Router.from_config(config={"type":"router", "hostname":"router", "num_ports":3, "start_up_duration":0}) + router = Router.from_config(config={"type": "router", "hostname": "router", "num_ports": 3, "start_up_duration": 0}) router.power_on() router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") # 1.2: Create and connect switches - switch_1 = Switch.from_config(config={"type":"switch", "hostname":"switch_1", "num_ports":6, "start_up_duration":0}) + switch_1 = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0} + ) switch_1.power_on() network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6]) router.enable_port(1) - switch_2 = Switch.from_config(config={"type":"switch", "hostname":"switch_2", "num_ports":6, "start_up_duration":0}) + switch_2 = Switch.from_config( + config={"type": "switch", "hostname": "switch_2", "num_ports": 6, "start_up_duration": 0} + ) switch_2.power_on() network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6]) router.enable_port(2) # 1.3: Create and connect computer - client_1_cfg = {"type": "computer", - "hostname": "client_1", - "ip_address":"10.0.1.2", - "subnet_mask":"255.255.255.0", - "default_gateway": "10.0.1.1", - "start_up_duration":0} + client_1_cfg = { + "type": "computer", + "hostname": "client_1", + "ip_address": "10.0.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "10.0.1.1", + "start_up_duration": 0, + } client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() network.connect( @@ -378,24 +397,26 @@ def install_stuff_to_sim(sim: Simulation): ) # 1.4: Create and connect servers - server_1_cfg = {"type": "server", - "hostname":"server_1", - "ip_address": "10.0.2.2", - "subnet_mask":"255.255.255.0", - "default_gateway":"10.0.2.1", - "start_up_duration": 0} - + server_1_cfg = { + "type": "server", + "hostname": "server_1", + "ip_address": "10.0.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "10.0.2.1", + "start_up_duration": 0, + } server_1: Server = Server.from_config(config=server_1_cfg) server_1.power_on() network.connect(endpoint_a=server_1.network_interface[1], endpoint_b=switch_2.network_interface[1]) - server_2_cfg = {"type": "server", - "hostname":"server_2", - "ip_address": "10.0.2.3", - "subnet_mask":"255.255.255.0", - "default_gateway":"10.0.2.1", - "start_up_duration": 0} - + server_2_cfg = { + "type": "server", + "hostname": "server_2", + "ip_address": "10.0.2.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "10.0.2.1", + "start_up_duration": 0, + } server_2: Server = Server.from_config(config=server_2_cfg) server_2.power_on() diff --git a/tests/integration_tests/component_creation/test_action_integration.py b/tests/integration_tests/component_creation/test_action_integration.py index 2d493045..0fd0aa19 100644 --- a/tests/integration_tests/component_creation/test_action_integration.py +++ b/tests/integration_tests/component_creation/test_action_integration.py @@ -12,12 +12,18 @@ def test_passing_actions_down(monkeypatch) -> None: sim = Simulation() - pc1 = Computer.from_config(config={"type":"computer", "hostname":"PC-1", "ip_address":"10.10.1.1", "subnet_mask":"255.255.255.0"}) + pc1 = Computer.from_config( + config={"type": "computer", "hostname": "PC-1", "ip_address": "10.10.1.1", "subnet_mask": "255.255.255.0"} + ) pc1.start_up_duration = 0 pc1.power_on() - pc2 = Computer.from_config(config={"type":"computer", "hostname":"PC-2", "ip_address":"10.10.1.2", "subnet_mask":"255.255.255.0"}) - srv = Server.from_config(config={"type":"server", "hostname":"WEBSERVER", "ip_address":"10.10.1.100", "subnet_mask":"255.255.255.0"}) - s1 = Switch.from_config(config={"type":"switch", "hostname":"switch1"}) + pc2 = Computer.from_config( + config={"type": "computer", "hostname": "PC-2", "ip_address": "10.10.1.2", "subnet_mask": "255.255.255.0"} + ) + srv = Server.from_config( + config={"type": "server", "hostname": "WEBSERVER", "ip_address": "10.10.1.100", "subnet_mask": "255.255.255.0"} + ) + s1 = Switch.from_config(config={"type": "switch", "hostname": "switch1"}) for n in [pc1, pc2, srv, s1]: sim.network.add_node(n) diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py index c9691fab..7ca3a6aa 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -5,8 +5,8 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config diff --git a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py index 764a7aac..6ccbf4e1 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py @@ -3,6 +3,8 @@ from primaite.config.load import data_manipulation_config_path from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config diff --git a/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py index c588829b..1352f894 100644 --- a/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py +++ b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py @@ -4,6 +4,7 @@ import yaml from primaite.session.environment import PrimaiteGymEnv from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests.conftest import TEST_ASSETS_ROOT folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders" diff --git a/tests/integration_tests/game_layer/actions/test_node_request_permission.py b/tests/integration_tests/game_layer/actions/test_node_request_permission.py index 8a438673..32c9e8a5 100644 --- a/tests/integration_tests/game_layer/actions/test_node_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_node_request_permission.py @@ -35,7 +35,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN - for i in range(client_1.shut_down_duration + 1): + for i in range(client_1.config.shut_down_duration + 1): action = ("do_nothing", {}) agent.store_action(action) game.step() @@ -49,7 +49,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy assert client_1.operating_state == NodeOperatingState.BOOTING - for i in range(client_1.start_up_duration + 1): + for i in range(client_1.config.start_up_duration + 1): action = ("do_nothing", {}) agent.store_action(action) game.step() @@ -79,7 +79,7 @@ def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture: client_1 = game.simulation.network.get_node_by_hostname("client_1") client_1.power_off() - for i in range(client_1.shut_down_duration + 1): + for i in range(client_1.config.shut_down_duration + 1): action = ("do_nothing", {}) agent.store_action(action) game.step() diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 6b0d4359..17c7775f 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -117,7 +117,9 @@ def test_firewall_observation(): assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4)) # connect a switch to the firewall and check that only the correct port is updated - switch: Switch = Switch.from_config(config={"type": "switch", "hostname":"switch", "num_ports":1, "operating_state":NodeOperatingState.ON}) + switch: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": NodeOperatingState.ON} + ) link = net.connect(firewall.network_interface[1], switch.network_interface[1]) assert firewall.network_interface[1].enabled observation = firewall_observation.observe(firewall.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py index b5cd6134..f95d35c2 100644 --- a/tests/integration_tests/game_layer/observations/test_link_observations.py +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -56,12 +56,26 @@ def test_link_observation(): """Check the shape and contents of the link observation.""" net = Network() sim = Simulation(network=net) - switch: Switch = Switch.from_config(config={"type":"switch", "hostname":"switch", "num_ports":5, "operating_state":NodeOperatingState.ON}) - computer_1: Computer = Computer.from_config(config={"type": "computer", - "hostname":"computer_1", "ip_address":"10.0.0.1", "subnet_mask":"255.255.255.0", "start_up_duration":0} + switch: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch", "num_ports": 5, "operating_state": NodeOperatingState.ON} ) - computer_2: Computer = Computer.from_config(config={"type":"computer", - "hostname":"computer_2", "ip_address":"10.0.0.2", "subnet_mask":"255.255.255.0", "start_up_duration":0} + computer_1: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer_1", + "ip_address": "10.0.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) + computer_2: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer_2", + "ip_address": "10.0.0.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) computer_1.power_on() computer_2.power_on() diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index 2a311853..5e1c0f81 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -163,7 +163,9 @@ def test_nic_monitored_traffic(simulation): pc2: Computer = simulation.network.get_node_by_hostname("client_2") nic_obs = NICObservation( - where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic + where=["network", "nodes", pc.config.hostname, "NICs", 1], + include_nmne=False, + monitored_traffic=monitored_traffic, ) simulation.pre_timestep(0) # apply timestep to whole sim diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index 131af57f..8335867d 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -16,7 +16,9 @@ from primaite.utils.validation.port import PORT_LOOKUP def test_router_observation(): """Test adding/removing acl rules and enabling/disabling ports.""" net = Network() - router = Router.from_config(config={"type": "router", "hostname":"router", "num_ports":5, "operating_state":NodeOperatingState.ON}) + router = Router.from_config( + config={"type": "router", "hostname": "router", "num_ports": 5, "operating_state": NodeOperatingState.ON} + ) ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)] acl = ACLObservation( @@ -89,7 +91,9 @@ def test_router_observation(): assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6)) # connect a switch to the router and check that only the correct port is updated - switch: Switch = Switch.from_config(config={"type": "switch", "hostname":"switch", "num_ports":1, "operating_state":NodeOperatingState.ON}) + switch: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": NodeOperatingState.ON} + ) link = net.connect(router.network_interface[1], switch.network_interface[1]) assert router.network_interface[1].enabled observed_output = router_observation.observe(router.describe_state()) diff --git a/tests/integration_tests/game_layer/test_action_mask.py b/tests/integration_tests/game_layer/test_action_mask.py index ebba1119..4ac7b9a6 100644 --- a/tests/integration_tests/game_layer/test_action_mask.py +++ b/tests/integration_tests/game_layer/test_action_mask.py @@ -2,8 +2,8 @@ from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.host_node import HostNode -from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter +from primaite.simulator.system.services.service import ServiceOperatingState from tests.conftest import TEST_ASSETS_ROOT CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 9d9b528c..cd230546 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -17,11 +17,11 @@ from typing import Tuple import pytest import yaml -from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.software import SoftwareHealthState diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index 5afad296..090725b5 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -8,14 +8,17 @@ from primaite.simulator.sim_container import Simulation def test_file_observation(): sim = Simulation() - pc = Computer(hostname="beep", ip_address="123.123.123.123", subnet_mask="255.255.255.0") + pc: Computer = Computer.from_config(config={"type":"computer", + "hostname":"beep", + "ip_address":"123.123.123.123", + "subnet_mask":"255.255.255.0"}) sim.network.add_node(pc) f = pc.file_system.create_file(file_name="dog.png") state = sim.describe_state() dog_file_obs = FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, file_system_requires_scan=False, ) diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index 5c30d2ac..5469e803 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -84,24 +84,28 @@ class BroadcastTestClient(Application, identifier="BroadcastTestClient"): def broadcast_network() -> Network: network = Network() - client_1_cfg = {"type": "computer", - "hostname": "client_1", - "ip_address":"192.168.1.2", - "subnet_mask": "255.255.255.0", - "default_gateway": "192.168.1.1", - "start_up_duration":0} + client_1_cfg = { + "type": "computer", + "hostname": "client_1", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() client_1.software_manager.install(BroadcastTestClient) application_1 = client_1.software_manager.software["BroadcastTestClient"] application_1.run() - client_2_cfg = {"type": "computer", - "hostname": "client_2", - "ip_address":"192.168.1.3", - "subnet_mask": "255.255.255.0", - "default_gateway": "192.168.1.1", - "start_up_duration":0} + client_2_cfg = { + "type": "computer", + "hostname": "client_2", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } client_2: Computer = Computer.from_config(config=client_2_cfg) client_2.power_on() @@ -109,14 +113,16 @@ def broadcast_network() -> Network: application_2 = client_2.software_manager.software["BroadcastTestClient"] application_2.run() - server_1_cfg = {"type": "server", - "hostname": "server_1", - "ip_address":"192.168.1.1", - "subnet_mask": "255.255.255.0", - "default_gateway":"192.168.1.1", - "start_up_duration": 0} + server_1_cfg = { + "type": "server", + "hostname": "server_1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } - server_1 :Server = Server.from_config(config=server_1_cfg) + server_1: Server = Server.from_config(config=server_1_cfg) server_1.power_on() @@ -124,7 +130,9 @@ def broadcast_network() -> Network: service: BroadcastTestService = server_1.software_manager.software["BroadcastService"] service.start() - switch_1: Switch = Switch.from_config(config={"type": "switch", "hostname":"switch_1", "num_ports":6, "start_up_duration":0}) + switch_1: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0} + ) switch_1.power_on() network.connect(endpoint_a=client_1.network_interface[1], endpoint_b=switch_1.network_interface[1]) diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index 24fbfd05..69f3e5ab 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -41,7 +41,9 @@ def dmz_external_internal_network() -> Network: """ network = Network() - firewall_node: Firewall = Firewall(hostname="firewall_1", start_up_duration=0) + firewall_node: Firewall = Firewall.from_config( + config={"type": "firewall", "hostname": "firewall_1", "start_up_duration": 0} + ) firewall_node.power_on() # configure firewall ports firewall_node.configure_external_port( diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index 86a68865..1cc4e8e2 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -34,54 +34,64 @@ def basic_network() -> Network: # Creating two generic nodes for the C2 Server and the C2 Beacon. - node_a_cfg = {"type": "computer", - "hostname": "node_a", - "ip_address": "192.168.0.2", - "subnet_mask": "255.255.255.252", - "default_gateway": "192.168.0.1", - "start_up_duration": 0} - + node_a_cfg = { + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.252", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } + node_a: Computer = Computer.from_config(config=node_a_cfg) node_a.power_on() node_a.software_manager.get_open_ports() node_a.software_manager.install(software_class=C2Server) - node_b_cfg = {"type": "computer", - "hostname": "node_b", - "ip_address": "192.168.255.2", - "subnet_mask": "255.255.255.248", - "default_gateway": "192.168.255.1", - "start_up_duration": 0} - + node_b_cfg = { + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.255.2", + "subnet_mask": "255.255.255.248", + "default_gateway": "192.168.255.1", + "start_up_duration": 0, + } + node_b: Computer = Computer.from_config(config=node_b_cfg) node_b.power_on() node_b.software_manager.install(software_class=C2Beacon) # Creating a generic computer for testing remote terminal connections. - node_c_cfg = {"type": "computer", - "hostname": "node_c", - "ip_address": "192.168.255.3", - "subnet_mask": "255.255.255.248", - "default_gateway": "192.168.255.1", - "start_up_duration": 0} - + node_c_cfg = { + "type": "computer", + "hostname": "node_c", + "ip_address": "192.168.255.3", + "subnet_mask": "255.255.255.248", + "default_gateway": "192.168.255.1", + "start_up_duration": 0, + } + node_c: Computer = Computer.from_config(config=node_c_cfg) node_c.power_on() # Creating a router to sit between node 1 and node 2. - router = Router.from_config(config={"type":"router", "hostname":"router", "num_ports":3, "start_up_duration":0}) + router = Router.from_config(config={"type": "router", "hostname": "router", "num_ports": 3, "start_up_duration": 0}) # Default allow all. router.acl.add_rule(action=ACLAction.PERMIT) router.power_on() # Creating switches for each client. - switch_1 = Switch.from_config(config={"type":"switch", "hostname":"switch_1", "num_ports":6, "start_up_duration":0}) + switch_1 = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0} + ) switch_1.power_on() # Connecting the switches to the router. router.configure_port(port=1, ip_address="192.168.0.1", subnet_mask="255.255.255.252") network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6]) - switch_2 = Switch.from_config(config={"type":"switch", "hostname":"switch_2", "num_ports":6, "start_up_duration":0}) + switch_2 = Switch.from_config( + config={"type": "switch", "hostname": "switch_2", "num_ports": 6, "start_up_duration": 0} + ) switch_2.power_on() network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6]) diff --git a/tests/integration_tests/system/test_dns_client_server.py b/tests/integration_tests/system/test_dns_client_server.py index 38caf1a2..8266c814 100644 --- a/tests/integration_tests/system/test_dns_client_server.py +++ b/tests/integration_tests/system/test_dns_client_server.py @@ -72,7 +72,7 @@ def test_dns_client_requests_offline_dns_server(dns_client_and_dns_server): server.power_off() - for i in range(server.shut_down_duration + 1): + for i in range(server.config.shut_down_duration + 1): server.apply_timestep(timestep=i) assert server.operating_state == NodeOperatingState.OFF diff --git a/tests/integration_tests/system/test_service_on_node.py b/tests/integration_tests/system/test_service_on_node.py index 4e73a050..15fd3ccd 100644 --- a/tests/integration_tests/system/test_service_on_node.py +++ b/tests/integration_tests/system/test_service_on_node.py @@ -13,13 +13,15 @@ from primaite.simulator.system.services.service import Service, ServiceOperating def populated_node( service_class, ) -> Tuple[Server, Service]: - server = Server( - hostname="server", - ip_address="192.168.0.1", - subnet_mask="255.255.255.0", - start_up_duration=0, - shut_down_duration=0, - ) + server_cfg = { + "type": "server", + "hostname": "server", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + "shut_down_duration": 0, + } + server: Server = Server.from_config(config=server_cfg) server.power_on() server.software_manager.install(service_class) @@ -31,14 +33,16 @@ def populated_node( def test_service_on_offline_node(service_class): """Test to check that the service cannot be interacted with when node it is on is off.""" - computer: Computer = Computer( - hostname="test_computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - shut_down_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "test_computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + "shut_down_duration": 0, + } + computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() computer.software_manager.install(service_class) diff --git a/tests/integration_tests/system/test_web_client_server.py b/tests/integration_tests/system/test_web_client_server.py index 8aea34c1..8873a494 100644 --- a/tests/integration_tests/system/test_web_client_server.py +++ b/tests/integration_tests/system/test_web_client_server.py @@ -94,7 +94,7 @@ def test_web_page_request_from_shut_down_server(web_client_and_web_server): server.power_off() - for i in range(server.shut_down_duration + 1): + for i in range(server.config.shut_down_duration + 1): server.apply_timestep(timestep=i) # node should be off diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py index dbc04f6d..e45fe45d 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py @@ -7,10 +7,7 @@ from primaite.simulator.network.hardware.nodes.network.switch import Switch @pytest.fixture(scope="function") def switch() -> Switch: - switch_cfg = {"type": "switch", - "hostname": "switch_1", - "num_ports": 8, - "start_up_duration": 0} + switch_cfg = {"type": "switch", "hostname": "switch_1", "num_ports": 8, "start_up_duration": 0} switch: Switch = Switch.from_config(config=switch_cfg) switch.power_on() switch.show() diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py index 0e0023cd..cb2d3935 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py @@ -7,10 +7,7 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer @pytest.fixture def node() -> Node: - computer_cfg = {"type": "computer", - "hostname": "test", - "ip_address": "192.168.1.2", - "subnet_mask": "255.255.255.0"} + computer_cfg = {"type": "computer", "hostname": "test", "ip_address": "192.168.1.2", "subnet_mask": "255.255.255.0"} computer = Computer.from_config(config=computer_cfg) return computer diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py index d077f46b..f6308a21 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py @@ -12,13 +12,12 @@ from tests.conftest import DummyApplication, DummyService @pytest.fixture def node() -> Node: - computer_cfg = {"type": "computer", - "hostname": "test", - "ip_address": "192.168.1.2", - "subnet_mask": "255.255.255.0", - "shut_down_duration": 3, - "operating_state": "OFF", - } + computer_cfg = { + "type": "computer", + "hostname": "test", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + } computer = Computer.from_config(config=computer_cfg) return computer diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index d175b865..9a54f7b2 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -74,7 +74,16 @@ def test_removing_node_that_does_not_exist(network): """Node that does not exist on network should not affect existing nodes.""" assert len(network.nodes) is 7 - network.remove_node(Computer.from_config(config = {"type":"computer","hostname":"new_node", "ip_address":"192.168.1.2", "subnet_mask":"255.255.255.0"})) + network.remove_node( + Computer.from_config( + config={ + "type": "computer", + "hostname": "new_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + } + ) + ) assert len(network.nodes) is 7 diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index 5d8bea80..4ce224f7 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -16,23 +16,25 @@ def basic_c2_network() -> Network: network = Network() # Creating two generic nodes for the C2 Server and the C2 Beacon. - computer_a_cfg = {"type": "computer", - "hostname": "computer_a", - "ip_address": "192.168.0.1", - "subnet_mask": "255.255.255.252", - "start_up_duration": 0} - computer_a = Computer.from_config(config = computer_a_cfg) + computer_a_cfg = { + "type": "computer", + "hostname": "computer_a", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.252", + "start_up_duration": 0, + } + computer_a = Computer.from_config(config=computer_a_cfg) computer_a.power_on() computer_a.software_manager.install(software_class=C2Server) - - computer_b_cfg = {"type": "computer", - "hostname": "computer_b", - "ip_address": "192.168.0.2", - "subnet_mask": "255.255.255.252", - "start_up_duration": 0, - } + computer_b_cfg = { + "type": "computer", + "hostname": "computer_b", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.252", + "start_up_duration": 0, + } computer_b = Computer.from_config(config=computer_b_cfg) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index 02b13724..f73e661e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -12,12 +12,13 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dos_bot() -> DoSBot: - computer_cfg = {"type":"computer", - "hostname": "compromised_pc", - "ip_address": "192.168.0.1", - "subnet_mask": "255.255.255.0", - "start_up_duration": 0, - } + computer_cfg = { + "type": "computer", + "hostname": "compromised_pc", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() @@ -39,7 +40,7 @@ def test_dos_bot_cannot_run_when_node_offline(dos_bot): dos_bot_node.power_off() - for i in range(dos_bot_node.shut_down_duration + 1): + for i in range(dos_bot_node.config.shut_down_duration + 1): dos_bot_node.apply_timestep(timestep=i) assert dos_bot_node.operating_state is NodeOperatingState.OFF diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py index 6e32b646..f2b538c0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py @@ -17,14 +17,28 @@ from primaite.simulator.system.services.database.database_service import Databas def database_client_on_computer() -> Tuple[DatabaseClient, Computer]: network = Network() - db_server: Server = Server.from_config(config={"type": "server", "hostname":"db_server", "ip_address":"192.168.0.1", "subnet_mask":"255.255.255.0", "start_up_duration":0}) + db_server: Server = Server.from_config( + config={ + "type": "server", + "hostname": "db_server", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) db_server.power_on() db_server.software_manager.install(DatabaseService) db_server.software_manager.software["DatabaseService"].start() - db_client: Computer = Computer.from_config(config = {"type":"computer", - "hostname":"db_client", "ip_address":"192.168.0.2", "subnet_mask":"255.255.255.0", "start_up_duration":0 - }) + db_client: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "db_client", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) db_client.power_on() db_client.software_manager.install(DatabaseClient) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index 85cd369f..a76e4c07 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -12,8 +12,15 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def web_browser() -> WebBrowser: - computer_cfg = {"type": "computer", "hostname": "web_client", "ip_address": "192.168.1.11", "subnet_mask": "255.255.255.0", "default_gateway": "192.168.1.1", "start_up_duration": 0} - + computer_cfg = { + "type": "computer", + "hostname": "web_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() @@ -25,8 +32,15 @@ def web_browser() -> WebBrowser: def test_create_web_client(): - computer_cfg = {"type": "computer", "hostname": "web_client", "ip_address": "192.168.1.11", "subnet_mask": "255.255.255.0", "default_gateway": "192.168.1.1", "start_up_duration": 0} - + computer_cfg = { + "type": "computer", + "hostname": "web_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py index b7ba2d04..3382e2f7 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py @@ -8,12 +8,13 @@ from primaite.simulator.system.services.database.database_service import Databas @pytest.fixture(scope="function") def database_server() -> Node: - node_cfg = {"type": "computer", - "hostname": "db_node", - "ip_address": "192.168.1.2", - "subnet_mask": "255.255.255.0", - "start_up_duration": 0, - } + node_cfg = { + "type": "computer", + "hostname": "db_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } node = Computer.from_config(config=node_cfg) node.power_on() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index 3f621331..430e3835 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -14,21 +14,15 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dns_client() -> Computer: - - node_cfg = {"type": "computer", - "hostname": "dns_client", - "ip_address": "192.168.1.11", - "subnet_mask": "255.255.255.0", - "default_gateway": "192.168.1.1", - "dns_server": IPv4Address("192.168.1.10")} + node_cfg = { + "type": "computer", + "hostname": "dns_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + } node = Computer.from_config(config=node_cfg) - # node = Computer( - # hostname="dns_client", - # ip_address="192.168.1.11", - # subnet_mask="255.255.255.0", - # default_gateway="192.168.1.1", - # dns_server=IPv4Address("192.168.1.10"), - # ) return node @@ -69,7 +63,7 @@ def test_dns_client_check_domain_exists_when_not_running(dns_client): dns_client.power_off() - for i in range(dns_client.shut_down_duration + 1): + for i in range(dns_client.config.shut_down_duration + 1): dns_client.apply_timestep(timestep=i) assert dns_client.operating_state is NodeOperatingState.OFF diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py index 8df96099..fd193415 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -16,12 +16,14 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dns_server() -> Node: - node_cfg = {"type": "server", - "hostname": "dns_server", - "ip_address": "192.168.1.10", - "subnet_mask":"255.255.255.0", - "default_gateway": "192.168.1.1", - "start_up_duration":0} + node_cfg = { + "type": "server", + "hostname": "dns_server", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(software_class=DNSServer) @@ -55,12 +57,13 @@ def test_dns_server_receive(dns_server): # register the web server in the domain controller dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) - client_cfg = {"type": "computer", - "hostname": "client", - "ip_address": "192.168.1.11", - "subnet_mask": "255.255.255.0", - "start_up_duration": 0, - } + client_cfg = { + "type": "computer", + "hostname": "client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } client = Computer.from_config(config=client_cfg) client.power_on() client.dns_server = IPv4Address("192.168.1.10") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index c6e10b7d..28ca194e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -16,13 +16,14 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def ftp_client() -> Node: - node_cfg = {"type": "computer", - "hostname": "ftp_client", - "ip_address":"192.168.1.11", - "subnet_mask":"255.255.255.0", - "default_gateway":"192.168.1.1", - "start_up_duration": 0, - } + node_cfg = { + "type": "computer", + "hostname": "ftp_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } node = Computer.from_config(config=node_cfg) node.power_on() return node diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index 5cae88e0..ea8ab071 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -14,12 +14,14 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def ftp_server() -> Node: - node_cfg = {"type": "server", - "hostname":"ftp_server", - "ip_address":"192.168.1.10", - "subnet_mask": "255.255.255.0", - "default_gateway": "192.168.1.1", - "start_up_duration":0} + node_cfg = { + "type": "server", + "hostname": "ftp_server", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(software_class=FTPServer) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 08bef92d..1666f008 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -29,8 +29,8 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def terminal_on_computer() -> Tuple[Terminal, Computer]: - computer: Computer = Computer( - hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0 + computer: Computer = Computer.from_config(config={"type":"computer", + "hostname":"node_a", "ip_address":"192.168.0.10", "subnet_mask":"255.255.255.0", "start_up_duration":0} ) computer.power_on() terminal: Terminal = computer.software_manager.software.get("Terminal") @@ -41,11 +41,19 @@ def terminal_on_computer() -> Tuple[Terminal, Computer]: @pytest.fixture(scope="function") def basic_network() -> Network: network = Network() - node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a = Computer.from_config(config={"type":"computer", + "hostname":"node_a", + "ip_address":"192.168.0.10", + "subnet_mask":"255.255.255.0", + "start_up_duration":0}) node_a.power_on() node_a.software_manager.get_open_ports() - node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b = Computer.from_config(config={"type":"computer", + "hostname":"node_b", + "ip_address":"192.168.0.11", + "subnet_mask":"255.255.255.0", + "start_up_duration":0}) node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) @@ -57,18 +65,20 @@ def wireless_wan_network(): network = Network() # Configure PC A - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, - ) + pc_a_cfg = {"type": "computer", + "hostname":"pc_a", + "ip_address":"192.168.0.2", + "subnet_mask":"255.255.255.0", + "default_gateway":"192.168.0.1", + "start_up_duration":0, + } + + pc_a = Computer.from_config(config=pc_a_cfg) pc_a.power_on() network.add_node(pc_a) # Configure Router 1 - router_1 = WirelessRouter(hostname="router_1", start_up_duration=0, airspace=network.airspace) + router_1 = WirelessRouter.from_config(config={"type":"wireless_router", "hostname":"router_1", "start_up_duration":0, "airspace":network.airspace}) router_1.power_on() network.add_node(router_1) @@ -88,18 +98,21 @@ def wireless_wan_network(): ) # Configure PC B - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, - ) + + pc_b_cfg = {"type": "computer", + "hostname":"pc_b", + "ip_address":"192.168.2.2", + "subnet_mask":"255.255.255.0", + "default_gateway":"192.168.2.1", + "start_up_duration":0, + } + + pc_b = Computer.from_config(config=pc_b_cfg) pc_b.power_on() network.add_node(pc_b) # Configure Router 2 - router_2 = WirelessRouter(hostname="router_2", start_up_duration=0, airspace=network.airspace) + router_2 = WirelessRouter.from_config(config={"type":"wireless_router", "hostname":"router_2", "start_up_duration":0, "airspace":network.airspace}) router_2.power_on() network.add_node(router_2) @@ -131,7 +144,7 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - client_1.start_up_duration = 3 + client_1.config.start_up_duration = 3 return game, agent @@ -143,7 +156,11 @@ def test_terminal_creation(terminal_on_computer): def test_terminal_install_default(): """Terminal should be auto installed onto Nodes""" - computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + computer: Computer = Computer.from_config(config={"type":"computer", + "hostname":"node_a", + "ip_address":"192.168.0.10", + "subnet_mask":"255.255.255.0", + "start_up_duration":0}) computer.power_on() assert computer.software_manager.software.get("Terminal") @@ -151,7 +168,7 @@ def test_terminal_install_default(): def test_terminal_not_on_switch(): """Ensure terminal does not auto-install to switch""" - test_switch = Switch(hostname="Test") + test_switch = Switch.from_config(config={"type":"switch", "hostname":"Test"}) assert not test_switch.software_manager.software.get("Terminal") @@ -357,8 +374,6 @@ def test_multiple_remote_terminals_same_node(basic_network): for attempt in range(3): remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") - terminal_a.show() - assert len(terminal_a._connections) == 3 diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index f0901b70..916e1991 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -16,12 +16,14 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def web_server() -> Server: - node_cfg = {"type": "server", - "hostname":"web_server", - "ip_address": "192.168.1.10", - "subnet_mask": "255.255.255.0", - "default_gateway":"192.168.1.1", - "start_up_duration":0 } + node_cfg = { + "type": "server", + "hostname": "web_server", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(WebServer) From e1f2f73db08840c7c9670cb54a4feb0424301e9d Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 28 Jan 2025 09:37:58 +0000 Subject: [PATCH 09/23] #2887 - Test changes to correct NodeOperatingState is correct per passed config. --- src/primaite/simulator/network/hardware/base.py | 1 + .../game_layer/observations/test_link_observations.py | 2 +- .../game_layer/observations/test_router_observation.py | 4 ++-- .../integration_tests/system/test_database_on_node.py | 10 +++++----- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index de97f22b..21f2946b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1574,6 +1574,7 @@ class Node(SimComponent, ABC): msg = f"Configuration contains an invalid Node type: {config['type']}" return ValueError(msg) obj = cls(config=cls.ConfigSchema(**config)) + obj.operating_state = NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] return obj def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py index f95d35c2..1ab50a68 100644 --- a/tests/integration_tests/game_layer/observations/test_link_observations.py +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -57,7 +57,7 @@ def test_link_observation(): net = Network() sim = Simulation(network=net) switch: Switch = Switch.from_config( - config={"type": "switch", "hostname": "switch", "num_ports": 5, "operating_state": NodeOperatingState.ON} + config={"type": "switch", "hostname": "switch", "num_ports": 5, "operating_state": "ON"} ) computer_1: Computer = Computer.from_config( config={ diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index 8335867d..495e102d 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -17,7 +17,7 @@ def test_router_observation(): """Test adding/removing acl rules and enabling/disabling ports.""" net = Network() router = Router.from_config( - config={"type": "router", "hostname": "router", "num_ports": 5, "operating_state": NodeOperatingState.ON} + config={"type": "router", "hostname": "router", "num_ports": 5, "operating_state": "ON"} ) ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)] @@ -92,7 +92,7 @@ def test_router_observation(): # connect a switch to the router and check that only the correct port is updated switch: Switch = Switch.from_config( - config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": NodeOperatingState.ON} + config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": "ON"} ) link = net.connect(router.network_interface[1], switch.network_interface[1]) assert router.network_interface[1].enabled diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index bb25f8c8..87aca129 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -20,11 +20,11 @@ from primaite.simulator.system.software import SoftwareHealthState @pytest.fixture(scope="function") def peer_to_peer() -> Tuple[Computer, Computer]: network = Network() - node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a: Computer = Computer.from_config(config={"type":"computer", "hostname":"node_a", "ip_address":"192.168.0.10", "subnet_mask":"255.255.255.0", "start_up_duration":0}) node_a.power_on() node_a.software_manager.get_open_ports() - node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b: Computer = Computer.from_config(config={"type":"computer", "hostname":"node_b", "ip_address":"192.168.0.11", "subnet_mask":"255.255.255.0", "start_up_duration":0}) node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) @@ -412,8 +412,8 @@ def test_database_service_can_terminate_connection(peer_to_peer): def test_client_connection_terminate_does_not_terminate_another_clients_connection(): network = Network() - db_server = Server( - hostname="db_client", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0 + db_server: Server = Server.from_config(config={"type":"server", + "hostname":"db_client", "ip_address":"192.168.0.11", "subnet_mask":"255.255.255.0", "start_up_duration":0} ) db_server.power_on() @@ -465,6 +465,6 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti def test_database_server_install_ftp_client(): - server = Server(hostname="db_server", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + server: Server = Server.from_config(config={"type":"server", "hostname":"db_server", "ip_address":"192.168.1.2", "subnet_mask":"255.255.255.0", "start_up_duration":0}) server.software_manager.install(DatabaseService) assert server.software_manager.software.get("FTPClient") From f85aace31b1a994c81db70bb92333ed3cf44fc3d Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 28 Jan 2025 19:35:27 +0000 Subject: [PATCH 10/23] #2887 - Correct networking troubles causing test failures --- src/primaite/game/game.py | 16 +++++------ src/primaite/simulator/network/container.py | 2 +- .../simulator/network/hardware/base.py | 17 +++++++----- .../network/hardware/nodes/host/host_node.py | 27 +++++++++---------- .../hardware/nodes/network/firewall.py | 1 + .../network/hardware/nodes/network/router.py | 23 +++++----------- src/primaite/simulator/network/networks.py | 14 +++++----- .../network/test_frame_transmission.py | 2 +- .../system/test_database_on_node.py | 2 +- .../system/test_ftp_client_server.py | 2 +- .../test_simulation/test_request_response.py | 4 +-- 11 files changed, 53 insertions(+), 57 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index f4d118ac..b1ff1f9d 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -278,11 +278,11 @@ class PrimaiteGame: # TODO: handle simulation defaults more cleanly if "node_start_up_duration" in defaults_config: - new_node.start_up_duration = defaults_config["node_startup_duration"] + new_node.config.start_up_duration = defaults_config["node_startup_duration"] if "node_shut_down_duration" in defaults_config: - new_node.shut_down_duration = defaults_config["node_shut_down_duration"] + new_node.config.shut_down_duration = defaults_config["node_shut_down_duration"] if "node_scan_duration" in defaults_config: - new_node.node_scan_duration = defaults_config["node_scan_duration"] + new_node.config.node_scan_duration = defaults_config["node_scan_duration"] if "folder_scan_duration" in defaults_config: new_node.file_system._default_folder_scan_duration = defaults_config["folder_scan_duration"] if "folder_restore_duration" in defaults_config: @@ -337,7 +337,7 @@ class PrimaiteGame: # TODO: handle simulation defaults more cleanly if "service_fix_duration" in defaults_config: - new_service.fixing_duration = defaults_config["service_fix_duration"] + new_service.config.fixing_duration = defaults_config["service_fix_duration"] if "service_restart_duration" in defaults_config: new_service.restart_duration = defaults_config["service_restart_duration"] if "service_install_duration" in defaults_config: @@ -394,8 +394,8 @@ class PrimaiteGame: new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) # temporarily set to 0 so all nodes are initially on - new_node.start_up_duration = 0 - new_node.shut_down_duration = 0 + new_node.config.start_up_duration = 0 + new_node.config.shut_down_duration = 0 net.add_node(new_node) # run through the power on step if the node is to be turned on at the start @@ -403,8 +403,8 @@ class PrimaiteGame: new_node.power_on() # set start up and shut down duration - new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3)) - new_node.shut_down_duration = int(node_cfg.get("shut_down_duration", 3)) + new_node.config.start_up_duration = int(node_cfg.get("start_up_duration", 3)) + new_node.config.shut_down_duration = int(node_cfg.get("shut_down_duration", 3)) # 1.1 Create Node Sets for node_set_cfg in node_sets_cfg: diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 982495a4..247c06bb 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -201,7 +201,7 @@ class Network(SimComponent): port_str, port.ip_address, port.subnet_mask, - node.default_gateway, + node.config.default_gateway, ] ) print(table) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 21f2946b..08e100d2 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1493,17 +1493,12 @@ class Node(SimComponent, ABC): :param hostname: The node hostname on the network. :param operating_state: The node operating state, either ON or OFF. """ - - default_gateway: Optional[IPV4Address] = None - "The default gateway IP address for forwarding network traffic to other networks." operating_state: NodeOperatingState = NodeOperatingState.OFF "The hardware state of the node." network_interfaces: Dict[str, NetworkInterface] = {} "The Network Interfaces on the node." network_interface: Dict[int, NetworkInterface] = {} "The Network Interfaces on the node by port id." - dns_server: Optional[IPv4Address] = None - "List of IP addresses of DNS servers used for name resolution." accounts: Dict[str, Account] = {} "All accounts on the node." applications: Dict[str, Application] = {} @@ -1567,6 +1562,16 @@ class Node(SimComponent, ABC): red_scan_countdown: int = 0 "Time steps until reveal to red scan is complete." + dns_server: Optional[IPv4Address] = None + "List of IP addresses of DNS servers used for name resolution." + + default_gateway: Optional[IPV4Address] = None + "The default gateway IP address for forwarding network traffic to other networks." + + @property + def dns_server(self) -> Optional[IPv4Address]: + return self.config.dns_server + @classmethod def from_config(cls, config: Dict) -> "Node": """Create Node object from a given configuration dictionary.""" @@ -1615,7 +1620,7 @@ class Node(SimComponent, ABC): sys_log=kwargs.get("sys_log"), session_manager=kwargs.get("session_manager"), file_system=kwargs.get("file_system"), - dns_server=kwargs.get("dns_server"), + dns_server=kwargs["config"].dns_server, ) super().__init__(**kwargs) self._install_system_software() diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 23db025d..3b1d8e48 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -46,8 +46,8 @@ class HostARP(ARP): :return: The MAC address of the default gateway if present in the ARP cache; otherwise, None. """ - if self.software_manager.node.default_gateway: - return self.get_arp_cache_mac_address(self.software_manager.node.default_gateway) + if self.software_manager.node.config.default_gateway: + return self.get_arp_cache_mac_address(self.software_manager.node.config.default_gateway) def get_default_gateway_network_interface(self) -> Optional[NIC]: """ @@ -55,8 +55,8 @@ class HostARP(ARP): :return: The NIC associated with the default gateway if it exists in the ARP cache; otherwise, None. """ - if self.software_manager.node.default_gateway and self.software_manager.node.has_enabled_network_interface: - return self.get_arp_cache_network_interface(self.software_manager.node.default_gateway) + if self.software_manager.node.config.default_gateway and self.software_manager.node.has_enabled_network_interface: + return self.get_arp_cache_network_interface(self.software_manager.node.config.default_gateway) def _get_arp_cache_mac_address( self, ip_address: IPV4Address, is_reattempt: bool = False, is_default_gateway_attempt: bool = False @@ -75,7 +75,7 @@ class HostARP(ARP): if arp_entry: return arp_entry.mac_address - if ip_address == self.software_manager.node.default_gateway: + if ip_address == self.software_manager.node.config.default_gateway: is_reattempt = True if not is_reattempt: self.send_arp_request(ip_address) @@ -83,11 +83,11 @@ class HostARP(ARP): ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt ) else: - if self.software_manager.node.default_gateway: + if self.software_manager.node.config.default_gateway: if not is_default_gateway_attempt: - self.send_arp_request(self.software_manager.node.default_gateway) + self.send_arp_request(self.software_manager.node.config.default_gateway) return self._get_arp_cache_mac_address( - ip_address=self.software_manager.node.default_gateway, + ip_address=self.software_manager.node.config.default_gateway, is_reattempt=True, is_default_gateway_attempt=True, ) @@ -118,7 +118,7 @@ class HostARP(ARP): if arp_entry: return self.software_manager.node.network_interfaces[arp_entry.network_interface_uuid] else: - if ip_address == self.software_manager.node.default_gateway: + if ip_address == self.software_manager.node.config.default_gateway: is_reattempt = True if not is_reattempt: self.send_arp_request(ip_address) @@ -126,11 +126,11 @@ class HostARP(ARP): ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt ) else: - if self.software_manager.node.default_gateway: + if self.software_manager.node.config.default_gateway: if not is_default_gateway_attempt: - self.send_arp_request(self.software_manager.node.default_gateway) + self.send_arp_request(self.software_manager.node.config.default_gateway) return self._get_arp_cache_network_interface( - ip_address=self.software_manager.node.default_gateway, + ip_address=self.software_manager.node.config.default_gateway, is_reattempt=True, is_default_gateway_attempt=True, ) @@ -333,9 +333,8 @@ class HostNode(Node, identifier="HostNode"): """Configuration Schema for HostNode class.""" hostname: str = "HostNode" - ip_address: IPV4Address = "192.168.0.1" subnet_mask: IPV4Address = "255.255.255.0" - default_gateway: IPV4Address = "192.168.10.1" + ip_address: IPV4Address def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 01c5159b..99dd48c4 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -141,6 +141,7 @@ class Firewall(Router, identifier="firewall"): self.external_outbound_acl.sys_log = kwargs["sys_log"] self.external_outbound_acl.name = f"{kwargs['config'].hostname} - External Outbound" + self.power_on() def _init_request_manager(self) -> RequestManager: """ diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 3ecb761b..ebb35cf3 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1211,32 +1211,22 @@ class Router(NetworkNode, identifier="router"): "The Router Interfaces on the node." network_interface: Dict[int, RouterInterface] = {} "The Router Interfaces on the node by port id." - - sys_log: SysLog - acl: AccessControlList - route_table: RouteTable - config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema()) + config: "Router.ConfigSchema" class ConfigSchema(NetworkNode.ConfigSchema): - """Configuration Schema for Router Objects.""" + + hostname: str = "router" + num_ports: int - num_ports: int = 5 - """Number of ports available for this Router. Default is 5""" - - hostname: str = "Router" - - ports: Dict[Union[int, str], Dict] = {} def __init__(self, **kwargs): if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(kwargs["config"].hostname) if not kwargs.get("acl"): - kwargs["acl"] = AccessControlList( - sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname - ) + kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname) if not kwargs.get("route_table"): kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"]) super().__init__(**kwargs) @@ -1562,7 +1552,7 @@ class Router(NetworkNode, identifier="router"): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Network Interfaces" + table.title = f"{self.config.hostname} Network Interfaces" for port, network_interface in self.network_interface.items(): table.add_row( [ @@ -1666,4 +1656,5 @@ class Router(NetworkNode, identifier="router"): next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) + router.operating_state = NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] return router diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 0579f137..644b2a4a 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -40,42 +40,42 @@ def client_server_routed() -> Network: network = Network() # Router 1 - router_1 = Router(hostname="router_1", num_ports=3) + router_1 = Router(config=dict(hostname="router_1", num_ports=3)) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.2.1", subnet_mask="255.255.255.0") # Switch 1 - switch_1 = Switch(hostname="switch_1", num_ports=6) + switch_1 = Switch(config=dict(hostname="switch_1", num_ports=6)) switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[6]) router_1.enable_port(1) # Switch 2 - switch_2 = Switch(hostname="switch_2", num_ports=6) + switch_2 = Switch(config=dict(hostname="switch_2", num_ports=6)) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[6]) router_1.enable_port(2) # Client 1 - client_1 = Computer( + client_1 = Computer(config=dict( hostname="client_1", ip_address="192.168.2.2", subnet_mask="255.255.255.0", default_gateway="192.168.2.1", start_up_duration=0, - ) + )) client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) # Server 1 - server_1 = Server( + server_1 = Server(config=dict( hostname="server_1", ip_address="192.168.1.2", subnet_mask="255.255.255.0", default_gateway="192.168.1.1", start_up_duration=0, - ) + )) server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) diff --git a/tests/integration_tests/network/test_frame_transmission.py b/tests/integration_tests/network/test_frame_transmission.py index 327c87e5..cff99e07 100644 --- a/tests/integration_tests/network/test_frame_transmission.py +++ b/tests/integration_tests/network/test_frame_transmission.py @@ -41,7 +41,7 @@ def test_multi_nic(): """Tests that Computers with multiple NICs can ping each other and the data go across the correct links.""" network = Network() - node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a = Computer(config=dict(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)) node_a.power_on() node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 87aca129..64b6ddbc 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -338,7 +338,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network): assert db_connection.query("INSERT") is True db_server.power_off() - for i in range(db_server.shut_down_duration + 1): + for i in range(db_server.config.shut_down_duration + 1): uc2_network.apply_timestep(timestep=i) assert db_server.operating_state is NodeOperatingState.OFF diff --git a/tests/integration_tests/system/test_ftp_client_server.py b/tests/integration_tests/system/test_ftp_client_server.py index fa4df0a9..57e42457 100644 --- a/tests/integration_tests/system/test_ftp_client_server.py +++ b/tests/integration_tests/system/test_ftp_client_server.py @@ -87,7 +87,7 @@ def test_ftp_client_tries_to_connect_to_offline_server(ftp_client_and_ftp_server server.power_off() - for i in range(server.shut_down_duration + 1): + for i in range(server.config.shut_down_duration + 1): server.apply_timestep(timestep=i) assert ftp_client.operating_state == ServiceOperatingState.RUNNING diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 21152199..efc97ce6 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -140,9 +140,9 @@ class TestDataManipulationGreenRequests: client_1 = net.get_node_by_hostname("client_1") client_2 = net.get_node_by_hostname("client_2") - client_1.shut_down_duration = 0 + client_1.config.shut_down_duration = 0 client_1.power_off() - client_2.shut_down_duration = 0 + client_2.config.shut_down_duration = 0 client_2.power_off() client_1_browser_execute_off = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) From 51f1c91e154d01874cce2bc87eb00d6c2a70858e Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 29 Jan 2025 11:55:10 +0000 Subject: [PATCH 11/23] #2887 - Fixed Node unit test failures --- src/primaite/game/game.py | 3 +- src/primaite/simulator/network/airspace.py | 2 +- .../network/hardware/nodes/network/router.py | 16 ++- .../hardware/nodes/network/wireless_router.py | 8 +- .../simulator/system/applications/nmap.py | 4 +- .../test_agents_use_action_masks.py | 1 + .../extensions/nodes/super_computer.py | 4 +- ...ndwidth_load_checks_before_transmission.py | 2 +- .../network/test_firewall.py | 63 +++++---- .../network/test_frame_transmission.py | 62 ++++++--- ...test_multi_lan_internet_example_network.py | 1 + .../network/test_network_creation.py | 84 +++++++++-- .../integration_tests/network/test_routing.py | 70 ++++++---- .../network/test_wireless_router.py | 46 +++--- .../system/test_application_on_node.py | 34 +++-- .../system/test_database_on_node.py | 62 +++++++-- .../test_user_session_manager_logins.py | 30 ++-- .../test_simulation/test_request_response.py | 2 +- .../_network/_hardware/test_node_actions.py | 2 + .../_system/_services/test_dns_client.py | 10 ++ .../_system/_services/test_dns_server.py | 2 +- .../_system/_services/test_terminal.py | 131 ++++++++++-------- 22 files changed, 427 insertions(+), 212 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index b1ff1f9d..05c13c2a 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -269,7 +269,8 @@ class PrimaiteGame: new_node = None if n_type in Node._registry: - # simplify down Node creation: + if n_type == "wireless_router": + node_cfg["airspace"] = net.airspace new_node = Node._registry[n_type].from_config(config=node_cfg) else: msg = f"invalid node type {n_type} in config" diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 5549eb78..7ede0bb0 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -178,7 +178,7 @@ class AirSpace(BaseModel): status = "Enabled" if interface.enabled else "Disabled" table.add_row( [ - interface._connected_node.hostname, # noqa + interface._connected_node.config.hostname, # noqa interface.mac_address, interface.ip_address if hasattr(interface, "ip_address") else None, interface.subnet_mask if hasattr(interface, "subnet_mask") else None, diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index ebb35cf3..dd32fa31 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -7,7 +7,7 @@ from ipaddress import IPv4Address, IPv4Network from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import Field, validate_call +from pydantic import validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -1217,16 +1217,18 @@ class Router(NetworkNode, identifier="router"): config: "Router.ConfigSchema" class ConfigSchema(NetworkNode.ConfigSchema): - - hostname: str = "router" - num_ports: int + """Configuration Schema for Routers.""" + hostname: str = "router" + num_ports: int = 5 def __init__(self, **kwargs): if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(kwargs["config"].hostname) if not kwargs.get("acl"): - kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname) + kwargs["acl"] = AccessControlList( + sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname + ) if not kwargs.get("route_table"): kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"]) super().__init__(**kwargs) @@ -1656,5 +1658,7 @@ class Router(NetworkNode, identifier="router"): next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) - router.operating_state = NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] + router.operating_state = ( + NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] + ) return router diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 70e655ac..2ca854d4 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -129,7 +129,7 @@ class WirelessRouter(Router, identifier="wireless_router"): """Configuration Schema for WirelessRouter nodes within PrimAITE.""" hostname: str = "WirelessRouter" - airspace: Optional[AirSpace] = None + airspace: AirSpace def __init__(self, **kwargs): super().__init__(**kwargs) @@ -262,9 +262,6 @@ class WirelessRouter(Router, identifier="wireless_router"): :return: WirelessRouter instance. :rtype: WirelessRouter """ - operating_state = ( - NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] - ) router = cls(config=cls.ConfigSchema(**config)) if "router_interface" in config: ip_address = config["router_interface"]["ip_address"] @@ -297,4 +294,7 @@ class WirelessRouter(Router, identifier="wireless_router"): next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), metric=float(route.get("metric", 0)), ) + router.operating_state = ( + NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] + ) return router diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index 3eeda4b6..46fb66a6 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -208,7 +208,7 @@ class NMAP(Application, identifier="NMAP"): if show: table = PrettyTable(["IP Address", "Can Ping"]) table.align = "l" - table.title = f"{self.software_manager.node.hostname} NMAP Ping Scan" + table.title = f"{self.software_manager.node.config.hostname} NMAP Ping Scan" ip_addresses = self._explode_ip_address_network_array(target_ip_address) @@ -367,7 +367,7 @@ class NMAP(Application, identifier="NMAP"): if show: table = PrettyTable(["IP Address", "Port", "Protocol"]) table.align = "l" - table.title = f"{self.software_manager.node.hostname} NMAP Port Scan ({scan_type})" + table.title = f"{self.software_manager.node.config.hostname} NMAP Port Scan ({scan_type})" self.sys_log.info(f"{self.name}: Starting port scan") for ip_address in ip_addresses: # Prevent port scan on this node diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py index a34d430b..6da801d4 100644 --- a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -12,6 +12,7 @@ from sb3_contrib import MaskablePPO from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests import TEST_ASSETS_ROOT CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" diff --git a/tests/integration_tests/extensions/nodes/super_computer.py b/tests/integration_tests/extensions/nodes/super_computer.py index 4af1b748..4cf45706 100644 --- a/tests/integration_tests/extensions/nodes/super_computer.py +++ b/tests/integration_tests/extensions/nodes/super_computer.py @@ -36,8 +36,8 @@ class SuperComputer(HostNode, identifier="supercomputer"): SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} - def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): + def __init__(self, **kwargs): print("--- Extended Component: SuperComputer ---") - super().__init__(ip_address=ip_address, subnet_mask=subnet_mask, **kwargs) + super().__init__(**kwargs) pass diff --git a/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py b/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py index 36c77fe1..32193946 100644 --- a/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py +++ b/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py @@ -16,7 +16,7 @@ def test_wireless_link_loading(wireless_wan_network): # Configure Router 2 ACLs router_2.acl.add_rule(action=ACLAction.PERMIT, position=1) - airspace = router_1.airspace + airspace = router_1.config.airspace client.software_manager.install(FTPClient) ftp_client: FTPClient = client.software_manager.software.get("FTPClient") diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index 69f3e5ab..131abe78 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -83,12 +83,15 @@ def dmz_external_internal_network() -> Network: ) # external node - external_node = Computer( - hostname="external_node", - ip_address="192.168.10.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - start_up_duration=0, + external_node: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "external_node", + "ip_address": "192.168.10.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } ) external_node.power_on() external_node.software_manager.install(NTPServer) @@ -98,12 +101,15 @@ def dmz_external_internal_network() -> Network: network.connect(endpoint_b=external_node.network_interface[1], endpoint_a=firewall_node.external_port) # internal node - internal_node = Computer( - hostname="internal_node", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + internal_node: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "internal_node", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) internal_node.power_on() internal_node.software_manager.install(NTPClient) @@ -114,12 +120,15 @@ def dmz_external_internal_network() -> Network: network.connect(endpoint_b=internal_node.network_interface[1], endpoint_a=firewall_node.internal_port) # dmz node - dmz_node = Computer( - hostname="dmz_node", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + dmz_node: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "dmz_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) dmz_node.power_on() dmz_ntp_client: NTPClient = dmz_node.software_manager.software["NTPClient"] @@ -157,9 +166,9 @@ def test_nodes_can_ping_default_gateway(dmz_external_internal_network): internal_node = dmz_external_internal_network.get_node_by_hostname("internal_node") dmz_node = dmz_external_internal_network.get_node_by_hostname("dmz_node") - assert internal_node.ping(internal_node.default_gateway) # default gateway internal - assert dmz_node.ping(dmz_node.default_gateway) # default gateway dmz - assert external_node.ping(external_node.default_gateway) # default gateway external + assert internal_node.ping(internal_node.config.default_gateway) # default gateway internal + assert dmz_node.ping(dmz_node.config.default_gateway) # default gateway dmz + assert external_node.ping(external_node.config.default_gateway) # default gateway external def test_nodes_can_ping_default_gateway_on_another_subnet(dmz_external_internal_network): @@ -173,14 +182,14 @@ def test_nodes_can_ping_default_gateway_on_another_subnet(dmz_external_internal_ internal_node = dmz_external_internal_network.get_node_by_hostname("internal_node") dmz_node = dmz_external_internal_network.get_node_by_hostname("dmz_node") - assert internal_node.ping(external_node.default_gateway) # internal node to external default gateway - assert internal_node.ping(dmz_node.default_gateway) # internal node to dmz default gateway + assert internal_node.ping(external_node.config.default_gateway) # internal node to external default gateway + assert internal_node.ping(dmz_node.config.default_gateway) # internal node to dmz default gateway - assert dmz_node.ping(internal_node.default_gateway) # dmz node to internal default gateway - assert dmz_node.ping(external_node.default_gateway) # dmz node to external default gateway + assert dmz_node.ping(internal_node.config.default_gateway) # dmz node to internal default gateway + assert dmz_node.ping(external_node.config.default_gateway) # dmz node to external default gateway - assert external_node.ping(external_node.default_gateway) # external node to internal default gateway - assert external_node.ping(dmz_node.default_gateway) # external node to dmz default gateway + assert external_node.ping(external_node.config.default_gateway) # external node to internal default gateway + assert external_node.ping(dmz_node.config.default_gateway) # external node to dmz default gateway def test_nodes_can_ping_each_other(dmz_external_internal_network): diff --git a/tests/integration_tests/network/test_frame_transmission.py b/tests/integration_tests/network/test_frame_transmission.py index cff99e07..6a514bdc 100644 --- a/tests/integration_tests/network/test_frame_transmission.py +++ b/tests/integration_tests/network/test_frame_transmission.py @@ -10,25 +10,31 @@ def test_node_to_node_ping(): """Tests two Computers are able to ping each other.""" network = Network() - client_1 = Computer( - hostname="client_1", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + client_1: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "client_1", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) client_1.power_on() - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + server_1: Server = Server.from_config( + config={ + "type": "server", + "hostname": "server_1", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) server_1.power_on() - switch_1 = Switch(hostname="switch_1", start_up_duration=0) + switch_1: Switch = Switch.from_config(config={"type": "switch", "hostname": "switch_1", "start_up_duration": 0}) switch_1.power_on() network.connect(endpoint_a=client_1.network_interface[1], endpoint_b=switch_1.network_interface[1]) @@ -41,14 +47,38 @@ def test_multi_nic(): """Tests that Computers with multiple NICs can ping each other and the data go across the correct links.""" network = Network() - node_a = Computer(config=dict(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)) + node_a: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_a.power_on() - node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_b.power_on() node_b.connect_nic(NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0")) - node_c = Computer(hostname="node_c", ip_address="10.0.0.13", subnet_mask="255.0.0.0", start_up_duration=0) + node_c: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_c", + "ip_address": "10.0.0.13", + "subnet_mask": "255.0.0.0", + "start_up_duration": 0, + } + ) node_c.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) diff --git a/tests/integration_tests/network/test_multi_lan_internet_example_network.py b/tests/integration_tests/network/test_multi_lan_internet_example_network.py index ea7e1c45..897b4008 100644 --- a/tests/integration_tests/network/test_multi_lan_internet_example_network.py +++ b/tests/integration_tests/network/test_multi_lan_internet_example_network.py @@ -1,6 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.networks import multi_lan_internet_network_example from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser diff --git a/tests/integration_tests/network/test_network_creation.py b/tests/integration_tests/network/test_network_creation.py index 1ee3ccc2..4d88eac3 100644 --- a/tests/integration_tests/network/test_network_creation.py +++ b/tests/integration_tests/network/test_network_creation.py @@ -27,7 +27,15 @@ def test_network(example_network): def test_adding_removing_nodes(): """Check that we can create and add a node to a network.""" net = Network() - n1 = Computer(hostname="computer", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.add_node(n1) assert n1.parent is net assert n1 in net @@ -37,10 +45,18 @@ def test_adding_removing_nodes(): assert n1 not in net -def test_readding_node(): - """Check that warning is raised when readding a node.""" +def test_reading_node(): + """Check that warning is raised when reading a node.""" net = Network() - n1 = Computer(hostname="computer", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.add_node(n1) net.add_node(n1) assert n1.parent is net @@ -50,7 +66,15 @@ def test_readding_node(): def test_removing_nonexistent_node(): """Check that warning is raised when trying to remove a node that is not in the network.""" net = Network() - n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.remove_node(n1) assert n1.parent is None assert n1 not in net @@ -59,8 +83,24 @@ def test_removing_nonexistent_node(): def test_connecting_nodes(): """Check that two nodes on the network can be connected.""" net = Network() - n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0) - n2 = Computer(hostname="computer2", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) + n2: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer2", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.add_node(n1) net.add_node(n2) @@ -75,7 +115,15 @@ def test_connecting_nodes(): def test_connecting_node_to_itself_fails(): net = Network() - node = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node.power_on() node.connect_nic(NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0")) @@ -92,8 +140,24 @@ def test_connecting_node_to_itself_fails(): def test_disconnecting_nodes(): net = Network() - n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0) - n2 = Computer(hostname="computer2", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) + n2 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer2", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.connect(n1.network_interface[1], n2.network_interface[1]) assert len(net.links) == 1 diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index 948b409f..b60f3f6b 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -15,25 +15,31 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def pc_a_pc_b_router_1() -> Tuple[Computer, Computer, Router]: network = Network() - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + pc_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) pc_a.power_on() - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + pc_b = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) pc_b.power_on() - router_1 = Router(hostname="router_1", start_up_duration=0) + router_1 = Router.from_config(config={"type": "router", "hostname": "router_1", "start_up_duration": 0}) router_1.power_on() router_1.configure_port(1, "192.168.0.1", "255.255.255.0") @@ -52,18 +58,21 @@ def multi_hop_network() -> Network: network = Network() # Configure PC A - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + pc_a: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) pc_a.power_on() network.add_node(pc_a) # Configure Router 1 - router_1 = Router(hostname="router_1", start_up_duration=0) + router_1: Router = Router.from_config(config={"type": "router", "hostname": "router_1", "start_up_duration": 0}) router_1.power_on() network.add_node(router_1) @@ -79,18 +88,21 @@ def multi_hop_network() -> Network: router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, + pc_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.2.1", + "start_up_duration": 0, + } ) pc_b.power_on() network.add_node(pc_b) # Configure Router 2 - router_2 = Router(hostname="router_2", start_up_duration=0) + router_2: Router = Router.from_config(config={"type": "router", "hostname": "router_2", "start_up_duration": 0}) router_2.power_on() network.add_node(router_2) @@ -113,13 +125,13 @@ def multi_hop_network() -> Network: def test_ping_default_gateway(pc_a_pc_b_router_1): pc_a, pc_b, router_1 = pc_a_pc_b_router_1 - assert pc_a.ping(pc_a.default_gateway) + assert pc_a.ping(pc_a.config.default_gateway) def test_ping_other_router_port(pc_a_pc_b_router_1): pc_a, pc_b, router_1 = pc_a_pc_b_router_1 - assert pc_a.ping(pc_b.default_gateway) + assert pc_a.ping(pc_b.config.default_gateway) def test_host_on_other_subnet(pc_a_pc_b_router_1): diff --git a/tests/integration_tests/network/test_wireless_router.py b/tests/integration_tests/network/test_wireless_router.py index 26e50f4a..487736e7 100644 --- a/tests/integration_tests/network/test_wireless_router.py +++ b/tests/integration_tests/network/test_wireless_router.py @@ -17,18 +17,23 @@ def wireless_wan_network(): network = Network() # Configure PC A - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + pc_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) pc_a.power_on() network.add_node(pc_a) # Configure Router 1 - router_1 = WirelessRouter(hostname="router_1", start_up_duration=0, airspace=network.airspace) + router_1 = WirelessRouter.from_config( + config={"type": "wireless_router", "hostname": "router_1", "start_up_duration": 0, "airspace": network.airspace} + ) router_1.power_on() network.add_node(router_1) @@ -43,18 +48,23 @@ def wireless_wan_network(): router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, + pc_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.2.1", + "start_up_duration": 0, + } ) pc_b.power_on() network.add_node(pc_b) # Configure Router 2 - router_2 = WirelessRouter(hostname="router_2", start_up_duration=0, airspace=network.airspace) + router_2: WirelessRouter = WirelessRouter.from_config( + config={"type": "wireless_router", "hostname": "router_2", "start_up_duration": 0, "airspace": network.airspace} + ) router_2.power_on() network.add_node(router_2) @@ -98,8 +108,8 @@ def wireless_wan_network_from_config_yaml(): def test_cross_wireless_wan_connectivity(wireless_wan_network): pc_a, pc_b, router_1, router_2 = wireless_wan_network # Ensure that PCs can ping across routers before any frequency change - assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully." - assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully." + assert pc_a.ping(pc_a.config.default_gateway), "PC A should ping its default gateway successfully." + assert pc_b.ping(pc_b.config.default_gateway), "PC B should ping its default gateway successfully." assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." @@ -109,8 +119,8 @@ def test_cross_wireless_wan_connectivity_from_yaml(wireless_wan_network_from_con pc_a = wireless_wan_network_from_config_yaml.get_node_by_hostname("pc_a") pc_b = wireless_wan_network_from_config_yaml.get_node_by_hostname("pc_b") - assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully." - assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully." + assert pc_a.ping(pc_a.config.default_gateway), "PC A should ping its default gateway successfully." + assert pc_b.ping(pc_b.config.default_gateway), "PC B should ping its default gateway successfully." assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." diff --git a/tests/integration_tests/system/test_application_on_node.py b/tests/integration_tests/system/test_application_on_node.py index fc7aa69c..e1795a36 100644 --- a/tests/integration_tests/system/test_application_on_node.py +++ b/tests/integration_tests/system/test_application_on_node.py @@ -10,13 +10,16 @@ from primaite.simulator.system.applications.application import Application, Appl @pytest.fixture(scope="function") def populated_node(application_class) -> Tuple[Application, Computer]: - computer: Computer = Computer( - hostname="test_computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - shut_down_duration=0, + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "test_computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + "shut_down_duration": 0, + } ) computer.power_on() computer.software_manager.install(application_class) @@ -29,13 +32,16 @@ def populated_node(application_class) -> Tuple[Application, Computer]: def test_application_on_offline_node(application_class): """Test to check that the application cannot be interacted with when node it is on is off.""" - computer: Computer = Computer( - hostname="test_computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - shut_down_duration=0, + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "test_computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + "shut_down_duration": 0, + } ) computer.software_manager.install(application_class) diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 64b6ddbc..8ad292b2 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -20,11 +20,27 @@ from primaite.simulator.system.software import SoftwareHealthState @pytest.fixture(scope="function") def peer_to_peer() -> Tuple[Computer, Computer]: network = Network() - node_a: Computer = Computer.from_config(config={"type":"computer", "hostname":"node_a", "ip_address":"192.168.0.10", "subnet_mask":"255.255.255.0", "start_up_duration":0}) + node_a: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_a.power_on() node_a.software_manager.get_open_ports() - node_b: Computer = Computer.from_config(config={"type":"computer", "hostname":"node_b", "ip_address":"192.168.0.11", "subnet_mask":"255.255.255.0", "start_up_duration":0}) + node_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) @@ -412,8 +428,14 @@ def test_database_service_can_terminate_connection(peer_to_peer): def test_client_connection_terminate_does_not_terminate_another_clients_connection(): network = Network() - db_server: Server = Server.from_config(config={"type":"server", - "hostname":"db_client", "ip_address":"192.168.0.11", "subnet_mask":"255.255.255.0", "start_up_duration":0} + db_server: Server = Server.from_config( + config={ + "type": "server", + "hostname": "db_client", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) db_server.power_on() @@ -421,8 +443,14 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] # noqa db_service.start() - client_a = Computer( - hostname="client_a", ip_address="192.168.0.12", subnet_mask="255.255.255.0", start_up_duration=0 + client_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "client_a", + "ip_address": "192.168.0.12", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) client_a.power_on() @@ -430,8 +458,14 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti client_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11")) client_a.software_manager.software["DatabaseClient"].run() - client_b = Computer( - hostname="client_b", ip_address="192.168.0.13", subnet_mask="255.255.255.0", start_up_duration=0 + client_b = Computer.from_config( + config={ + "type": "computer", + "hostname": "client_b", + "ip_address": "192.168.0.13", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) client_b.power_on() @@ -439,7 +473,7 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti client_b.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11")) client_b.software_manager.software["DatabaseClient"].run() - switch = Switch(hostname="switch", start_up_duration=0, num_ports=3) + switch = Switch.from_config(config={"type": "switch", "hostname": "switch", "start_up_duration": 0, "num_ports": 3}) switch.power_on() network.connect(endpoint_a=switch.network_interface[1], endpoint_b=db_server.network_interface[1]) @@ -465,6 +499,14 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti def test_database_server_install_ftp_client(): - server: Server = Server.from_config(config={"type":"server", "hostname":"db_server", "ip_address":"192.168.1.2", "subnet_mask":"255.255.255.0", "start_up_duration":0}) + server: Server = Server.from_config( + config={ + "type": "server", + "hostname": "db_server", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) server.software_manager.install(DatabaseService) assert server.software_manager.software.get("FTPClient") diff --git a/tests/integration_tests/system/test_user_session_manager_logins.py b/tests/integration_tests/system/test_user_session_manager_logins.py index 0c591a4b..9736232b 100644 --- a/tests/integration_tests/system/test_user_session_manager_logins.py +++ b/tests/integration_tests/system/test_user_session_manager_logins.py @@ -14,21 +14,27 @@ from primaite.simulator.network.hardware.nodes.host.server import Server def client_server_network() -> Tuple[Computer, Server, Network]: network = Network() - client = Computer( - hostname="client", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + client = Computer.from_config( + config={ + "type": "computer", + "hostname": "client", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) client.power_on() - server = Server( - hostname="server", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + server = Server.from_config( + config={ + "type": "server", + "hostname": "server", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) server.power_on() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index efc97ce6..960ef50f 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -111,7 +111,7 @@ def test_request_fails_if_node_off(example_network, node_request): """Test that requests succeed when the node is on, and fail if the node is off.""" net = example_network client_1: HostNode = net.get_node_by_hostname("client_1") - client_1.shut_down_duration = 0 + client_1.config.shut_down_duration = 0 assert client_1.operating_state == NodeOperatingState.ON resp_1 = net.apply_request(node_request) diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py index f6308a21..0b307fe5 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py @@ -17,11 +17,13 @@ def node() -> Node: "hostname": "test", "ip_address": "192.168.1.2", "subnet_mask": "255.255.255.0", + "operating_state": "OFF", } computer = Computer.from_config(config=computer_cfg) return computer + def test_node_startup(node): assert node.operating_state == NodeOperatingState.OFF node.apply_request(["startup"]) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index 430e3835..914d64e1 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -36,6 +36,16 @@ def test_create_dns_client(dns_client): def test_dns_client_add_domain_to_cache_when_not_running(dns_client): dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") + + # shutdown the dns_client + dns_client.power_off() + + # wait for dns_client to turn off + idx = 0 + while dns_client.operating_state == NodeOperatingState.SHUTTING_DOWN: + dns_client.apply_timestep(idx) + idx += 1 + assert dns_client.operating_state is NodeOperatingState.OFF assert dns_client_service.operating_state is ServiceOperatingState.STOPPED diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py index fd193415..6f8664b2 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -66,7 +66,7 @@ def test_dns_server_receive(dns_server): } client = Computer.from_config(config=client_cfg) client.power_on() - client.dns_server = IPv4Address("192.168.1.10") + client.config.dns_server = IPv4Address("192.168.1.10") network = Network() network.connect(dns_server.network_interface[1], client.network_interface[1]) dns_client: DNSClient = client.software_manager.software["DNSClient"] # noqa diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 1666f008..e5fe2013 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -12,6 +12,7 @@ from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter +from primaite.simulator.network.networks import arcd_uc2_network from primaite.simulator.network.protocols.ssh import ( SSHConnectionMessage, SSHPacket, @@ -29,8 +30,14 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def terminal_on_computer() -> Tuple[Terminal, Computer]: - computer: Computer = Computer.from_config(config={"type":"computer", - "hostname":"node_a", "ip_address":"192.168.0.10", "subnet_mask":"255.255.255.0", "start_up_duration":0} + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) computer.power_on() terminal: Terminal = computer.software_manager.software.get("Terminal") @@ -41,19 +48,27 @@ def terminal_on_computer() -> Tuple[Terminal, Computer]: @pytest.fixture(scope="function") def basic_network() -> Network: network = Network() - node_a = Computer.from_config(config={"type":"computer", - "hostname":"node_a", - "ip_address":"192.168.0.10", - "subnet_mask":"255.255.255.0", - "start_up_duration":0}) + node_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_a.power_on() node_a.software_manager.get_open_ports() - node_b = Computer.from_config(config={"type":"computer", - "hostname":"node_b", - "ip_address":"192.168.0.11", - "subnet_mask":"255.255.255.0", - "start_up_duration":0}) + node_b = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) @@ -65,20 +80,23 @@ def wireless_wan_network(): network = Network() # Configure PC A - pc_a_cfg = {"type": "computer", - "hostname":"pc_a", - "ip_address":"192.168.0.2", - "subnet_mask":"255.255.255.0", - "default_gateway":"192.168.0.1", - "start_up_duration":0, - } + pc_a_cfg = { + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } pc_a = Computer.from_config(config=pc_a_cfg) pc_a.power_on() network.add_node(pc_a) # Configure Router 1 - router_1 = WirelessRouter.from_config(config={"type":"wireless_router", "hostname":"router_1", "start_up_duration":0, "airspace":network.airspace}) + router_1 = WirelessRouter.from_config( + config={"type": "wireless_router", "hostname": "router_1", "start_up_duration": 0, "airspace": network.airspace} + ) router_1.power_on() network.add_node(router_1) @@ -99,43 +117,29 @@ def wireless_wan_network(): # Configure PC B - pc_b_cfg = {"type": "computer", - "hostname":"pc_b", - "ip_address":"192.168.2.2", - "subnet_mask":"255.255.255.0", - "default_gateway":"192.168.2.1", - "start_up_duration":0, - } + pc_b_cfg = { + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.2.1", + "start_up_duration": 0, + } pc_b = Computer.from_config(config=pc_b_cfg) pc_b.power_on() network.add_node(pc_b) - # Configure Router 2 - router_2 = WirelessRouter.from_config(config={"type":"wireless_router", "hostname":"router_2", "start_up_duration":0, "airspace":network.airspace}) - router_2.power_on() - network.add_node(router_2) - - # Configure the connection between PC B and Router 2 port 2 - router_2.configure_router_interface("192.168.2.1", "255.255.255.0") - network.connect(pc_b.network_interface[1], router_2.network_interface[2]) - # Configure Router 2 ACLs # Configure the wireless connection between Router 1 port 1 and Router 2 port 1 router_1.configure_wireless_access_point("192.168.1.1", "255.255.255.0") - router_2.configure_wireless_access_point("192.168.1.2", "255.255.255.0") router_1.route_table.add_route( address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2" ) - # Configure Route from Router 2 to PC A subnet - router_2.route_table.add_route( - address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1" - ) - - return pc_a, pc_b, router_1, router_2 + return network @pytest.fixture @@ -156,11 +160,15 @@ def test_terminal_creation(terminal_on_computer): def test_terminal_install_default(): """Terminal should be auto installed onto Nodes""" - computer: Computer = Computer.from_config(config={"type":"computer", - "hostname":"node_a", - "ip_address":"192.168.0.10", - "subnet_mask":"255.255.255.0", - "start_up_duration":0}) + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) computer.power_on() assert computer.software_manager.software.get("Terminal") @@ -168,7 +176,7 @@ def test_terminal_install_default(): def test_terminal_not_on_switch(): """Ensure terminal does not auto-install to switch""" - test_switch = Switch.from_config(config={"type":"switch", "hostname":"Test"}) + test_switch = Switch.from_config(config={"type": "switch", "hostname": "Test"}) assert not test_switch.software_manager.software.get("Terminal") @@ -291,7 +299,10 @@ def test_terminal_ignores_when_off(basic_network): def test_computer_remote_login_to_router(wireless_wan_network): """Test to confirm that a computer can SSH into a router.""" - pc_a, _, router_1, _ = wireless_wan_network + + pc_a = wireless_wan_network.get_node_by_hostname("pc_a") + + router_1 = wireless_wan_network.get_node_by_hostname("router_1") pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal") @@ -310,7 +321,9 @@ def test_computer_remote_login_to_router(wireless_wan_network): def test_router_remote_login_to_computer(wireless_wan_network): """Test to confirm that a router can ssh into a computer.""" - pc_a, _, router_1, _ = wireless_wan_network + pc_a = wireless_wan_network.get_node_by_hostname("pc_a") + + router_1 = wireless_wan_network.get_node_by_hostname("router_1") router_1_terminal: Terminal = router_1.software_manager.software.get("Terminal") @@ -329,7 +342,9 @@ def test_router_remote_login_to_computer(wireless_wan_network): def test_router_blocks_SSH_traffic(wireless_wan_network): """Test to check that router will block SSH traffic if no ACL rule.""" - pc_a, _, router_1, _ = wireless_wan_network + pc_a = wireless_wan_network.get_node_by_hostname("pc_a") + + router_1 = wireless_wan_network.get_node_by_hostname("router_1") # Remove rule that allows SSH traffic. router_1.acl.remove_rule(position=21) @@ -343,20 +358,22 @@ def test_router_blocks_SSH_traffic(wireless_wan_network): assert len(pc_a_terminal._connections) == 0 -def test_SSH_across_network(wireless_wan_network): +def test_SSH_across_network(): """Test to show ability to SSH across a network.""" - pc_a, pc_b, router_1, router_2 = wireless_wan_network + network: Network = arcd_uc2_network() + pc_a = network.get_node_by_hostname("client_1") + router_1 = network.get_node_by_hostname("router_1") terminal_a: Terminal = pc_a.software_manager.software.get("Terminal") - terminal_b: Terminal = pc_b.software_manager.software.get("Terminal") - router_2.acl.add_rule( + router_1.acl.add_rule( action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21 ) assert len(terminal_a._connections) == 0 - terminal_b_on_terminal_a = terminal_b.login(username="admin", password="admin", ip_address="192.168.0.2") + # Login to the Domain Controller + terminal_a.login(username="admin", password="admin", ip_address="192.168.1.10") assert len(terminal_a._connections) == 1 From 4b42a74ac89fb5f425a86545cc934e54b970da5b Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 29 Jan 2025 16:57:25 +0000 Subject: [PATCH 12/23] #2887 - Corrected failures seen when generating services from config & syntax issues. Wireless Router tests currently fail due to port 1 being disabled on startup --- .../hardware/nodes/network/wireless_router.py | 6 +++--- .../system/services/database/database_service.py | 7 +++---- .../simulator/system/services/dns/dns_client.py | 6 ++++-- .../simulator/system/services/ftp/ftp_server.py | 7 +++++-- .../simulator/system/services/ntp/ntp_client.py | 14 ++++++++------ tests/conftest.py | 1 - .../extensions/services/extended_service.py | 6 +++--- 7 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 2ca854d4..5e52de7e 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -263,6 +263,9 @@ class WirelessRouter(Router, identifier="wireless_router"): :rtype: WirelessRouter """ router = cls(config=cls.ConfigSchema(**config)) + router.operating_state = ( + NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] + ) if "router_interface" in config: ip_address = config["router_interface"]["ip_address"] subnet_mask = config["router_interface"]["subnet_mask"] @@ -294,7 +297,4 @@ class WirelessRouter(Router, identifier="wireless_router"): next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), metric=float(route.get("metric", 0)), ) - router.operating_state = ( - NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] - ) return router diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 4ba4c4d4..fc56483d 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -31,12 +31,11 @@ class DatabaseService(Service, identifier="DatabaseService"): type: str = "DatabaseService" backup_server_ip: Optional[IPv4Address] = None + db_password: Optional[str] = None + """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema()) - password: Optional[str] = None - """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" - backup_server_ip: IPv4Address = None """IP address of the backup server.""" @@ -217,7 +216,7 @@ class DatabaseService(Service, identifier="DatabaseService"): SoftwareHealthState.FIXING, SoftwareHealthState.COMPROMISED, ]: - if self.password == password: + if self.config.db_password == password: status_code = 200 # ok connection_id = self._generate_connection_id() # try to create connection diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 0756eb05..3ff5b930 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -22,11 +22,13 @@ class DNSClient(Service, identifier="DNSClient"): type: str = "DNSClient" + dns_server: Optional[IPv4Address] = None + "The DNS Server the client sends requests to." + config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema()) dns_cache: Dict[str, IPv4Address] = {} "A dict of known mappings between domain/URLs names and IPv4 addresses." - dns_server: Optional[IPv4Address] = None - "The DNS Server the client sends requests to." + def __init__(self, **kwargs): kwargs["name"] = "DNSClient" diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 054bfe15..a5b59ec9 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -23,14 +23,17 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"): config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema()) - server_password: Optional[str] = None - """Password needed to connect to FTP server. Default is None.""" + class ConfigSchema(Service.ConfigSchema): """ConfigSchema for FTPServer.""" type: str = "FTPServer" + server_password: Optional[str] = None + """Password needed to connect to FTP server. Default is None.""" + + def __init__(self, **kwargs): kwargs["name"] = "FTPServer" kwargs["port"] = PORT_LOOKUP["FTP"] diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index fb470faf..b27d1241 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -22,10 +22,12 @@ class NTPClient(Service, identifier="NTPClient"): type: str = "NTPClient" + ntp_server_ip: Optional[IPv4Address] = None + "The NTP server the client sends requests to." + config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema()) - ntp_server: Optional[IPv4Address] = None - "The NTP server the client sends requests to." + time: Optional[datetime] = None def __init__(self, **kwargs): @@ -42,8 +44,8 @@ class NTPClient(Service, identifier="NTPClient"): :param ntp_server_ip_address: IPv4 address of NTP server. :param ntp_client_ip_Address: IPv4 address of NTP client. """ - self.ntp_server = ntp_server_ip_address - self.sys_log.info(f"{self.name}: ntp_server: {self.ntp_server}") + self.config.ntp_server_ip = ntp_server_ip_address + self.sys_log.info(f"{self.name}: ntp_server: {self.config.ntp_server_ip}") def describe_state(self) -> Dict: """ @@ -105,10 +107,10 @@ class NTPClient(Service, identifier="NTPClient"): def request_time(self) -> None: """Send request to ntp_server.""" - if self.ntp_server: + if self.config.ntp_server_ip: self.software_manager.session_manager.receive_payload_from_software_manager( payload=NTPPacket(), - dst_ip_address=self.ntp_server, + dst_ip_address=self.config.ntp_server_ip, src_port=self.port, dst_port=self.port, ip_protocol=self.protocol, diff --git a/tests/conftest.py b/tests/conftest.py index 6ac227ef..765ed8dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -387,7 +387,6 @@ def install_stuff_to_sim(sim: Simulation): "ip_address": "10.0.1.2", "subnet_mask": "255.255.255.0", "default_gateway": "10.0.1.1", - "start_up_duration": 0, } client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index ba247369..11adc53b 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -31,14 +31,14 @@ class ExtendedService(Service, identifier="ExtendedService"): type: str = "ExtendedService" + backup_server_ip: IPv4Address = None + """IP address of the backup server.""" + config: "ExtendedService.ConfigSchema" = Field(default_factory=lambda: ExtendedService.ConfigSchema()) password: Optional[str] = None """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" - backup_server_ip: IPv4Address = None - """IP address of the backup server.""" - latest_backup_directory: str = None """Directory of latest backup.""" From 3d47b9c8638322b537d6346612f71a7a69b309ad Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 30 Jan 2025 17:33:00 +0000 Subject: [PATCH 13/23] #2887 - Further fixes to unit tests --- src/primaite/simulator/network/hardware/base.py | 5 ++++- src/primaite/simulator/system/core/software_manager.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 08e100d2..8cbe2b87 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1532,6 +1532,7 @@ class Node(SimComponent, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) """Configure pydantic to allow arbitrary types, let the instance have attributes not present in the model.""" + hostname: str = "default" "The node hostname on the network." @@ -1568,6 +1569,8 @@ class Node(SimComponent, ABC): default_gateway: Optional[IPV4Address] = None "The default gateway IP address for forwarding network traffic to other networks." + operating_state: Any = None + @property def dns_server(self) -> Optional[IPv4Address]: return self.config.dns_server @@ -1579,7 +1582,6 @@ class Node(SimComponent, ABC): msg = f"Configuration contains an invalid Node type: {config['type']}" return ValueError(msg) obj = cls(config=cls.ConfigSchema(**config)) - obj.operating_state = NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] return obj def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: @@ -1623,6 +1625,7 @@ class Node(SimComponent, ABC): dns_server=kwargs["config"].dns_server, ) super().__init__(**kwargs) + self.operating_state = NodeOperatingState.ON if not (p := kwargs["config"].operating_state) else NodeOperatingState[p.upper()] self._install_system_software() self.session_manager.node = self self.session_manager.software_manager = self.software_manager diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index ddb30a3b..0f7aa936 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -140,6 +140,7 @@ class SoftwareManager: elif isinstance(software, Service): self.node.services[software.uuid] = software self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager)) + software.start() software.install() software.software_manager = self self.software[software.name] = software From d806391625a1569eba51f1cfbbc877b420909b87 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 31 Jan 2025 18:46:02 +0000 Subject: [PATCH 14/23] #2887 - Test fixes --- src/primaite/simulator/network/hardware/base.py | 2 +- .../network/hardware/nodes/network/wireless_router.py | 1 + .../game_layer/observations/test_firewall_observation.py | 4 ++-- tests/integration_tests/system/test_database_on_node.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 8cbe2b87..732b79f5 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1564,7 +1564,7 @@ class Node(SimComponent, ABC): "Time steps until reveal to red scan is complete." dns_server: Optional[IPv4Address] = None - "List of IP addresses of DNS servers used for name resolution." + "List of IP addresses of DNS servers used for name resolution." default_gateway: Optional[IPV4Address] = None "The default gateway IP address for forwarding network traffic to other networks." diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 5e52de7e..348c2aaa 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -130,6 +130,7 @@ class WirelessRouter(Router, identifier="wireless_router"): hostname: str = "WirelessRouter" airspace: AirSpace + num_ports: int = 0 def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 17c7775f..874fa49e 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -25,7 +25,7 @@ def check_default_rules(acl_obs): def test_firewall_observation(): """Test adding/removing acl rules and enabling/disabling ports.""" net = Network() - firewall_cfg = {"type": "firewall", "hostname": "firewall", "opertating_state": NodeOperatingState.ON} + firewall_cfg = {"type": "firewall", "hostname": "firewall"} firewall = Firewall.from_config(config=firewall_cfg) firewall_observation = FirewallObservation( where=[], @@ -118,7 +118,7 @@ def test_firewall_observation(): # connect a switch to the firewall and check that only the correct port is updated switch: Switch = Switch.from_config( - config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": NodeOperatingState.ON} + config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": "ON"} ) link = net.connect(firewall.network_interface[1], switch.network_interface[1]) assert firewall.network_interface[1].enabled diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 8ad292b2..59e50659 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -62,7 +62,7 @@ def peer_to_peer_secure_db(peer_to_peer) -> Tuple[Computer, Computer]: database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa database_service.stop() - database_service.password = "12345" + database_service.config.db_password = "12345" database_service.start() return node_a, node_b From 3d01f52eea304906022c426427b04f8ef0ecbc84 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 3 Feb 2025 11:18:34 +0000 Subject: [PATCH 15/23] #2887 - Merge in changes on dev to resolve conflicts. All tests should now pass --- src/primaite/game/game.py | 1 - .../simulator/network/hardware/base.py | 8 +++-- .../network/hardware/nodes/host/computer.py | 2 +- .../network/hardware/nodes/host/host_node.py | 5 ++- src/primaite/simulator/network/networks.py | 32 +++++++++++-------- .../system/services/ftp/ftp_server.py | 1 - .../system/services/ntp/ntp_client.py | 1 - tests/conftest.py | 6 ++-- .../nodes/test_node_config.py | 2 +- .../game_layer/test_observations.py | 7 ++-- 10 files changed, 36 insertions(+), 29 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index b52d506f..e9941a12 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,6 +1,5 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """PrimAITE game - Encapsulates the simulation and agents.""" -from ipaddress import IPv4Address from typing import Dict, List, Optional, Union import numpy as np diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 23ffbcf8..06b0dbe4 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1493,6 +1493,7 @@ class Node(SimComponent, ABC): :param hostname: The node hostname on the network. :param operating_state: The node operating state, either ON or OFF. """ + operating_state: NodeOperatingState = NodeOperatingState.OFF "The hardware state of the node." network_interfaces: Dict[str, NetworkInterface] = {} @@ -1564,7 +1565,7 @@ class Node(SimComponent, ABC): "Time steps until reveal to red scan is complete." dns_server: Optional[IPv4Address] = None - "List of IP addresses of DNS servers used for name resolution." + "List of IP addresses of DNS servers used for name resolution." default_gateway: Optional[IPV4Address] = None "The default gateway IP address for forwarding network traffic to other networks." @@ -1573,6 +1574,7 @@ class Node(SimComponent, ABC): @property def dns_server(self) -> Optional[IPv4Address]: + """Convenience method to access the dns_server IP.""" return self.config.dns_server @classmethod @@ -1625,7 +1627,9 @@ class Node(SimComponent, ABC): dns_server=kwargs["config"].dns_server, ) super().__init__(**kwargs) - self.operating_state = NodeOperatingState.ON if not (p := kwargs["config"].operating_state) else NodeOperatingState[p.upper()] + self.operating_state = ( + NodeOperatingState.ON if not (p := kwargs["config"].operating_state) else NodeOperatingState[p.upper()] + ) self._install_system_software() self.session_manager.node = self self.session_manager.software_manager = self.software_manager diff --git a/src/primaite/simulator/network/hardware/nodes/host/computer.py b/src/primaite/simulator/network/hardware/nodes/host/computer.py index 1aebc3af..85857a44 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/computer.py +++ b/src/primaite/simulator/network/hardware/nodes/host/computer.py @@ -37,7 +37,7 @@ class Computer(HostNode, identifier="computer"): SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} - config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema()) + config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema()) class ConfigSchema(HostNode.ConfigSchema): """Configuration Schema for Computer class.""" diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 3b1d8e48..424e39f1 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -55,7 +55,10 @@ class HostARP(ARP): :return: The NIC associated with the default gateway if it exists in the ARP cache; otherwise, None. """ - if self.software_manager.node.config.default_gateway and self.software_manager.node.has_enabled_network_interface: + if ( + self.software_manager.node.config.default_gateway + and self.software_manager.node.has_enabled_network_interface + ): return self.get_arp_cache_network_interface(self.software_manager.node.config.default_gateway) def _get_arp_cache_mac_address( diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 644b2a4a..75978bee 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -58,24 +58,28 @@ def client_server_routed() -> Network: router_1.enable_port(2) # Client 1 - client_1 = Computer(config=dict( - hostname="client_1", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, - )) + client_1 = Computer( + config=dict( + hostname="client_1", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.2.1", + start_up_duration=0, + ) + ) client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) # Server 1 - server_1 = Server(config=dict( - hostname="server_1", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - )) + server_1 = Server( + config=dict( + hostname="server_1", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + ) server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 623bdf90..ebb93a7b 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -43,7 +43,6 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"): """Convenience method for accessing FTP server password.""" return self.config.server_password - def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 81947395..72e8f6c0 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -29,7 +29,6 @@ class NTPClient(Service, identifier="NTPClient"): config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema()) - time: Optional[datetime] = None def __init__(self, **kwargs): diff --git a/tests/conftest.py b/tests/conftest.py index b3cac34a..c1d44aec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -232,7 +232,7 @@ def example_network() -> Network: # Router 1 - router_1_cfg = {"hostname": "router_1", "type": "router", "start_up_duration":0} + router_1_cfg = {"hostname": "router_1", "type": "router", "start_up_duration": 0} # router_1 = Router(hostname="router_1", start_up_duration=0) router_1 = Router.from_config(config=router_1_cfg) @@ -253,7 +253,7 @@ def example_network() -> Network: router_1.enable_port(1) # Switch 2 - switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8, "start_up_duration":0} + switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8, "start_up_duration": 0} # switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) switch_2 = Switch.from_config(config=switch_2_config) switch_2.power_on() @@ -387,7 +387,7 @@ def install_stuff_to_sim(sim: Simulation): "ip_address": "10.0.1.2", "subnet_mask": "255.255.255.0", "default_gateway": "10.0.1.1", - "start_up_duration":0, + "start_up_duration": 0, } client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() diff --git a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py index 6ccbf4e1..f3911691 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py @@ -3,8 +3,8 @@ from primaite.config.load import data_manipulation_config_path from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index 090725b5..17b9b71e 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -8,10 +8,9 @@ from primaite.simulator.sim_container import Simulation def test_file_observation(): sim = Simulation() - pc: Computer = Computer.from_config(config={"type":"computer", - "hostname":"beep", - "ip_address":"123.123.123.123", - "subnet_mask":"255.255.255.0"}) + pc: Computer = Computer.from_config( + config={"type": "computer", "hostname": "beep", "ip_address": "123.123.123.123", "subnet_mask": "255.255.255.0"} + ) sim.network.add_node(pc) f = pc.file_system.create_file(file_name="dog.png") From 0920ec5f5b47a43be2d93e6f22f8af6ae98eb0f5 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 3 Feb 2025 11:32:07 +0000 Subject: [PATCH 16/23] #2887 - Remove debug print statements --- src/primaite/simulator/network/hardware/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 06b0dbe4..0564e1f3 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -2100,17 +2100,14 @@ class Node(SimComponent, ABC): def power_on(self) -> bool: """Power on the Node, enabling its NICs if it is in the OFF state.""" - print("HI") if self.config.start_up_duration <= 0: self.operating_state = NodeOperatingState.ON - print(f"Powering On: f{self.config.hostname}") self._start_up_actions() self.sys_log.info("Power on") for network_interface in self.network_interfaces.values(): network_interface.enable() return True if self.operating_state == NodeOperatingState.OFF: - print("OOOOOOOOOOOOOOOOOOH") self.operating_state = NodeOperatingState.BOOTING self.config.start_up_countdown = self.config.start_up_duration return True From f3bbfffe7f4503bd4d14ddbc2d7c08766962c8eb Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 3 Feb 2025 14:03:21 +0000 Subject: [PATCH 17/23] #2887 - Update CHANGELOG.md --- CHANGELOG.md | 2 ++ docs/source/how_to_guides/extensible_nodes.rst | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c91bf4f4..7f87f54e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Relabeled action parameters to match the new action config schemas, and updated the values to no longer rely on indices - Removed action space options which were previously used for assigning meaning to action space IDs - Updated tests that don't use YAMLs to still use the new action and agent schemas +- Nodes now use a config schema and are extensible, allowing for plugin support. +- Node tests have been updated to use the new node config schemas when not using YAML files. ### Fixed - DNS client no longer fails to check its cache if a DNS server address is missing. diff --git a/docs/source/how_to_guides/extensible_nodes.rst b/docs/source/how_to_guides/extensible_nodes.rst index f0b78b08..78ee550e 100644 --- a/docs/source/how_to_guides/extensible_nodes.rst +++ b/docs/source/how_to_guides/extensible_nodes.rst @@ -48,8 +48,6 @@ class Router(NetworkNode, identifier="router"): hostname: str = "Router" - ports: list = [] - Changes to YAML file. From c1a5a26ffca6beb906bad50c0855d6a6dbc9252c Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 4 Feb 2025 10:21:56 +0000 Subject: [PATCH 18/23] #2887 - Actioning review comments --- src/primaite/simulator/core.py | 2 +- .../simulator/network/hardware/base.py | 8 ++++---- .../network/hardware/nodes/host/computer.py | 4 ++-- .../network/hardware/nodes/host/host_node.py | 4 ++-- .../network/hardware/nodes/host/server.py | 8 ++++---- .../network/hardware/nodes/network/firewall.py | 4 ++-- .../network/hardware/nodes/network/router.py | 18 +++++++++--------- .../network/hardware/nodes/network/switch.py | 4 ++-- .../hardware/nodes/network/wireless_router.py | 4 ++-- .../system/services/dns/dns_client.py | 4 ++-- .../system/services/ftp/ftp_server.py | 5 ++--- .../system/services/ntp/ntp_client.py | 4 +--- 12 files changed, 33 insertions(+), 36 deletions(-) diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index dc4ae73b..567a0493 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -3,7 +3,7 @@ """Core of the PrimAITE Simulator.""" import warnings from abc import abstractmethod -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union +from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union from uuid import uuid4 from prettytable import PrettyTable diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 0564e1f3..36623a6f 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1525,16 +1525,13 @@ class Node(SimComponent, ABC): _identifier: ClassVar[str] = "unknown" """Identifier for this particular class, used for printing and logging. Each subclass redefines this.""" - config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) - """Configuration items within Node""" - class ConfigSchema(BaseModel, ABC): """Configuration Schema for Node based classes.""" model_config = ConfigDict(arbitrary_types_allowed=True) """Configure pydantic to allow arbitrary types, let the instance have attributes not present in the model.""" - hostname: str = "default" + hostname: str "The node hostname on the network." revealed_to_red: bool = False @@ -1572,6 +1569,9 @@ class Node(SimComponent, ABC): operating_state: Any = None + config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) + """Configuration items within Node""" + @property def dns_server(self) -> Optional[IPv4Address]: """Convenience method to access the dns_server IP.""" diff --git a/src/primaite/simulator/network/hardware/nodes/host/computer.py b/src/primaite/simulator/network/hardware/nodes/host/computer.py index 85857a44..a0450443 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/computer.py +++ b/src/primaite/simulator/network/hardware/nodes/host/computer.py @@ -37,11 +37,11 @@ class Computer(HostNode, identifier="computer"): SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} - config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema()) - class ConfigSchema(HostNode.ConfigSchema): """Configuration Schema for Computer class.""" hostname: str = "Computer" + config: ConfigSchema = Field(default_factory=lambda: Computer.ConfigSchema()) + pass diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 424e39f1..1aa482db 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -330,8 +330,6 @@ class HostNode(Node, identifier="HostNode"): network_interface: Dict[int, NIC] = {} "The NICs on the node by port id." - config: HostNode.ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema()) - class ConfigSchema(Node.ConfigSchema): """Configuration Schema for HostNode class.""" @@ -339,6 +337,8 @@ class HostNode(Node, identifier="HostNode"): subnet_mask: IPV4Address = "255.255.255.0" ip_address: IPV4Address + config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema()) + def __init__(self, **kwargs): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask)) diff --git a/src/primaite/simulator/network/hardware/nodes/host/server.py b/src/primaite/simulator/network/hardware/nodes/host/server.py index bdf4e8c2..3a9fc2f9 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/server.py +++ b/src/primaite/simulator/network/hardware/nodes/host/server.py @@ -33,22 +33,22 @@ class Server(HostNode, identifier="server"): * Web Browser """ - config: "Server.ConfigSchema" = Field(default_factory=lambda: Server.ConfigSchema()) - class ConfigSchema(HostNode.ConfigSchema): """Configuration Schema for Server class.""" hostname: str = "server" + config: ConfigSchema = Field(default_factory=lambda: Server.ConfigSchema()) + class Printer(HostNode, identifier="printer"): """Printer? I don't even know her!.""" # TODO: Implement printer-specific behaviour - config: "Printer.ConfigSchema" = Field(default_factory=lambda: Printer.ConfigSchema()) - class ConfigSchema(HostNode.ConfigSchema): """Configuration Schema for Printer class.""" hostname: str = "printer" + + config: ConfigSchema = Field(default_factory=lambda: Printer.ConfigSchema()) diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 99dd48c4..c4ddea58 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -100,14 +100,14 @@ class Firewall(Router, identifier="firewall"): _identifier: str = "firewall" - config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema()) - class ConfigSchema(Router.ConfigSchema): """Configuration Schema for Firewall 'Nodes' within PrimAITE.""" hostname: str = "firewall" num_ports: int = 0 + config: ConfigSchema = Field(default_factory=lambda: Firewall.ConfigSchema()) + def __init__(self, **kwargs): if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(kwargs["config"].hostname) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index dd32fa31..7138cf4f 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -7,7 +7,7 @@ from ipaddress import IPv4Address, IPv4Network from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent @@ -1201,6 +1201,14 @@ class Router(NetworkNode, identifier="router"): RouteTable, RouterARP, and RouterICMP services. """ + class ConfigSchema(NetworkNode.ConfigSchema): + """Configuration Schema for Routers.""" + + hostname: str = "router" + num_ports: int = 5 + + config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema()) + SYSTEM_SOFTWARE: ClassVar[Dict] = { "UserSessionManager": UserSessionManager, "UserManager": UserManager, @@ -1214,14 +1222,6 @@ class Router(NetworkNode, identifier="router"): acl: AccessControlList route_table: RouteTable - config: "Router.ConfigSchema" - - class ConfigSchema(NetworkNode.ConfigSchema): - """Configuration Schema for Routers.""" - - hostname: str = "router" - num_ports: int = 5 - def __init__(self, **kwargs): if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(kwargs["config"].hostname) diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 3cb335f7..54e1d7ef 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -98,8 +98,6 @@ class Switch(NetworkNode, identifier="switch"): mac_address_table: Dict[str, SwitchPort] = {} "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." - config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema()) - class ConfigSchema(NetworkNode.ConfigSchema): """Configuration Schema for Switch nodes within PrimAITE.""" @@ -107,6 +105,8 @@ class Switch(NetworkNode, identifier="switch"): num_ports: int = 24 "The number of ports on the switch. Default is 24." + config: ConfigSchema = Field(default_factory=lambda: Switch.ConfigSchema()) + def __init__(self, **kwargs): super().__init__(**kwargs) for i in range(1, kwargs["config"].num_ports + 1): diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 348c2aaa..2beb03d6 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -123,8 +123,6 @@ class WirelessRouter(Router, identifier="wireless_router"): network_interfaces: Dict[str, Union[RouterInterface, WirelessAccessPoint]] = {} network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {} - config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema()) - class ConfigSchema(Router.ConfigSchema): """Configuration Schema for WirelessRouter nodes within PrimAITE.""" @@ -132,6 +130,8 @@ class WirelessRouter(Router, identifier="wireless_router"): airspace: AirSpace num_ports: int = 0 + config: ConfigSchema = Field(default_factory=lambda: WirelessRouter.ConfigSchema()) + def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 4a1fd292..4f926a25 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -25,12 +25,12 @@ class DNSClient(Service, identifier="DNSClient"): """ConfigSchema for DNSClient.""" type: str = "DNSClient" - dns_server: Optional[IPV4Address] = None dns_server: Optional[IPv4Address] = None "The DNS Server the client sends requests to." - config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema()) + config: ConfigSchema = Field(default_factory=lambda: DNSClient.ConfigSchema()) + dns_cache: Dict[str, IPv4Address] = {} "A dict of known mappings between domain/URLs names and IPv4 addresses." diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index ebb93a7b..147d2dbb 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -20,17 +20,16 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"): RFC 959: https://datatracker.ietf.org/doc/html/rfc959 """ - config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema()) - class ConfigSchema(FTPServiceABC.ConfigSchema): """ConfigSchema for FTPServer.""" type: str = "FTPServer" - server_password: Optional[str] = None server_password: Optional[str] = None """Password needed to connect to FTP server. Default is None.""" + config: ConfigSchema = Field(default_factory=lambda: FTPServer.ConfigSchema()) + def __init__(self, **kwargs): kwargs["name"] = "FTPServer" kwargs["port"] = PORT_LOOKUP["FTP"] diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 72e8f6c0..7c3efd25 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -9,7 +9,6 @@ from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket from primaite.simulator.system.services.service import Service, ServiceOperatingState from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP -from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) @@ -22,12 +21,11 @@ class NTPClient(Service, identifier="NTPClient"): """ConfigSchema for NTPClient.""" type: str = "NTPClient" - ntp_server_ip: Optional[IPV4Address] = None ntp_server_ip: Optional[IPv4Address] = None "The NTP server the client sends requests to." - config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema()) + config: ConfigSchema = Field(default_factory=lambda: NTPClient.ConfigSchema()) time: Optional[datetime] = None From 961136fb4213d2274800760ff220d79bf5f59382 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 4 Feb 2025 10:41:51 +0000 Subject: [PATCH 19/23] #2887 - Updates to extensible_nodes.rst --- docs/source/how_to_guides/extensible_nodes.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/how_to_guides/extensible_nodes.rst b/docs/source/how_to_guides/extensible_nodes.rst index 78ee550e..6651b618 100644 --- a/docs/source/how_to_guides/extensible_nodes.rst +++ b/docs/source/how_to_guides/extensible_nodes.rst @@ -53,4 +53,4 @@ class Router(NetworkNode, identifier="router"): Changes to YAML file. ===================== -Nodes defined within configuration YAML files for use with PrimAITE 3.X should still be compatible following these changes. +While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed. From 05946431ca5eac277e6269a8d2c172cd0b084cf8 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 4 Feb 2025 11:19:13 +0000 Subject: [PATCH 20/23] #2887 - Correct type in documentation --- docs/source/how_to_guides/extensible_nodes.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/how_to_guides/extensible_nodes.rst b/docs/source/how_to_guides/extensible_nodes.rst index 6651b618..043d0f06 100644 --- a/docs/source/how_to_guides/extensible_nodes.rst +++ b/docs/source/how_to_guides/extensible_nodes.rst @@ -18,8 +18,7 @@ Node classes all inherit from the base Node Class, though new classes should inh The use of an `__init__` method is not necessary, as configurable variables for the class should be specified within the `config` of the class, and passed at run time via your YAML configuration using the `from_config` method. - -An example of how additional Node classes is below, taken from `router.py` withing PrimAITE. +An example of how additional Node classes is below, taken from `router.py` within PrimAITE. .. code-block:: Python From 99e38fbbc2277c981abd904be44ea9311843cfc9 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 4 Feb 2025 14:25:26 +0000 Subject: [PATCH 21/23] #2887 - Removal of un-necessary code and cleanup following review comments --- src/primaite/simulator/network/creation.py | 3 +-- src/primaite/simulator/network/hardware/base.py | 8 -------- .../simulator/network/hardware/nodes/network/switch.py | 2 +- 3 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index 5e0d0ce8..009ac861 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -167,7 +167,6 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): # Optionally include a router in the LAN if config.include_router: default_gateway = IPv4Address(f"192.168.{config.subnet_base}.1") - # router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0) router = Router.from_config( config={"hostname": f"router_{config.lan_name}", "type": "router", "start_up_duration": 0} ) @@ -230,7 +229,7 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): "type": "computer", "hostname": f"pc_{i}_{config.lan_name}", "ip_address": f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}", - "default_gateway": "192.168.10.1", + "default_gateway": default_gateway, "start_up_duration": 0, } pc = Computer.from_config(config=pc_cfg) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 36623a6f..29d46164 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -2242,10 +2242,6 @@ class Node(SimComponent, ABC): for app_id in self.applications: self.applications[app_id].close() - # Turn off all processes in the node - # for process_id in self.processes: - # self.processes[process_id] - def _start_up_actions(self): """Actions to perform when the node is starting up.""" # Turn on all the services in the node @@ -2258,10 +2254,6 @@ class Node(SimComponent, ABC): print(f"Starting application:{self.applications[app_id].config.type}") self.applications[app_id].run() - # Turn off all processes in the node - # for process_id in self.processes: - # self.processes[process_id] - def _install_system_software(self) -> None: """Preinstall required software.""" for _, software_class in self.SYSTEM_SOFTWARE.items(): diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 54e1d7ef..8a9fdb24 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -109,7 +109,7 @@ class Switch(NetworkNode, identifier="switch"): def __init__(self, **kwargs): super().__init__(**kwargs) - for i in range(1, kwargs["config"].num_ports + 1): + for i in range(1, self.config.num_ports + 1): self.connect_nic(SwitchPort()) def _install_system_software(self): From 51bb3f5b07deede2506bb242eb8586e8307aedf4 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 4 Feb 2025 14:26:55 +0000 Subject: [PATCH 22/23] #2887 - Removal of un-necessary print statement that was used for debugging --- src/primaite/simulator/network/hardware/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 29d46164..6543d793 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -2250,8 +2250,6 @@ class Node(SimComponent, ABC): # Turn on all the applications in the node for app_id in self.applications: - print(app_id) - print(f"Starting application:{self.applications[app_id].config.type}") self.applications[app_id].run() def _install_system_software(self) -> None: From 24161bb3fc038ed946ee1bebf57711d1b443e65e Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 4 Feb 2025 15:06:23 +0000 Subject: [PATCH 23/23] #2887 - Removal of commented out code --- tests/conftest.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c1d44aec..a1fde7dd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -246,7 +246,6 @@ def example_network() -> Network: switch_1 = Switch.from_config(config=switch_1_cfg) - # switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8]) @@ -254,7 +253,6 @@ def example_network() -> Network: # Switch 2 switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8, "start_up_duration": 0} - # switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) switch_2 = Switch.from_config(config=switch_2_config) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])