diff --git a/docs/source/how_to_guides/extensible_nodes.rst b/docs/source/how_to_guides/extensible_nodes.rst new file mode 100644 index 00000000..21907767 --- /dev/null +++ b/docs/source/how_to_guides/extensible_nodes.rst @@ -0,0 +1,58 @@ +.. only:: comment + + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +.. _about: + + +Extensible Nodes +**************** + +Node classes within PrimAITE have been updated to allow for easier generation of custom nodes within simulations. + + +Changes to Node Class structure. +================================ + +Node classes all inherit from the base Node Class, though new classes should inherit from either HostNode or NetworkNode, subject to the intended application of the Node. + +The use of an `__init__` method is not necessary, as configurable variables for the class should be specified within the `config` of the class, and passed at run time via your YAML configuration using the `from_config` method. + + +An example of how additional Node classes is below, taken from `router.py` withing PrimAITE. + +.. code-block:: Python + +class Router(NetworkNode, identifier="router"): + """ Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces.""" + + SYSTEM_SOFTWARE: ClassVar[Dict] = { + "UserSessionManager": UserSessionManager, + "UserManager": UserManager, + "Terminal": Terminal, + } + + network_interfaces: Dict[str, RouterInterface] = {} + "The Router Interfaces on the node." + network_interface: Dict[int, RouterInterface] = {} + "The Router Interfaces on the node by port id." + + sys_log: SysLog + + config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema()) + + class ConfigSchema(NetworkNode.ConfigSchema): + """Configuration Schema for Router Objects.""" + + num_ports: int = 5 + + hostname: ClassVar[str] = "Router" + + ports: list = [] + + + +Changes to YAML file. +===================== + +Nodes defined within configuration YAML files for use with PrimAITE 3.X should still be compatible following these changes. \ No newline at end of file diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 6599430a..b9dc9c4d 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -271,12 +271,13 @@ class PrimaiteGame: for node_cfg in nodes_cfg: n_type = node_cfg["type"] - node_config: dict = node_cfg["config"] + # node_config: dict = node_cfg["config"] + print(f"{n_type}:{node_cfg}") new_node = None if n_type in Node._registry: # simplify down Node creation: - new_node = Node._registry["n_type"].from_config(config=node_config) + new_node = Node._registry[n_type].from_config(config=node_cfg) else: msg = f"invalid node type {n_type} in config" _LOGGER.error(msg) @@ -313,7 +314,7 @@ class PrimaiteGame: service_class = SERVICE_TYPES_MAPPING[service_type] if service_class is not None: - _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") + _LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}") new_node.software_manager.install(service_class, **service_cfg.get("options", {})) new_service = new_node.software_manager.software[service_class.__name__] diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 567a0493..dc4ae73b 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -3,7 +3,7 @@ """Core of the PrimAITE Simulator.""" import warnings from abc import abstractmethod -from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union from uuid import uuid4 from prettytable import PrettyTable diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index bf677d5c..aac82633 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -180,7 +180,7 @@ class Network(SimComponent): table.align = "l" table.title = "Nodes" for node in self.nodes.values(): - table.add_row((node.hostname, type(node)._identifier, node.operating_state.name)) + table.add_row((node.config.hostname, type(node)._identifier, node.operating_state.name)) print(table) if ip_addresses: @@ -196,7 +196,7 @@ class Network(SimComponent): if port.ip_address != IPv4Address("127.0.0.1"): port_str = port.port_name if port.port_name else port.port_num table.add_row( - [node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway] + [node.config.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway] ) print(table) @@ -215,9 +215,9 @@ class Network(SimComponent): if node in [link.endpoint_a.parent, link.endpoint_b.parent]: table.add_row( [ - link.endpoint_a.parent.hostname, + link.endpoint_a.parent.config.hostname, str(link.endpoint_a), - link.endpoint_b.parent.hostname, + link.endpoint_b.parent.config.hostname, str(link.endpoint_b), link.is_up, link.bandwidth, @@ -251,7 +251,7 @@ class Network(SimComponent): state = super().describe_state() state.update( { - "nodes": {node.hostname: node.describe_state() for node in self.nodes.values()}, + "nodes": {node.config.hostname: node.describe_state() for node in self.nodes.values()}, "links": {}, } ) @@ -259,8 +259,8 @@ class Network(SimComponent): for _, link in self.links.items(): node_a = link.endpoint_a._connected_node node_b = link.endpoint_b._connected_node - hostname_a = node_a.hostname if node_a else None - hostname_b = node_b.hostname if node_b else None + hostname_a = node_a.config.hostname if node_a else None + hostname_b = node_b.config.hostname if node_b else None port_a = link.endpoint_a.port_num port_b = link.endpoint_b.port_num link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}" @@ -286,9 +286,9 @@ class Network(SimComponent): self.nodes[node.uuid] = node self._node_id_map[len(self.nodes)] = node node.parent = self - self._nx_graph.add_node(node.hostname) + self._nx_graph.add_node(node.config.hostname) _LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}") - self._node_request_manager.add_request(name=node.hostname, request_type=RequestType(func=node._request_manager)) + self._node_request_manager.add_request(name=node.config.hostname, request_type=RequestType(func=node._request_manager)) def get_node_by_hostname(self, hostname: str) -> Optional[Node]: """ @@ -300,7 +300,7 @@ class Network(SimComponent): :return: The Node if it exists in the network. """ for node in self.nodes.values(): - if node.hostname == hostname: + if node.config.hostname == hostname: return node def remove_node(self, node: Node) -> None: @@ -313,7 +313,7 @@ class Network(SimComponent): :type node: Node """ if node not in self: - _LOGGER.warning(f"Can't remove node {node.hostname}. It's not in the network.") + _LOGGER.warning(f"Can't remove node {node.config.hostname}. It's not in the network.") return self.nodes.pop(node.uuid) for i, _node in self._node_id_map.items(): @@ -321,8 +321,8 @@ class Network(SimComponent): self._node_id_map.pop(i) break node.parent = None - self._node_request_manager.remove_request(name=node.hostname) - _LOGGER.info(f"Removed node {node.hostname} from network {self.uuid}") + self._node_request_manager.remove_request(name=node.config.hostname) + _LOGGER.info(f"Removed node {node.config.hostname} from network {self.uuid}") def connect( self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, bandwidth: int = 100, **kwargs @@ -352,7 +352,7 @@ class Network(SimComponent): link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth, **kwargs) self.links[link.uuid] = link self._link_id_map[len(self.links)] = link - self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname) + self._nx_graph.add_edge(endpoint_a.parent.config.hostname, endpoint_b.parent.config.hostname) link.parent = self _LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}") return link diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 822714cb..dbe9705b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -431,7 +431,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): self.enabled = True self._connected_node.sys_log.info(f"Network Interface {self} enabled") self.pcap = PacketCapture( - hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name + hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name ) if self._connected_link: self._connected_link.endpoint_up() @@ -1515,14 +1515,16 @@ class Node(SimComponent, ABC): _identifier: ClassVar[str] = "unknown" """Identifier for this particular class, used for printing and logging. Each subclass redefines this.""" - config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) + config: Node.ConfigSchema + """Configuration items within Node""" class ConfigSchema(BaseModel, ABC): """Configuration Schema for Node based classes.""" model_config = ConfigDict(arbitrary_types_allowed=True) """Configure pydantic to allow arbitrary types and to let the instance have attributes not present in the model.""" - hostname: str + + hostname: str = "default" "The node hostname on the network." revealed_to_red: bool = False @@ -1552,6 +1554,7 @@ class Node(SimComponent, ABC): red_scan_countdown: int = 0 "Time steps until reveal to red scan is complete." + @classmethod def from_config(cls, config: Dict) -> "Node": """Create Node object from a given configuration dictionary.""" @@ -1586,11 +1589,11 @@ class Node(SimComponent, ABC): provided. """ if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(kwargs["hostname"]) + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) if not kwargs.get("session_manager"): kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log")) if not kwargs.get("root"): - kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"] + kwargs["root"] = SIM_OUTPUT.path / kwargs["config"].hostname if not kwargs.get("file_system"): kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs") if not kwargs.get("software_manager"): @@ -1601,10 +1604,12 @@ class Node(SimComponent, ABC): file_system=kwargs.get("file_system"), dns_server=kwargs.get("dns_server"), ) + super().__init__(**kwargs) self._install_system_software() self.session_manager.node = self self.session_manager.software_manager = self.software_manager + self.power_on() @property def user_manager(self) -> Optional[UserManager]: @@ -1856,7 +1861,7 @@ class Node(SimComponent, ABC): state = super().describe_state() state.update( { - "hostname": self.hostname, + "hostname": self.config.hostname, "operating_state": self.operating_state.value, "NICs": { eth_num: network_interface.describe_state() diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 00f21342..23db025d 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -333,10 +333,13 @@ class HostNode(Node, identifier="HostNode"): """Configuration Schema for HostNode class.""" hostname: str = "HostNode" + ip_address: IPV4Address = "192.168.0.1" + subnet_mask: IPV4Address = "255.255.255.0" + default_gateway: IPV4Address = "192.168.10.1" - def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) - self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) + self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask)) @property def nmap(self) -> Optional[NMAP]: diff --git a/src/primaite/simulator/network/hardware/nodes/host/server.py b/src/primaite/simulator/network/hardware/nodes/host/server.py index e16cfd8f..1b3f6c58 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/server.py +++ b/src/primaite/simulator/network/hardware/nodes/host/server.py @@ -1,4 +1,6 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from typing import ClassVar +from pydantic import Field from primaite.simulator.network.hardware.nodes.host.host_node import HostNode @@ -30,8 +32,23 @@ class Server(HostNode, identifier="server"): * Web Browser """ + config: "Server.ConfigSchema" = Field(default_factory=lambda: Server.ConfigSchema()) + + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Server class.""" + + hostname: str = "server" + class Printer(HostNode, identifier="printer"): """Printer? I don't even know her!.""" # TODO: Implement printer-specific behaviour + + + config: "Printer.ConfigSchema" = Field(default_factory=lambda: Printer.ConfigSchema()) + + class ConfigSchema(HostNode.ConfigSchema): + """Configuration Schema for Printer class.""" + + hostname: ClassVar[str] = "printer" \ No newline at end of file diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index c7e22d49..2ebfe44a 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -99,19 +99,22 @@ class Firewall(Router, identifier="firewall"): ) """Access Control List for managing traffic leaving towards an external network.""" + _identifier: str = "firewall" + config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema()) - class ConfigSchema(Router.ConfigSChema): + class ConfigSchema(Router.ConfigSchema): """Configuration Schema for Firewall 'Nodes' within PrimAITE.""" - hostname: str = "Firewall" + hostname: str = "firewall" num_ports: int = 0 + operating_state: NodeOperatingState = NodeOperatingState.ON def __init__(self, **kwargs): if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(self.config.hostname) + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) - super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs) + super().__init__(**kwargs) self.connect_nic( RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external") @@ -124,22 +127,22 @@ class Firewall(Router, identifier="firewall"): ) # Update ACL objects with firewall's hostname and syslog to allow accurate logging self.internal_inbound_acl.sys_log = kwargs["sys_log"] - self.internal_inbound_acl.name = f"{hostname} - Internal Inbound" + self.internal_inbound_acl.name = f"{kwargs['config'].hostname} - Internal Inbound" self.internal_outbound_acl.sys_log = kwargs["sys_log"] - self.internal_outbound_acl.name = f"{hostname} - Internal Outbound" + self.internal_outbound_acl.name = f"{kwargs['config'].hostname} - Internal Outbound" self.dmz_inbound_acl.sys_log = kwargs["sys_log"] - self.dmz_inbound_acl.name = f"{hostname} - DMZ Inbound" + self.dmz_inbound_acl.name = f"{kwargs['config'].hostname} - DMZ Inbound" self.dmz_outbound_acl.sys_log = kwargs["sys_log"] - self.dmz_outbound_acl.name = f"{hostname} - DMZ Outbound" + self.dmz_outbound_acl.name = f"{kwargs['config'].hostname} - DMZ Outbound" self.external_inbound_acl.sys_log = kwargs["sys_log"] - self.external_inbound_acl.name = f"{hostname} - External Inbound" + self.external_inbound_acl.name = f"{kwargs['config'].hostname} - External Inbound" self.external_outbound_acl.sys_log = kwargs["sys_log"] - self.external_outbound_acl.name = f"{hostname} - External Outbound" + self.external_outbound_acl.name = f"{kwargs['config'].hostname} - External Outbound" def _init_request_manager(self) -> RequestManager: """ @@ -567,18 +570,21 @@ class Firewall(Router, identifier="firewall"): self.dmz_port.enable() @classmethod - def from_config(cls, cfg: dict) -> "Firewall": + def from_config(cls, config: dict) -> "Firewall": """Create a firewall based on a config dict.""" - firewall = Firewall( - hostname=cfg["hostname"], - operating_state=NodeOperatingState.ON - if not (p := cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) - if "ports" in cfg: - internal_port = cfg["ports"]["internal_port"] - external_port = cfg["ports"]["external_port"] - dmz_port = cfg["ports"].get("dmz_port") + # firewall = Firewall( + # hostname=config["hostname"], + # operating_state=NodeOperatingState.ON + # if not (p := config.get("operating_state")) + # else NodeOperatingState[p.upper()], + # ) + + firewall = Firewall(config = cls.ConfigSchema(**config)) + + if "ports" in config: + internal_port = config["ports"]["internal_port"] + external_port = config["ports"]["external_port"] + dmz_port = config["ports"].get("dmz_port") # configure internal port firewall.configure_internal_port( @@ -598,10 +604,10 @@ class Firewall(Router, identifier="firewall"): ip_address=IPV4Address(dmz_port.get("ip_address")), subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")), ) - if "acl" in cfg: + if "acl" in config: # acl rules for internal_inbound_acl - if cfg["acl"]["internal_inbound_acl"]: - for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items(): + if config["acl"]["internal_inbound_acl"]: + for r_num, r_cfg in config["acl"]["internal_inbound_acl"].items(): firewall.internal_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -615,8 +621,8 @@ class Firewall(Router, identifier="firewall"): ) # acl rules for internal_outbound_acl - if cfg["acl"]["internal_outbound_acl"]: - for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items(): + if config["acl"]["internal_outbound_acl"]: + for r_num, r_cfg in config["acl"]["internal_outbound_acl"].items(): firewall.internal_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -630,8 +636,8 @@ class Firewall(Router, identifier="firewall"): ) # acl rules for dmz_inbound_acl - if cfg["acl"]["dmz_inbound_acl"]: - for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items(): + if config["acl"]["dmz_inbound_acl"]: + for r_num, r_cfg in config["acl"]["dmz_inbound_acl"].items(): firewall.dmz_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -645,8 +651,8 @@ class Firewall(Router, identifier="firewall"): ) # acl rules for dmz_outbound_acl - if cfg["acl"]["dmz_outbound_acl"]: - for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items(): + if config["acl"]["dmz_outbound_acl"]: + for r_num, r_cfg in config["acl"]["dmz_outbound_acl"].items(): firewall.dmz_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -660,8 +666,8 @@ class Firewall(Router, identifier="firewall"): ) # acl rules for external_inbound_acl - if cfg["acl"].get("external_inbound_acl"): - for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items(): + if config["acl"].get("external_inbound_acl"): + for r_num, r_cfg in config["acl"]["external_inbound_acl"].items(): firewall.external_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -675,8 +681,8 @@ class Firewall(Router, identifier="firewall"): ) # acl rules for external_outbound_acl - if cfg["acl"].get("external_outbound_acl"): - for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items(): + if config["acl"].get("external_outbound_acl"): + for r_num, r_cfg in config["acl"]["external_outbound_acl"].items(): firewall.external_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], @@ -689,16 +695,16 @@ class Firewall(Router, identifier="firewall"): position=r_num, ) - if "routes" in cfg: - for route in cfg.get("routes"): + if "routes" in config: + for route in config.get("routes"): firewall.route_table.add_route( address=IPv4Address(route.get("address")), subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), metric=float(route.get("metric", 0)), ) - if "default_route" in cfg: - next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None) + if "default_route" in config: + next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: firewall.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 83fa066d..e475df66 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1212,21 +1212,34 @@ class Router(NetworkNode, identifier="router"): network_interface: Dict[int, RouterInterface] = {} "The Router Interfaces on the node by port id." - config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema()) + sys_log: SysLog = None + + acl: AccessControlList = None + + route_table: RouteTable = None + + config: "Router.ConfigSchema" class ConfigSchema(NetworkNode.ConfigSchema): """Configuration Schema for Router Objects.""" num_ports: int = 5 + """Number of ports available for this Router. Default is 5""" + hostname: str = "Router" - ports: list = [] - sys_log: SysLog = SysLog(hostname) - acl: AccessControlList = AccessControlList(sys_log=sys_log, implicit_action=ACLAction.DENY, name=hostname) - route_table: RouteTable = RouteTable(sys_log=sys_log) + + ports: Dict[Union[int, str], Dict] = {} + def __init__(self, **kwargs): - super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs) - self.session_manager = RouterSessionManager(sys_log=self.config.sys_log) + if not kwargs.get("sys_log"): + kwargs["sys_log"] = SysLog(kwargs["config"].hostname) + if not kwargs.get("acl"): + kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname) + if not kwargs.get("route_table"): + kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"]) + super().__init__(**kwargs) + self.session_manager = RouterSessionManager(sys_log=self.sys_log) self.session_manager.node = self self.software_manager.session_manager = self.session_manager self.session_manager.software_manager = self.software_manager @@ -1234,9 +1247,11 @@ class Router(NetworkNode, identifier="router"): network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0") self.connect_nic(network_interface) self.network_interface[i] = network_interface - self.operating_state = NodeOperatingState.ON + self._set_default_acl() + + def _install_system_software(self): """ Installs essential system software and network services on the router. @@ -1260,10 +1275,10 @@ class Router(NetworkNode, identifier="router"): Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP and ICMP, which are necessary for basic network operations and diagnostics. """ - self.config.acl.add_rule( + self.acl.add_rule( action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) - self.config.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) def setup_for_episode(self, episode: int): """ @@ -1287,7 +1302,7 @@ class Router(NetworkNode, identifier="router"): More information in user guide and docstring for SimComponent._init_request_manager. """ rm = super()._init_request_manager() - rm.add_request("acl", RequestType(func=self.config.acl._request_manager)) + rm.add_request("acl", RequestType(func=self.acl._request_manager)) return rm def ip_is_router_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool: @@ -1341,7 +1356,7 @@ class Router(NetworkNode, identifier="router"): """ state = super().describe_state() state["num_ports"] = self.config.num_ports - state["acl"] = self.config.acl.describe_state() + state["acl"] = self.acl.describe_state() return state def check_send_frame_to_session_manager(self, frame: Frame) -> bool: @@ -1562,7 +1577,7 @@ class Router(NetworkNode, identifier="router"): print(table) def setup_router(self, cfg: dict) -> Router: - """ TODO: This is the extra bit of Router's from_config metho. Needs sorting.""" + """TODO: This is the extra bit of Router's from_config metho. Needs sorting.""" if "ports" in cfg: for port_num, port_cfg in cfg["ports"].items(): self.configure_port( @@ -1594,5 +1609,100 @@ class Router(NetworkNode, identifier="router"): if "default_route" in cfg: next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: - self.config.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) + self.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) return self + + + @classmethod + def from_config(cls, config: dict, **kwargs) -> "Router": + """Create a router based on a config dict. + + Schema: + - hostname (str): unique name for this router. + - num_ports (int, optional): Number of network ports on the router. 8 by default + - ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying + ip_address and subnet_mask assigned to that ports (as strings) + - acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL + where the rule will be added (lower number is resolved first). The values should describe valid ACL + Rules as: + - action (str): either PERMIT or DENY + - src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER + - dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER + - protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP + - src_ip_address (str, optional): IP address octet written in base 10 + - dst_ip_address (str, optional): IP address octet written in base 10 + - routes (list[dict]): List of route dicts with values: + - address (str): The destination address of the route. + - subnet_mask (str): The subnet mask of the route. + - next_hop_ip_address (str): The next hop IP for the route. + - metric (int): The metric of the route. Optional. + - default_route: + - next_hop_ip_address (str): The next hop IP for the route. + + Example config: + ``` + { + 'hostname': 'router_1', + 'num_ports': 5, + 'ports': { + 1: { + 'ip_address' : '192.168.1.1', + 'subnet_mask' : '255.255.255.0', + }, + 2: { + 'ip_address' : '192.168.0.1', + 'subnet_mask' : '255.255.255.252', + } + }, + 'acl' : { + 21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'}, + 22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'}, + 23: {'action': 'PERMIT', 'protocol': 'ICMP'}, + }, + 'routes' : [ + {'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'} + ], + 'default_route': {'next_hop_ip_address': '192.168.0.2'} + } + ``` + + :param cfg: Router config adhering to schema described in main docstring body + :type cfg: dict + :return: Configured router. + :rtype: Router + """ + router = Router(config=Router.ConfigSchema(**config) + ) + if "ports" in config: + for port_num, port_cfg in config["ports"].items(): + router.configure_port( + port=port_num, + ip_address=port_cfg["ip_address"], + subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")), + ) + if "acl" in config: + for r_num, r_cfg in config["acl"].items(): + router.acl.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], + protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p], + src_ip_address=r_cfg.get("src_ip"), + src_wildcard_mask=r_cfg.get("src_wildcard_mask"), + dst_ip_address=r_cfg.get("dst_ip"), + dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"), + position=r_num, + ) + if "routes" in config: + for route in config.get("routes"): + router.route_table.add_route( + address=IPv4Address(route.get("address")), + subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), + next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), + metric=float(route.get("metric", 0)), + ) + if "default_route" in config: + next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None) + if next_hop_ip_address: + router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) + return router \ No newline at end of file diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index a2d0050b..2ca0cafd 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, Optional +from typing import ClassVar, Dict, Optional from prettytable import MARKDOWN, PrettyTable from pydantic import Field @@ -102,7 +102,7 @@ class Switch(NetworkNode, identifier="switch"): mac_address_table: Dict[str, SwitchPort] = {} "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." - config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema()) + config: "Switch.ConfigSchema" class ConfigSchema(NetworkNode.ConfigSchema): """Configuration Schema for Switch nodes within PrimAITE.""" @@ -113,7 +113,7 @@ class Switch(NetworkNode, identifier="switch"): def __init__(self, **kwargs): super().__init__(**kwargs) - for i in range(1, self.config.num_ports + 1): + for i in range(1, kwargs["config"].num_ports + 1): self.connect_nic(SwitchPort()) def _install_system_software(self): diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 34c893eb..9a30f3e3 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -294,7 +294,7 @@ class IOSoftware(Software): """ if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON: self.software_manager.node.sys_log.error( - f"{self.name} Error: {self.software_manager.node.hostname} is not powered on." + f"{self.name} Error: {self.software_manager.node.config.hostname} is not powered on." ) return False return True diff --git a/tests/assets/configs/data_manipulation.yaml b/tests/assets/configs/data_manipulation.yaml index 97442903..bddea1a0 100644 --- a/tests/assets/configs/data_manipulation.yaml +++ b/tests/assets/configs/data_manipulation.yaml @@ -187,7 +187,7 @@ agents: num_files: 1 num_nics: 2 include_num_access: false - include_nmne: true + include_nmne: true monitored_traffic: icmp: - NONE diff --git a/tests/conftest.py b/tests/conftest.py index 6cbcfa84..08c16537 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -195,68 +195,91 @@ def example_network() -> Network: network = Network() # Router 1 + + router_1_cfg = {"hostname":"router_1", "type":"router"} + # router_1 = Router(hostname="router_1", start_up_duration=0) - router_1 = Router(hostname="router_1", start_up_duration=0) + router_1 = Router.from_config(config=router_1_cfg) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") # Switch 1 - # switch_1_config = Switch.ConfigSchema() - switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) + + switch_1_cfg = {"hostname": "switch_1", "type": "switch"} + + switch_1 = Switch.from_config(config=switch_1_cfg) + + # switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) switch_1.power_on() network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8]) router_1.enable_port(1) # Switch 2 - # switch_2_config = Switch.ConfigSchema() - switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) + switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8} + # switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) + switch_2 = Switch.from_config(config=switch_2_config) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8]) router_1.enable_port(2) - # Client 1 - client_1 = Computer( - hostname="client_1", - ip_address="192.168.10.21", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - start_up_duration=0, - ) + # # Client 1 + + client_1_cfg = {"type": "computer", + "hostname": "client_1", + "ip_address": "192.168.10.21", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "start_up_duration": 0} + + client_1=Computer.from_config(config=client_1_cfg) + client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) - # Client 2 - client_2 = Computer( - hostname="client_2", - ip_address="192.168.10.22", - subnet_mask="255.255.255.0", - default_gateway="192.168.10.1", - start_up_duration=0, - ) + # # Client 2 + + client_2_cfg = {"type": "computer", + "hostname": "client_2", + "ip_address": "192.168.10.22", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.10.1", + "start_up_duration": 0, + } + + client_2 = Computer.from_config(config=client_2_cfg) + client_2.power_on() network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2]) - # Server 1 - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + # # Server 1 + + server_1_cfg = {"type": "server", + "hostname": "server_1", + "ip_address":"192.168.1.10", + "subnet_mask":"255.255.255.0", + "default_gateway":"192.168.1.1", + "start_up_duration":0, + } + + server_1 = Server.from_config(config=server_1_cfg) + server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) - # DServer 2 - server_2 = Server( - hostname="server_2", - ip_address="192.168.1.14", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + # # DServer 2 + + server_2_cfg = {"type": "server", + "hostname": "server_2", + "ip_address":"192.168.1.14", + "subnet_mask":"255.255.255.0", + "default_gateway":"192.168.1.1", + "start_up_duration":0, + } + + server_2 = Server.from_config(config=server_2_cfg) + server_2.power_on() network.connect(endpoint_b=server_2.network_interface[1], endpoint_a=switch_1.network_interface[2]) @@ -264,6 +287,8 @@ def example_network() -> Network: assert all(link.is_up for link in network.links.values()) + client_1.software_manager.show() + return network diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py index 16f4dee5..c9691fab 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -6,6 +6,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index 02cf005a..68964b90 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -36,7 +36,7 @@ def test_acl_observations(simulation): router.acl.add_rule(action=ACLAction.PERMIT, dst_port=PORT_LOOKUP["NTP"], src_port=PORT_LOOKUP["NTP"], position=1) acl_obs = ACLObservation( - where=["network", "nodes", router.hostname, "acl", "acl"], + where=["network", "nodes", router.config.hostname, "acl", "acl"], ip_list=[], port_list=[123, 80, 5432], protocol_list=["tcp", "udp", "icmp"], diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 0268cb95..a56deb2b 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -24,7 +24,7 @@ def test_file_observation(simulation): file = pc.file_system.create_file(file_name="dog.png") dog_file_obs = FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, file_system_requires_scan=True, ) @@ -52,7 +52,7 @@ def test_folder_observation(simulation): file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder") root_folder_obs = FolderObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"], + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "test_folder"], include_num_access=False, file_system_requires_scan=True, num_files=1, diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py index fe0c3a57..5a0ebe8f 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_router.py @@ -50,7 +50,7 @@ def test_wireless_router_from_config(): }, } - rt = Router.from_config(cfg=cfg) + rt = Router.from_config(config=cfg) assert rt.num_ports == 6