diff --git a/CHANGELOG.md b/CHANGELOG.md index c91bf4f4..7f87f54e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Relabeled action parameters to match the new action config schemas, and updated the values to no longer rely on indices - Removed action space options which were previously used for assigning meaning to action space IDs - Updated tests that don't use YAMLs to still use the new action and agent schemas +- Nodes now use a config schema and are extensible, allowing for plugin support. +- Node tests have been updated to use the new node config schemas when not using YAML files. ### Fixed - DNS client no longer fails to check its cache if a DNS server address is missing. diff --git a/docs/source/how_to_guides/extensible_nodes.rst b/docs/source/how_to_guides/extensible_nodes.rst new file mode 100644 index 00000000..78ee550e --- /dev/null +++ b/docs/source/how_to_guides/extensible_nodes.rst @@ -0,0 +1,56 @@ +.. only:: comment + + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +.. _about: + + +Extensible Nodes +**************** + +Node classes within PrimAITE have been updated to allow for easier generation of custom nodes within simulations. + + +Changes to Node Class structure. +================================ + +Node classes all inherit from the base Node Class, though new classes should inherit from either HostNode or NetworkNode, subject to the intended application of the Node. + +The use of an `__init__` method is not necessary, as configurable variables for the class should be specified within the `config` of the class, and passed at run time via your YAML configuration using the `from_config` method. + + +An example of how additional Node classes is below, taken from `router.py` withing PrimAITE. + +.. code-block:: Python + +class Router(NetworkNode, identifier="router"): + """ Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces.""" + + SYSTEM_SOFTWARE: ClassVar[Dict] = { + "UserSessionManager": UserSessionManager, + "UserManager": UserManager, + "Terminal": Terminal, + } + + network_interfaces: Dict[str, RouterInterface] = {} + "The Router Interfaces on the node." + network_interface: Dict[int, RouterInterface] = {} + "The Router Interfaces on the node by port id." + + sys_log: SysLog + + config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema()) + + class ConfigSchema(NetworkNode.ConfigSchema): + """Configuration Schema for Router Objects.""" + + num_ports: int = 5 + + hostname: str = "Router" + + + +Changes to YAML file. +===================== + +Nodes defined within configuration YAML files for use with PrimAITE 3.X should still be compatible following these changes. diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index f42d6824..e8ea4625 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,6 +1,5 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """PrimAITE game - Encapsulates the simulation and agents.""" -from ipaddress import IPv4Address from typing import Dict, List, Optional, Union import numpy as np @@ -13,14 +12,8 @@ from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.creation import NetworkNodeAdder from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager -from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC -from primaite.simulator.network.hardware.nodes.host.server import Printer, Server -from primaite.simulator.network.hardware.nodes.network.firewall import Firewall -from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode -from primaite.simulator.network.hardware.nodes.network.router import Router +from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application @@ -274,66 +267,10 @@ class PrimaiteGame: n_type = node_cfg["type"] new_node = None - # Default PrimAITE nodes - if n_type == "computer": - new_node = Computer( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "server": - new_node = Server( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "switch": - new_node = Switch( - hostname=node_cfg["hostname"], - num_ports=int(node_cfg.get("num_ports", "8")), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type == "router": - new_node = Router.from_config(node_cfg) - elif n_type == "firewall": - new_node = Firewall.from_config(node_cfg) - elif n_type == "wireless_router": - new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace) - elif n_type == "printer": - new_node = Printer( - hostname=node_cfg["hostname"], - ip_address=node_cfg["ip_address"], - subnet_mask=node_cfg["subnet_mask"], - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - # Handle extended nodes - elif n_type.lower() in Node._registry: - new_node = HostNode._registry[n_type]( - hostname=node_cfg["hostname"], - ip_address=node_cfg.get("ip_address"), - subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - default_gateway=node_cfg.get("default_gateway"), - dns_server=node_cfg.get("dns_server", None), - operating_state=NodeOperatingState.ON - if not (p := node_cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - elif n_type in NetworkNode._registry: - new_node = NetworkNode._registry[n_type](**node_cfg) + if n_type in Node._registry: + if n_type == "wireless-router": + node_cfg["airspace"] = net.airspace + new_node = Node._registry[n_type].from_config(config=node_cfg) else: msg = f"invalid node type {n_type} in config" _LOGGER.error(msg) @@ -341,11 +278,11 @@ class PrimaiteGame: # TODO: handle simulation defaults more cleanly if "node_start_up_duration" in defaults_config: - new_node.start_up_duration = defaults_config["node_startup_duration"] + new_node.config.start_up_duration = defaults_config["node_startup_duration"] if "node_shut_down_duration" in defaults_config: - new_node.shut_down_duration = defaults_config["node_shut_down_duration"] + new_node.config.shut_down_duration = defaults_config["node_shut_down_duration"] if "node_scan_duration" in defaults_config: - new_node.node_scan_duration = defaults_config["node_scan_duration"] + new_node.config.node_scan_duration = defaults_config["node_scan_duration"] if "folder_scan_duration" in defaults_config: new_node.file_system._default_folder_scan_duration = defaults_config["folder_scan_duration"] if "folder_restore_duration" in defaults_config: @@ -382,7 +319,7 @@ class PrimaiteGame: service_class = SERVICE_TYPES_MAPPING[service_type] if service_class is not None: - _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") + _LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}") new_node.software_manager.install(service_class, software_config=service_cfg.get("options", {})) new_service = new_node.software_manager.software[service_type] @@ -400,7 +337,7 @@ class PrimaiteGame: # TODO: handle simulation defaults more cleanly if "service_fix_duration" in defaults_config: - new_service.fixing_duration = defaults_config["service_fix_duration"] + new_service.config.fixing_duration = defaults_config["service_fix_duration"] if "service_restart_duration" in defaults_config: new_service.restart_duration = defaults_config["service_restart_duration"] if "service_install_duration" in defaults_config: @@ -431,8 +368,8 @@ class PrimaiteGame: new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) # temporarily set to 0 so all nodes are initially on - new_node.start_up_duration = 0 - new_node.shut_down_duration = 0 + new_node.config.start_up_duration = 0 + new_node.config.shut_down_duration = 0 net.add_node(new_node) # run through the power on step if the node is to be turned on at the start @@ -440,8 +377,8 @@ class PrimaiteGame: new_node.power_on() # set start up and shut down duration - new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3)) - new_node.shut_down_duration = int(node_cfg.get("shut_down_duration", 3)) + new_node.config.start_up_duration = int(node_cfg.get("start_up_duration", 3)) + new_node.config.shut_down_duration = int(node_cfg.get("shut_down_duration", 3)) # 1.1 Create Node Sets for node_set_cfg in node_sets_cfg: diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 750372b3..d8b5e93c 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -3,7 +3,7 @@ """Core of the PrimAITE Simulator.""" import warnings from abc import abstractmethod -from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union from uuid import uuid4 from prettytable import PrettyTable diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 1f6fe6b0..7ede0bb0 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -178,7 +178,7 @@ class AirSpace(BaseModel): status = "Enabled" if interface.enabled else "Disabled" table.add_row( [ - interface._connected_node.hostname, # noqa + interface._connected_node.config.hostname, # noqa interface.mac_address, interface.ip_address if hasattr(interface, "ip_address") else None, interface.subnet_mask if hasattr(interface, "subnet_mask") else None, @@ -320,7 +320,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC): self.enabled = True self._connected_node.sys_log.info(f"Network Interface {self} enabled") self.pcap = PacketCapture( - hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name + hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name ) self.airspace.add_wireless_interface(self) diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index f5ae0232..2e494910 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -180,7 +180,7 @@ class Network(SimComponent): table.align = "l" table.title = "Nodes" for node in self.nodes.values(): - table.add_row((node.hostname, type(node)._discriminator, node.operating_state.name)) + table.add_row((node.config.hostname, type(node)._discriminator, node.operating_state.name)) print(table) if ip_addresses: @@ -196,7 +196,13 @@ class Network(SimComponent): if port.ip_address != IPv4Address("127.0.0.1"): port_str = port.port_name if port.port_name else port.port_num table.add_row( - [node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway] + [ + node.config.hostname, + port_str, + port.ip_address, + port.subnet_mask, + node.config.default_gateway, + ] ) print(table) @@ -215,9 +221,9 @@ class Network(SimComponent): if node in [link.endpoint_a.parent, link.endpoint_b.parent]: table.add_row( [ - link.endpoint_a.parent.hostname, + link.endpoint_a.parent.config.hostname, str(link.endpoint_a), - link.endpoint_b.parent.hostname, + link.endpoint_b.parent.config.hostname, str(link.endpoint_b), link.is_up, link.bandwidth, @@ -251,7 +257,7 @@ class Network(SimComponent): state = super().describe_state() state.update( { - "nodes": {node.hostname: node.describe_state() for node in self.nodes.values()}, + "nodes": {node.config.hostname: node.describe_state() for node in self.nodes.values()}, "links": {}, } ) @@ -259,8 +265,8 @@ class Network(SimComponent): for _, link in self.links.items(): node_a = link.endpoint_a._connected_node node_b = link.endpoint_b._connected_node - hostname_a = node_a.hostname if node_a else None - hostname_b = node_b.hostname if node_b else None + hostname_a = node_a.config.hostname if node_a else None + hostname_b = node_b.config.hostname if node_b else None port_a = link.endpoint_a.port_num port_b = link.endpoint_b.port_num link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}" @@ -286,9 +292,11 @@ class Network(SimComponent): self.nodes[node.uuid] = node self._node_id_map[len(self.nodes)] = node node.parent = self - self._nx_graph.add_node(node.hostname) + self._nx_graph.add_node(node.config.hostname) _LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}") - self._node_request_manager.add_request(name=node.hostname, request_type=RequestType(func=node._request_manager)) + self._node_request_manager.add_request( + name=node.config.hostname, request_type=RequestType(func=node._request_manager) + ) def get_node_by_hostname(self, hostname: str) -> Optional[Node]: """ @@ -300,7 +308,7 @@ class Network(SimComponent): :return: The Node if it exists in the network. """ for node in self.nodes.values(): - if node.hostname == hostname: + if node.config.hostname == hostname: return node def remove_node(self, node: Node) -> None: @@ -313,7 +321,7 @@ class Network(SimComponent): :type node: Node """ if node not in self: - _LOGGER.warning(f"Can't remove node {node.hostname}. It's not in the network.") + _LOGGER.warning(f"Can't remove node {node.config.hostname}. It's not in the network.") return self.nodes.pop(node.uuid) for i, _node in self._node_id_map.items(): @@ -321,8 +329,8 @@ class Network(SimComponent): self._node_id_map.pop(i) break node.parent = None - self._node_request_manager.remove_request(name=node.hostname) - _LOGGER.info(f"Removed node {node.hostname} from network {self.uuid}") + self._node_request_manager.remove_request(name=node.config.hostname) + _LOGGER.info(f"Removed node {node.config.hostname} from network {self.uuid}") def connect( self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, bandwidth: int = 100, **kwargs @@ -352,7 +360,7 @@ class Network(SimComponent): link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth, **kwargs) self.links[link.uuid] = link self._link_id_map[len(self.links)] = link - self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname) + self._nx_graph.add_edge(endpoint_a.parent.config.hostname, endpoint_b.parent.config.hostname) link.parent = self _LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}") return link diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index fdc5f2a1..82adc750 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -154,7 +154,9 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"): # Create a core switch if more than one edge switch is needed if num_of_switches > 1: - core_switch = Switch(hostname=f"switch_core_{config.lan_name}", start_up_duration=0) + core_switch = Switch.from_config( + config={"type": "switch", "hostname": f"switch_core_{config.lan_name}", "start_up_duration": 0} + ) core_switch.power_on() network.add_node(core_switch) core_switch_port = 1 @@ -165,7 +167,10 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"): # Optionally include a router in the LAN if config.include_router: default_gateway = IPv4Address(f"192.168.{config.subnet_base}.1") - router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0) + # router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0) + router = Router.from_config( + config={"hostname": f"router_{config.lan_name}", "type": "router", "start_up_duration": 0} + ) router.power_on() router.acl.add_rule( action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 @@ -178,7 +183,9 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"): # Initialise the first edge switch and connect to the router or core switch switch_port = 0 switch_n = 1 - switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0) + switch = Switch.from_config( + config={"type": "switch", "hostname": f"switch_edge_{switch_n}_{config.lan_name}", "start_up_duration": 0} + ) switch.power_on() network.add_node(switch) if num_of_switches > 1: @@ -196,7 +203,13 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"): if switch_port == effective_network_interface: switch_n += 1 switch_port = 0 - switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0) + switch = Switch.from_config( + config={ + "type": "switch", + "hostname": f"switch_edge_{switch_n}_{config.lan_name}", + "start_up_duration": 0, + } + ) switch.power_on() network.add_node(switch) # Connect the new switch to the router or core switch @@ -213,13 +226,14 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"): ) # Create and add a PC to the network - pc = Computer( - hostname=f"pc_{i}_{config.lan_name}", - ip_address=f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}", - subnet_mask="255.255.255.0", - default_gateway=default_gateway, - start_up_duration=0, - ) + pc_cfg = { + "type": "computer", + "hostname": f"pc_{i}_{config.lan_name}", + "ip_address": f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } + pc = Computer.from_config(config=pc_cfg) pc.power_on() network.add_node(pc) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 89254459..32881e19 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, Field, validate_call +from pydantic import BaseModel, ConfigDict, Field, validate_call from primaite import getLogger from primaite.exceptions import NetworkError @@ -431,7 +431,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): self.enabled = True self._connected_node.sys_log.info(f"Network Interface {self} enabled") self.pcap = PacketCapture( - hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name + hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name ) if self._connected_link: self._connected_link.endpoint_up() @@ -1494,19 +1494,12 @@ class Node(SimComponent, ABC): :param operating_state: The node operating state, either ON or OFF. """ - hostname: str - "The node hostname on the network." - default_gateway: Optional[IPV4Address] = None - "The default gateway IP address for forwarding network traffic to other networks." operating_state: NodeOperatingState = NodeOperatingState.OFF "The hardware state of the node." network_interfaces: Dict[str, NetworkInterface] = {} "The Network Interfaces on the node." network_interface: Dict[int, NetworkInterface] = {} "The Network Interfaces on the node by port id." - dns_server: Optional[IPv4Address] = None - "List of IP addresses of DNS servers used for name resolution." - accounts: Dict[str, Account] = {} "All accounts on the node." applications: Dict[str, Application] = {} @@ -1523,33 +1516,6 @@ class Node(SimComponent, ABC): session_manager: SessionManager software_manager: SoftwareManager - revealed_to_red: bool = False - "Informs whether the node has been revealed to a red agent." - - start_up_duration: int = 3 - "Time steps needed for the node to start up." - - start_up_countdown: int = 0 - "Time steps needed until node is booted up." - - shut_down_duration: int = 3 - "Time steps needed for the node to shut down." - - shut_down_countdown: int = 0 - "Time steps needed until node is shut down." - - is_resetting: bool = False - "If true, the node will try turning itself off then back on again." - - node_scan_duration: int = 10 - "How many timesteps until the whole node is scanned. Default 10 time steps." - - node_scan_countdown: int = 0 - "Time steps until scan is complete" - - red_scan_countdown: int = 0 - "Time steps until reveal to red scan is complete." - SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {} "Base system software that must be preinstalled." @@ -1560,6 +1526,67 @@ class Node(SimComponent, ABC): _discriminator: ClassVar[str] """discriminator for this particular class, used for printing and logging. Each subclass redefines this.""" + config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) + """Configuration items within Node""" + + class ConfigSchema(BaseModel, ABC): + """Configuration Schema for Node based classes.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + """Configure pydantic to allow arbitrary types, let the instance have attributes not present in the model.""" + + hostname: str = "default" + "The node hostname on the network." + + revealed_to_red: bool = False + "Informs whether the node has been revealed to a red agent." + + start_up_duration: int = 3 + "Time steps needed for the node to start up." + + start_up_countdown: int = 0 + "Time steps needed until node is booted up." + + shut_down_duration: int = 3 + "Time steps needed for the node to shut down." + + shut_down_countdown: int = 0 + "Time steps needed until node is shut down." + + is_resetting: bool = False + "If true, the node will try turning itself off then back on again." + + node_scan_duration: int = 10 + "How many timesteps until the whole node is scanned. Default 10 time steps." + + node_scan_countdown: int = 0 + "Time steps until scan is complete" + + red_scan_countdown: int = 0 + "Time steps until reveal to red scan is complete." + + dns_server: Optional[IPv4Address] = None + "List of IP addresses of DNS servers used for name resolution." + + default_gateway: Optional[IPV4Address] = None + "The default gateway IP address for forwarding network traffic to other networks." + + operating_state: Any = None + + @property + def dns_server(self) -> Optional[IPv4Address]: + """Convenience method to access the dns_server IP.""" + return self.config.dns_server + + @classmethod + def from_config(cls, config: Dict) -> "Node": + """Create Node object from a given configuration dictionary.""" + if config["type"] not in cls._registry: + msg = f"Configuration contains an invalid Node type: {config['type']}" + return ValueError(msg) + obj = cls(config=cls.ConfigSchema(**config)) + return obj + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: """ Register a node type. @@ -1585,11 +1612,11 @@ class Node(SimComponent, ABC): provided. """ if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(kwargs["hostname"]) + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) if not kwargs.get("session_manager"): kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log")) if not kwargs.get("root"): - kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"] + kwargs["root"] = SIM_OUTPUT.path / kwargs["config"].hostname if not kwargs.get("file_system"): kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs") if not kwargs.get("software_manager"): @@ -1598,9 +1625,12 @@ class Node(SimComponent, ABC): sys_log=kwargs.get("sys_log"), session_manager=kwargs.get("session_manager"), file_system=kwargs.get("file_system"), - dns_server=kwargs.get("dns_server"), + dns_server=kwargs["config"].dns_server, ) super().__init__(**kwargs) + self.operating_state = ( + NodeOperatingState.ON if not (p := kwargs["config"].operating_state) else NodeOperatingState[p.upper()] + ) self._install_system_software() self.session_manager.node = self self.session_manager.software_manager = self.software_manager @@ -1694,7 +1724,7 @@ class Node(SimComponent, ABC): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return f"Cannot perform request on node '{self.node.hostname}' because it is not powered on." + return f"Cannot perform request on node '{self.node.config.hostname}' because it is not powered on." class _NodeIsOffValidator(RequestPermissionValidator): """ @@ -1713,7 +1743,7 @@ class Node(SimComponent, ABC): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return f"Cannot perform request on node '{self.node.hostname}' because it is not turned off." + return f"Cannot perform request on node '{self.node.config.hostname}' because it is not turned off." def _init_request_manager(self) -> RequestManager: """ @@ -1741,7 +1771,7 @@ class Node(SimComponent, ABC): self.software_manager.install(application_class) application_instance = self.software_manager.software.get(application_name) self.applications[application_instance.uuid] = application_instance - _LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}") + _LOGGER.debug(f"Added application {application_instance.name} to node {self.config.hostname}") self._application_request_manager.add_request( application_name, RequestType(func=application_instance._request_manager) ) @@ -1855,7 +1885,7 @@ class Node(SimComponent, ABC): state = super().describe_state() state.update( { - "hostname": self.hostname, + "hostname": self.config.hostname, "operating_state": self.operating_state.value, "NICs": { eth_num: network_interface.describe_state() @@ -1865,7 +1895,7 @@ class Node(SimComponent, ABC): "applications": {app.name: app.describe_state() for app in self.applications.values()}, "services": {svc.name: svc.describe_state() for svc in self.services.values()}, "process": {proc.name: proc.describe_state() for proc in self.processes.values()}, - "revealed_to_red": self.revealed_to_red, + "revealed_to_red": self.config.revealed_to_red, } ) return state @@ -1881,7 +1911,7 @@ class Node(SimComponent, ABC): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Open Ports" + table.title = f"{self.config.hostname} Open Ports" for port in self.software_manager.get_open_ports(): if port > 0: table.add_row([port]) @@ -1908,7 +1938,7 @@ class Node(SimComponent, ABC): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Network Interface Cards" + table.title = f"{self.config.hostname} Network Interface Cards" for port, network_interface in self.network_interface.items(): ip_address = "" if hasattr(network_interface, "ip_address"): @@ -1943,38 +1973,38 @@ class Node(SimComponent, ABC): network_interface.apply_timestep(timestep=timestep) # count down to boot up - if self.start_up_countdown > 0: - self.start_up_countdown -= 1 + if self.config.start_up_countdown > 0: + self.config.start_up_countdown -= 1 else: if self.operating_state == NodeOperatingState.BOOTING: self.operating_state = NodeOperatingState.ON - self.sys_log.info(f"{self.hostname}: Turned on") + self.sys_log.info(f"{self.config.hostname}: Turned on") for network_interface in self.network_interfaces.values(): network_interface.enable() self._start_up_actions() # count down to shut down - if self.shut_down_countdown > 0: - self.shut_down_countdown -= 1 + if self.config.shut_down_countdown > 0: + self.config.shut_down_countdown -= 1 else: if self.operating_state == NodeOperatingState.SHUTTING_DOWN: self.operating_state = NodeOperatingState.OFF - self.sys_log.info(f"{self.hostname}: Turned off") + self.sys_log.info(f"{self.config.hostname}: Turned off") self._shut_down_actions() # if resetting turn back on - if self.is_resetting: - self.is_resetting = False + if self.config.is_resetting: + self.config.is_resetting = False self.power_on() # time steps which require the node to be on if self.operating_state == NodeOperatingState.ON: # node scanning - if self.node_scan_countdown > 0: - self.node_scan_countdown -= 1 + if self.config.node_scan_countdown > 0: + self.config.node_scan_countdown -= 1 - if self.node_scan_countdown == 0: + if self.config.node_scan_countdown == 0: # scan everything! for process_id in self.processes: self.processes[process_id].scan() @@ -1990,10 +2020,10 @@ class Node(SimComponent, ABC): # scan file system self.file_system.scan(instant_scan=True) - if self.red_scan_countdown > 0: - self.red_scan_countdown -= 1 + if self.config.red_scan_countdown > 0: + self.config.red_scan_countdown -= 1 - if self.red_scan_countdown == 0: + if self.config.red_scan_countdown == 0: # scan processes for process_id in self.processes: self.processes[process_id].reveal_to_red() @@ -2050,7 +2080,7 @@ class Node(SimComponent, ABC): to the red agent. """ - self.node_scan_countdown = self.node_scan_duration + self.config.node_scan_countdown = self.config.node_scan_duration return True def reveal_to_red(self) -> bool: @@ -2066,12 +2096,12 @@ class Node(SimComponent, ABC): `revealed_to_red` to `True`. """ - self.red_scan_countdown = self.node_scan_duration + self.config.red_scan_countdown = self.config.node_scan_duration return True def power_on(self) -> bool: """Power on the Node, enabling its NICs if it is in the OFF state.""" - if self.start_up_duration <= 0: + if self.config.start_up_duration <= 0: self.operating_state = NodeOperatingState.ON self._start_up_actions() self.sys_log.info("Power on") @@ -2080,14 +2110,14 @@ class Node(SimComponent, ABC): return True if self.operating_state == NodeOperatingState.OFF: self.operating_state = NodeOperatingState.BOOTING - self.start_up_countdown = self.start_up_duration + self.config.start_up_countdown = self.config.start_up_duration return True return False def power_off(self) -> bool: """Power off the Node, disabling its NICs if it is in the ON state.""" - if self.shut_down_duration <= 0: + if self.config.shut_down_duration <= 0: self._shut_down_actions() self.operating_state = NodeOperatingState.OFF self.sys_log.info("Power off") @@ -2096,7 +2126,7 @@ class Node(SimComponent, ABC): for network_interface in self.network_interfaces.values(): network_interface.disable() self.operating_state = NodeOperatingState.SHUTTING_DOWN - self.shut_down_countdown = self.shut_down_duration + self.config.shut_down_countdown = self.config.shut_down_duration return True return False @@ -2108,7 +2138,7 @@ class Node(SimComponent, ABC): Applying more timesteps will eventually turn the node back on. """ if self.operating_state.ON: - self.is_resetting = True + self.config.is_resetting = True self.sys_log.info("Resetting") self.power_off() return True @@ -2225,6 +2255,8 @@ class Node(SimComponent, ABC): # Turn on all the applications in the node for app_id in self.applications: + print(app_id) + print(f"Starting application:{self.applications[app_id].config.type}") self.applications[app_id].run() # Turn off all processes in the node diff --git a/src/primaite/simulator/network/hardware/nodes/host/computer.py b/src/primaite/simulator/network/hardware/nodes/host/computer.py index ed11b8e5..64e3ac96 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/computer.py +++ b/src/primaite/simulator/network/hardware/nodes/host/computer.py @@ -1,6 +1,8 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from typing import ClassVar, Dict +from pydantic import Field + from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.system.services.ftp.ftp_client import FTPClient @@ -35,4 +37,11 @@ class Computer(HostNode, discriminator="computer"): SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "ftp-client": FTPClient} + config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema()) + + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Computer class.""" + + hostname: str = "Computer" + pass diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 13796602..d640de2e 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -4,6 +4,8 @@ from __future__ import annotations from ipaddress import IPv4Address from typing import Any, ClassVar, Dict, Optional +from pydantic import Field + from primaite import getLogger from primaite.simulator.network.hardware.base import ( IPWiredNetworkInterface, @@ -44,8 +46,8 @@ class HostARP(ARP): :return: The MAC address of the default gateway if present in the ARP cache; otherwise, None. """ - if self.software_manager.node.default_gateway: - return self.get_arp_cache_mac_address(self.software_manager.node.default_gateway) + if self.software_manager.node.config.default_gateway: + return self.get_arp_cache_mac_address(self.software_manager.node.config.default_gateway) def get_default_gateway_network_interface(self) -> Optional[NIC]: """ @@ -53,8 +55,11 @@ class HostARP(ARP): :return: The NIC associated with the default gateway if it exists in the ARP cache; otherwise, None. """ - if self.software_manager.node.default_gateway and self.software_manager.node.has_enabled_network_interface: - return self.get_arp_cache_network_interface(self.software_manager.node.default_gateway) + if ( + self.software_manager.node.config.default_gateway + and self.software_manager.node.has_enabled_network_interface + ): + return self.get_arp_cache_network_interface(self.software_manager.node.config.default_gateway) def _get_arp_cache_mac_address( self, ip_address: IPV4Address, is_reattempt: bool = False, is_default_gateway_attempt: bool = False @@ -73,7 +78,7 @@ class HostARP(ARP): if arp_entry: return arp_entry.mac_address - if ip_address == self.software_manager.node.default_gateway: + if ip_address == self.software_manager.node.config.default_gateway: is_reattempt = True if not is_reattempt: self.send_arp_request(ip_address) @@ -81,11 +86,11 @@ class HostARP(ARP): ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt ) else: - if self.software_manager.node.default_gateway: + if self.software_manager.node.config.default_gateway: if not is_default_gateway_attempt: - self.send_arp_request(self.software_manager.node.default_gateway) + self.send_arp_request(self.software_manager.node.config.default_gateway) return self._get_arp_cache_mac_address( - ip_address=self.software_manager.node.default_gateway, + ip_address=self.software_manager.node.config.default_gateway, is_reattempt=True, is_default_gateway_attempt=True, ) @@ -116,7 +121,7 @@ class HostARP(ARP): if arp_entry: return self.software_manager.node.network_interfaces[arp_entry.network_interface_uuid] else: - if ip_address == self.software_manager.node.default_gateway: + if ip_address == self.software_manager.node.config.default_gateway: is_reattempt = True if not is_reattempt: self.send_arp_request(ip_address) @@ -124,11 +129,11 @@ class HostARP(ARP): ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt ) else: - if self.software_manager.node.default_gateway: + if self.software_manager.node.config.default_gateway: if not is_default_gateway_attempt: - self.send_arp_request(self.software_manager.node.default_gateway) + self.send_arp_request(self.software_manager.node.config.default_gateway) return self._get_arp_cache_network_interface( - ip_address=self.software_manager.node.default_gateway, + ip_address=self.software_manager.node.config.default_gateway, is_reattempt=True, is_default_gateway_attempt=True, ) @@ -325,9 +330,18 @@ class HostNode(Node, discriminator="host-node"): network_interface: Dict[int, NIC] = {} "The NICs on the node by port id." - def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): + config: HostNode.ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema()) + + class ConfigSchema(Node.ConfigSchema): + """Configuration Schema for HostNode class.""" + + hostname: str = "HostNode" + subnet_mask: IPV4Address = "255.255.255.0" + ip_address: IPV4Address + + def __init__(self, **kwargs): super().__init__(**kwargs) - self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) + self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask)) @property def nmap(self) -> Optional[NMAP]: diff --git a/src/primaite/simulator/network/hardware/nodes/host/server.py b/src/primaite/simulator/network/hardware/nodes/host/server.py index 50b82122..af44cf5a 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/server.py +++ b/src/primaite/simulator/network/hardware/nodes/host/server.py @@ -1,4 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +from pydantic import Field + from primaite.simulator.network.hardware.nodes.host.host_node import HostNode @@ -30,8 +33,22 @@ class Server(HostNode, discriminator="server"): * Web Browser """ + config: "Server.ConfigSchema" = Field(default_factory=lambda: Server.ConfigSchema()) + + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Server class.""" + + hostname: str = "server" + class Printer(HostNode, discriminator="printer"): """Printer? I don't even know her!.""" # TODO: Implement printer-specific behaviour + + config: "Printer.ConfigSchema" = Field(default_factory=lambda: Printer.ConfigSchema()) + + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Printer class.""" + + hostname: str = "printer" diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 4da9e24c..21b356a0 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -6,7 +6,6 @@ from prettytable import MARKDOWN, PrettyTable from pydantic import Field, validate_call from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ( AccessControlList, ACLAction, @@ -99,11 +98,21 @@ class Firewall(Router, discriminator="firewall"): ) """Access Control List for managing traffic leaving towards an external network.""" - def __init__(self, hostname: str, **kwargs): - if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(hostname) + _identifier: str = "firewall" - super().__init__(hostname=hostname, num_ports=0, **kwargs) + config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema()) + + class ConfigSchema(Router.ConfigSchema): + """Configuration Schema for Firewall 'Nodes' within PrimAITE.""" + + hostname: str = "firewall" + num_ports: int = 0 + + def __init__(self, **kwargs): + if not kwargs.get("sys_log"): + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) + + super().__init__(**kwargs) self.connect_nic( RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external") @@ -116,22 +125,23 @@ class Firewall(Router, discriminator="firewall"): ) # Update ACL objects with firewall's hostname and syslog to allow accurate logging self.internal_inbound_acl.sys_log = kwargs["sys_log"] - self.internal_inbound_acl.name = f"{hostname} - Internal Inbound" + self.internal_inbound_acl.name = f"{kwargs['config'].hostname} - Internal Inbound" self.internal_outbound_acl.sys_log = kwargs["sys_log"] - self.internal_outbound_acl.name = f"{hostname} - Internal Outbound" + self.internal_outbound_acl.name = f"{kwargs['config'].hostname} - Internal Outbound" self.dmz_inbound_acl.sys_log = kwargs["sys_log"] - self.dmz_inbound_acl.name = f"{hostname} - DMZ Inbound" + self.dmz_inbound_acl.name = f"{kwargs['config'].hostname} - DMZ Inbound" self.dmz_outbound_acl.sys_log = kwargs["sys_log"] - self.dmz_outbound_acl.name = f"{hostname} - DMZ Outbound" + self.dmz_outbound_acl.name = f"{kwargs['config'].hostname} - DMZ Outbound" self.external_inbound_acl.sys_log = kwargs["sys_log"] - self.external_inbound_acl.name = f"{hostname} - External Inbound" + self.external_inbound_acl.name = f"{kwargs['config'].hostname} - External Inbound" self.external_outbound_acl.sys_log = kwargs["sys_log"] - self.external_outbound_acl.name = f"{hostname} - External Outbound" + self.external_outbound_acl.name = f"{kwargs['config'].hostname} - External Outbound" + self.power_on() def _init_request_manager(self) -> RequestManager: """ @@ -231,7 +241,7 @@ class Firewall(Router, discriminator="firewall"): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Network Interfaces" + table.title = f"{self.config.hostname} Network Interfaces" ports = {"External": self.external_port, "Internal": self.internal_port, "DMZ": self.dmz_port} for port, network_interface in ports.items(): table.add_row( @@ -551,18 +561,14 @@ class Firewall(Router, discriminator="firewall"): self.dmz_port.enable() @classmethod - def from_config(cls, cfg: dict) -> "Firewall": + def from_config(cls, config: dict) -> "Firewall": """Create a firewall based on a config dict.""" - firewall = Firewall( - hostname=cfg["hostname"], - operating_state=NodeOperatingState.ON - if not (p := cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - if "ports" in cfg: - internal_port = cfg["ports"]["internal_port"] - external_port = cfg["ports"]["external_port"] - dmz_port = cfg["ports"].get("dmz_port") + firewall = Firewall(config=cls.ConfigSchema(**config)) + + if "ports" in config: + internal_port = config["ports"]["internal_port"] + external_port = config["ports"]["external_port"] + dmz_port = config["ports"].get("dmz_port") # configure internal port firewall.configure_internal_port( @@ -582,10 +588,10 @@ class Firewall(Router, discriminator="firewall"): ip_address=IPV4Address(dmz_port.get("ip_address")), subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")), ) - if "acl" in cfg: + if "acl" in config: # acl rules for internal_inbound_acl - if cfg["acl"]["internal_inbound_acl"]: - for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items(): + if config["acl"]["internal_inbound_acl"]: + for r_num, r_cfg in config["acl"]["internal_inbound_acl"].items(): firewall.internal_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -599,8 +605,8 @@ class Firewall(Router, discriminator="firewall"): ) # acl rules for internal_outbound_acl - if cfg["acl"]["internal_outbound_acl"]: - for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items(): + if config["acl"]["internal_outbound_acl"]: + for r_num, r_cfg in config["acl"]["internal_outbound_acl"].items(): firewall.internal_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -614,8 +620,8 @@ class Firewall(Router, discriminator="firewall"): ) # acl rules for dmz_inbound_acl - if cfg["acl"]["dmz_inbound_acl"]: - for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items(): + if config["acl"]["dmz_inbound_acl"]: + for r_num, r_cfg in config["acl"]["dmz_inbound_acl"].items(): firewall.dmz_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -629,8 +635,8 @@ class Firewall(Router, discriminator="firewall"): ) # acl rules for dmz_outbound_acl - if cfg["acl"]["dmz_outbound_acl"]: - for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items(): + if config["acl"]["dmz_outbound_acl"]: + for r_num, r_cfg in config["acl"]["dmz_outbound_acl"].items(): firewall.dmz_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -644,8 +650,8 @@ class Firewall(Router, discriminator="firewall"): ) # acl rules for external_inbound_acl - if cfg["acl"].get("external_inbound_acl"): - for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items(): + if config["acl"].get("external_inbound_acl"): + for r_num, r_cfg in config["acl"]["external_inbound_acl"].items(): firewall.external_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -659,8 +665,8 @@ class Firewall(Router, discriminator="firewall"): ) # acl rules for external_outbound_acl - if cfg["acl"].get("external_outbound_acl"): - for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items(): + if config["acl"].get("external_outbound_acl"): + for r_num, r_cfg in config["acl"]["external_outbound_acl"].items(): firewall.external_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -673,16 +679,16 @@ class Firewall(Router, discriminator="firewall"): position=r_num, ) - if "routes" in cfg: - for route in cfg.get("routes"): + if "routes" in config: + for route in config.get("routes"): firewall.route_table.add_route( address=IPv4Address(route.get("address")), subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), metric=float(route.get("metric", 0)), ) - if "default_route" in cfg: - next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None) + if "default_route" in config: + next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: firewall.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index b6004e8e..17fbbc94 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1207,7 +1207,6 @@ class Router(NetworkNode, discriminator="router"): "Terminal": Terminal, } - num_ports: int network_interfaces: Dict[str, RouterInterface] = {} "The Router Interfaces on the node." network_interface: Dict[int, RouterInterface] = {} @@ -1215,19 +1214,29 @@ class Router(NetworkNode, discriminator="router"): acl: AccessControlList route_table: RouteTable - def __init__(self, hostname: str, num_ports: int = 5, **kwargs): + config: "Router.ConfigSchema" + + class ConfigSchema(NetworkNode.ConfigSchema): + """Configuration Schema for Routers.""" + + hostname: str = "router" + num_ports: int = 5 + + def __init__(self, **kwargs): if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(hostname) + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) if not kwargs.get("acl"): - kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=hostname) + kwargs["acl"] = AccessControlList( + sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname + ) if not kwargs.get("route_table"): kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"]) - super().__init__(hostname=hostname, num_ports=num_ports, **kwargs) + super().__init__(**kwargs) self.session_manager = RouterSessionManager(sys_log=self.sys_log) self.session_manager.node = self self.software_manager.session_manager = self.session_manager self.session_manager.software_manager = self.software_manager - for i in range(1, self.num_ports + 1): + for i in range(1, self.config.num_ports + 1): network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0") self.connect_nic(network_interface) self.network_interface[i] = network_interface @@ -1337,7 +1346,7 @@ class Router(NetworkNode, discriminator="router"): :return: A dictionary representing the current state. """ state = super().describe_state() - state["num_ports"] = self.num_ports + state["num_ports"] = self.config.num_ports state["acl"] = self.acl.describe_state() return state @@ -1545,7 +1554,7 @@ class Router(NetworkNode, discriminator="router"): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Network Interfaces" + table.title = f"{self.config.hostname} Network Interfaces" for port, network_interface in self.network_interface.items(): table.add_row( [ @@ -1559,7 +1568,7 @@ class Router(NetworkNode, discriminator="router"): print(table) @classmethod - def from_config(cls, cfg: dict, **kwargs) -> "Router": + def from_config(cls, config: dict, **kwargs) -> "Router": """Create a router based on a config dict. Schema: @@ -1616,22 +1625,16 @@ class Router(NetworkNode, discriminator="router"): :return: Configured router. :rtype: Router """ - router = Router( - hostname=cfg["hostname"], - num_ports=int(cfg.get("num_ports", "5")), - operating_state=NodeOperatingState.ON - if not (p := cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - if "ports" in cfg: - for port_num, port_cfg in cfg["ports"].items(): + router = Router(config=Router.ConfigSchema(**config)) + if "ports" in config: + for port_num, port_cfg in config["ports"].items(): router.configure_port( port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")), ) - if "acl" in cfg: - for r_num, r_cfg in cfg["acl"].items(): + if "acl" in config: + for r_num, r_cfg in config["acl"].items(): router.acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -1643,16 +1646,19 @@ class Router(NetworkNode, discriminator="router"): dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"), position=r_num, ) - if "routes" in cfg: - for route in cfg.get("routes"): + if "routes" in config: + for route in config.get("routes"): router.route_table.add_route( address=IPv4Address(route.get("address")), subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), metric=float(route.get("metric", 0)), ) - if "default_route" in cfg: - next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None) + if "default_route" in config: + next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) + router.operating_state = ( + NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] + ) return router diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index f06337aa..69218f86 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -4,6 +4,7 @@ from __future__ import annotations from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable +from pydantic import Field from primaite import getLogger from primaite.exceptions import NetworkError @@ -88,14 +89,8 @@ class SwitchPort(WiredNetworkInterface): class Switch(NetworkNode, discriminator="switch"): - """ - A class representing a Layer 2 network switch. + """A class representing a Layer 2 network switch.""" - :ivar num_ports: The number of ports on the switch. Default is 24. - """ - - num_ports: int = 24 - "The number of ports on the switch." network_interfaces: Dict[str, SwitchPort] = {} "The SwitchPorts on the Switch." network_interface: Dict[int, SwitchPort] = {} @@ -103,9 +98,18 @@ class Switch(NetworkNode, discriminator="switch"): mac_address_table: Dict[str, SwitchPort] = {} "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." + config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema()) + + class ConfigSchema(NetworkNode.ConfigSchema): + """Configuration Schema for Switch nodes within PrimAITE.""" + + hostname: str = "Switch" + num_ports: int = 24 + "The number of ports on the switch. Default is 24." + def __init__(self, **kwargs): super().__init__(**kwargs) - for i in range(1, self.num_ports + 1): + for i in range(1, kwargs["config"].num_ports + 1): self.connect_nic(SwitchPort()) def _install_system_software(self): @@ -121,7 +125,7 @@ class Switch(NetworkNode, discriminator="switch"): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Switch Ports" + table.title = f"{self.config.hostname} Switch Ports" for port_num, port in self.network_interface.items(): table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"]) print(table) @@ -134,7 +138,7 @@ class Switch(NetworkNode, discriminator="switch"): """ state = super().describe_state() state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()} - state["num_ports"] = self.num_ports # redundant? + state["num_ports"] = self.config.num_ports # redundant? state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()} return state diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index ca962c24..f8c29923 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -2,7 +2,7 @@ from ipaddress import IPv4Address from typing import Any, Dict, Optional, Union -from pydantic import validate_call +from pydantic import Field, validate_call from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, FREQ_WIFI_2_4, IPWirelessNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState @@ -122,13 +122,23 @@ class WirelessRouter(Router, discriminator="wireless-router"): network_interfaces: Dict[str, Union[RouterInterface, WirelessAccessPoint]] = {} network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {} - airspace: AirSpace - def __init__(self, hostname: str, airspace: AirSpace, **kwargs): - super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs) + config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema()) + + class ConfigSchema(Router.ConfigSchema): + """Configuration Schema for WirelessRouter nodes within PrimAITE.""" + + hostname: str = "WirelessRouter" + airspace: AirSpace + num_ports: int = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) self.connect_nic( - WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=airspace) + WirelessAccessPoint( + ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=kwargs["config"].airspace + ) ) self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")) @@ -226,7 +236,7 @@ class WirelessRouter(Router, discriminator="wireless-router"): ) @classmethod - def from_config(cls, cfg: Dict, **kwargs) -> "WirelessRouter": + def from_config(cls, config: Dict, **kwargs) -> "WirelessRouter": """Generate the wireless router from config. Schema: @@ -253,22 +263,22 @@ class WirelessRouter(Router, discriminator="wireless-router"): :return: WirelessRouter instance. :rtype: WirelessRouter """ - operating_state = ( - NodeOperatingState.ON if not (p := cfg.get("operating_state")) else NodeOperatingState[p.upper()] + router = cls(config=cls.ConfigSchema(**config)) + router.operating_state = ( + NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()] ) - router = cls(hostname=cfg["hostname"], operating_state=operating_state, airspace=kwargs["airspace"]) - if "router_interface" in cfg: - ip_address = cfg["router_interface"]["ip_address"] - subnet_mask = cfg["router_interface"]["subnet_mask"] + if "router_interface" in config: + ip_address = config["router_interface"]["ip_address"] + subnet_mask = config["router_interface"]["subnet_mask"] router.configure_router_interface(ip_address=ip_address, subnet_mask=subnet_mask) - if "wireless_access_point" in cfg: - ip_address = cfg["wireless_access_point"]["ip_address"] - subnet_mask = cfg["wireless_access_point"]["subnet_mask"] - frequency = AirSpaceFrequency._registry[cfg["wireless_access_point"]["frequency"]] + if "wireless_access_point" in config: + ip_address = config["wireless_access_point"]["ip_address"] + subnet_mask = config["wireless_access_point"]["subnet_mask"] + frequency = AirSpaceFrequency._registry[config["wireless_access_point"]["frequency"]] router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency) - if "acl" in cfg: - for r_num, r_cfg in cfg["acl"].items(): + if "acl" in config: + for r_num, r_cfg in config["acl"].items(): router.acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -280,8 +290,8 @@ class WirelessRouter(Router, discriminator="wireless-router"): dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"), position=r_num, ) - if "routes" in cfg: - for route in cfg.get("routes"): + if "routes" in config: + for route in config.get("routes"): router.route_table.add_route( address=IPv4Address(route.get("address")), subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 97dde839..21a0041e 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -40,41 +40,45 @@ def client_server_routed() -> Network: network = Network() # Router 1 - router_1 = Router(hostname="router_1", num_ports=3) + router_1 = Router(config=dict(hostname="router_1", num_ports=3)) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.2.1", subnet_mask="255.255.255.0") # Switch 1 - switch_1 = Switch(hostname="switch_1", num_ports=6) + switch_1 = Switch(config=dict(hostname="switch_1", num_ports=6)) switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[6]) router_1.enable_port(1) # Switch 2 - switch_2 = Switch(hostname="switch_2", num_ports=6) + switch_2 = Switch(config=dict(hostname="switch_2", num_ports=6)) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[6]) router_1.enable_port(2) # Client 1 client_1 = Computer( - hostname="client_1", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, + config=dict( + hostname="client_1", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.2.1", + start_up_duration=0, + ) ) client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) # Server 1 server_1 = Server( - hostname="server_1", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + config=dict( + hostname="server_1", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) ) server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) @@ -128,32 +132,41 @@ def arcd_uc2_network() -> Network: network = Network() # Router 1 - router_1 = Router(hostname="router_1", num_ports=5, start_up_duration=0) + router_1 = Router.from_config( + config={"type": "router", "hostname": "router_1", "num_ports": 5, "start_up_duration": 0} + ) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") # Switch 1 - switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) + switch_1 = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 8, "start_up_duration": 0} + ) switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8]) router_1.enable_port(1) # Switch 2 - switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) + switch_2 = Switch.from_config( + config={"type": "switch", "hostname": "switch_2", "num_ports": 8, "start_up_duration": 0} + ) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8]) router_1.enable_port(2) # Client 1 - client_1 = Computer( - hostname="client_1", - ip_address="192.168.10.21", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + client_1_cfg = { + "type": "computer", + "hostname": "client_1", + "ip_address": "192.168.10.21", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + client_1: Computer = Computer.from_config(config=client_1_cfg) + client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) client_1.software_manager.install(DatabaseClient) @@ -172,14 +185,18 @@ def arcd_uc2_network() -> Network: ) # Client 2 - client_2 = Computer( - hostname="client_2", - ip_address="192.168.10.22", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + + client_2_cfg = { + "type": "computer", + "hostname": "client_2", + "ip_address": "192.168.10.22", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + client_2: Computer = Computer.from_config(config=client_2_cfg) + client_2.power_on() client_2.software_manager.install(DatabaseClient) db_client_2 = client_2.software_manager.software.get("database-client") @@ -193,27 +210,36 @@ def arcd_uc2_network() -> Network: ) # Domain Controller - domain_controller = Server( - hostname="domain_controller", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + + domain_controller_cfg = { + "type": "server", + "hostname": "domain_controller", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + domain_controller = Server.from_config(config=domain_controller_cfg) domain_controller.power_on() domain_controller.software_manager.install(DNSServer) network.connect(endpoint_b=domain_controller.network_interface[1], endpoint_a=switch_1.network_interface[1]) # Database Server - database_server = Server( - hostname="database_server", - ip_address="192.168.1.14", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + + database_server_cfg = { + "type": "server", + "hostname": "database_server", + "ip_address": "192.168.1.14", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + + database_server = Server.from_config(config=database_server_cfg) + database_server.power_on() network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3]) @@ -223,14 +249,18 @@ def arcd_uc2_network() -> Network: database_service.configure_backup(backup_server=IPv4Address("192.168.1.16")) # Web Server - web_server = Server( - hostname="web_server", - ip_address="192.168.1.12", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + + web_server_cfg = { + "type": "server", + "hostname": "web_server", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + web_server = Server.from_config(config=web_server_cfg) + web_server.power_on() web_server.software_manager.install(DatabaseClient) @@ -247,27 +277,32 @@ def arcd_uc2_network() -> Network: dns_server_service.dns_register("arcd.com", web_server.network_interface[1].ip_address) # Backup Server - backup_server = Server( - hostname="backup_server", - ip_address="192.168.1.16", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + backup_server_cfg = { + "type": "server", + "hostname": "backup_server", + "ip_address": "192.168.1.16", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + backup_server: Server = Server.from_config(config=backup_server_cfg) + backup_server.power_on() backup_server.software_manager.install(FTPServer) network.connect(endpoint_b=backup_server.network_interface[1], endpoint_a=switch_1.network_interface[4]) # Security Suite - security_suite = Server( - hostname="security_suite", - ip_address="192.168.1.110", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - start_up_duration=0, - ) + security_suite_cfg = { + "type": "server", + "hostname": "backup_server", + "ip_address": "192.168.1.110", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + "start_up_duration": 0, + } + security_suite: Server = Server.from_config(config=security_suite_cfg) security_suite.power_on() network.connect(endpoint_b=security_suite.network_interface[1], endpoint_a=switch_1.network_interface[7]) security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0")) diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index 5aada5fb..93de43bb 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -208,7 +208,7 @@ class NMAP(Application, discriminator="nmap"): if show: table = PrettyTable(["IP Address", "Can Ping"]) table.align = "l" - table.title = f"{self.software_manager.node.hostname} NMAP Ping Scan" + table.title = f"{self.software_manager.node.config.hostname} NMAP Ping Scan" ip_addresses = self._explode_ip_address_network_array(target_ip_address) @@ -367,7 +367,7 @@ class NMAP(Application, discriminator="nmap"): if show: table = PrettyTable(["IP Address", "Port", "Protocol"]) table.align = "l" - table.title = f"{self.software_manager.node.hostname} NMAP Port Scan ({scan_type})" + table.title = f"{self.software_manager.node.config.hostname} NMAP Port Scan ({scan_type})" self.sys_log.info(f"{self.name}: Starting port scan") for ip_address in ip_addresses: # Prevent port scan on this node diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index adc8d565..67e555ae 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -140,6 +140,7 @@ class SoftwareManager: elif isinstance(software, Service): self.node.services[software.uuid] = software self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager)) + software.start() software.install() software.software_manager = self self.software[software.name] = software diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 0969370a..edc3f6b4 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -32,6 +32,7 @@ class DatabaseService(Service, discriminator="database-service"): type: str = "database-service" backup_server_ip: Optional[IPv4Address] = None db_password: Optional[str] = None + """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" config: ConfigSchema = Field(default_factory=lambda: DatabaseService.ConfigSchema()) @@ -224,7 +225,7 @@ class DatabaseService(Service, discriminator="database-service"): SoftwareHealthState.FIXING, SoftwareHealthState.COMPROMISED, ]: - if self.password == password: + if self.config.db_password == password: status_code = 200 # ok connection_id = self._generate_connection_id() # try to create connection diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 0c451d12..4d1ee6ba 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -36,7 +36,11 @@ class FTPServer(FTPServiceABC, discriminator="ftp-server"): kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) self.start() - self.server_password = self.config.server_password + + @property + def server_password(self) -> Optional[str]: + """Convenience method for accessing FTP server password.""" + return self.config.server_password def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 30d8c258..93d2fbf7 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -26,8 +26,6 @@ class NTPClient(Service, discriminator="ntp-client"): config: ConfigSchema = Field(default_factory=lambda: NTPClient.ConfigSchema()) - ntp_server: Optional[IPv4Address] = None - "The NTP server the client sends requests to." time: Optional[datetime] = None def __init__(self, **kwargs): @@ -45,8 +43,8 @@ class NTPClient(Service, discriminator="ntp-client"): :param ntp_server_ip_address: IPv4 address of NTP server. :param ntp_client_ip_Address: IPv4 address of NTP client. """ - self.ntp_server = ntp_server_ip_address - self.sys_log.info(f"{self.name}: ntp_server: {self.ntp_server}") + self.config.ntp_server_ip = ntp_server_ip_address + self.sys_log.info(f"{self.name}: ntp_server: {self.config.ntp_server_ip}") def describe_state(self) -> Dict: """ @@ -108,10 +106,10 @@ class NTPClient(Service, discriminator="ntp-client"): def request_time(self) -> None: """Send request to ntp_server.""" - if self.ntp_server: + if self.config.ntp_server_ip: self.software_manager.session_manager.receive_payload_from_software_manager( payload=NTPPacket(), - dst_ip_address=self.ntp_server, + dst_ip_address=self.config.ntp_server_ip, src_port=self.port, dst_port=self.port, ip_protocol=self.protocol, diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 25b2366c..950f77c6 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -315,7 +315,7 @@ class IOSoftware(Software, ABC): """ if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON: self.software_manager.node.sys_log.error( - f"{self.name} Error: {self.software_manager.node.hostname} is not powered on." + f"{self.name} Error: {self.software_manager.node.config.hostname} is not powered on." ) return False return True diff --git a/tests/conftest.py b/tests/conftest.py index 230a763d..226a8cef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,7 +119,14 @@ def application_class(): @pytest.fixture(scope="function") def file_system() -> FileSystem: - computer = Computer(hostname="fs_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + computer_cfg = { + "type": "computer", + "hostname": "fs_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + computer = Computer.from_config(config=computer_cfg) computer.power_on() return computer.file_system @@ -129,23 +136,29 @@ def client_server() -> Tuple[Computer, Server]: network = Network() # Create Computer - computer = Computer( - hostname="computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() # Create Server - server = Server( - hostname="server", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + server_cfg = { + "type": "server", + "hostname": "server", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server: Server = Server.from_config(config=server_cfg) server.power_on() # Connect Computer and Server @@ -162,26 +175,33 @@ def client_switch_server() -> Tuple[Computer, Switch, Server]: network = Network() # Create Computer - computer = Computer( - hostname="computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() # Create Server - server = Server( - hostname="server", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + server_cfg = { + "type": "server", + "hostname": "server", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server: Server = Server.from_config(config=server_cfg) server.power_on() - switch = Switch(hostname="switch", start_up_duration=0) + # Create Switch + switch: Switch = Switch.from_config(config={"type": "switch", "hostname": "switch", "start_up_duration": 0}) switch.power_on() network.connect(endpoint_a=computer.network_interface[1], endpoint_b=switch.network_interface[1]) @@ -211,65 +231,96 @@ def example_network() -> Network: network = Network() # Router 1 - router_1 = Router(hostname="router_1", start_up_duration=0) + + router_1_cfg = {"hostname": "router_1", "type": "router", "start_up_duration": 0} + + # router_1 = Router(hostname="router_1", start_up_duration=0) + router_1 = Router.from_config(config=router_1_cfg) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") # Switch 1 - switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) + + switch_1_cfg = {"hostname": "switch_1", "type": "switch", "start_up_duration": 0} + + switch_1 = Switch.from_config(config=switch_1_cfg) + + # switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8]) router_1.enable_port(1) # Switch 2 - switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) + switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8, "start_up_duration": 0} + # switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) + switch_2 = Switch.from_config(config=switch_2_config) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8]) router_1.enable_port(2) - # Client 1 - client_1 = Computer( - hostname="client_1", - ip_address="192.168.10.21", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - start_up_duration=0, - ) + # # Client 1 + + client_1_cfg = { + "type": "computer", + "hostname": "client_1", + "ip_address": "192.168.10.21", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } + + client_1 = Computer.from_config(config=client_1_cfg) + client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) - # Client 2 - client_2 = Computer( - hostname="client_2", - ip_address="192.168.10.22", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - start_up_duration=0, - ) + # # Client 2 + + client_2_cfg = { + "type": "computer", + "hostname": "client_2", + "ip_address": "192.168.10.22", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } + + client_2 = Computer.from_config(config=client_2_cfg) + client_2.power_on() network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2]) - # Server 1 - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + # # Server 1 + + server_1_cfg = { + "type": "server", + "hostname": "server_1", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server_1 = Server.from_config(config=server_1_cfg) + server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) - # DServer 2 - server_2 = Server( - hostname="server_2", - ip_address="192.168.1.14", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + # # DServer 2 + + server_2_cfg = { + "type": "server", + "hostname": "server_2", + "ip_address": "192.168.1.14", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server_2 = Server.from_config(config=server_2_cfg) + server_2.power_on() network.connect(endpoint_b=server_2.network_interface[1], endpoint_a=switch_1.network_interface[2]) @@ -277,6 +328,8 @@ def example_network() -> Network: assert all(link.is_up for link in network.links.values()) + client_1.software_manager.show() + return network @@ -309,29 +362,35 @@ def install_stuff_to_sim(sim: Simulation): # 1: Set up network hardware # 1.1: Configure the router - router = Router(hostname="router", num_ports=3, start_up_duration=0) + router = Router.from_config(config={"type": "router", "hostname": "router", "num_ports": 3, "start_up_duration": 0}) router.power_on() router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") # 1.2: Create and connect switches - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1 = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0} + ) switch_1.power_on() network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6]) router.enable_port(1) - switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0) + switch_2 = Switch.from_config( + config={"type": "switch", "hostname": "switch_2", "num_ports": 6, "start_up_duration": 0} + ) switch_2.power_on() network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6]) router.enable_port(2) # 1.3: Create and connect computer - client_1 = Computer( - hostname="client_1", - ip_address="10.0.1.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.1.1", - start_up_duration=0, - ) + client_1_cfg = { + "type": "computer", + "hostname": "client_1", + "ip_address": "10.0.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "10.0.1.1", + "start_up_duration": 0, + } + client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() network.connect( endpoint_a=client_1.network_interface[1], @@ -339,23 +398,28 @@ def install_stuff_to_sim(sim: Simulation): ) # 1.4: Create and connect servers - server_1 = Server( - hostname="server_1", - ip_address="10.0.2.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - start_up_duration=0, - ) + server_1_cfg = { + "type": "server", + "hostname": "server_1", + "ip_address": "10.0.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "10.0.2.1", + "start_up_duration": 0, + } + + server_1: Server = Server.from_config(config=server_1_cfg) server_1.power_on() network.connect(endpoint_a=server_1.network_interface[1], endpoint_b=switch_2.network_interface[1]) + server_2_cfg = { + "type": "server", + "hostname": "server_2", + "ip_address": "10.0.2.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "10.0.2.1", + "start_up_duration": 0, + } - server_2 = Server( - hostname="server_2", - ip_address="10.0.2.3", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - start_up_duration=0, - ) + server_2: Server = Server.from_config(config=server_2_cfg) server_2.power_on() network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) @@ -403,18 +467,18 @@ def install_stuff_to_sim(sim: Simulation): assert acl_rule is None # 5.2: Assert the client is correctly configured - c: Computer = [node for node in sim.network.nodes.values() if node.hostname == "client_1"][0] + c: Computer = [node for node in sim.network.nodes.values() if node.config.hostname == "client_1"][0] assert c.software_manager.software.get("web-browser") is not None assert c.software_manager.software.get("dns-client") is not None assert str(c.network_interface[1].ip_address) == "10.0.1.2" # 5.3: Assert that server_1 is correctly configured - s1: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_1"][0] + s1: Server = [node for node in sim.network.nodes.values() if node.config.hostname == "server_1"][0] assert str(s1.network_interface[1].ip_address) == "10.0.2.2" assert s1.software_manager.software.get("dns-server") is not None # 5.4: Assert that server_2 is correctly configured - s2: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_2"][0] + s2: Server = [node for node in sim.network.nodes.values() if node.config.hostname == "server_2"][0] assert str(s2.network_interface[1].ip_address) == "10.0.2.3" assert s2.software_manager.software.get("web-server") is not None diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py index a34d430b..6da801d4 100644 --- a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -12,6 +12,7 @@ from sb3_contrib import MaskablePPO from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests import TEST_ASSETS_ROOT CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" diff --git a/tests/integration_tests/component_creation/test_action_integration.py b/tests/integration_tests/component_creation/test_action_integration.py index 8b81b7d3..0fd0aa19 100644 --- a/tests/integration_tests/component_creation/test_action_integration.py +++ b/tests/integration_tests/component_creation/test_action_integration.py @@ -12,12 +12,18 @@ def test_passing_actions_down(monkeypatch) -> None: sim = Simulation() - pc1 = Computer(hostname="PC-1", ip_address="10.10.1.1", subnet_mask="255.255.255.0") + pc1 = Computer.from_config( + config={"type": "computer", "hostname": "PC-1", "ip_address": "10.10.1.1", "subnet_mask": "255.255.255.0"} + ) pc1.start_up_duration = 0 pc1.power_on() - pc2 = Computer(hostname="PC-2", ip_address="10.10.1.2", subnet_mask="255.255.255.0") - srv = Server(hostname="WEBSERVER", ip_address="10.10.1.100", subnet_mask="255.255.255.0") - s1 = Switch(hostname="switch1") + pc2 = Computer.from_config( + config={"type": "computer", "hostname": "PC-2", "ip_address": "10.10.1.2", "subnet_mask": "255.255.255.0"} + ) + srv = Server.from_config( + config={"type": "server", "hostname": "WEBSERVER", "ip_address": "10.10.1.100", "subnet_mask": "255.255.255.0"} + ) + s1 = Switch.from_config(config={"type": "switch", "hostname": "switch1"}) for n in [pc1, pc2, srv, s1]: sim.network.add_node(n) @@ -48,6 +54,6 @@ def test_passing_actions_down(monkeypatch) -> None: assert not action_invoked # call the patched method - sim.apply_request(["network", "node", pc1.hostname, "file_system", "folder", "downloads", "repair"]) + sim.apply_request(["network", "node", pc1.config.hostname, "file_system", "folder", "downloads", "repair"]) assert action_invoked diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py index 16f4dee5..7ca3a6aa 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -5,6 +5,7 @@ from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import PORT_LOOKUP diff --git a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py index 764a7aac..f3911691 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py @@ -3,6 +3,8 @@ from primaite.config.load import data_manipulation_config_path from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config diff --git a/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py index c588829b..1352f894 100644 --- a/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py +++ b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py @@ -4,6 +4,7 @@ import yaml from primaite.session.environment import PrimaiteGymEnv from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests.conftest import TEST_ASSETS_ROOT folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders" diff --git a/tests/integration_tests/extensions/nodes/super_computer.py b/tests/integration_tests/extensions/nodes/super_computer.py index cf3ead58..c4ad61ae 100644 --- a/tests/integration_tests/extensions/nodes/super_computer.py +++ b/tests/integration_tests/extensions/nodes/super_computer.py @@ -36,8 +36,8 @@ class SuperComputer(HostNode, discriminator="supercomputer"): SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "ftp-client": FTPClient} - def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): + def __init__(self, **kwargs): print("--- Extended Component: SuperComputer ---") - super().__init__(ip_address=ip_address, subnet_mask=subnet_mask, **kwargs) + super().__init__(**kwargs) pass diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index fc8dc630..b1cf7ed5 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -31,14 +31,14 @@ class ExtendedService(Service, discriminator="extended-service"): type: str = "extended-service" + backup_server_ip: IPv4Address = None + """IP address of the backup server.""" + config: "ExtendedService.ConfigSchema" = Field(default_factory=lambda: ExtendedService.ConfigSchema()) password: Optional[str] = None """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" - backup_server_ip: IPv4Address = None - """IP address of the backup server.""" - latest_backup_directory: str = None """Directory of latest backup.""" diff --git a/tests/integration_tests/game_layer/actions/test_node_request_permission.py b/tests/integration_tests/game_layer/actions/test_node_request_permission.py index 58ea4c98..baddee46 100644 --- a/tests/integration_tests/game_layer/actions/test_node_request_permission.py +++ b/tests/integration_tests/game_layer/actions/test_node_request_permission.py @@ -25,6 +25,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy game, agent = game_and_agent_fixture client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.config.shut_down_duration = 3 assert client_1.operating_state == NodeOperatingState.ON @@ -35,13 +36,15 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN - for i in range(client_1.shut_down_duration + 1): + for i in range(client_1.config.shut_down_duration + 1): action = ("do-nothing", {}) agent.store_action(action) game.step() assert client_1.operating_state == NodeOperatingState.OFF + client_1.config.start_up_duration = 3 + # turn it on action = ("node-startup", {"node_name": "client_1"}) agent.store_action(action) @@ -49,7 +52,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy assert client_1.operating_state == NodeOperatingState.BOOTING - for i in range(client_1.start_up_duration + 1): + for i in range(client_1.config.start_up_duration + 1): action = ("do-nothing", {}) agent.store_action(action) game.step() @@ -79,7 +82,7 @@ def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture: client_1 = game.simulation.network.get_node_by_hostname("client_1") client_1.power_off() - for i in range(client_1.shut_down_duration + 1): + for i in range(client_1.config.shut_down_duration + 1): action = ("do-nothing", {}) agent.store_action(action) game.step() diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index 44ef4a70..0a633b2d 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -36,7 +36,7 @@ def test_acl_observations(simulation): router.acl.add_rule(action=ACLAction.PERMIT, dst_port=PORT_LOOKUP["NTP"], src_port=PORT_LOOKUP["NTP"], position=1) acl_obs = ACLObservation( - where=["network", "nodes", router.hostname, "acl", "acl"], + where=["network", "nodes", router.config.hostname, "acl", "acl"], ip_list=[], port_list=[123, 80, 5432], protocol_list=["tcp", "udp", "icmp"], diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 19c0c4bc..7323461c 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -24,7 +24,7 @@ def test_file_observation(simulation): file = pc.file_system.create_file(file_name="dog.png") dog_file_obs = FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, file_system_requires_scan=True, ) @@ -52,7 +52,7 @@ def test_folder_observation(simulation): file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder") root_folder_obs = FolderObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "test_folder"], include_num_access=False, file_system_requires_scan=True, num_files=1, diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 97608132..874fa49e 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -25,7 +25,8 @@ def check_default_rules(acl_obs): def test_firewall_observation(): """Test adding/removing acl rules and enabling/disabling ports.""" net = Network() - firewall = Firewall(hostname="firewall", operating_state=NodeOperatingState.ON) + firewall_cfg = {"type": "firewall", "hostname": "firewall"} + firewall = Firewall.from_config(config=firewall_cfg) firewall_observation = FirewallObservation( where=[], num_rules=7, @@ -116,7 +117,9 @@ def test_firewall_observation(): assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4)) # connect a switch to the firewall and check that only the correct port is updated - switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + switch: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": "ON"} + ) link = net.connect(firewall.network_interface[1], switch.network_interface[1]) assert firewall.network_interface[1].enabled observation = firewall_observation.observe(firewall.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py index 630e29ea..1ab50a68 100644 --- a/tests/integration_tests/game_layer/observations/test_link_observations.py +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -56,12 +56,26 @@ def test_link_observation(): """Check the shape and contents of the link observation.""" net = Network() sim = Simulation(network=net) - switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON) - computer_1 = Computer( - hostname="computer_1", ip_address="10.0.0.1", subnet_mask="255.255.255.0", start_up_duration=0 + switch: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch", "num_ports": 5, "operating_state": "ON"} ) - computer_2 = Computer( - hostname="computer_2", ip_address="10.0.0.2", subnet_mask="255.255.255.0", start_up_duration=0 + computer_1: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer_1", + "ip_address": "10.0.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) + computer_2: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer_2", + "ip_address": "10.0.0.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) computer_1.power_on() computer_2.power_on() diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index 2a6ac44d..b5e5ca81 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -75,7 +75,7 @@ def test_nic(simulation): nic: NIC = pc.network_interface[1] - nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True) # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs nmne_config = { @@ -108,7 +108,7 @@ def test_nic_categories(simulation): """Test the NIC observation nmne count categories.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") - nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True) assert nic_obs.high_nmne_threshold == 10 # default assert nic_obs.med_nmne_threshold == 5 # default @@ -163,7 +163,9 @@ def test_nic_monitored_traffic(simulation): pc2: Computer = simulation.network.get_node_by_hostname("client_2") nic_obs = NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic + where=["network", "nodes", pc.config.hostname, "NICs", 1], + include_nmne=False, + monitored_traffic=monitored_traffic, ) simulation.pre_timestep(0) # apply timestep to whole sim diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 63ca8f6b..09eb3fe4 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -25,7 +25,7 @@ def test_host_observation(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") host_obs = HostObservation( - where=["network", "nodes", pc.hostname], + where=["network", "nodes", pc.config.hostname], num_applications=0, num_files=1, num_folders=1, @@ -56,7 +56,7 @@ def test_host_observation(simulation): observation_state = host_obs.observe(simulation.describe_state()) assert observation_state.get("operating_status") == 4 # shutting down - for i in range(pc.shut_down_duration + 1): + for i in range(pc.config.shut_down_duration + 1): pc.apply_timestep(i) observation_state = host_obs.observe(simulation.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index f4bfb193..495e102d 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -16,7 +16,9 @@ from primaite.utils.validation.port import PORT_LOOKUP def test_router_observation(): """Test adding/removing acl rules and enabling/disabling ports.""" net = Network() - router = Router(hostname="router", num_ports=5, operating_state=NodeOperatingState.ON) + router = Router.from_config( + config={"type": "router", "hostname": "router", "num_ports": 5, "operating_state": "ON"} + ) ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)] acl = ACLObservation( @@ -89,7 +91,9 @@ def test_router_observation(): assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6)) # connect a switch to the router and check that only the correct port is updated - switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + switch: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": "ON"} + ) link = net.connect(router.network_interface[1], switch.network_interface[1]) assert router.network_interface[1].enabled observed_output = router_observation.observe(router.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 1d457b0f..28cdaf01 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -29,7 +29,7 @@ def test_service_observation(simulation): ntp_server = pc.software_manager.software.get("ntp-server") assert ntp_server - service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "ntp-server"]) + service_obs = ServiceObservation(where=["network", "nodes", pc.config.hostname, "services", "ntp-server"]) assert service_obs.space["operating_status"] == spaces.Discrete(7) assert service_obs.space["health_status"] == spaces.Discrete(5) @@ -54,7 +54,7 @@ def test_application_observation(simulation): web_browser: WebBrowser = pc.software_manager.software.get("web-browser") assert web_browser - app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "web-browser"]) + app_obs = ApplicationObservation(where=["network", "nodes", pc.config.hostname, "applications", "web-browser"]) web_browser.close() observation_state = app_obs.observe(simulation.describe_state()) diff --git a/tests/integration_tests/game_layer/test_action_mask.py b/tests/integration_tests/game_layer/test_action_mask.py index 9a489a50..e0337929 100644 --- a/tests/integration_tests/game_layer/test_action_mask.py +++ b/tests/integration_tests/game_layer/test_action_mask.py @@ -2,6 +2,7 @@ from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.host_node import HostNode +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.system.services.service import ServiceOperatingState from tests.conftest import TEST_ASSETS_ROOT diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 60cbaa53..4ae3dd6e 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -21,6 +21,7 @@ from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.software import SoftwareHealthState diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index 5afad296..17b9b71e 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -8,14 +8,16 @@ from primaite.simulator.sim_container import Simulation def test_file_observation(): sim = Simulation() - pc = Computer(hostname="beep", ip_address="123.123.123.123", subnet_mask="255.255.255.0") + pc: Computer = Computer.from_config( + config={"type": "computer", "hostname": "beep", "ip_address": "123.123.123.123", "subnet_mask": "255.255.255.0"} + ) sim.network.add_node(pc) f = pc.file_system.create_file(file_name="dog.png") state = sim.describe_state() dog_file_obs = FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, file_system_requires_scan=False, ) diff --git a/tests/integration_tests/network/test_airspace_config.py b/tests/integration_tests/network/test_airspace_config.py index e8abc0f2..fd3f6f28 100644 --- a/tests/integration_tests/network/test_airspace_config.py +++ b/tests/integration_tests/network/test_airspace_config.py @@ -2,6 +2,7 @@ import yaml from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests import TEST_ASSETS_ROOT diff --git a/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py b/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py index 479473d1..6f3e7546 100644 --- a/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py +++ b/tests/integration_tests/network/test_bandwidth_load_checks_before_transmission.py @@ -16,7 +16,7 @@ def test_wireless_link_loading(wireless_wan_network): # Configure Router 2 ACLs router_2.acl.add_rule(action=ACLAction.PERMIT, position=1) - airspace = router_1.airspace + airspace = router_1.config.airspace client.software_manager.install(FTPClient) ftp_client: FTPClient = client.software_manager.software.get("ftp-client") diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index ab944564..b1bbfc63 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -84,44 +84,55 @@ class BroadcastTestClient(Application, discriminator="broadcast-test-client"): def broadcast_network() -> Network: network = Network() - client_1 = Computer( - hostname="client_1", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + client_1_cfg = { + "type": "computer", + "hostname": "client_1", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() client_1.software_manager.install(BroadcastTestClient) application_1 = client_1.software_manager.software["broadcast-test-client"] application_1.run() + client_2_cfg = { + "type": "computer", + "hostname": "client_2", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } - client_2 = Computer( - hostname="client_2", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + client_2: Computer = Computer.from_config(config=client_2_cfg) client_2.power_on() client_2.software_manager.install(BroadcastTestClient) application_2 = client_2.software_manager.software["broadcast-test-client"] application_2.run() - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.1", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + server_1_cfg = { + "type": "server", + "hostname": "server_1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + server_1: Server = Server.from_config(config=server_1_cfg) + server_1.power_on() server_1.software_manager.install(BroadcastTestService) service: BroadcastTestService = server_1.software_manager.software["BroadcastService"] service.start() - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1: Switch = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0} + ) switch_1.power_on() network.connect(endpoint_a=client_1.network_interface[1], endpoint_b=switch_1.network_interface[1]) diff --git a/tests/integration_tests/network/test_firewall.py b/tests/integration_tests/network/test_firewall.py index ec0200a2..7c2c36c0 100644 --- a/tests/integration_tests/network/test_firewall.py +++ b/tests/integration_tests/network/test_firewall.py @@ -41,7 +41,9 @@ def dmz_external_internal_network() -> Network: """ network = Network() - firewall_node: Firewall = Firewall(hostname="firewall_1", start_up_duration=0) + firewall_node: Firewall = Firewall.from_config( + config={"type": "firewall", "hostname": "firewall_1", "start_up_duration": 0} + ) firewall_node.power_on() # configure firewall ports firewall_node.configure_external_port( @@ -81,12 +83,15 @@ def dmz_external_internal_network() -> Network: ) # external node - external_node = Computer( - hostname="external_node", - ip_address="192.168.10.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - start_up_duration=0, + external_node: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "external_node", + "ip_address": "192.168.10.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } ) external_node.power_on() external_node.software_manager.install(NTPServer) @@ -96,12 +101,15 @@ def dmz_external_internal_network() -> Network: network.connect(endpoint_b=external_node.network_interface[1], endpoint_a=firewall_node.external_port) # internal node - internal_node = Computer( - hostname="internal_node", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + internal_node: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "internal_node", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) internal_node.power_on() internal_node.software_manager.install(NTPClient) @@ -112,12 +120,15 @@ def dmz_external_internal_network() -> Network: network.connect(endpoint_b=internal_node.network_interface[1], endpoint_a=firewall_node.internal_port) # dmz node - dmz_node = Computer( - hostname="dmz_node", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + dmz_node: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "dmz_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) dmz_node.power_on() dmz_ntp_client: NTPClient = dmz_node.software_manager.software["ntp-client"] @@ -155,9 +166,9 @@ def test_nodes_can_ping_default_gateway(dmz_external_internal_network): internal_node = dmz_external_internal_network.get_node_by_hostname("internal_node") dmz_node = dmz_external_internal_network.get_node_by_hostname("dmz_node") - assert internal_node.ping(internal_node.default_gateway) # default gateway internal - assert dmz_node.ping(dmz_node.default_gateway) # default gateway dmz - assert external_node.ping(external_node.default_gateway) # default gateway external + assert internal_node.ping(internal_node.config.default_gateway) # default gateway internal + assert dmz_node.ping(dmz_node.config.default_gateway) # default gateway dmz + assert external_node.ping(external_node.config.default_gateway) # default gateway external def test_nodes_can_ping_default_gateway_on_another_subnet(dmz_external_internal_network): @@ -171,14 +182,14 @@ def test_nodes_can_ping_default_gateway_on_another_subnet(dmz_external_internal_ internal_node = dmz_external_internal_network.get_node_by_hostname("internal_node") dmz_node = dmz_external_internal_network.get_node_by_hostname("dmz_node") - assert internal_node.ping(external_node.default_gateway) # internal node to external default gateway - assert internal_node.ping(dmz_node.default_gateway) # internal node to dmz default gateway + assert internal_node.ping(external_node.config.default_gateway) # internal node to external default gateway + assert internal_node.ping(dmz_node.config.default_gateway) # internal node to dmz default gateway - assert dmz_node.ping(internal_node.default_gateway) # dmz node to internal default gateway - assert dmz_node.ping(external_node.default_gateway) # dmz node to external default gateway + assert dmz_node.ping(internal_node.config.default_gateway) # dmz node to internal default gateway + assert dmz_node.ping(external_node.config.default_gateway) # dmz node to external default gateway - assert external_node.ping(external_node.default_gateway) # external node to internal default gateway - assert external_node.ping(dmz_node.default_gateway) # external node to dmz default gateway + assert external_node.ping(external_node.config.default_gateway) # external node to internal default gateway + assert external_node.ping(dmz_node.config.default_gateway) # external node to dmz default gateway def test_nodes_can_ping_each_other(dmz_external_internal_network): diff --git a/tests/integration_tests/network/test_frame_transmission.py b/tests/integration_tests/network/test_frame_transmission.py index 327c87e5..6a514bdc 100644 --- a/tests/integration_tests/network/test_frame_transmission.py +++ b/tests/integration_tests/network/test_frame_transmission.py @@ -10,25 +10,31 @@ def test_node_to_node_ping(): """Tests two Computers are able to ping each other.""" network = Network() - client_1 = Computer( - hostname="client_1", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + client_1: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "client_1", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) client_1.power_on() - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + server_1: Server = Server.from_config( + config={ + "type": "server", + "hostname": "server_1", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) server_1.power_on() - switch_1 = Switch(hostname="switch_1", start_up_duration=0) + switch_1: Switch = Switch.from_config(config={"type": "switch", "hostname": "switch_1", "start_up_duration": 0}) switch_1.power_on() network.connect(endpoint_a=client_1.network_interface[1], endpoint_b=switch_1.network_interface[1]) @@ -41,14 +47,38 @@ def test_multi_nic(): """Tests that Computers with multiple NICs can ping each other and the data go across the correct links.""" network = Network() - node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_a.power_on() - node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_b.power_on() node_b.connect_nic(NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0")) - node_c = Computer(hostname="node_c", ip_address="10.0.0.13", subnet_mask="255.0.0.0", start_up_duration=0) + node_c: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_c", + "ip_address": "10.0.0.13", + "subnet_mask": "255.0.0.0", + "start_up_duration": 0, + } + ) node_c.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) diff --git a/tests/integration_tests/network/test_multi_lan_internet_example_network.py b/tests/integration_tests/network/test_multi_lan_internet_example_network.py index c58d79a4..381fea62 100644 --- a/tests/integration_tests/network/test_multi_lan_internet_example_network.py +++ b/tests/integration_tests/network/test_multi_lan_internet_example_network.py @@ -1,6 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.networks import multi_lan_internet_network_example from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser diff --git a/tests/integration_tests/network/test_network_creation.py b/tests/integration_tests/network/test_network_creation.py index 1ee3ccc2..4d88eac3 100644 --- a/tests/integration_tests/network/test_network_creation.py +++ b/tests/integration_tests/network/test_network_creation.py @@ -27,7 +27,15 @@ def test_network(example_network): def test_adding_removing_nodes(): """Check that we can create and add a node to a network.""" net = Network() - n1 = Computer(hostname="computer", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.add_node(n1) assert n1.parent is net assert n1 in net @@ -37,10 +45,18 @@ def test_adding_removing_nodes(): assert n1 not in net -def test_readding_node(): - """Check that warning is raised when readding a node.""" +def test_reading_node(): + """Check that warning is raised when reading a node.""" net = Network() - n1 = Computer(hostname="computer", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.add_node(n1) net.add_node(n1) assert n1.parent is net @@ -50,7 +66,15 @@ def test_readding_node(): def test_removing_nonexistent_node(): """Check that warning is raised when trying to remove a node that is not in the network.""" net = Network() - n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.remove_node(n1) assert n1.parent is None assert n1 not in net @@ -59,8 +83,24 @@ def test_removing_nonexistent_node(): def test_connecting_nodes(): """Check that two nodes on the network can be connected.""" net = Network() - n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0) - n2 = Computer(hostname="computer2", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) + n2: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer2", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.add_node(n1) net.add_node(n2) @@ -75,7 +115,15 @@ def test_connecting_nodes(): def test_connecting_node_to_itself_fails(): net = Network() - node = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node.power_on() node.connect_nic(NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0")) @@ -92,8 +140,24 @@ def test_connecting_node_to_itself_fails(): def test_disconnecting_nodes(): net = Network() - n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0) - n2 = Computer(hostname="computer2", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + n1 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer1", + "ip_address": "192.168.1.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) + n2 = Computer.from_config( + config={ + "type": "computer", + "hostname": "computer2", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) net.connect(n1.network_interface[1], n2.network_interface[1]) assert len(net.links) == 1 diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index 7d23a2a6..ccf7c8ff 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -15,25 +15,31 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def pc_a_pc_b_router_1() -> Tuple[Computer, Computer, Router]: network = Network() - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + pc_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) pc_a.power_on() - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + pc_b = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) pc_b.power_on() - router_1 = Router(hostname="router_1", start_up_duration=0) + router_1 = Router.from_config(config={"type": "router", "hostname": "router_1", "start_up_duration": 0}) router_1.power_on() router_1.configure_port(1, "192.168.0.1", "255.255.255.0") @@ -52,18 +58,21 @@ def multi_hop_network() -> Network: network = Network() # Configure PC A - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + pc_a: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) pc_a.power_on() network.add_node(pc_a) # Configure Router 1 - router_1 = Router(hostname="router_1", start_up_duration=0) + router_1: Router = Router.from_config(config={"type": "router", "hostname": "router_1", "start_up_duration": 0}) router_1.power_on() network.add_node(router_1) @@ -79,18 +88,21 @@ def multi_hop_network() -> Network: router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, + pc_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.2.1", + "start_up_duration": 0, + } ) pc_b.power_on() network.add_node(pc_b) # Configure Router 2 - router_2 = Router(hostname="router_2", start_up_duration=0) + router_2: Router = Router.from_config(config={"type": "router", "hostname": "router_2", "start_up_duration": 0}) router_2.power_on() network.add_node(router_2) @@ -113,13 +125,13 @@ def multi_hop_network() -> Network: def test_ping_default_gateway(pc_a_pc_b_router_1): pc_a, pc_b, router_1 = pc_a_pc_b_router_1 - assert pc_a.ping(pc_a.default_gateway) + assert pc_a.ping(pc_a.config.default_gateway) def test_ping_other_router_port(pc_a_pc_b_router_1): pc_a, pc_b, router_1 = pc_a_pc_b_router_1 - assert pc_a.ping(pc_b.default_gateway) + assert pc_a.ping(pc_b.config.default_gateway) def test_host_on_other_subnet(pc_a_pc_b_router_1): diff --git a/tests/integration_tests/network/test_wireless_router.py b/tests/integration_tests/network/test_wireless_router.py index 26e50f4a..487736e7 100644 --- a/tests/integration_tests/network/test_wireless_router.py +++ b/tests/integration_tests/network/test_wireless_router.py @@ -17,18 +17,23 @@ def wireless_wan_network(): network = Network() # Configure PC A - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, + pc_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } ) pc_a.power_on() network.add_node(pc_a) # Configure Router 1 - router_1 = WirelessRouter(hostname="router_1", start_up_duration=0, airspace=network.airspace) + router_1 = WirelessRouter.from_config( + config={"type": "wireless_router", "hostname": "router_1", "start_up_duration": 0, "airspace": network.airspace} + ) router_1.power_on() network.add_node(router_1) @@ -43,18 +48,23 @@ def wireless_wan_network(): router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, + pc_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.2.1", + "start_up_duration": 0, + } ) pc_b.power_on() network.add_node(pc_b) # Configure Router 2 - router_2 = WirelessRouter(hostname="router_2", start_up_duration=0, airspace=network.airspace) + router_2: WirelessRouter = WirelessRouter.from_config( + config={"type": "wireless_router", "hostname": "router_2", "start_up_duration": 0, "airspace": network.airspace} + ) router_2.power_on() network.add_node(router_2) @@ -98,8 +108,8 @@ def wireless_wan_network_from_config_yaml(): def test_cross_wireless_wan_connectivity(wireless_wan_network): pc_a, pc_b, router_1, router_2 = wireless_wan_network # Ensure that PCs can ping across routers before any frequency change - assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully." - assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully." + assert pc_a.ping(pc_a.config.default_gateway), "PC A should ping its default gateway successfully." + assert pc_b.ping(pc_b.config.default_gateway), "PC B should ping its default gateway successfully." assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." @@ -109,8 +119,8 @@ def test_cross_wireless_wan_connectivity_from_yaml(wireless_wan_network_from_con pc_a = wireless_wan_network_from_config_yaml.get_node_by_hostname("pc_a") pc_b = wireless_wan_network_from_config_yaml.get_node_by_hostname("pc_b") - assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully." - assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully." + assert pc_a.ping(pc_a.config.default_gateway), "PC A should ping its default gateway successfully." + assert pc_b.ping(pc_b.config.default_gateway), "PC B should ping its default gateway successfully." assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully." assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully." diff --git a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py index 5475edd1..62eda37d 100644 --- a/tests/integration_tests/system/red_applications/test_c2_suite_integration.py +++ b/tests/integration_tests/system/red_applications/test_c2_suite_integration.py @@ -34,52 +34,64 @@ def basic_network() -> Network: # Creating two generic nodes for the C2 Server and the C2 Beacon. - node_a = Computer( - hostname="node_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.252", - default_gateway="192.168.0.1", - start_up_duration=0, - ) + node_a_cfg = { + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.252", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } + + node_a: Computer = Computer.from_config(config=node_a_cfg) node_a.power_on() node_a.software_manager.get_open_ports() node_a.software_manager.install(software_class=C2Server) - node_b = Computer( - hostname="node_b", - ip_address="192.168.255.2", - subnet_mask="255.255.255.248", - default_gateway="192.168.255.1", - start_up_duration=0, - ) + node_b_cfg = { + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.255.2", + "subnet_mask": "255.255.255.248", + "default_gateway": "192.168.255.1", + "start_up_duration": 0, + } + node_b: Computer = Computer.from_config(config=node_b_cfg) node_b.power_on() node_b.software_manager.install(software_class=C2Beacon) # Creating a generic computer for testing remote terminal connections. - node_c = Computer( - hostname="node_c", - ip_address="192.168.255.3", - subnet_mask="255.255.255.248", - default_gateway="192.168.255.1", - start_up_duration=0, - ) + node_c_cfg = { + "type": "computer", + "hostname": "node_c", + "ip_address": "192.168.255.3", + "subnet_mask": "255.255.255.248", + "default_gateway": "192.168.255.1", + "start_up_duration": 0, + } + + node_c: Computer = Computer.from_config(config=node_c_cfg) node_c.power_on() # Creating a router to sit between node 1 and node 2. - router = Router(hostname="router", num_ports=3, start_up_duration=0) + router = Router.from_config(config={"type": "router", "hostname": "router", "num_ports": 3, "start_up_duration": 0}) # Default allow all. router.acl.add_rule(action=ACLAction.PERMIT) router.power_on() # Creating switches for each client. - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1 = Switch.from_config( + config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0} + ) switch_1.power_on() # Connecting the switches to the router. router.configure_port(port=1, ip_address="192.168.0.1", subnet_mask="255.255.255.252") network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6]) - switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0) + switch_2 = Switch.from_config( + config={"type": "switch", "hostname": "switch_2", "num_ports": 6, "start_up_duration": 0} + ) switch_2.power_on() network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6]) diff --git a/tests/integration_tests/system/test_application_on_node.py b/tests/integration_tests/system/test_application_on_node.py index 04fbb298..38a7ca03 100644 --- a/tests/integration_tests/system/test_application_on_node.py +++ b/tests/integration_tests/system/test_application_on_node.py @@ -10,13 +10,16 @@ from primaite.simulator.system.applications.application import Application, Appl @pytest.fixture(scope="function") def populated_node(application_class) -> Tuple[Application, Computer]: - computer: Computer = Computer( - hostname="test_computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - shut_down_duration=0, + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "test_computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + "shut_down_duration": 0, + } ) computer.power_on() computer.software_manager.install(application_class) @@ -29,13 +32,16 @@ def populated_node(application_class) -> Tuple[Application, Computer]: def test_application_on_offline_node(application_class): """Test to check that the application cannot be interacted with when node it is on is off.""" - computer: Computer = Computer( - hostname="test_computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - shut_down_duration=0, + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "test_computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + "shut_down_duration": 0, + } ) computer.software_manager.install(application_class) diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 7c3ded2b..6627b7a1 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -20,11 +20,27 @@ from primaite.simulator.system.software import SoftwareHealthState @pytest.fixture(scope="function") def peer_to_peer() -> Tuple[Computer, Computer]: network = Network() - node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_a.power_on() node_a.software_manager.get_open_ports() - node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) @@ -46,7 +62,7 @@ def peer_to_peer_secure_db(peer_to_peer) -> Tuple[Computer, Computer]: database_service: DatabaseService = node_b.software_manager.software["database-service"] # noqa database_service.stop() - database_service.password = "12345" + database_service.config.db_password = "12345" database_service.start() return node_a, node_b @@ -338,7 +354,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network): assert db_connection.query("INSERT") is True db_server.power_off() - for i in range(db_server.shut_down_duration + 1): + for i in range(db_server.config.shut_down_duration + 1): uc2_network.apply_timestep(timestep=i) assert db_server.operating_state is NodeOperatingState.OFF @@ -412,8 +428,14 @@ def test_database_service_can_terminate_connection(peer_to_peer): def test_client_connection_terminate_does_not_terminate_another_clients_connection(): network = Network() - db_server = Server( - hostname="db_client", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0 + db_server: Server = Server.from_config( + config={ + "type": "server", + "hostname": "db_client", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) db_server.power_on() @@ -421,8 +443,14 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti db_service: DatabaseService = db_server.software_manager.software["database-service"] # noqa db_service.start() - client_a = Computer( - hostname="client_a", ip_address="192.168.0.12", subnet_mask="255.255.255.0", start_up_duration=0 + client_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "client_a", + "ip_address": "192.168.0.12", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) client_a.power_on() @@ -430,8 +458,14 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti client_a.software_manager.software["database-client"].configure(server_ip_address=IPv4Address("192.168.0.11")) client_a.software_manager.software["database-client"].run() - client_b = Computer( - hostname="client_b", ip_address="192.168.0.13", subnet_mask="255.255.255.0", start_up_duration=0 + client_b = Computer.from_config( + config={ + "type": "computer", + "hostname": "client_b", + "ip_address": "192.168.0.13", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) client_b.power_on() @@ -439,7 +473,7 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti client_b.software_manager.software["database-client"].configure(server_ip_address=IPv4Address("192.168.0.11")) client_b.software_manager.software["database-client"].run() - switch = Switch(hostname="switch", start_up_duration=0, num_ports=3) + switch = Switch.from_config(config={"type": "switch", "hostname": "switch", "start_up_duration": 0, "num_ports": 3}) switch.power_on() network.connect(endpoint_a=switch.network_interface[1], endpoint_b=db_server.network_interface[1]) @@ -465,6 +499,14 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti def test_database_server_install_ftp_client(): - server = Server(hostname="db_server", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + server: Server = Server.from_config( + config={ + "type": "server", + "hostname": "db_server", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) server.software_manager.install(DatabaseService) assert server.software_manager.software.get("ftp-client") diff --git a/tests/integration_tests/system/test_dns_client_server.py b/tests/integration_tests/system/test_dns_client_server.py index 863c8c12..068b94d2 100644 --- a/tests/integration_tests/system/test_dns_client_server.py +++ b/tests/integration_tests/system/test_dns_client_server.py @@ -72,7 +72,7 @@ def test_dns_client_requests_offline_dns_server(dns_client_and_dns_server): server.power_off() - for i in range(server.shut_down_duration + 1): + for i in range(server.config.shut_down_duration + 1): server.apply_timestep(timestep=i) assert server.operating_state == NodeOperatingState.OFF diff --git a/tests/integration_tests/system/test_ftp_client_server.py b/tests/integration_tests/system/test_ftp_client_server.py index 97b9049c..bb3aa8f2 100644 --- a/tests/integration_tests/system/test_ftp_client_server.py +++ b/tests/integration_tests/system/test_ftp_client_server.py @@ -87,7 +87,7 @@ def test_ftp_client_tries_to_connect_to_offline_server(ftp_client_and_ftp_server server.power_off() - for i in range(server.shut_down_duration + 1): + for i in range(server.config.shut_down_duration + 1): server.apply_timestep(timestep=i) assert ftp_client.operating_state == ServiceOperatingState.RUNNING diff --git a/tests/integration_tests/system/test_service_on_node.py b/tests/integration_tests/system/test_service_on_node.py index 560b7770..5afb71dc 100644 --- a/tests/integration_tests/system/test_service_on_node.py +++ b/tests/integration_tests/system/test_service_on_node.py @@ -13,13 +13,15 @@ from primaite.simulator.system.services.service import Service, ServiceOperating def populated_node( service_class, ) -> Tuple[Server, Service]: - server = Server( - hostname="server", - ip_address="192.168.0.1", - subnet_mask="255.255.255.0", - start_up_duration=0, - shut_down_duration=0, - ) + server_cfg = { + "type": "server", + "hostname": "server", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + "shut_down_duration": 0, + } + server: Server = Server.from_config(config=server_cfg) server.power_on() server.software_manager.install(service_class) @@ -31,14 +33,16 @@ def populated_node( def test_service_on_offline_node(service_class): """Test to check that the service cannot be interacted with when node it is on is off.""" - computer: Computer = Computer( - hostname="test_computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - shut_down_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "test_computer", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + "shut_down_duration": 0, + } + computer: Computer = Computer.from_config(config=computer_cfg) computer.power_on() computer.software_manager.install(service_class) diff --git a/tests/integration_tests/system/test_user_session_manager_logins.py b/tests/integration_tests/system/test_user_session_manager_logins.py index 0c591a4b..9736232b 100644 --- a/tests/integration_tests/system/test_user_session_manager_logins.py +++ b/tests/integration_tests/system/test_user_session_manager_logins.py @@ -14,21 +14,27 @@ from primaite.simulator.network.hardware.nodes.host.server import Server def client_server_network() -> Tuple[Computer, Server, Network]: network = Network() - client = Computer( - hostname="client", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + client = Computer.from_config( + config={ + "type": "computer", + "hostname": "client", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) client.power_on() - server = Server( - hostname="server", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, + server = Server.from_config( + config={ + "type": "server", + "hostname": "server", + "ip_address": "192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } ) server.power_on() diff --git a/tests/integration_tests/system/test_web_client_server.py b/tests/integration_tests/system/test_web_client_server.py index 5c79f755..fef483e9 100644 --- a/tests/integration_tests/system/test_web_client_server.py +++ b/tests/integration_tests/system/test_web_client_server.py @@ -94,7 +94,7 @@ def test_web_page_request_from_shut_down_server(web_client_and_web_server): server.power_off() - for i in range(server.shut_down_duration + 1): + for i in range(server.config.shut_down_duration + 1): server.apply_timestep(timestep=i) # node should be off diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index 87ad3aaf..ebff2893 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -111,7 +111,7 @@ def test_request_fails_if_node_off(example_network, node_request): """Test that requests succeed when the node is on, and fail if the node is off.""" net = example_network client_1: HostNode = net.get_node_by_hostname("client_1") - client_1.shut_down_duration = 0 + client_1.config.shut_down_duration = 0 assert client_1.operating_state == NodeOperatingState.ON resp_1 = net.apply_request(node_request) @@ -144,9 +144,9 @@ class TestDataManipulationGreenRequests: client_1 = net.get_node_by_hostname("client_1") client_2 = net.get_node_by_hostname("client_2") - client_1.shut_down_duration = 0 + client_1.config.shut_down_duration = 0 client_1.power_off() - client_2.shut_down_duration = 0 + client_2.config.shut_down_duration = 0 client_2.power_off() client_1_browser_execute_off = net.apply_request(["node", "client_1", "application", "web-browser", "execute"]) diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py index 79392d66..ee7eb08f 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_acl.py @@ -25,7 +25,8 @@ def router_with_acl_rules(): :return: A configured Router object with ACL rules. """ - router = Router("Router") + router_cfg = {"hostname": "router_1", "type": "router"} + router = Router.from_config(config=router_cfg) acl = router.acl # Add rules here as needed acl.add_rule( @@ -62,7 +63,8 @@ def router_with_wildcard_acl(): :return: A Router object with configured ACL rules, including rules with wildcard masking. """ - router = Router("Router") + router_cfg = {"hostname": "router_1", "type": "router"} + router = Router.from_config(config=router_cfg) acl = router.acl # Rule to permit traffic from a specific source IP and port to a specific destination IP and port acl.add_rule( @@ -243,7 +245,8 @@ def test_ip_traffic_from_specific_subnet(): - Traffic from outside the 192.168.1.0/24 subnet is denied. """ - router = Router("Router") + router_cfg = {"hostname": "router_1", "type": "router"} + router = Router.from_config(config=router_cfg) acl = router.acl # Add rules here as needed acl.add_rule( diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py index fe0c3a57..e9d16533 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -50,9 +50,9 @@ def test_wireless_router_from_config(): }, } - rt = Router.from_config(cfg=cfg) + rt = Router.from_config(config=cfg) - assert rt.num_ports == 6 + assert rt.config.num_ports == 6 assert rt.network_interface[1].ip_address == IPv4Address("192.168.1.1") assert rt.network_interface[1].subnet_mask == IPv4Address("255.255.255.0") diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py index e6bff60e..e45fe45d 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py @@ -7,7 +7,8 @@ from primaite.simulator.network.hardware.nodes.network.switch import Switch @pytest.fixture(scope="function") def switch() -> Switch: - switch: Switch = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) + switch_cfg = {"type": "switch", "hostname": "switch_1", "num_ports": 8, "start_up_duration": 0} + switch: Switch = Switch.from_config(config=switch_cfg) switch.power_on() switch.show() return switch diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py index 5cff4407..cb2d3935 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_network_interface_actions.py @@ -7,7 +7,10 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer @pytest.fixture def node() -> Node: - return Computer(hostname="test", ip_address="192.168.1.2", subnet_mask="255.255.255.0") + computer_cfg = {"type": "computer", "hostname": "test", "ip_address": "192.168.1.2", "subnet_mask": "255.255.255.0"} + computer = Computer.from_config(config=computer_cfg) + + return computer def test_nic_enabled_validator(node): diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py index 7026c116..425c0887 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py @@ -12,7 +12,16 @@ from tests.conftest import DummyApplication, DummyService @pytest.fixture def node() -> Node: - return Computer(hostname="test", ip_address="192.168.1.2", subnet_mask="255.255.255.0") + computer_cfg = { + "type": "computer", + "hostname": "test", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "operating_state": "OFF", + } + computer = Computer.from_config(config=computer_cfg) + + return computer def test_node_startup(node): @@ -166,7 +175,7 @@ def test_node_is_on_validator(node): """Test that the node is on validator.""" node.power_on() - for i in range(node.start_up_duration + 1): + for i in range(node.config.start_up_duration + 1): node.apply_timestep(i) validator = Node._NodeIsOnValidator(node=node) @@ -174,7 +183,7 @@ def test_node_is_on_validator(node): assert validator(request=[], context={}) node.power_off() - for i in range(node.shut_down_duration + 1): + for i in range(node.config.shut_down_duration + 1): node.apply_timestep(i) assert validator(request=[], context={}) is False @@ -184,7 +193,7 @@ def test_node_is_off_validator(node): """Test that the node is on validator.""" node.power_on() - for i in range(node.start_up_duration + 1): + for i in range(node.config.start_up_duration + 1): node.apply_timestep(i) validator = Node._NodeIsOffValidator(node=node) @@ -192,7 +201,7 @@ def test_node_is_off_validator(node): assert validator(request=[], context={}) is False node.power_off() - for i in range(node.shut_down_duration + 1): + for i in range(node.config.shut_down_duration + 1): node.apply_timestep(i) assert validator(request=[], context={}) diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index b1de710a..9a54f7b2 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -61,12 +61,12 @@ def test_apply_timestep_to_nodes(network): client_1.power_off() assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN - for i in range(client_1.shut_down_duration + 1): + for i in range(client_1.config.shut_down_duration + 1): network.apply_timestep(timestep=i) assert client_1.operating_state is NodeOperatingState.OFF - network.apply_timestep(client_1.shut_down_duration + 2) + network.apply_timestep(client_1.config.shut_down_duration + 2) assert client_1.operating_state is NodeOperatingState.OFF @@ -74,7 +74,16 @@ def test_removing_node_that_does_not_exist(network): """Node that does not exist on network should not affect existing nodes.""" assert len(network.nodes) is 7 - network.remove_node(Computer(hostname="new_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0")) + network.remove_node( + Computer.from_config( + config={ + "type": "computer", + "hostname": "new_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + } + ) + ) assert len(network.nodes) is 7 diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_creation.py b/tests/unit_tests/_primaite/_simulator/_network/test_creation.py index 43ffe366..700780d0 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_creation.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_creation.py @@ -19,7 +19,7 @@ def _assert_valid_creation(net: Network, lan_name, subnet_base, pcs_ip_block_sta num_routers = 1 if include_router else 0 total_nodes = num_pcs + num_switches + num_routers - assert all((n.hostname.endswith(lan_name) for n in net.nodes.values())) + assert all((n.config.hostname.endswith(lan_name) for n in net.nodes.values())) assert len(net.computer_nodes) == num_pcs assert len(net.switch_nodes) == num_switches assert len(net.router_nodes) == num_routers diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py index dfbd26cb..64dbdd52 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_c2_suite.py @@ -16,19 +16,27 @@ def basic_c2_network() -> Network: network = Network() # Creating two generic nodes for the C2 Server and the C2 Beacon. + computer_a_cfg = { + "type": "computer", + "hostname": "computer_a", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.252", + "start_up_duration": 0, + } + computer_a = Computer.from_config(config=computer_a_cfg) - computer_a = Computer( - hostname="computer_a", - ip_address="192.168.0.1", - subnet_mask="255.255.255.252", - start_up_duration=0, - ) computer_a.power_on() computer_a.software_manager.install(software_class=C2Server) - computer_b = Computer( - hostname="computer_b", ip_address="192.168.0.2", subnet_mask="255.255.255.252", start_up_duration=0 - ) + computer_b_cfg = { + "type": "computer", + "hostname": "computer_b", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.252", + "start_up_duration": 0, + } + + computer_b = Computer.from_config(config=computer_b_cfg) computer_b.power_on() computer_b.software_manager.install(software_class=C2Beacon) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index eab283b8..fffb7c84 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -12,9 +12,15 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dos_bot() -> DoSBot: - computer = Computer( - hostname="compromised_pc", ip_address="192.168.0.1", subnet_mask="255.255.255.0", start_up_duration=0 - ) + computer_cfg = { + "type": "computer", + "hostname": "compromised_pc", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + computer: Computer = Computer.from_config(config=computer_cfg) + computer.power_on() computer.software_manager.install(DoSBot) @@ -34,7 +40,7 @@ def test_dos_bot_cannot_run_when_node_offline(dos_bot): dos_bot_node.power_off() - for i in range(dos_bot_node.shut_down_duration + 1): + for i in range(dos_bot_node.config.shut_down_duration + 1): dos_bot_node.apply_timestep(timestep=i) assert dos_bot_node.operating_state is NodeOperatingState.OFF diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py index 58215671..177f31b0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py @@ -17,13 +17,27 @@ from primaite.simulator.system.services.database.database_service import Databas def database_client_on_computer() -> Tuple[DatabaseClient, Computer]: network = Network() - db_server = Server(hostname="db_server", ip_address="192.168.0.1", subnet_mask="255.255.255.0", start_up_duration=0) + db_server: Server = Server.from_config( + config={ + "type": "server", + "hostname": "db_server", + "ip_address": "192.168.0.1", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) db_server.power_on() db_server.software_manager.install(DatabaseService) db_server.software_manager.software["database-service"].start() - db_client = Computer( - hostname="db_client", ip_address="192.168.0.2", subnet_mask="255.255.255.0", start_up_duration=0 + db_client: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "db_client", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) db_client.power_on() db_client.software_manager.install(DatabaseClient) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index c5f4e74c..6cffc5c0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -12,13 +12,17 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def web_browser() -> WebBrowser: - computer = Computer( - hostname="web_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "web_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + computer: Computer = Computer.from_config(config=computer_cfg) + computer.power_on() # Web Browser should be pre-installed in computer web_browser: WebBrowser = computer.software_manager.software.get("web-browser") @@ -28,13 +32,17 @@ def web_browser() -> WebBrowser: def test_create_web_client(): - computer = Computer( - hostname="web_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = { + "type": "computer", + "hostname": "web_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + + computer: Computer = Computer.from_config(config=computer_cfg) + computer.power_on() # Web Browser should be pre-installed in computer web_browser: WebBrowser = computer.software_manager.software.get("web-browser") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py index 2d0fdb8d..2154ebf9 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py @@ -8,7 +8,15 @@ from primaite.simulator.system.services.database.database_service import Databas @pytest.fixture(scope="function") def database_server() -> Node: - node = Computer(hostname="db_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0) + node_cfg = { + "type": "computer", + "hostname": "db_node", + "ip_address": "192.168.1.2", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + + node = Computer.from_config(config=node_cfg) node.power_on() node.software_manager.install(DatabaseService) node.software_manager.software.get("database-service").start() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index f4dfe20e..71ed7140 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -14,13 +14,15 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dns_client() -> Computer: - node = Computer( - hostname="dns_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - dns_server=IPv4Address("192.168.1.10"), - ) + node_cfg = { + "type": "computer", + "hostname": "dns_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "dns_server": IPv4Address("192.168.1.10"), + } + node = Computer.from_config(config=node_cfg) return node @@ -34,6 +36,16 @@ def test_create_dns_client(dns_client): def test_dns_client_add_domain_to_cache_when_not_running(dns_client): dns_client_service: DNSClient = dns_client.software_manager.software.get("dns-client") + + # shutdown the dns_client + dns_client.power_off() + + # wait for dns_client to turn off + idx = 0 + while dns_client.operating_state == NodeOperatingState.SHUTTING_DOWN: + dns_client.apply_timestep(idx) + idx += 1 + assert dns_client.operating_state is NodeOperatingState.OFF assert dns_client_service.operating_state is ServiceOperatingState.STOPPED @@ -61,7 +73,7 @@ def test_dns_client_check_domain_exists_when_not_running(dns_client): dns_client.power_off() - for i in range(dns_client.shut_down_duration + 1): + for i in range(dns_client.config.shut_down_duration + 1): dns_client.apply_timestep(timestep=i) assert dns_client.operating_state is NodeOperatingState.OFF diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py index 65f472d9..72304df3 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -16,13 +16,15 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def dns_server() -> Node: - node = Server( - hostname="dns_server", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = { + "type": "server", + "hostname": "dns_server", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(software_class=DNSServer) return node @@ -55,9 +57,16 @@ def test_dns_server_receive(dns_server): # register the web server in the domain controller dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) - client = Computer(hostname="client", ip_address="192.168.1.11", subnet_mask="255.255.255.0", start_up_duration=0) + client_cfg = { + "type": "computer", + "hostname": "client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + client = Computer.from_config(config=client_cfg) client.power_on() - client.dns_server = IPv4Address("192.168.1.10") + client.config.dns_server = IPv4Address("192.168.1.10") network = Network() network.connect(dns_server.network_interface[1], client.network_interface[1]) dns_client: DNSClient = client.software_manager.software["dns-client"] # noqa diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 26aa5dd5..4a0f362b 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -16,13 +16,15 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def ftp_client() -> Node: - node = Computer( - hostname="ftp_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = { + "type": "computer", + "hostname": "ftp_client", + "ip_address": "192.168.1.11", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + node = Computer.from_config(config=node_cfg) node.power_on() return node @@ -94,7 +96,7 @@ def test_offline_ftp_client_receives_request(ftp_client): ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client") ftp_client.power_off() - for i in range(ftp_client.shut_down_duration + 1): + for i in range(ftp_client.config.shut_down_duration + 1): ftp_client.apply_timestep(timestep=i) assert ftp_client.operating_state is NodeOperatingState.OFF diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index fb2f82fe..2527a401 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -14,13 +14,15 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def ftp_server() -> Node: - node = Server( - hostname="ftp_server", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = { + "type": "server", + "hostname": "ftp_server", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(software_class=FTPServer) return node diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 144037f5..12d12ea8 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -12,6 +12,7 @@ from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter +from primaite.simulator.network.networks import arcd_uc2_network from primaite.simulator.network.protocols.ssh import ( SSHConnectionMessage, SSHPacket, @@ -29,8 +30,14 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def terminal_on_computer() -> Tuple[Terminal, Computer]: - computer: Computer = Computer( - hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0 + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } ) computer.power_on() terminal: Terminal = computer.software_manager.software.get("terminal") @@ -41,11 +48,27 @@ def terminal_on_computer() -> Tuple[Terminal, Computer]: @pytest.fixture(scope="function") def basic_network() -> Network: network = Network() - node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_a.power_on() node_a.software_manager.get_open_ports() - node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_b", + "ip_address": "192.168.0.11", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) @@ -57,18 +80,23 @@ def wireless_wan_network(): network = Network() # Configure PC A - pc_a = Computer( - hostname="pc_a", - ip_address="192.168.0.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.0.1", - start_up_duration=0, - ) + pc_a_cfg = { + "type": "computer", + "hostname": "pc_a", + "ip_address": "192.168.0.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.0.1", + "start_up_duration": 0, + } + + pc_a = Computer.from_config(config=pc_a_cfg) pc_a.power_on() network.add_node(pc_a) # Configure Router 1 - router_1 = WirelessRouter(hostname="router_1", start_up_duration=0, airspace=network.airspace) + router_1 = WirelessRouter.from_config( + config={"type": "wireless_router", "hostname": "router_1", "start_up_duration": 0, "airspace": network.airspace} + ) router_1.power_on() network.add_node(router_1) @@ -88,41 +116,30 @@ def wireless_wan_network(): ) # Configure PC B - pc_b = Computer( - hostname="pc_b", - ip_address="192.168.2.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.2.1", - start_up_duration=0, - ) + + pc_b_cfg = { + "type": "computer", + "hostname": "pc_b", + "ip_address": "192.168.2.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.2.1", + "start_up_duration": 0, + } + + pc_b = Computer.from_config(config=pc_b_cfg) pc_b.power_on() network.add_node(pc_b) - # Configure Router 2 - router_2 = WirelessRouter(hostname="router_2", start_up_duration=0, airspace=network.airspace) - router_2.power_on() - network.add_node(router_2) - - # Configure the connection between PC B and Router 2 port 2 - router_2.configure_router_interface("192.168.2.1", "255.255.255.0") - network.connect(pc_b.network_interface[1], router_2.network_interface[2]) - # Configure Router 2 ACLs # Configure the wireless connection between Router 1 port 1 and Router 2 port 1 router_1.configure_wireless_access_point("192.168.1.1", "255.255.255.0") - router_2.configure_wireless_access_point("192.168.1.2", "255.255.255.0") router_1.route_table.add_route( address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2" ) - # Configure Route from Router 2 to PC A subnet - router_2.route_table.add_route( - address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1" - ) - - return pc_a, pc_b, router_1, router_2 + return network @pytest.fixture @@ -131,7 +148,7 @@ def game_and_agent_fixture(game_and_agent): game, agent = game_and_agent client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - client_1.start_up_duration = 3 + client_1.config.start_up_duration = 3 return game, agent @@ -142,8 +159,16 @@ def test_terminal_creation(terminal_on_computer): def test_terminal_install_default(): - """terminal should be auto installed onto Nodes""" - computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + """Terminal should be auto installed onto Nodes""" + computer: Computer = Computer.from_config( + config={ + "type": "computer", + "hostname": "node_a", + "ip_address": "192.168.0.10", + "subnet_mask": "255.255.255.0", + "start_up_duration": 0, + } + ) computer.power_on() assert computer.software_manager.software.get("terminal") @@ -151,7 +176,7 @@ def test_terminal_install_default(): def test_terminal_not_on_switch(): """Ensure terminal does not auto-install to switch""" - test_switch = Switch(hostname="Test") + test_switch = Switch.from_config(config={"type": "switch", "hostname": "Test"}) assert not test_switch.software_manager.software.get("terminal") @@ -274,7 +299,10 @@ def test_terminal_ignores_when_off(basic_network): def test_computer_remote_login_to_router(wireless_wan_network): """Test to confirm that a computer can SSH into a router.""" - pc_a, _, router_1, _ = wireless_wan_network + + pc_a = wireless_wan_network.get_node_by_hostname("pc_a") + + router_1 = wireless_wan_network.get_node_by_hostname("router_1") pc_a_terminal: Terminal = pc_a.software_manager.software.get("terminal") @@ -293,7 +321,9 @@ def test_computer_remote_login_to_router(wireless_wan_network): def test_router_remote_login_to_computer(wireless_wan_network): """Test to confirm that a router can ssh into a computer.""" - pc_a, _, router_1, _ = wireless_wan_network + pc_a = wireless_wan_network.get_node_by_hostname("pc_a") + + router_1 = wireless_wan_network.get_node_by_hostname("router_1") router_1_terminal: Terminal = router_1.software_manager.software.get("terminal") @@ -311,8 +341,10 @@ def test_router_remote_login_to_computer(wireless_wan_network): def test_router_blocks_SSH_traffic(wireless_wan_network): - """Test to check that router will block SSH traffic if no acl rule.""" - pc_a, _, router_1, _ = wireless_wan_network + """Test to check that router will block SSH traffic if no ACL rule.""" + pc_a = wireless_wan_network.get_node_by_hostname("pc_a") + + router_1 = wireless_wan_network.get_node_by_hostname("router_1") # Remove rule that allows SSH traffic. router_1.acl.remove_rule(position=21) @@ -326,20 +358,22 @@ def test_router_blocks_SSH_traffic(wireless_wan_network): assert len(pc_a_terminal._connections) == 0 -def test_SSH_across_network(wireless_wan_network): +def test_SSH_across_network(): """Test to show ability to SSH across a network.""" - pc_a, pc_b, router_1, router_2 = wireless_wan_network + network: Network = arcd_uc2_network() + pc_a = network.get_node_by_hostname("client_1") + router_1 = network.get_node_by_hostname("router_1") - terminal_a: Terminal = pc_a.software_manager.software.get("terminal") - terminal_b: Terminal = pc_b.software_manager.software.get("terminal") + terminal_a: Terminal = pc_a.software_manager.software.get("Terminal") - router_2.acl.add_rule( + router_1.acl.add_rule( action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21 ) assert len(terminal_a._connections) == 0 - terminal_b_on_terminal_a = terminal_b.login(username="admin", password="admin", ip_address="192.168.0.2") + # Login to the Domain Controller + terminal_a.login(username="admin", password="admin", ip_address="192.168.1.10") assert len(terminal_a._connections) == 1 @@ -357,8 +391,6 @@ def test_multiple_remote_terminals_same_node(basic_network): for attempt in range(3): remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") - terminal_a.show() - assert len(terminal_a._connections) == 3 diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index 4bd8a7e3..ad6afa82 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -16,13 +16,15 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def web_server() -> Server: - node = Server( - hostname="web_server", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + node_cfg = { + "type": "server", + "hostname": "web_server", + "ip_address": "192.168.1.10", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration": 0, + } + node = Server.from_config(config=node_cfg) node.power_on() node.software_manager.install(WebServer) node.software_manager.software.get("web-server").start()