#2887 - Updates to Node components to use rom_config and allow for extensibility. Router and Firewall continue to have custom from_config. Some test fixes to reflect changes to functionality.
This commit is contained in:
58
docs/source/how_to_guides/extensible_nodes.rst
Normal file
58
docs/source/how_to_guides/extensible_nodes.rst
Normal file
@@ -0,0 +1,58 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _about:
|
||||
|
||||
|
||||
Extensible Nodes
|
||||
****************
|
||||
|
||||
Node classes within PrimAITE have been updated to allow for easier generation of custom nodes within simulations.
|
||||
|
||||
|
||||
Changes to Node Class structure.
|
||||
================================
|
||||
|
||||
Node classes all inherit from the base Node Class, though new classes should inherit from either HostNode or NetworkNode, subject to the intended application of the Node.
|
||||
|
||||
The use of an `__init__` method is not necessary, as configurable variables for the class should be specified within the `config` of the class, and passed at run time via your YAML configuration using the `from_config` method.
|
||||
|
||||
|
||||
An example of how additional Node classes is below, taken from `router.py` withing PrimAITE.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
class Router(NetworkNode, identifier="router"):
|
||||
""" Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces."""
|
||||
|
||||
SYSTEM_SOFTWARE: ClassVar[Dict] = {
|
||||
"UserSessionManager": UserSessionManager,
|
||||
"UserManager": UserManager,
|
||||
"Terminal": Terminal,
|
||||
}
|
||||
|
||||
network_interfaces: Dict[str, RouterInterface] = {}
|
||||
"The Router Interfaces on the node."
|
||||
network_interface: Dict[int, RouterInterface] = {}
|
||||
"The Router Interfaces on the node by port id."
|
||||
|
||||
sys_log: SysLog
|
||||
|
||||
config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema())
|
||||
|
||||
class ConfigSchema(NetworkNode.ConfigSchema):
|
||||
"""Configuration Schema for Router Objects."""
|
||||
|
||||
num_ports: int = 5
|
||||
|
||||
hostname: ClassVar[str] = "Router"
|
||||
|
||||
ports: list = []
|
||||
|
||||
|
||||
|
||||
Changes to YAML file.
|
||||
=====================
|
||||
|
||||
Nodes defined within configuration YAML files for use with PrimAITE 3.X should still be compatible following these changes.
|
||||
@@ -271,12 +271,13 @@ class PrimaiteGame:
|
||||
|
||||
for node_cfg in nodes_cfg:
|
||||
n_type = node_cfg["type"]
|
||||
node_config: dict = node_cfg["config"]
|
||||
# node_config: dict = node_cfg["config"]
|
||||
print(f"{n_type}:{node_cfg}")
|
||||
|
||||
new_node = None
|
||||
if n_type in Node._registry:
|
||||
# simplify down Node creation:
|
||||
new_node = Node._registry["n_type"].from_config(config=node_config)
|
||||
new_node = Node._registry[n_type].from_config(config=node_cfg)
|
||||
else:
|
||||
msg = f"invalid node type {n_type} in config"
|
||||
_LOGGER.error(msg)
|
||||
@@ -313,7 +314,7 @@ class PrimaiteGame:
|
||||
service_class = SERVICE_TYPES_MAPPING[service_type]
|
||||
|
||||
if service_class is not None:
|
||||
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
|
||||
_LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}")
|
||||
new_node.software_manager.install(service_class, **service_cfg.get("options", {}))
|
||||
new_service = new_node.software_manager.software[service_class.__name__]
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"""Core of the PrimAITE Simulator."""
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import PrettyTable
|
||||
|
||||
@@ -180,7 +180,7 @@ class Network(SimComponent):
|
||||
table.align = "l"
|
||||
table.title = "Nodes"
|
||||
for node in self.nodes.values():
|
||||
table.add_row((node.hostname, type(node)._identifier, node.operating_state.name))
|
||||
table.add_row((node.config.hostname, type(node)._identifier, node.operating_state.name))
|
||||
print(table)
|
||||
|
||||
if ip_addresses:
|
||||
@@ -196,7 +196,7 @@ class Network(SimComponent):
|
||||
if port.ip_address != IPv4Address("127.0.0.1"):
|
||||
port_str = port.port_name if port.port_name else port.port_num
|
||||
table.add_row(
|
||||
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
|
||||
[node.config.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -215,9 +215,9 @@ class Network(SimComponent):
|
||||
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
|
||||
table.add_row(
|
||||
[
|
||||
link.endpoint_a.parent.hostname,
|
||||
link.endpoint_a.parent.config.hostname,
|
||||
str(link.endpoint_a),
|
||||
link.endpoint_b.parent.hostname,
|
||||
link.endpoint_b.parent.config.hostname,
|
||||
str(link.endpoint_b),
|
||||
link.is_up,
|
||||
link.bandwidth,
|
||||
@@ -251,7 +251,7 @@ class Network(SimComponent):
|
||||
state = super().describe_state()
|
||||
state.update(
|
||||
{
|
||||
"nodes": {node.hostname: node.describe_state() for node in self.nodes.values()},
|
||||
"nodes": {node.config.hostname: node.describe_state() for node in self.nodes.values()},
|
||||
"links": {},
|
||||
}
|
||||
)
|
||||
@@ -259,8 +259,8 @@ class Network(SimComponent):
|
||||
for _, link in self.links.items():
|
||||
node_a = link.endpoint_a._connected_node
|
||||
node_b = link.endpoint_b._connected_node
|
||||
hostname_a = node_a.hostname if node_a else None
|
||||
hostname_b = node_b.hostname if node_b else None
|
||||
hostname_a = node_a.config.hostname if node_a else None
|
||||
hostname_b = node_b.config.hostname if node_b else None
|
||||
port_a = link.endpoint_a.port_num
|
||||
port_b = link.endpoint_b.port_num
|
||||
link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}"
|
||||
@@ -286,9 +286,9 @@ class Network(SimComponent):
|
||||
self.nodes[node.uuid] = node
|
||||
self._node_id_map[len(self.nodes)] = node
|
||||
node.parent = self
|
||||
self._nx_graph.add_node(node.hostname)
|
||||
self._nx_graph.add_node(node.config.hostname)
|
||||
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
|
||||
self._node_request_manager.add_request(name=node.hostname, request_type=RequestType(func=node._request_manager))
|
||||
self._node_request_manager.add_request(name=node.config.hostname, request_type=RequestType(func=node._request_manager))
|
||||
|
||||
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
|
||||
"""
|
||||
@@ -300,7 +300,7 @@ class Network(SimComponent):
|
||||
:return: The Node if it exists in the network.
|
||||
"""
|
||||
for node in self.nodes.values():
|
||||
if node.hostname == hostname:
|
||||
if node.config.hostname == hostname:
|
||||
return node
|
||||
|
||||
def remove_node(self, node: Node) -> None:
|
||||
@@ -313,7 +313,7 @@ class Network(SimComponent):
|
||||
:type node: Node
|
||||
"""
|
||||
if node not in self:
|
||||
_LOGGER.warning(f"Can't remove node {node.hostname}. It's not in the network.")
|
||||
_LOGGER.warning(f"Can't remove node {node.config.hostname}. It's not in the network.")
|
||||
return
|
||||
self.nodes.pop(node.uuid)
|
||||
for i, _node in self._node_id_map.items():
|
||||
@@ -321,8 +321,8 @@ class Network(SimComponent):
|
||||
self._node_id_map.pop(i)
|
||||
break
|
||||
node.parent = None
|
||||
self._node_request_manager.remove_request(name=node.hostname)
|
||||
_LOGGER.info(f"Removed node {node.hostname} from network {self.uuid}")
|
||||
self._node_request_manager.remove_request(name=node.config.hostname)
|
||||
_LOGGER.info(f"Removed node {node.config.hostname} from network {self.uuid}")
|
||||
|
||||
def connect(
|
||||
self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, bandwidth: int = 100, **kwargs
|
||||
@@ -352,7 +352,7 @@ class Network(SimComponent):
|
||||
link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth, **kwargs)
|
||||
self.links[link.uuid] = link
|
||||
self._link_id_map[len(self.links)] = link
|
||||
self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname)
|
||||
self._nx_graph.add_edge(endpoint_a.parent.config.hostname, endpoint_b.parent.config.hostname)
|
||||
link.parent = self
|
||||
_LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
|
||||
return link
|
||||
|
||||
@@ -431,7 +431,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
self.enabled = True
|
||||
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
|
||||
self.pcap = PacketCapture(
|
||||
hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name
|
||||
hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name
|
||||
)
|
||||
if self._connected_link:
|
||||
self._connected_link.endpoint_up()
|
||||
@@ -1515,14 +1515,16 @@ class Node(SimComponent, ABC):
|
||||
_identifier: ClassVar[str] = "unknown"
|
||||
"""Identifier for this particular class, used for printing and logging. Each subclass redefines this."""
|
||||
|
||||
config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
|
||||
config: Node.ConfigSchema
|
||||
"""Configuration items within Node"""
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Configuration Schema for Node based classes."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
"""Configure pydantic to allow arbitrary types and to let the instance have attributes not present in the model."""
|
||||
hostname: str
|
||||
|
||||
hostname: str = "default"
|
||||
"The node hostname on the network."
|
||||
|
||||
revealed_to_red: bool = False
|
||||
@@ -1552,6 +1554,7 @@ class Node(SimComponent, ABC):
|
||||
red_scan_countdown: int = 0
|
||||
"Time steps until reveal to red scan is complete."
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "Node":
|
||||
"""Create Node object from a given configuration dictionary."""
|
||||
@@ -1586,11 +1589,11 @@ class Node(SimComponent, ABC):
|
||||
provided.
|
||||
"""
|
||||
if not kwargs.get("sys_log"):
|
||||
kwargs["sys_log"] = SysLog(kwargs["hostname"])
|
||||
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
|
||||
if not kwargs.get("session_manager"):
|
||||
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"))
|
||||
if not kwargs.get("root"):
|
||||
kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"]
|
||||
kwargs["root"] = SIM_OUTPUT.path / kwargs["config"].hostname
|
||||
if not kwargs.get("file_system"):
|
||||
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
|
||||
if not kwargs.get("software_manager"):
|
||||
@@ -1601,10 +1604,12 @@ class Node(SimComponent, ABC):
|
||||
file_system=kwargs.get("file_system"),
|
||||
dns_server=kwargs.get("dns_server"),
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._install_system_software()
|
||||
self.session_manager.node = self
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
self.power_on()
|
||||
|
||||
@property
|
||||
def user_manager(self) -> Optional[UserManager]:
|
||||
@@ -1856,7 +1861,7 @@ class Node(SimComponent, ABC):
|
||||
state = super().describe_state()
|
||||
state.update(
|
||||
{
|
||||
"hostname": self.hostname,
|
||||
"hostname": self.config.hostname,
|
||||
"operating_state": self.operating_state.value,
|
||||
"NICs": {
|
||||
eth_num: network_interface.describe_state()
|
||||
|
||||
@@ -333,10 +333,13 @@ class HostNode(Node, identifier="HostNode"):
|
||||
"""Configuration Schema for HostNode class."""
|
||||
|
||||
hostname: str = "HostNode"
|
||||
ip_address: IPV4Address = "192.168.0.1"
|
||||
subnet_mask: IPV4Address = "255.255.255.0"
|
||||
default_gateway: IPV4Address = "192.168.10.1"
|
||||
|
||||
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
|
||||
self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask))
|
||||
|
||||
@property
|
||||
def nmap(self) -> Optional[NMAP]:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from typing import ClassVar
|
||||
from pydantic import Field
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
|
||||
|
||||
@@ -30,8 +32,23 @@ class Server(HostNode, identifier="server"):
|
||||
* Web Browser
|
||||
"""
|
||||
|
||||
config: "Server.ConfigSchema" = Field(default_factory=lambda: Server.ConfigSchema())
|
||||
|
||||
class ConfigSchema(HostNode.ConfigSchema):
|
||||
"""Configuration Schema for Server class."""
|
||||
|
||||
hostname: str = "server"
|
||||
|
||||
|
||||
class Printer(HostNode, identifier="printer"):
|
||||
"""Printer? I don't even know her!."""
|
||||
|
||||
# TODO: Implement printer-specific behaviour
|
||||
|
||||
|
||||
config: "Printer.ConfigSchema" = Field(default_factory=lambda: Printer.ConfigSchema())
|
||||
|
||||
class ConfigSchema(HostNode.ConfigSchema):
|
||||
"""Configuration Schema for Printer class."""
|
||||
|
||||
hostname: ClassVar[str] = "printer"
|
||||
@@ -99,19 +99,22 @@ class Firewall(Router, identifier="firewall"):
|
||||
)
|
||||
"""Access Control List for managing traffic leaving towards an external network."""
|
||||
|
||||
_identifier: str = "firewall"
|
||||
|
||||
config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema())
|
||||
|
||||
class ConfigSchema(Router.ConfigSChema):
|
||||
class ConfigSchema(Router.ConfigSchema):
|
||||
"""Configuration Schema for Firewall 'Nodes' within PrimAITE."""
|
||||
|
||||
hostname: str = "Firewall"
|
||||
hostname: str = "firewall"
|
||||
num_ports: int = 0
|
||||
operating_state: NodeOperatingState = NodeOperatingState.ON
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if not kwargs.get("sys_log"):
|
||||
kwargs["sys_log"] = SysLog(self.config.hostname)
|
||||
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
|
||||
|
||||
super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.connect_nic(
|
||||
RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external")
|
||||
@@ -124,22 +127,22 @@ class Firewall(Router, identifier="firewall"):
|
||||
)
|
||||
# Update ACL objects with firewall's hostname and syslog to allow accurate logging
|
||||
self.internal_inbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.internal_inbound_acl.name = f"{hostname} - Internal Inbound"
|
||||
self.internal_inbound_acl.name = f"{kwargs['config'].hostname} - Internal Inbound"
|
||||
|
||||
self.internal_outbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.internal_outbound_acl.name = f"{hostname} - Internal Outbound"
|
||||
self.internal_outbound_acl.name = f"{kwargs['config'].hostname} - Internal Outbound"
|
||||
|
||||
self.dmz_inbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.dmz_inbound_acl.name = f"{hostname} - DMZ Inbound"
|
||||
self.dmz_inbound_acl.name = f"{kwargs['config'].hostname} - DMZ Inbound"
|
||||
|
||||
self.dmz_outbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.dmz_outbound_acl.name = f"{hostname} - DMZ Outbound"
|
||||
self.dmz_outbound_acl.name = f"{kwargs['config'].hostname} - DMZ Outbound"
|
||||
|
||||
self.external_inbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.external_inbound_acl.name = f"{hostname} - External Inbound"
|
||||
self.external_inbound_acl.name = f"{kwargs['config'].hostname} - External Inbound"
|
||||
|
||||
self.external_outbound_acl.sys_log = kwargs["sys_log"]
|
||||
self.external_outbound_acl.name = f"{hostname} - External Outbound"
|
||||
self.external_outbound_acl.name = f"{kwargs['config'].hostname} - External Outbound"
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
@@ -567,18 +570,21 @@ class Firewall(Router, identifier="firewall"):
|
||||
self.dmz_port.enable()
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "Firewall":
|
||||
def from_config(cls, config: dict) -> "Firewall":
|
||||
"""Create a firewall based on a config dict."""
|
||||
firewall = Firewall(
|
||||
hostname=cfg["hostname"],
|
||||
operating_state=NodeOperatingState.ON
|
||||
if not (p := cfg.get("operating_state"))
|
||||
else NodeOperatingState[p.upper()],
|
||||
)
|
||||
if "ports" in cfg:
|
||||
internal_port = cfg["ports"]["internal_port"]
|
||||
external_port = cfg["ports"]["external_port"]
|
||||
dmz_port = cfg["ports"].get("dmz_port")
|
||||
# firewall = Firewall(
|
||||
# hostname=config["hostname"],
|
||||
# operating_state=NodeOperatingState.ON
|
||||
# if not (p := config.get("operating_state"))
|
||||
# else NodeOperatingState[p.upper()],
|
||||
# )
|
||||
|
||||
firewall = Firewall(config = cls.ConfigSchema(**config))
|
||||
|
||||
if "ports" in config:
|
||||
internal_port = config["ports"]["internal_port"]
|
||||
external_port = config["ports"]["external_port"]
|
||||
dmz_port = config["ports"].get("dmz_port")
|
||||
|
||||
# configure internal port
|
||||
firewall.configure_internal_port(
|
||||
@@ -598,10 +604,10 @@ class Firewall(Router, identifier="firewall"):
|
||||
ip_address=IPV4Address(dmz_port.get("ip_address")),
|
||||
subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")),
|
||||
)
|
||||
if "acl" in cfg:
|
||||
if "acl" in config:
|
||||
# acl rules for internal_inbound_acl
|
||||
if cfg["acl"]["internal_inbound_acl"]:
|
||||
for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items():
|
||||
if config["acl"]["internal_inbound_acl"]:
|
||||
for r_num, r_cfg in config["acl"]["internal_inbound_acl"].items():
|
||||
firewall.internal_inbound_acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
|
||||
@@ -615,8 +621,8 @@ class Firewall(Router, identifier="firewall"):
|
||||
)
|
||||
|
||||
# acl rules for internal_outbound_acl
|
||||
if cfg["acl"]["internal_outbound_acl"]:
|
||||
for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items():
|
||||
if config["acl"]["internal_outbound_acl"]:
|
||||
for r_num, r_cfg in config["acl"]["internal_outbound_acl"].items():
|
||||
firewall.internal_outbound_acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
|
||||
@@ -630,8 +636,8 @@ class Firewall(Router, identifier="firewall"):
|
||||
)
|
||||
|
||||
# acl rules for dmz_inbound_acl
|
||||
if cfg["acl"]["dmz_inbound_acl"]:
|
||||
for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items():
|
||||
if config["acl"]["dmz_inbound_acl"]:
|
||||
for r_num, r_cfg in config["acl"]["dmz_inbound_acl"].items():
|
||||
firewall.dmz_inbound_acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
|
||||
@@ -645,8 +651,8 @@ class Firewall(Router, identifier="firewall"):
|
||||
)
|
||||
|
||||
# acl rules for dmz_outbound_acl
|
||||
if cfg["acl"]["dmz_outbound_acl"]:
|
||||
for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items():
|
||||
if config["acl"]["dmz_outbound_acl"]:
|
||||
for r_num, r_cfg in config["acl"]["dmz_outbound_acl"].items():
|
||||
firewall.dmz_outbound_acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
|
||||
@@ -660,8 +666,8 @@ class Firewall(Router, identifier="firewall"):
|
||||
)
|
||||
|
||||
# acl rules for external_inbound_acl
|
||||
if cfg["acl"].get("external_inbound_acl"):
|
||||
for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items():
|
||||
if config["acl"].get("external_inbound_acl"):
|
||||
for r_num, r_cfg in config["acl"]["external_inbound_acl"].items():
|
||||
firewall.external_inbound_acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
|
||||
@@ -675,8 +681,8 @@ class Firewall(Router, identifier="firewall"):
|
||||
)
|
||||
|
||||
# acl rules for external_outbound_acl
|
||||
if cfg["acl"].get("external_outbound_acl"):
|
||||
for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items():
|
||||
if config["acl"].get("external_outbound_acl"):
|
||||
for r_num, r_cfg in config["acl"]["external_outbound_acl"].items():
|
||||
firewall.external_outbound_acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
|
||||
@@ -689,16 +695,16 @@ class Firewall(Router, identifier="firewall"):
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
if "routes" in cfg:
|
||||
for route in cfg.get("routes"):
|
||||
if "routes" in config:
|
||||
for route in config.get("routes"):
|
||||
firewall.route_table.add_route(
|
||||
address=IPv4Address(route.get("address")),
|
||||
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
|
||||
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
|
||||
metric=float(route.get("metric", 0)),
|
||||
)
|
||||
if "default_route" in cfg:
|
||||
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
|
||||
if "default_route" in config:
|
||||
next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None)
|
||||
if next_hop_ip_address:
|
||||
firewall.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
|
||||
|
||||
|
||||
@@ -1212,21 +1212,34 @@ class Router(NetworkNode, identifier="router"):
|
||||
network_interface: Dict[int, RouterInterface] = {}
|
||||
"The Router Interfaces on the node by port id."
|
||||
|
||||
config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema())
|
||||
sys_log: SysLog = None
|
||||
|
||||
acl: AccessControlList = None
|
||||
|
||||
route_table: RouteTable = None
|
||||
|
||||
config: "Router.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NetworkNode.ConfigSchema):
|
||||
"""Configuration Schema for Router Objects."""
|
||||
|
||||
num_ports: int = 5
|
||||
"""Number of ports available for this Router. Default is 5"""
|
||||
|
||||
hostname: str = "Router"
|
||||
ports: list = []
|
||||
sys_log: SysLog = SysLog(hostname)
|
||||
acl: AccessControlList = AccessControlList(sys_log=sys_log, implicit_action=ACLAction.DENY, name=hostname)
|
||||
route_table: RouteTable = RouteTable(sys_log=sys_log)
|
||||
|
||||
ports: Dict[Union[int, str], Dict] = {}
|
||||
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs)
|
||||
self.session_manager = RouterSessionManager(sys_log=self.config.sys_log)
|
||||
if not kwargs.get("sys_log"):
|
||||
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
|
||||
if not kwargs.get("acl"):
|
||||
kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname)
|
||||
if not kwargs.get("route_table"):
|
||||
kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"])
|
||||
super().__init__(**kwargs)
|
||||
self.session_manager = RouterSessionManager(sys_log=self.sys_log)
|
||||
self.session_manager.node = self
|
||||
self.software_manager.session_manager = self.session_manager
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
@@ -1234,9 +1247,11 @@ class Router(NetworkNode, identifier="router"):
|
||||
network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
|
||||
self.connect_nic(network_interface)
|
||||
self.network_interface[i] = network_interface
|
||||
self.operating_state = NodeOperatingState.ON
|
||||
|
||||
self._set_default_acl()
|
||||
|
||||
|
||||
|
||||
def _install_system_software(self):
|
||||
"""
|
||||
Installs essential system software and network services on the router.
|
||||
@@ -1260,10 +1275,10 @@ class Router(NetworkNode, identifier="router"):
|
||||
Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP
|
||||
and ICMP, which are necessary for basic network operations and diagnostics.
|
||||
"""
|
||||
self.config.acl.add_rule(
|
||||
self.acl.add_rule(
|
||||
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22
|
||||
)
|
||||
self.config.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
|
||||
self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
|
||||
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""
|
||||
@@ -1287,7 +1302,7 @@ class Router(NetworkNode, identifier="router"):
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("acl", RequestType(func=self.config.acl._request_manager))
|
||||
rm.add_request("acl", RequestType(func=self.acl._request_manager))
|
||||
return rm
|
||||
|
||||
def ip_is_router_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool:
|
||||
@@ -1341,7 +1356,7 @@ class Router(NetworkNode, identifier="router"):
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["num_ports"] = self.config.num_ports
|
||||
state["acl"] = self.config.acl.describe_state()
|
||||
state["acl"] = self.acl.describe_state()
|
||||
return state
|
||||
|
||||
def check_send_frame_to_session_manager(self, frame: Frame) -> bool:
|
||||
@@ -1562,7 +1577,7 @@ class Router(NetworkNode, identifier="router"):
|
||||
print(table)
|
||||
|
||||
def setup_router(self, cfg: dict) -> Router:
|
||||
""" TODO: This is the extra bit of Router's from_config metho. Needs sorting."""
|
||||
"""TODO: This is the extra bit of Router's from_config metho. Needs sorting."""
|
||||
if "ports" in cfg:
|
||||
for port_num, port_cfg in cfg["ports"].items():
|
||||
self.configure_port(
|
||||
@@ -1594,5 +1609,100 @@ class Router(NetworkNode, identifier="router"):
|
||||
if "default_route" in cfg:
|
||||
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
|
||||
if next_hop_ip_address:
|
||||
self.config.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
|
||||
self.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
|
||||
return self
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict, **kwargs) -> "Router":
|
||||
"""Create a router based on a config dict.
|
||||
|
||||
Schema:
|
||||
- hostname (str): unique name for this router.
|
||||
- num_ports (int, optional): Number of network ports on the router. 8 by default
|
||||
- ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying
|
||||
ip_address and subnet_mask assigned to that ports (as strings)
|
||||
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
|
||||
where the rule will be added (lower number is resolved first). The values should describe valid ACL
|
||||
Rules as:
|
||||
- action (str): either PERMIT or DENY
|
||||
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
|
||||
- src_ip_address (str, optional): IP address octet written in base 10
|
||||
- dst_ip_address (str, optional): IP address octet written in base 10
|
||||
- routes (list[dict]): List of route dicts with values:
|
||||
- address (str): The destination address of the route.
|
||||
- subnet_mask (str): The subnet mask of the route.
|
||||
- next_hop_ip_address (str): The next hop IP for the route.
|
||||
- metric (int): The metric of the route. Optional.
|
||||
- default_route:
|
||||
- next_hop_ip_address (str): The next hop IP for the route.
|
||||
|
||||
Example config:
|
||||
```
|
||||
{
|
||||
'hostname': 'router_1',
|
||||
'num_ports': 5,
|
||||
'ports': {
|
||||
1: {
|
||||
'ip_address' : '192.168.1.1',
|
||||
'subnet_mask' : '255.255.255.0',
|
||||
},
|
||||
2: {
|
||||
'ip_address' : '192.168.0.1',
|
||||
'subnet_mask' : '255.255.255.252',
|
||||
}
|
||||
},
|
||||
'acl' : {
|
||||
21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'},
|
||||
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
|
||||
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
|
||||
},
|
||||
'routes' : [
|
||||
{'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'}
|
||||
],
|
||||
'default_route': {'next_hop_ip_address': '192.168.0.2'}
|
||||
}
|
||||
```
|
||||
|
||||
:param cfg: Router config adhering to schema described in main docstring body
|
||||
:type cfg: dict
|
||||
:return: Configured router.
|
||||
:rtype: Router
|
||||
"""
|
||||
router = Router(config=Router.ConfigSchema(**config)
|
||||
)
|
||||
if "ports" in config:
|
||||
for port_num, port_cfg in config["ports"].items():
|
||||
router.configure_port(
|
||||
port=port_num,
|
||||
ip_address=port_cfg["ip_address"],
|
||||
subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")),
|
||||
)
|
||||
if "acl" in config:
|
||||
for r_num, r_cfg in config["acl"].items():
|
||||
router.acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
if "routes" in config:
|
||||
for route in config.get("routes"):
|
||||
router.route_table.add_route(
|
||||
address=IPv4Address(route.get("address")),
|
||||
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
|
||||
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
|
||||
metric=float(route.get("metric", 0)),
|
||||
)
|
||||
if "default_route" in config:
|
||||
next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None)
|
||||
if next_hop_ip_address:
|
||||
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
|
||||
return router
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import ClassVar, Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
@@ -102,7 +102,7 @@ class Switch(NetworkNode, identifier="switch"):
|
||||
mac_address_table: Dict[str, SwitchPort] = {}
|
||||
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
|
||||
|
||||
config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema())
|
||||
config: "Switch.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NetworkNode.ConfigSchema):
|
||||
"""Configuration Schema for Switch nodes within PrimAITE."""
|
||||
@@ -113,7 +113,7 @@ class Switch(NetworkNode, identifier="switch"):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for i in range(1, self.config.num_ports + 1):
|
||||
for i in range(1, kwargs["config"].num_ports + 1):
|
||||
self.connect_nic(SwitchPort())
|
||||
|
||||
def _install_system_software(self):
|
||||
|
||||
@@ -294,7 +294,7 @@ class IOSoftware(Software):
|
||||
"""
|
||||
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
|
||||
self.software_manager.node.sys_log.error(
|
||||
f"{self.name} Error: {self.software_manager.node.hostname} is not powered on."
|
||||
f"{self.name} Error: {self.software_manager.node.config.hostname} is not powered on."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -187,7 +187,7 @@ agents:
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
include_nmne: true
|
||||
monitored_traffic:
|
||||
icmp:
|
||||
- NONE
|
||||
|
||||
@@ -195,68 +195,91 @@ def example_network() -> Network:
|
||||
network = Network()
|
||||
|
||||
# Router 1
|
||||
|
||||
router_1_cfg = {"hostname":"router_1", "type":"router"}
|
||||
|
||||
# router_1 = Router(hostname="router_1", start_up_duration=0)
|
||||
router_1 = Router(hostname="router_1", start_up_duration=0)
|
||||
router_1 = Router.from_config(config=router_1_cfg)
|
||||
router_1.power_on()
|
||||
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
|
||||
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
|
||||
|
||||
# Switch 1
|
||||
# switch_1_config = Switch.ConfigSchema()
|
||||
switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
|
||||
|
||||
switch_1_cfg = {"hostname": "switch_1", "type": "switch"}
|
||||
|
||||
switch_1 = Switch.from_config(config=switch_1_cfg)
|
||||
|
||||
# switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
|
||||
switch_1.power_on()
|
||||
|
||||
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8])
|
||||
router_1.enable_port(1)
|
||||
|
||||
# Switch 2
|
||||
# switch_2_config = Switch.ConfigSchema()
|
||||
switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
|
||||
switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8}
|
||||
# switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
|
||||
switch_2 = Switch.from_config(config=switch_2_config)
|
||||
switch_2.power_on()
|
||||
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])
|
||||
router_1.enable_port(2)
|
||||
|
||||
# Client 1
|
||||
client_1 = Computer(
|
||||
hostname="client_1",
|
||||
ip_address="192.168.10.21",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.10.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
# # Client 1
|
||||
|
||||
client_1_cfg = {"type": "computer",
|
||||
"hostname": "client_1",
|
||||
"ip_address": "192.168.10.21",
|
||||
"subnet_mask": "255.255.255.0",
|
||||
"default_gateway": "192.168.10.1",
|
||||
"start_up_duration": 0}
|
||||
|
||||
client_1=Computer.from_config(config=client_1_cfg)
|
||||
|
||||
client_1.power_on()
|
||||
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
|
||||
|
||||
# Client 2
|
||||
client_2 = Computer(
|
||||
hostname="client_2",
|
||||
ip_address="192.168.10.22",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.10.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
# # Client 2
|
||||
|
||||
client_2_cfg = {"type": "computer",
|
||||
"hostname": "client_2",
|
||||
"ip_address": "192.168.10.22",
|
||||
"subnet_mask": "255.255.255.0",
|
||||
"default_gateway": "192.168.10.1",
|
||||
"start_up_duration": 0,
|
||||
}
|
||||
|
||||
client_2 = Computer.from_config(config=client_2_cfg)
|
||||
|
||||
client_2.power_on()
|
||||
network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2])
|
||||
|
||||
# Server 1
|
||||
server_1 = Server(
|
||||
hostname="server_1",
|
||||
ip_address="192.168.1.10",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
# # Server 1
|
||||
|
||||
server_1_cfg = {"type": "server",
|
||||
"hostname": "server_1",
|
||||
"ip_address":"192.168.1.10",
|
||||
"subnet_mask":"255.255.255.0",
|
||||
"default_gateway":"192.168.1.1",
|
||||
"start_up_duration":0,
|
||||
}
|
||||
|
||||
server_1 = Server.from_config(config=server_1_cfg)
|
||||
|
||||
server_1.power_on()
|
||||
network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1])
|
||||
|
||||
# DServer 2
|
||||
server_2 = Server(
|
||||
hostname="server_2",
|
||||
ip_address="192.168.1.14",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
# # DServer 2
|
||||
|
||||
server_2_cfg = {"type": "server",
|
||||
"hostname": "server_2",
|
||||
"ip_address":"192.168.1.14",
|
||||
"subnet_mask":"255.255.255.0",
|
||||
"default_gateway":"192.168.1.1",
|
||||
"start_up_duration":0,
|
||||
}
|
||||
|
||||
server_2 = Server.from_config(config=server_2_cfg)
|
||||
|
||||
server_2.power_on()
|
||||
network.connect(endpoint_b=server_2.network_interface[1], endpoint_a=switch_1.network_interface[2])
|
||||
|
||||
@@ -264,6 +287,8 @@ def example_network() -> Network:
|
||||
|
||||
assert all(link.is_up for link in network.links.values())
|
||||
|
||||
client_1.software_manager.show()
|
||||
|
||||
return network
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config
|
||||
|
||||
@@ -36,7 +36,7 @@ def test_acl_observations(simulation):
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=PORT_LOOKUP["NTP"], src_port=PORT_LOOKUP["NTP"], position=1)
|
||||
|
||||
acl_obs = ACLObservation(
|
||||
where=["network", "nodes", router.hostname, "acl", "acl"],
|
||||
where=["network", "nodes", router.config.hostname, "acl", "acl"],
|
||||
ip_list=[],
|
||||
port_list=[123, 80, 5432],
|
||||
protocol_list=["tcp", "udp", "icmp"],
|
||||
|
||||
@@ -24,7 +24,7 @@ def test_file_observation(simulation):
|
||||
file = pc.file_system.create_file(file_name="dog.png")
|
||||
|
||||
dog_file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
)
|
||||
@@ -52,7 +52,7 @@ def test_folder_observation(simulation):
|
||||
file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder")
|
||||
|
||||
root_folder_obs = FolderObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"],
|
||||
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "test_folder"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
num_files=1,
|
||||
|
||||
@@ -50,7 +50,7 @@ def test_wireless_router_from_config():
|
||||
},
|
||||
}
|
||||
|
||||
rt = Router.from_config(cfg=cfg)
|
||||
rt = Router.from_config(config=cfg)
|
||||
|
||||
assert rt.num_ports == 6
|
||||
|
||||
|
||||
Reference in New Issue
Block a user