bugfix - Make node schemas stricter

This commit is contained in:
Marek Wolan
2025-02-05 15:04:41 +00:00
parent 4a472c5c75
commit c1abbfe58c
17 changed files with 73 additions and 34 deletions

View File

@@ -102,8 +102,7 @@ simulation:
subnet_mask: 255.255.255.252
default_gateway: 8.8.8.1
services:
- ref: dns_server
type: dns-server
- type: dns-server
options:
domain_mapping:
sometech.ai: 94.10.180.6
@@ -196,8 +195,7 @@ simulation:
default_gateway: 94.10.180.5
dns_server: 8.8.8.2
services:
- ref: web_server
type: web-server
- type: web-server
applications:
- type: database-client
options:

View File

@@ -14,6 +14,7 @@ from primaite.simulator.network.creation import NetworkNodeAdder
from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
@@ -268,9 +269,11 @@ class PrimaiteGame:
new_node = None
if n_type in Node._registry:
if n_type == "wireless-router":
node_cfg["airspace"] = net.airspace
new_node = Node._registry[n_type].from_config(config=node_cfg)
n_class = Node._registry[n_type]
if issubclass(n_class, WirelessRouter):
new_node = n_class(config=n_class.ConfigSchema(**node_cfg), airspace=net.airspace)
else:
new_node = Node._registry[n_type].from_config(config=node_cfg)
else:
msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg)

View File

@@ -1529,9 +1529,11 @@ class Node(SimComponent, ABC):
class ConfigSchema(BaseModel, ABC):
"""Configuration Schema for Node based classes."""
model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
"""Configure pydantic to allow arbitrary types, let the instance have attributes not present in the model."""
type: str
hostname: str
"The node hostname on the network."
@@ -1570,6 +1572,8 @@ class Node(SimComponent, ABC):
operating_state: Any = None
users: Any = None # Temporary to appease "extra=forbid"
config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
"""Configuration items within Node"""

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import ClassVar, Dict
from typing import ClassVar, Dict, Literal
from pydantic import Field
@@ -40,6 +40,7 @@ class Computer(HostNode, discriminator="computer"):
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Computer class."""
type: Literal["computer"] = "computer"
hostname: str = "Computer"
config: ConfigSchema = Field(default_factory=lambda: Computer.ConfigSchema())

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Optional
from typing import Any, ClassVar, Dict, Literal, Optional
from pydantic import Field
@@ -333,9 +333,14 @@ class HostNode(Node, discriminator="host-node"):
class ConfigSchema(Node.ConfigSchema):
"""Configuration Schema for HostNode class."""
type: Literal["host-node"]
hostname: str = "HostNode"
subnet_mask: IPV4Address = "255.255.255.0"
ip_address: IPV4Address
services: Any = None # temporarily unset to appease extra="forbid"
applications: Any = None # temporarily unset to appease extra="forbid"
folders: Any = None # temporarily unset to appease extra="forbid"
network_interfaces: Any = None # temporarily unset to appease extra="forbid"
config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())

View File

@@ -1,5 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import Literal
from pydantic import Field
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
@@ -36,6 +38,7 @@ class Server(HostNode, discriminator="server"):
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Server class."""
type: Literal["server"] = "server"
hostname: str = "server"
config: ConfigSchema = Field(default_factory=lambda: Server.ConfigSchema())
@@ -49,6 +52,7 @@ class Printer(HostNode, discriminator="printer"):
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Printer class."""
type: Literal["printer"] = "printer"
hostname: str = "printer"
config: ConfigSchema = Field(default_factory=lambda: Printer.ConfigSchema())

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address
from typing import Dict, Final, Union
from typing import Dict, Final, Literal, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field, validate_call
@@ -103,6 +103,7 @@ class Firewall(Router, discriminator="firewall"):
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for Firewall 'Nodes' within PrimAITE."""
type: Literal["firewall"] = "firewall"
hostname: str = "firewall"
num_ports: int = 0

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from typing import Optional
from typing import Any, Optional
from primaite.simulator.network.hardware.base import NetworkInterface, Node
from primaite.simulator.network.transmission.data_link_layer import Frame
@@ -16,6 +16,11 @@ class NetworkNode(Node, discriminator="network-node"):
provide functionality for receiving and processing frames received on their network interfaces.
"""
class ConfigSchema(Node.ConfigSchema):
"""Config schema for Node baseclass."""
num_ports: Any = None # temporarily unset to appease extra="forbid"
@abstractmethod
def receive_frame(self, frame: Frame, from_network_interface: NetworkInterface):
"""

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import secrets
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field, validate_call
@@ -1204,8 +1204,13 @@ class Router(NetworkNode, discriminator="router"):
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Routers."""
type: Literal["router"] = "router"
hostname: str = "router"
num_ports: int = 5
acl: Any = None # temporarily unset to appease extra="forbid"
routes: Any = None # temporarily unset to appease extra="forbid"
ports: Any = None # temporarily unset to appease extra="forbid"
default_route: Any = None # temporarily unset to appease extra="forbid"
config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema())
@@ -1625,16 +1630,20 @@ class Router(NetworkNode, discriminator="router"):
:return: Configured router.
:rtype: Router
"""
ports = config.pop("ports", None)
acl = config.pop("acl", None)
routes = config.pop("routes", None)
default_route = config.pop("default_route", None)
router = Router(config=Router.ConfigSchema(**config))
if "ports" in config:
for port_num, port_cfg in config["ports"].items():
if ports:
for port_num, port_cfg in ports.items():
router.configure_port(
port=port_num,
ip_address=port_cfg["ip_address"],
subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")),
)
if "acl" in config:
for r_num, r_cfg in config["acl"].items():
if acl:
for r_num, r_cfg in acl.items():
router.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -1646,16 +1655,16 @@ class Router(NetworkNode, discriminator="router"):
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
if "routes" in config:
for route in config.get("routes"):
if routes:
for route in routes:
router.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in config:
next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None)
if default_route:
next_hop_ip_address = default_route.get("next_hop_ip_address", None)
if next_hop_ip_address:
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
router.operating_state = (

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from typing import Dict, Optional
from typing import Dict, Literal, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
@@ -101,6 +101,7 @@ class Switch(NetworkNode, discriminator="switch"):
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Switch nodes within PrimAITE."""
type: Literal["switch"] = "switch"
hostname: str = "Switch"
num_ports: int = 24
"The number of ports on the switch. Default is 24."

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Literal, Optional, Union
from pydantic import Field, validate_call
@@ -126,10 +126,13 @@ class WirelessRouter(Router, discriminator="wireless-router"):
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for WirelessRouter nodes within PrimAITE."""
type: Literal["wireless-router"] = "wireless-router"
hostname: str = "WirelessRouter"
airspace: AirSpace
num_ports: int = 0
router_interface: Any = None # temporarily unset to appease extra="forbid"
wireless_access_point: Any = None # temporarily unset to appease extra="forbid"
airspace: AirSpace
config: ConfigSchema = Field(default_factory=lambda: WirelessRouter.ConfigSchema())
def __init__(self, **kwargs):
@@ -137,7 +140,7 @@ class WirelessRouter(Router, discriminator="wireless-router"):
self.connect_nic(
WirelessAccessPoint(
ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=kwargs["config"].airspace
ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=self.airspace
)
)
@@ -236,7 +239,7 @@ class WirelessRouter(Router, discriminator="wireless-router"):
)
@classmethod
def from_config(cls, config: Dict, **kwargs) -> "WirelessRouter":
def from_config(cls, config: Dict, airspace: AirSpace) -> "WirelessRouter":
"""Generate the wireless router from config.
Schema:
@@ -263,7 +266,7 @@ class WirelessRouter(Router, discriminator="wireless-router"):
:return: WirelessRouter instance.
:rtype: WirelessRouter
"""
router = cls(config=cls.ConfigSchema(**config))
router = cls(config=cls.ConfigSchema(**config), airspace=airspace)
router.operating_state = (
NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()]
)

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import Dict
from typing import Dict, Literal
from prettytable import MARKDOWN, PrettyTable
@@ -18,6 +18,9 @@ class GigaSwitch(NetworkNode, discriminator="gigaswitch"):
:ivar num_ports: The number of ports on the switch. Default is 24.
"""
class ConfigSchema(NetworkNode.ConfigSchema):
type: Literal["gigaswitch"] = "gigaswitch"
num_ports: int = 24
"The number of ports on the switch."
network_interfaces: Dict[str, SwitchPort] = {}

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import ClassVar, Dict
from typing import ClassVar, Dict, Literal
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
@@ -34,6 +34,9 @@ class SuperComputer(HostNode, discriminator="supercomputer"):
* Web Browser
"""
class ConfigSchema(HostNode.ConfigSchema):
type: Literal["supercomputer"] = "supercomputer"
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "ftp-client": FTPClient}
def __init__(self, **kwargs):

View File

@@ -16,7 +16,7 @@ def test_wireless_link_loading(wireless_wan_network):
# Configure Router 2 ACLs
router_2.acl.add_rule(action=ACLAction.PERMIT, position=1)
airspace = router_1.config.airspace
airspace = router_1.airspace
client.software_manager.install(FTPClient)
ftp_client: FTPClient = client.software_manager.software.get("ftp-client")

View File

@@ -32,7 +32,7 @@ def wireless_wan_network():
# Configure Router 1
router_1 = WirelessRouter.from_config(
config={"type": "wireless_router", "hostname": "router_1", "start_up_duration": 0, "airspace": network.airspace}
config={"type": "wireless-router", "hostname": "router_1", "start_up_duration": 0}, airspace=network.airspace
)
router_1.power_on()
network.add_node(router_1)
@@ -63,7 +63,7 @@ def wireless_wan_network():
# Configure Router 2
router_2: WirelessRouter = WirelessRouter.from_config(
config={"type": "wireless_router", "hostname": "router_2", "start_up_duration": 0, "airspace": network.airspace}
config={"type": "wireless-router", "hostname": "router_2", "start_up_duration": 0}, airspace=network.airspace
)
router_2.power_on()
network.add_node(router_2)

View File

@@ -8,7 +8,6 @@ from primaite.utils.validation.port import PORT_LOOKUP
def test_wireless_router_from_config():
cfg = {
"ref": "router_1",
"type": "router",
"hostname": "router_1",
"num_ports": 6,

View File

@@ -95,7 +95,7 @@ def wireless_wan_network():
# Configure Router 1
router_1 = WirelessRouter.from_config(
config={"type": "wireless_router", "hostname": "router_1", "start_up_duration": 0, "airspace": network.airspace}
config={"type": "wireless-router", "hostname": "router_1", "start_up_duration": 0}, airspace=network.airspace
)
router_1.power_on()
network.add_node(router_1)