#2887 - Updates to Node components to use rom_config and allow for extensibility. Router and Firewall continue to have custom from_config. Some test fixes to reflect changes to functionality.

This commit is contained in:
Charlie Crane
2025-01-22 17:20:38 +00:00
parent 70d9fe2fd9
commit 3957142afd
17 changed files with 350 additions and 124 deletions

View File

@@ -0,0 +1,58 @@
.. only:: comment
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
.. _about:
Extensible Nodes
****************
Node classes within PrimAITE have been updated to allow for easier generation of custom nodes within simulations.
Changes to Node Class structure.
================================
Node classes all inherit from the base Node Class, though new classes should inherit from either HostNode or NetworkNode, subject to the intended application of the Node.
The use of an `__init__` method is not necessary, as configurable variables for the class should be specified within the `config` of the class, and passed at run time via your YAML configuration using the `from_config` method.
An example of how additional Node classes is below, taken from `router.py` withing PrimAITE.
.. code-block:: Python
class Router(NetworkNode, identifier="router"):
""" Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces."""
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
}
network_interfaces: Dict[str, RouterInterface] = {}
"The Router Interfaces on the node."
network_interface: Dict[int, RouterInterface] = {}
"The Router Interfaces on the node by port id."
sys_log: SysLog
config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema())
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Router Objects."""
num_ports: int = 5
hostname: ClassVar[str] = "Router"
ports: list = []
Changes to YAML file.
=====================
Nodes defined within configuration YAML files for use with PrimAITE 3.X should still be compatible following these changes.

View File

@@ -271,12 +271,13 @@ class PrimaiteGame:
for node_cfg in nodes_cfg:
n_type = node_cfg["type"]
node_config: dict = node_cfg["config"]
# node_config: dict = node_cfg["config"]
print(f"{n_type}:{node_cfg}")
new_node = None
if n_type in Node._registry:
# simplify down Node creation:
new_node = Node._registry["n_type"].from_config(config=node_config)
new_node = Node._registry[n_type].from_config(config=node_cfg)
else:
msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg)
@@ -313,7 +314,7 @@ class PrimaiteGame:
service_class = SERVICE_TYPES_MAPPING[service_type]
if service_class is not None:
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
_LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}")
new_node.software_manager.install(service_class, **service_cfg.get("options", {}))
new_service = new_node.software_manager.software[service_class.__name__]

View File

@@ -3,7 +3,7 @@
"""Core of the PrimAITE Simulator."""
import warnings
from abc import abstractmethod
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from uuid import uuid4
from prettytable import PrettyTable

View File

@@ -180,7 +180,7 @@ class Network(SimComponent):
table.align = "l"
table.title = "Nodes"
for node in self.nodes.values():
table.add_row((node.hostname, type(node)._identifier, node.operating_state.name))
table.add_row((node.config.hostname, type(node)._identifier, node.operating_state.name))
print(table)
if ip_addresses:
@@ -196,7 +196,7 @@ class Network(SimComponent):
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
[node.config.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
)
print(table)
@@ -215,9 +215,9 @@ class Network(SimComponent):
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
table.add_row(
[
link.endpoint_a.parent.hostname,
link.endpoint_a.parent.config.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.hostname,
link.endpoint_b.parent.config.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
@@ -251,7 +251,7 @@ class Network(SimComponent):
state = super().describe_state()
state.update(
{
"nodes": {node.hostname: node.describe_state() for node in self.nodes.values()},
"nodes": {node.config.hostname: node.describe_state() for node in self.nodes.values()},
"links": {},
}
)
@@ -259,8 +259,8 @@ class Network(SimComponent):
for _, link in self.links.items():
node_a = link.endpoint_a._connected_node
node_b = link.endpoint_b._connected_node
hostname_a = node_a.hostname if node_a else None
hostname_b = node_b.hostname if node_b else None
hostname_a = node_a.config.hostname if node_a else None
hostname_b = node_b.config.hostname if node_b else None
port_a = link.endpoint_a.port_num
port_b = link.endpoint_b.port_num
link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}"
@@ -286,9 +286,9 @@ class Network(SimComponent):
self.nodes[node.uuid] = node
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
self._nx_graph.add_node(node.config.hostname)
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
self._node_request_manager.add_request(name=node.hostname, request_type=RequestType(func=node._request_manager))
self._node_request_manager.add_request(name=node.config.hostname, request_type=RequestType(func=node._request_manager))
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
@@ -300,7 +300,7 @@ class Network(SimComponent):
:return: The Node if it exists in the network.
"""
for node in self.nodes.values():
if node.hostname == hostname:
if node.config.hostname == hostname:
return node
def remove_node(self, node: Node) -> None:
@@ -313,7 +313,7 @@ class Network(SimComponent):
:type node: Node
"""
if node not in self:
_LOGGER.warning(f"Can't remove node {node.hostname}. It's not in the network.")
_LOGGER.warning(f"Can't remove node {node.config.hostname}. It's not in the network.")
return
self.nodes.pop(node.uuid)
for i, _node in self._node_id_map.items():
@@ -321,8 +321,8 @@ class Network(SimComponent):
self._node_id_map.pop(i)
break
node.parent = None
self._node_request_manager.remove_request(name=node.hostname)
_LOGGER.info(f"Removed node {node.hostname} from network {self.uuid}")
self._node_request_manager.remove_request(name=node.config.hostname)
_LOGGER.info(f"Removed node {node.config.hostname} from network {self.uuid}")
def connect(
self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, bandwidth: int = 100, **kwargs
@@ -352,7 +352,7 @@ class Network(SimComponent):
link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth, **kwargs)
self.links[link.uuid] = link
self._link_id_map[len(self.links)] = link
self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname)
self._nx_graph.add_edge(endpoint_a.parent.config.hostname, endpoint_b.parent.config.hostname)
link.parent = self
_LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
return link

View File

@@ -431,7 +431,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
self.enabled = True
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
self.pcap = PacketCapture(
hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name
hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name
)
if self._connected_link:
self._connected_link.endpoint_up()
@@ -1515,14 +1515,16 @@ class Node(SimComponent, ABC):
_identifier: ClassVar[str] = "unknown"
"""Identifier for this particular class, used for printing and logging. Each subclass redefines this."""
config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
config: Node.ConfigSchema
"""Configuration items within Node"""
class ConfigSchema(BaseModel, ABC):
"""Configuration Schema for Node based classes."""
model_config = ConfigDict(arbitrary_types_allowed=True)
"""Configure pydantic to allow arbitrary types and to let the instance have attributes not present in the model."""
hostname: str
hostname: str = "default"
"The node hostname on the network."
revealed_to_red: bool = False
@@ -1552,6 +1554,7 @@ class Node(SimComponent, ABC):
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
@classmethod
def from_config(cls, config: Dict) -> "Node":
"""Create Node object from a given configuration dictionary."""
@@ -1586,11 +1589,11 @@ class Node(SimComponent, ABC):
provided.
"""
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["hostname"])
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
if not kwargs.get("session_manager"):
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"))
if not kwargs.get("root"):
kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"]
kwargs["root"] = SIM_OUTPUT.path / kwargs["config"].hostname
if not kwargs.get("file_system"):
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
if not kwargs.get("software_manager"):
@@ -1601,10 +1604,12 @@ class Node(SimComponent, ABC):
file_system=kwargs.get("file_system"),
dns_server=kwargs.get("dns_server"),
)
super().__init__(**kwargs)
self._install_system_software()
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
self.power_on()
@property
def user_manager(self) -> Optional[UserManager]:
@@ -1856,7 +1861,7 @@ class Node(SimComponent, ABC):
state = super().describe_state()
state.update(
{
"hostname": self.hostname,
"hostname": self.config.hostname,
"operating_state": self.operating_state.value,
"NICs": {
eth_num: network_interface.describe_state()

View File

@@ -333,10 +333,13 @@ class HostNode(Node, identifier="HostNode"):
"""Configuration Schema for HostNode class."""
hostname: str = "HostNode"
ip_address: IPV4Address = "192.168.0.1"
subnet_mask: IPV4Address = "255.255.255.0"
default_gateway: IPV4Address = "192.168.10.1"
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask))
@property
def nmap(self) -> Optional[NMAP]:

View File

@@ -1,4 +1,6 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import ClassVar
from pydantic import Field
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
@@ -30,8 +32,23 @@ class Server(HostNode, identifier="server"):
* Web Browser
"""
config: "Server.ConfigSchema" = Field(default_factory=lambda: Server.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Server class."""
hostname: str = "server"
class Printer(HostNode, identifier="printer"):
"""Printer? I don't even know her!."""
# TODO: Implement printer-specific behaviour
config: "Printer.ConfigSchema" = Field(default_factory=lambda: Printer.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Printer class."""
hostname: ClassVar[str] = "printer"

View File

@@ -99,19 +99,22 @@ class Firewall(Router, identifier="firewall"):
)
"""Access Control List for managing traffic leaving towards an external network."""
_identifier: str = "firewall"
config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema())
class ConfigSchema(Router.ConfigSChema):
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for Firewall 'Nodes' within PrimAITE."""
hostname: str = "Firewall"
hostname: str = "firewall"
num_ports: int = 0
operating_state: NodeOperatingState = NodeOperatingState.ON
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(self.config.hostname)
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs)
super().__init__(**kwargs)
self.connect_nic(
RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external")
@@ -124,22 +127,22 @@ class Firewall(Router, identifier="firewall"):
)
# Update ACL objects with firewall's hostname and syslog to allow accurate logging
self.internal_inbound_acl.sys_log = kwargs["sys_log"]
self.internal_inbound_acl.name = f"{hostname} - Internal Inbound"
self.internal_inbound_acl.name = f"{kwargs['config'].hostname} - Internal Inbound"
self.internal_outbound_acl.sys_log = kwargs["sys_log"]
self.internal_outbound_acl.name = f"{hostname} - Internal Outbound"
self.internal_outbound_acl.name = f"{kwargs['config'].hostname} - Internal Outbound"
self.dmz_inbound_acl.sys_log = kwargs["sys_log"]
self.dmz_inbound_acl.name = f"{hostname} - DMZ Inbound"
self.dmz_inbound_acl.name = f"{kwargs['config'].hostname} - DMZ Inbound"
self.dmz_outbound_acl.sys_log = kwargs["sys_log"]
self.dmz_outbound_acl.name = f"{hostname} - DMZ Outbound"
self.dmz_outbound_acl.name = f"{kwargs['config'].hostname} - DMZ Outbound"
self.external_inbound_acl.sys_log = kwargs["sys_log"]
self.external_inbound_acl.name = f"{hostname} - External Inbound"
self.external_inbound_acl.name = f"{kwargs['config'].hostname} - External Inbound"
self.external_outbound_acl.sys_log = kwargs["sys_log"]
self.external_outbound_acl.name = f"{hostname} - External Outbound"
self.external_outbound_acl.name = f"{kwargs['config'].hostname} - External Outbound"
def _init_request_manager(self) -> RequestManager:
"""
@@ -567,18 +570,21 @@ class Firewall(Router, identifier="firewall"):
self.dmz_port.enable()
@classmethod
def from_config(cls, cfg: dict) -> "Firewall":
def from_config(cls, config: dict) -> "Firewall":
"""Create a firewall based on a config dict."""
firewall = Firewall(
hostname=cfg["hostname"],
operating_state=NodeOperatingState.ON
if not (p := cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
if "ports" in cfg:
internal_port = cfg["ports"]["internal_port"]
external_port = cfg["ports"]["external_port"]
dmz_port = cfg["ports"].get("dmz_port")
# firewall = Firewall(
# hostname=config["hostname"],
# operating_state=NodeOperatingState.ON
# if not (p := config.get("operating_state"))
# else NodeOperatingState[p.upper()],
# )
firewall = Firewall(config = cls.ConfigSchema(**config))
if "ports" in config:
internal_port = config["ports"]["internal_port"]
external_port = config["ports"]["external_port"]
dmz_port = config["ports"].get("dmz_port")
# configure internal port
firewall.configure_internal_port(
@@ -598,10 +604,10 @@ class Firewall(Router, identifier="firewall"):
ip_address=IPV4Address(dmz_port.get("ip_address")),
subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")),
)
if "acl" in cfg:
if "acl" in config:
# acl rules for internal_inbound_acl
if cfg["acl"]["internal_inbound_acl"]:
for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items():
if config["acl"]["internal_inbound_acl"]:
for r_num, r_cfg in config["acl"]["internal_inbound_acl"].items():
firewall.internal_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -615,8 +621,8 @@ class Firewall(Router, identifier="firewall"):
)
# acl rules for internal_outbound_acl
if cfg["acl"]["internal_outbound_acl"]:
for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items():
if config["acl"]["internal_outbound_acl"]:
for r_num, r_cfg in config["acl"]["internal_outbound_acl"].items():
firewall.internal_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -630,8 +636,8 @@ class Firewall(Router, identifier="firewall"):
)
# acl rules for dmz_inbound_acl
if cfg["acl"]["dmz_inbound_acl"]:
for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items():
if config["acl"]["dmz_inbound_acl"]:
for r_num, r_cfg in config["acl"]["dmz_inbound_acl"].items():
firewall.dmz_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -645,8 +651,8 @@ class Firewall(Router, identifier="firewall"):
)
# acl rules for dmz_outbound_acl
if cfg["acl"]["dmz_outbound_acl"]:
for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items():
if config["acl"]["dmz_outbound_acl"]:
for r_num, r_cfg in config["acl"]["dmz_outbound_acl"].items():
firewall.dmz_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -660,8 +666,8 @@ class Firewall(Router, identifier="firewall"):
)
# acl rules for external_inbound_acl
if cfg["acl"].get("external_inbound_acl"):
for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items():
if config["acl"].get("external_inbound_acl"):
for r_num, r_cfg in config["acl"]["external_inbound_acl"].items():
firewall.external_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -675,8 +681,8 @@ class Firewall(Router, identifier="firewall"):
)
# acl rules for external_outbound_acl
if cfg["acl"].get("external_outbound_acl"):
for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items():
if config["acl"].get("external_outbound_acl"):
for r_num, r_cfg in config["acl"]["external_outbound_acl"].items():
firewall.external_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -689,16 +695,16 @@ class Firewall(Router, identifier="firewall"):
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
if "routes" in config:
for route in config.get("routes"):
firewall.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if "default_route" in config:
next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
firewall.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)

View File

@@ -1212,21 +1212,34 @@ class Router(NetworkNode, identifier="router"):
network_interface: Dict[int, RouterInterface] = {}
"The Router Interfaces on the node by port id."
config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema())
sys_log: SysLog = None
acl: AccessControlList = None
route_table: RouteTable = None
config: "Router.ConfigSchema"
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Router Objects."""
num_ports: int = 5
"""Number of ports available for this Router. Default is 5"""
hostname: str = "Router"
ports: list = []
sys_log: SysLog = SysLog(hostname)
acl: AccessControlList = AccessControlList(sys_log=sys_log, implicit_action=ACLAction.DENY, name=hostname)
route_table: RouteTable = RouteTable(sys_log=sys_log)
ports: Dict[Union[int, str], Dict] = {}
def __init__(self, **kwargs):
super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs)
self.session_manager = RouterSessionManager(sys_log=self.config.sys_log)
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
if not kwargs.get("acl"):
kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname)
if not kwargs.get("route_table"):
kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"])
super().__init__(**kwargs)
self.session_manager = RouterSessionManager(sys_log=self.sys_log)
self.session_manager.node = self
self.software_manager.session_manager = self.session_manager
self.session_manager.software_manager = self.software_manager
@@ -1234,9 +1247,11 @@ class Router(NetworkNode, identifier="router"):
network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
self.connect_nic(network_interface)
self.network_interface[i] = network_interface
self.operating_state = NodeOperatingState.ON
self._set_default_acl()
def _install_system_software(self):
"""
Installs essential system software and network services on the router.
@@ -1260,10 +1275,10 @@ class Router(NetworkNode, identifier="router"):
Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP
and ICMP, which are necessary for basic network operations and diagnostics.
"""
self.config.acl.add_rule(
self.acl.add_rule(
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22
)
self.config.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
def setup_for_episode(self, episode: int):
"""
@@ -1287,7 +1302,7 @@ class Router(NetworkNode, identifier="router"):
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request("acl", RequestType(func=self.config.acl._request_manager))
rm.add_request("acl", RequestType(func=self.acl._request_manager))
return rm
def ip_is_router_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool:
@@ -1341,7 +1356,7 @@ class Router(NetworkNode, identifier="router"):
"""
state = super().describe_state()
state["num_ports"] = self.config.num_ports
state["acl"] = self.config.acl.describe_state()
state["acl"] = self.acl.describe_state()
return state
def check_send_frame_to_session_manager(self, frame: Frame) -> bool:
@@ -1562,7 +1577,7 @@ class Router(NetworkNode, identifier="router"):
print(table)
def setup_router(self, cfg: dict) -> Router:
""" TODO: This is the extra bit of Router's from_config metho. Needs sorting."""
"""TODO: This is the extra bit of Router's from_config metho. Needs sorting."""
if "ports" in cfg:
for port_num, port_cfg in cfg["ports"].items():
self.configure_port(
@@ -1594,5 +1609,100 @@ class Router(NetworkNode, identifier="router"):
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
self.config.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
self.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
return self
@classmethod
def from_config(cls, config: dict, **kwargs) -> "Router":
"""Create a router based on a config dict.
Schema:
- hostname (str): unique name for this router.
- num_ports (int, optional): Number of network ports on the router. 8 by default
- ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying
ip_address and subnet_mask assigned to that ports (as strings)
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
where the rule will be added (lower number is resolved first). The values should describe valid ACL
Rules as:
- action (str): either PERMIT or DENY
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
- src_ip_address (str, optional): IP address octet written in base 10
- dst_ip_address (str, optional): IP address octet written in base 10
- routes (list[dict]): List of route dicts with values:
- address (str): The destination address of the route.
- subnet_mask (str): The subnet mask of the route.
- next_hop_ip_address (str): The next hop IP for the route.
- metric (int): The metric of the route. Optional.
- default_route:
- next_hop_ip_address (str): The next hop IP for the route.
Example config:
```
{
'hostname': 'router_1',
'num_ports': 5,
'ports': {
1: {
'ip_address' : '192.168.1.1',
'subnet_mask' : '255.255.255.0',
},
2: {
'ip_address' : '192.168.0.1',
'subnet_mask' : '255.255.255.252',
}
},
'acl' : {
21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'},
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
},
'routes' : [
{'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'}
],
'default_route': {'next_hop_ip_address': '192.168.0.2'}
}
```
:param cfg: Router config adhering to schema described in main docstring body
:type cfg: dict
:return: Configured router.
:rtype: Router
"""
router = Router(config=Router.ConfigSchema(**config)
)
if "ports" in config:
for port_num, port_cfg in config["ports"].items():
router.configure_port(
port=port_num,
ip_address=port_cfg["ip_address"],
subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")),
)
if "acl" in config:
for r_num, r_cfg in config["acl"].items():
router.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p],
protocol=None if not (p := r_cfg.get("protocol")) else PROTOCOL_LOOKUP[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
if "routes" in config:
for route in config.get("routes"):
router.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in config:
next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
return router

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from typing import Dict, Optional
from typing import ClassVar, Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
@@ -102,7 +102,7 @@ class Switch(NetworkNode, identifier="switch"):
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema())
config: "Switch.ConfigSchema"
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Switch nodes within PrimAITE."""
@@ -113,7 +113,7 @@ class Switch(NetworkNode, identifier="switch"):
def __init__(self, **kwargs):
super().__init__(**kwargs)
for i in range(1, self.config.num_ports + 1):
for i in range(1, kwargs["config"].num_ports + 1):
self.connect_nic(SwitchPort())
def _install_system_software(self):

View File

@@ -294,7 +294,7 @@ class IOSoftware(Software):
"""
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
self.software_manager.node.sys_log.error(
f"{self.name} Error: {self.software_manager.node.hostname} is not powered on."
f"{self.name} Error: {self.software_manager.node.config.hostname} is not powered on."
)
return False
return True

View File

@@ -187,7 +187,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
include_nmne: true
monitored_traffic:
icmp:
- NONE

View File

@@ -195,68 +195,91 @@ def example_network() -> Network:
network = Network()
# Router 1
router_1_cfg = {"hostname":"router_1", "type":"router"}
# router_1 = Router(hostname="router_1", start_up_duration=0)
router_1 = Router(hostname="router_1", start_up_duration=0)
router_1 = Router.from_config(config=router_1_cfg)
router_1.power_on()
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
# Switch 1
# switch_1_config = Switch.ConfigSchema()
switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
switch_1_cfg = {"hostname": "switch_1", "type": "switch"}
switch_1 = Switch.from_config(config=switch_1_cfg)
# switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
switch_1.power_on()
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8])
router_1.enable_port(1)
# Switch 2
# switch_2_config = Switch.ConfigSchema()
switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8}
# switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
switch_2 = Switch.from_config(config=switch_2_config)
switch_2.power_on()
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])
router_1.enable_port(2)
# Client 1
client_1 = Computer(
hostname="client_1",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
start_up_duration=0,
)
# # Client 1
client_1_cfg = {"type": "computer",
"hostname": "client_1",
"ip_address": "192.168.10.21",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.10.1",
"start_up_duration": 0}
client_1=Computer.from_config(config=client_1_cfg)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
# Client 2
client_2 = Computer(
hostname="client_2",
ip_address="192.168.10.22",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
start_up_duration=0,
)
# # Client 2
client_2_cfg = {"type": "computer",
"hostname": "client_2",
"ip_address": "192.168.10.22",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.10.1",
"start_up_duration": 0,
}
client_2 = Computer.from_config(config=client_2_cfg)
client_2.power_on()
network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2])
# Server 1
server_1 = Server(
hostname="server_1",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
# # Server 1
server_1_cfg = {"type": "server",
"hostname": "server_1",
"ip_address":"192.168.1.10",
"subnet_mask":"255.255.255.0",
"default_gateway":"192.168.1.1",
"start_up_duration":0,
}
server_1 = Server.from_config(config=server_1_cfg)
server_1.power_on()
network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1])
# DServer 2
server_2 = Server(
hostname="server_2",
ip_address="192.168.1.14",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
# # DServer 2
server_2_cfg = {"type": "server",
"hostname": "server_2",
"ip_address":"192.168.1.14",
"subnet_mask":"255.255.255.0",
"default_gateway":"192.168.1.1",
"start_up_duration":0,
}
server_2 = Server.from_config(config=server_2_cfg)
server_2.power_on()
network.connect(endpoint_b=server_2.network_interface[1], endpoint_a=switch_1.network_interface[2])
@@ -264,6 +287,8 @@ def example_network() -> Network:
assert all(link.is_up for link in network.links.values())
client_1.software_manager.show()
return network

View File

@@ -6,6 +6,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config

View File

@@ -36,7 +36,7 @@ def test_acl_observations(simulation):
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=PORT_LOOKUP["NTP"], src_port=PORT_LOOKUP["NTP"], position=1)
acl_obs = ACLObservation(
where=["network", "nodes", router.hostname, "acl", "acl"],
where=["network", "nodes", router.config.hostname, "acl", "acl"],
ip_list=[],
port_list=[123, 80, 5432],
protocol_list=["tcp", "udp", "icmp"],

View File

@@ -24,7 +24,7 @@ def test_file_observation(simulation):
file = pc.file_system.create_file(file_name="dog.png")
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
)
@@ -52,7 +52,7 @@ def test_folder_observation(simulation):
file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder")
root_folder_obs = FolderObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"],
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "test_folder"],
include_num_access=False,
file_system_requires_scan=True,
num_files=1,

View File

@@ -50,7 +50,7 @@ def test_wireless_router_from_config():
},
}
rt = Router.from_config(cfg=cfg)
rt = Router.from_config(config=cfg)
assert rt.num_ports == 6