Merge remote-tracking branch 'origin/feature/2887-Align_Node_Types' into feature/3062-discriminators

This commit is contained in:
Marek Wolan
2025-02-04 15:20:48 +00:00
14 changed files with 35 additions and 46 deletions

View File

@@ -18,8 +18,7 @@ Node classes all inherit from the base Node Class, though new classes should inh
The use of an `__init__` method is not necessary, as configurable variables for the class should be specified within the `config` of the class, and passed at run time via your YAML configuration using the `from_config` method.
An example of how additional Node classes is below, taken from `router.py` withing PrimAITE.
An example of how additional Node classes is below, taken from `router.py` within PrimAITE.
.. code-block:: Python
@@ -53,4 +52,4 @@ class Router(NetworkNode, identifier="router"):
Changes to YAML file.
=====================
Nodes defined within configuration YAML files for use with PrimAITE 3.X should still be compatible following these changes.
While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed.

View File

@@ -3,7 +3,7 @@
"""Core of the PrimAITE Simulator."""
import warnings
from abc import abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from uuid import uuid4
from prettytable import PrettyTable

View File

@@ -167,7 +167,6 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
# Optionally include a router in the LAN
if config.include_router:
default_gateway = IPv4Address(f"192.168.{config.subnet_base}.1")
# router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0)
router = Router.from_config(
config={"hostname": f"router_{config.lan_name}", "type": "router", "start_up_duration": 0}
)
@@ -230,7 +229,7 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
"type": "computer",
"hostname": f"pc_{i}_{config.lan_name}",
"ip_address": f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}",
"default_gateway": "192.168.10.1",
"default_gateway": default_gateway,
"start_up_duration": 0,
}
pc = Computer.from_config(config=pc_cfg)

View File

@@ -1526,16 +1526,13 @@ class Node(SimComponent, ABC):
_discriminator: ClassVar[str]
"""discriminator for this particular class, used for printing and logging. Each subclass redefines this."""
config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
"""Configuration items within Node"""
class ConfigSchema(BaseModel, ABC):
"""Configuration Schema for Node based classes."""
model_config = ConfigDict(arbitrary_types_allowed=True)
"""Configure pydantic to allow arbitrary types, let the instance have attributes not present in the model."""
hostname: str = "default"
hostname: str
"The node hostname on the network."
revealed_to_red: bool = False
@@ -1573,6 +1570,9 @@ class Node(SimComponent, ABC):
operating_state: Any = None
config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
"""Configuration items within Node"""
@property
def dns_server(self) -> Optional[IPv4Address]:
"""Convenience method to access the dns_server IP."""
@@ -2243,10 +2243,6 @@ class Node(SimComponent, ABC):
for app_id in self.applications:
self.applications[app_id].close()
# Turn off all processes in the node
# for process_id in self.processes:
# self.processes[process_id]
def _start_up_actions(self):
"""Actions to perform when the node is starting up."""
# Turn on all the services in the node
@@ -2255,14 +2251,8 @@ class Node(SimComponent, ABC):
# Turn on all the applications in the node
for app_id in self.applications:
print(app_id)
print(f"Starting application:{self.applications[app_id].config.type}")
self.applications[app_id].run()
# Turn off all processes in the node
# for process_id in self.processes:
# self.processes[process_id]
def _install_system_software(self) -> None:
"""Preinstall required software."""
for _, software_class in self.SYSTEM_SOFTWARE.items():

View File

@@ -37,11 +37,11 @@ class Computer(HostNode, discriminator="computer"):
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "ftp-client": FTPClient}
config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Computer class."""
hostname: str = "Computer"
config: ConfigSchema = Field(default_factory=lambda: Computer.ConfigSchema())
pass

View File

@@ -330,8 +330,6 @@ class HostNode(Node, discriminator="host-node"):
network_interface: Dict[int, NIC] = {}
"The NICs on the node by port id."
config: HostNode.ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
class ConfigSchema(Node.ConfigSchema):
"""Configuration Schema for HostNode class."""
@@ -339,6 +337,8 @@ class HostNode(Node, discriminator="host-node"):
subnet_mask: IPV4Address = "255.255.255.0"
ip_address: IPV4Address
config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask))

View File

@@ -33,22 +33,22 @@ class Server(HostNode, discriminator="server"):
* Web Browser
"""
config: "Server.ConfigSchema" = Field(default_factory=lambda: Server.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Server class."""
hostname: str = "server"
config: ConfigSchema = Field(default_factory=lambda: Server.ConfigSchema())
class Printer(HostNode, discriminator="printer"):
"""Printer? I don't even know her!."""
# TODO: Implement printer-specific behaviour
config: "Printer.ConfigSchema" = Field(default_factory=lambda: Printer.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Printer class."""
hostname: str = "printer"
config: ConfigSchema = Field(default_factory=lambda: Printer.ConfigSchema())

View File

@@ -100,14 +100,14 @@ class Firewall(Router, discriminator="firewall"):
_identifier: str = "firewall"
config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema())
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for Firewall 'Nodes' within PrimAITE."""
hostname: str = "firewall"
num_ports: int = 0
config: ConfigSchema = Field(default_factory=lambda: Firewall.ConfigSchema())
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)

View File

@@ -7,7 +7,7 @@ from ipaddress import IPv4Address, IPv4Network
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
@@ -1201,6 +1201,14 @@ class Router(NetworkNode, discriminator="router"):
RouteTable, RouterARP, and RouterICMP services.
"""
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Routers."""
hostname: str = "router"
num_ports: int = 5
config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema())
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
@@ -1214,14 +1222,6 @@ class Router(NetworkNode, discriminator="router"):
acl: AccessControlList
route_table: RouteTable
config: "Router.ConfigSchema"
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Routers."""
hostname: str = "router"
num_ports: int = 5
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)

View File

@@ -98,8 +98,6 @@ class Switch(NetworkNode, discriminator="switch"):
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema())
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Switch nodes within PrimAITE."""
@@ -107,9 +105,11 @@ class Switch(NetworkNode, discriminator="switch"):
num_ports: int = 24
"The number of ports on the switch. Default is 24."
config: ConfigSchema = Field(default_factory=lambda: Switch.ConfigSchema())
def __init__(self, **kwargs):
super().__init__(**kwargs)
for i in range(1, kwargs["config"].num_ports + 1):
for i in range(1, self.config.num_ports + 1):
self.connect_nic(SwitchPort())
def _install_system_software(self):

View File

@@ -123,8 +123,6 @@ class WirelessRouter(Router, discriminator="wireless-router"):
network_interfaces: Dict[str, Union[RouterInterface, WirelessAccessPoint]] = {}
network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {}
config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema())
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for WirelessRouter nodes within PrimAITE."""
@@ -132,6 +130,8 @@ class WirelessRouter(Router, discriminator="wireless-router"):
airspace: AirSpace
num_ports: int = 0
config: ConfigSchema = Field(default_factory=lambda: WirelessRouter.ConfigSchema())
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -26,6 +26,7 @@ class DNSClient(Service, discriminator="dns-client"):
type: str = "dns-client"
dns_server: Optional[IPV4Address] = None
"The DNS Server the client sends requests to."
config: ConfigSchema = Field(default_factory=lambda: DNSClient.ConfigSchema())
dns_cache: Dict[str, IPv4Address] = {}

View File

@@ -20,8 +20,6 @@ class FTPServer(FTPServiceABC, discriminator="ftp-server"):
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
"""
server_password: Optional[str] = None
class ConfigSchema(FTPServiceABC.ConfigSchema):
"""ConfigSchema for FTPServer."""
@@ -29,6 +27,7 @@ class FTPServer(FTPServiceABC, discriminator="ftp-server"):
server_password: Optional[str] = None
config: ConfigSchema = Field(default_factory=lambda: FTPServer.ConfigSchema())
server_password: Optional[str] = None
def __init__(self, **kwargs):
kwargs["name"] = "ftp-server"

View File

@@ -23,6 +23,7 @@ class NTPClient(Service, discriminator="ntp-client"):
type: str = "ntp-client"
ntp_server_ip: Optional[IPV4Address] = None
"The NTP server the client sends requests to."
config: ConfigSchema = Field(default_factory=lambda: NTPClient.ConfigSchema())