#2887 - Initial commit of Node refactor for extensibility in version 4.0.0. Addition of ConfigSchema and changes to how Nodes are generated within Game.py

This commit is contained in:
Charlie Crane
2025-01-15 11:21:18 +00:00
parent eb91721518
commit 582e7cfec7
8 changed files with 185 additions and 125 deletions

View File

@@ -19,14 +19,8 @@ from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.creation import NetworkNodeAdder from primaite.simulator.network.creation import NetworkNodeAdder
from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.application import Application
@@ -277,68 +271,73 @@ class PrimaiteGame:
for node_cfg in nodes_cfg: for node_cfg in nodes_cfg:
n_type = node_cfg["type"] n_type = node_cfg["type"]
node_config: dict = node_cfg["config"]
new_node = None new_node = None
if n_type in Node._registry:
# simplify down Node creation:
new_node = Node._registry["n_type"].from_config(config=node_config)
# Default PrimAITE nodes # Default PrimAITE nodes
if n_type == "computer": # if n_type == "computer":
new_node = Computer( # new_node = Computer(
hostname=node_cfg["hostname"], # hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"], # ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"), # default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None), # dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON # operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state")) # if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()], # else NodeOperatingState[p.upper()],
) # )
elif n_type == "server": # elif n_type == "server":
new_node = Server( # new_node = Server(
hostname=node_cfg["hostname"], # hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"], # ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"), # default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None), # dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON # operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state")) # if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()], # else NodeOperatingState[p.upper()],
) # )
elif n_type == "switch": # elif n_type == "switch":
new_node = Switch( # new_node = Switch(
hostname=node_cfg["hostname"], # hostname=node_cfg["hostname"],
num_ports=int(node_cfg.get("num_ports", "8")), # num_ports=int(node_cfg.get("num_ports", "8")),
operating_state=NodeOperatingState.ON # operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state")) # if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()], # else NodeOperatingState[p.upper()],
) # )
elif n_type == "router": # elif n_type == "router":
new_node = Router.from_config(node_cfg) # new_node = Router.from_config(node_cfg)
elif n_type == "firewall": # elif n_type == "firewall":
new_node = Firewall.from_config(node_cfg) # new_node = Firewall.from_config(node_cfg)
elif n_type == "wireless_router": # elif n_type == "wireless_router":
new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace) # new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace)
elif n_type == "printer": # elif n_type == "printer":
new_node = Printer( # new_node = Printer(
hostname=node_cfg["hostname"], # hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"], # ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"], # subnet_mask=node_cfg["subnet_mask"],
operating_state=NodeOperatingState.ON # operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state")) # if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()], # else NodeOperatingState[p.upper()],
) # )
# Handle extended nodes # # Handle extended nodes
elif n_type.lower() in Node._registry: # elif n_type.lower() in Node._registry:
new_node = HostNode._registry[n_type]( # new_node = HostNode._registry[n_type](
hostname=node_cfg["hostname"], # hostname=node_cfg["hostname"],
ip_address=node_cfg.get("ip_address"), # ip_address=node_cfg.get("ip_address"),
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"), # default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None), # dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON # operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state")) # if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()], # else NodeOperatingState[p.upper()],
) # )
elif n_type in NetworkNode._registry: # elif n_type in NetworkNode._registry:
new_node = NetworkNode._registry[n_type](**node_cfg) # new_node = NetworkNode._registry[n_type](**node_cfg)
else: else:
msg = f"invalid node type {n_type} in config" msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg) _LOGGER.error(msg)

View File

@@ -1469,7 +1469,7 @@ class UserSessionManager(Service):
return self.local_session is not None return self.local_session is not None
class Node(SimComponent): class Node(SimComponent, ABC):
""" """
A basic Node class that represents a node on the network. A basic Node class that represents a node on the network.
@@ -1492,7 +1492,6 @@ class Node(SimComponent):
"The Network Interfaces on the node by port id." "The Network Interfaces on the node by port id."
dns_server: Optional[IPv4Address] = None dns_server: Optional[IPv4Address] = None
"List of IP addresses of DNS servers used for name resolution." "List of IP addresses of DNS servers used for name resolution."
accounts: Dict[str, Account] = {} accounts: Dict[str, Account] = {}
"All accounts on the node." "All accounts on the node."
applications: Dict[str, Application] = {} applications: Dict[str, Application] = {}
@@ -1509,33 +1508,6 @@ class Node(SimComponent):
session_manager: SessionManager session_manager: SessionManager
software_manager: SoftwareManager software_manager: SoftwareManager
revealed_to_red: bool = False
"Informs whether the node has been revealed to a red agent."
start_up_duration: int = 3
"Time steps needed for the node to start up."
start_up_countdown: int = 0
"Time steps needed until node is booted up."
shut_down_duration: int = 3
"Time steps needed for the node to shut down."
shut_down_countdown: int = 0
"Time steps needed until node is shut down."
is_resetting: bool = False
"If true, the node will try turning itself off then back on again."
node_scan_duration: int = 10
"How many timesteps until the whole node is scanned. Default 10 time steps."
node_scan_countdown: int = 0
"Time steps until scan is complete"
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {} SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {}
"Base system software that must be preinstalled." "Base system software that must be preinstalled."
@@ -1545,6 +1517,46 @@ class Node(SimComponent):
_identifier: ClassVar[str] = "unknown" _identifier: ClassVar[str] = "unknown"
"""Identifier for this particular class, used for printing and logging. Each subclass redefines this.""" """Identifier for this particular class, used for printing and logging. Each subclass redefines this."""
config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
class ConfigSchema:
"""Configuration Schema for Node based classes."""
revealed_to_red: bool = False
"Informs whether the node has been revealed to a red agent."
start_up_duration: int = 3
"Time steps needed for the node to start up."
start_up_countdown: int = 0
"Time steps needed until node is booted up."
shut_down_duration: int = 3
"Time steps needed for the node to shut down."
shut_down_countdown: int = 0
"Time steps needed until node is shut down."
is_resetting: bool = False
"If true, the node will try turning itself off then back on again."
node_scan_duration: int = 10
"How many timesteps until the whole node is scanned. Default 10 time steps."
node_scan_countdown: int = 0
"Time steps until scan is complete"
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
def from_config(cls, config: Dict) -> Node:
"""Create Node object from a given configuration."""
if config["type"] not in cls._registry:
msg = f"Configuration contains an invalid Node type: {config['type']}"
return ValueError(msg)
obj = cls(config=cls.ConfigSchema(**config))
return obj
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None: def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
""" """
Register a node type. Register a node type.
@@ -1850,7 +1862,7 @@ class Node(SimComponent):
"applications": {app.name: app.describe_state() for app in self.applications.values()}, "applications": {app.name: app.describe_state() for app in self.applications.values()},
"services": {svc.name: svc.describe_state() for svc in self.services.values()}, "services": {svc.name: svc.describe_state() for svc in self.services.values()},
"process": {proc.name: proc.describe_state() for proc in self.processes.values()}, "process": {proc.name: proc.describe_state() for proc in self.processes.values()},
"revealed_to_red": self.revealed_to_red, "revealed_to_red": self.config.revealed_to_red,
} }
) )
return state return state
@@ -1928,8 +1940,8 @@ class Node(SimComponent):
network_interface.apply_timestep(timestep=timestep) network_interface.apply_timestep(timestep=timestep)
# count down to boot up # count down to boot up
if self.start_up_countdown > 0: if self.config.start_up_countdown > 0:
self.start_up_countdown -= 1 self.config.start_up_countdown -= 1
else: else:
if self.operating_state == NodeOperatingState.BOOTING: if self.operating_state == NodeOperatingState.BOOTING:
self.operating_state = NodeOperatingState.ON self.operating_state = NodeOperatingState.ON
@@ -1940,8 +1952,8 @@ class Node(SimComponent):
self._start_up_actions() self._start_up_actions()
# count down to shut down # count down to shut down
if self.shut_down_countdown > 0: if self.config.shut_down_countdown > 0:
self.shut_down_countdown -= 1 self.config.shut_down_countdown -= 1
else: else:
if self.operating_state == NodeOperatingState.SHUTTING_DOWN: if self.operating_state == NodeOperatingState.SHUTTING_DOWN:
self.operating_state = NodeOperatingState.OFF self.operating_state = NodeOperatingState.OFF
@@ -1949,17 +1961,17 @@ class Node(SimComponent):
self._shut_down_actions() self._shut_down_actions()
# if resetting turn back on # if resetting turn back on
if self.is_resetting: if self.config.is_resetting:
self.is_resetting = False self.config.is_resetting = False
self.power_on() self.power_on()
# time steps which require the node to be on # time steps which require the node to be on
if self.operating_state == NodeOperatingState.ON: if self.operating_state == NodeOperatingState.ON:
# node scanning # node scanning
if self.node_scan_countdown > 0: if self.config.node_scan_countdown > 0:
self.node_scan_countdown -= 1 self.config.node_scan_countdown -= 1
if self.node_scan_countdown == 0: if self.config.node_scan_countdown == 0:
# scan everything! # scan everything!
for process_id in self.processes: for process_id in self.processes:
self.processes[process_id].scan() self.processes[process_id].scan()
@@ -1975,10 +1987,10 @@ class Node(SimComponent):
# scan file system # scan file system
self.file_system.scan(instant_scan=True) self.file_system.scan(instant_scan=True)
if self.red_scan_countdown > 0: if self.config.red_scan_countdown > 0:
self.red_scan_countdown -= 1 self.config.red_scan_countdown -= 1
if self.red_scan_countdown == 0: if self.config.red_scan_countdown == 0:
# scan processes # scan processes
for process_id in self.processes: for process_id in self.processes:
self.processes[process_id].reveal_to_red() self.processes[process_id].reveal_to_red()
@@ -2035,7 +2047,7 @@ class Node(SimComponent):
to the red agent. to the red agent.
""" """
self.node_scan_countdown = self.node_scan_duration self.config.node_scan_countdown = self.config.node_scan_duration
return True return True
def reveal_to_red(self) -> bool: def reveal_to_red(self) -> bool:
@@ -2051,12 +2063,12 @@ class Node(SimComponent):
`revealed_to_red` to `True`. `revealed_to_red` to `True`.
""" """
self.red_scan_countdown = self.node_scan_duration self.config.red_scan_countdown = self.config.node_scan_duration
return True return True
def power_on(self) -> bool: def power_on(self) -> bool:
"""Power on the Node, enabling its NICs if it is in the OFF state.""" """Power on the Node, enabling its NICs if it is in the OFF state."""
if self.start_up_duration <= 0: if self.config.start_up_duration <= 0:
self.operating_state = NodeOperatingState.ON self.operating_state = NodeOperatingState.ON
self._start_up_actions() self._start_up_actions()
self.sys_log.info("Power on") self.sys_log.info("Power on")
@@ -2065,14 +2077,14 @@ class Node(SimComponent):
return True return True
if self.operating_state == NodeOperatingState.OFF: if self.operating_state == NodeOperatingState.OFF:
self.operating_state = NodeOperatingState.BOOTING self.operating_state = NodeOperatingState.BOOTING
self.start_up_countdown = self.start_up_duration self.config.start_up_countdown = self.config.start_up_duration
return True return True
return False return False
def power_off(self) -> bool: def power_off(self) -> bool:
"""Power off the Node, disabling its NICs if it is in the ON state.""" """Power off the Node, disabling its NICs if it is in the ON state."""
if self.shut_down_duration <= 0: if self.config.shut_down_duration <= 0:
self._shut_down_actions() self._shut_down_actions()
self.operating_state = NodeOperatingState.OFF self.operating_state = NodeOperatingState.OFF
self.sys_log.info("Power off") self.sys_log.info("Power off")
@@ -2081,7 +2093,7 @@ class Node(SimComponent):
for network_interface in self.network_interfaces.values(): for network_interface in self.network_interfaces.values():
network_interface.disable() network_interface.disable()
self.operating_state = NodeOperatingState.SHUTTING_DOWN self.operating_state = NodeOperatingState.SHUTTING_DOWN
self.shut_down_countdown = self.shut_down_duration self.config.shut_down_countdown = self.config.shut_down_duration
return True return True
return False return False
@@ -2093,7 +2105,7 @@ class Node(SimComponent):
Applying more timesteps will eventually turn the node back on. Applying more timesteps will eventually turn the node back on.
""" """
if self.operating_state.ON: if self.operating_state.ON:
self.is_resetting = True self.config.is_resetting = True
self.sys_log.info("Resetting") self.sys_log.info("Resetting")
self.power_off() self.power_off()
return True return True

View File

@@ -1,6 +1,8 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import ClassVar, Dict from typing import ClassVar, Dict
from pydantic import Field
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_client import FTPClient
@@ -35,4 +37,11 @@ class Computer(HostNode, identifier="computer"):
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient} SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient}
config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Computer class."""
pass
pass pass

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Optional from typing import Any, ClassVar, Dict, Optional
from pydantic import Field
from primaite import getLogger from primaite import getLogger
from primaite.simulator.network.hardware.base import ( from primaite.simulator.network.hardware.base import (
IPWiredNetworkInterface, IPWiredNetworkInterface,
@@ -325,6 +327,13 @@ class HostNode(Node, identifier="HostNode"):
network_interface: Dict[int, NIC] = {} network_interface: Dict[int, NIC] = {}
"The NICs on the node by port id." "The NICs on the node by port id."
config: HostNode.ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
class ConfigSchema(Node.ConfigSchema):
"""Configuration Schema for HostNode class."""
pass
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))

View File

@@ -99,6 +99,13 @@ class Firewall(Router, identifier="firewall"):
) )
"""Access Control List for managing traffic leaving towards an external network.""" """Access Control List for managing traffic leaving towards an external network."""
config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema())
class ConfigSchema(Router.ConfigSChema):
"""Configuration Schema for Firewall 'Nodes' within PrimAITE."""
pass
def __init__(self, hostname: str, **kwargs): def __init__(self, hostname: str, **kwargs):
if not kwargs.get("sys_log"): if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname) kwargs["sys_log"] = SysLog(hostname)

View File

@@ -7,7 +7,7 @@ from ipaddress import IPv4Address, IPv4Network
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call from pydantic import Field, validate_call
from primaite.interface.request import RequestResponse from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.core import RequestManager, RequestType, SimComponent
@@ -1207,7 +1207,6 @@ class Router(NetworkNode, identifier="router"):
"Terminal": Terminal, "Terminal": Terminal,
} }
num_ports: int
network_interfaces: Dict[str, RouterInterface] = {} network_interfaces: Dict[str, RouterInterface] = {}
"The Router Interfaces on the node." "The Router Interfaces on the node."
network_interface: Dict[int, RouterInterface] = {} network_interface: Dict[int, RouterInterface] = {}
@@ -1215,6 +1214,15 @@ class Router(NetworkNode, identifier="router"):
acl: AccessControlList acl: AccessControlList
route_table: RouteTable route_table: RouteTable
config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSChema())
class ConfigSChema(NetworkNode.ConfigSchema):
"""Configuration Schema for Router Objects."""
num_ports: int = 10
hostname: str = "Router"
ports: list = []
def __init__(self, hostname: str, num_ports: int = 5, **kwargs): def __init__(self, hostname: str, num_ports: int = 5, **kwargs):
if not kwargs.get("sys_log"): if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname) kwargs["sys_log"] = SysLog(hostname)
@@ -1227,11 +1235,11 @@ class Router(NetworkNode, identifier="router"):
self.session_manager.node = self self.session_manager.node = self
self.software_manager.session_manager = self.session_manager self.software_manager.session_manager = self.session_manager
self.session_manager.software_manager = self.software_manager self.session_manager.software_manager = self.software_manager
for i in range(1, self.num_ports + 1): for i in range(1, self.config.num_ports + 1):
network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0") network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
self.connect_nic(network_interface) self.connect_nic(network_interface)
self.network_interface[i] = network_interface self.network_interface[i] = network_interface
self.operating_state = NodeOperatingState.ON
self._set_default_acl() self._set_default_acl()
def _install_system_software(self): def _install_system_software(self):
@@ -1337,7 +1345,7 @@ class Router(NetworkNode, identifier="router"):
:return: A dictionary representing the current state. :return: A dictionary representing the current state.
""" """
state = super().describe_state() state = super().describe_state()
state["num_ports"] = self.num_ports state["num_ports"] = self.config.num_ports
state["acl"] = self.acl.describe_state() state["acl"] = self.acl.describe_state()
return state return state
@@ -1558,6 +1566,8 @@ class Router(NetworkNode, identifier="router"):
) )
print(table) print(table)
# TODO: Remove - Cover normal config items with ConfigSchema. Move additional setup components to __init__ ?
@classmethod @classmethod
def from_config(cls, cfg: dict, **kwargs) -> "Router": def from_config(cls, cfg: dict, **kwargs) -> "Router":
"""Create a router based on a config dict. """Create a router based on a config dict.

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from typing import Dict, Optional from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite import getLogger from primaite import getLogger
from primaite.exceptions import NetworkError from primaite.exceptions import NetworkError
@@ -94,8 +95,6 @@ class Switch(NetworkNode, identifier="switch"):
:ivar num_ports: The number of ports on the switch. Default is 24. :ivar num_ports: The number of ports on the switch. Default is 24.
""" """
num_ports: int = 24
"The number of ports on the switch."
network_interfaces: Dict[str, SwitchPort] = {} network_interfaces: Dict[str, SwitchPort] = {}
"The SwitchPorts on the Switch." "The SwitchPorts on the Switch."
network_interface: Dict[int, SwitchPort] = {} network_interface: Dict[int, SwitchPort] = {}
@@ -103,9 +102,17 @@ class Switch(NetworkNode, identifier="switch"):
mac_address_table: Dict[str, SwitchPort] = {} mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts." "A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema())
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Switch nodes within PrimAITE."""
num_ports: int = 24
"The number of ports on the switch."
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
for i in range(1, self.num_ports + 1): for i in range(1, self.config.num_ports + 1):
self.connect_nic(SwitchPort()) self.connect_nic(SwitchPort())
def _install_system_software(self): def _install_system_software(self):
@@ -134,7 +141,7 @@ class Switch(NetworkNode, identifier="switch"):
""" """
state = super().describe_state() state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()} state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()}
state["num_ports"] = self.num_ports # redundant? state["num_ports"] = self.config.num_ports # redundant?
state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()} state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
return state return state

View File

@@ -2,7 +2,7 @@
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
from pydantic import validate_call from pydantic import Field, validate_call
from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, FREQ_WIFI_2_4, IPWirelessNetworkInterface from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, FREQ_WIFI_2_4, IPWirelessNetworkInterface
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
@@ -124,6 +124,13 @@ class WirelessRouter(Router, identifier="wireless_router"):
network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {} network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {}
airspace: AirSpace airspace: AirSpace
config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.Configschema())
class ConfigSchema(Router.ConfigSChema):
"""Configuration Schema for WirelessRouter nodes within PrimAITE."""
pass
def __init__(self, hostname: str, airspace: AirSpace, **kwargs): def __init__(self, hostname: str, airspace: AirSpace, **kwargs):
super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs) super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs)