#2887 - Actioning review comments

This commit is contained in:
Charlie Crane
2025-02-04 10:21:56 +00:00
parent f3bbfffe7f
commit c1a5a26ffc
12 changed files with 33 additions and 36 deletions

View File

@@ -3,7 +3,7 @@
"""Core of the PrimAITE Simulator."""
import warnings
from abc import abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from uuid import uuid4
from prettytable import PrettyTable

View File

@@ -1525,16 +1525,13 @@ class Node(SimComponent, ABC):
_identifier: ClassVar[str] = "unknown"
"""Identifier for this particular class, used for printing and logging. Each subclass redefines this."""
config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
"""Configuration items within Node"""
class ConfigSchema(BaseModel, ABC):
"""Configuration Schema for Node based classes."""
model_config = ConfigDict(arbitrary_types_allowed=True)
"""Configure pydantic to allow arbitrary types, let the instance have attributes not present in the model."""
hostname: str = "default"
hostname: str
"The node hostname on the network."
revealed_to_red: bool = False
@@ -1572,6 +1569,9 @@ class Node(SimComponent, ABC):
operating_state: Any = None
config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
"""Configuration items within Node"""
@property
def dns_server(self) -> Optional[IPv4Address]:
"""Convenience method to access the dns_server IP."""

View File

@@ -37,11 +37,11 @@ class Computer(HostNode, identifier="computer"):
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient}
config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Computer class."""
hostname: str = "Computer"
config: ConfigSchema = Field(default_factory=lambda: Computer.ConfigSchema())
pass

View File

@@ -330,8 +330,6 @@ class HostNode(Node, identifier="HostNode"):
network_interface: Dict[int, NIC] = {}
"The NICs on the node by port id."
config: HostNode.ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
class ConfigSchema(Node.ConfigSchema):
"""Configuration Schema for HostNode class."""
@@ -339,6 +337,8 @@ class HostNode(Node, identifier="HostNode"):
subnet_mask: IPV4Address = "255.255.255.0"
ip_address: IPV4Address
config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask))

View File

@@ -33,22 +33,22 @@ class Server(HostNode, identifier="server"):
* Web Browser
"""
config: "Server.ConfigSchema" = Field(default_factory=lambda: Server.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Server class."""
hostname: str = "server"
config: ConfigSchema = Field(default_factory=lambda: Server.ConfigSchema())
class Printer(HostNode, identifier="printer"):
"""Printer? I don't even know her!."""
# TODO: Implement printer-specific behaviour
config: "Printer.ConfigSchema" = Field(default_factory=lambda: Printer.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Printer class."""
hostname: str = "printer"
config: ConfigSchema = Field(default_factory=lambda: Printer.ConfigSchema())

View File

@@ -100,14 +100,14 @@ class Firewall(Router, identifier="firewall"):
_identifier: str = "firewall"
config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema())
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for Firewall 'Nodes' within PrimAITE."""
hostname: str = "firewall"
num_ports: int = 0
config: ConfigSchema = Field(default_factory=lambda: Firewall.ConfigSchema())
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)

View File

@@ -7,7 +7,7 @@ from ipaddress import IPv4Address, IPv4Network
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
@@ -1201,6 +1201,14 @@ class Router(NetworkNode, identifier="router"):
RouteTable, RouterARP, and RouterICMP services.
"""
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Routers."""
hostname: str = "router"
num_ports: int = 5
config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema())
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
@@ -1214,14 +1222,6 @@ class Router(NetworkNode, identifier="router"):
acl: AccessControlList
route_table: RouteTable
config: "Router.ConfigSchema"
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Routers."""
hostname: str = "router"
num_ports: int = 5
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)

View File

@@ -98,8 +98,6 @@ class Switch(NetworkNode, identifier="switch"):
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema())
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Switch nodes within PrimAITE."""
@@ -107,6 +105,8 @@ class Switch(NetworkNode, identifier="switch"):
num_ports: int = 24
"The number of ports on the switch. Default is 24."
config: ConfigSchema = Field(default_factory=lambda: Switch.ConfigSchema())
def __init__(self, **kwargs):
super().__init__(**kwargs)
for i in range(1, kwargs["config"].num_ports + 1):

View File

@@ -123,8 +123,6 @@ class WirelessRouter(Router, identifier="wireless_router"):
network_interfaces: Dict[str, Union[RouterInterface, WirelessAccessPoint]] = {}
network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {}
config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema())
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for WirelessRouter nodes within PrimAITE."""
@@ -132,6 +130,8 @@ class WirelessRouter(Router, identifier="wireless_router"):
airspace: AirSpace
num_ports: int = 0
config: ConfigSchema = Field(default_factory=lambda: WirelessRouter.ConfigSchema())
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -25,12 +25,12 @@ class DNSClient(Service, identifier="DNSClient"):
"""ConfigSchema for DNSClient."""
type: str = "DNSClient"
dns_server: Optional[IPV4Address] = None
dns_server: Optional[IPv4Address] = None
"The DNS Server the client sends requests to."
config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema())
config: ConfigSchema = Field(default_factory=lambda: DNSClient.ConfigSchema())
dns_cache: Dict[str, IPv4Address] = {}
"A dict of known mappings between domain/URLs names and IPv4 addresses."

View File

@@ -20,17 +20,16 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"):
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
"""
config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema())
class ConfigSchema(FTPServiceABC.ConfigSchema):
"""ConfigSchema for FTPServer."""
type: str = "FTPServer"
server_password: Optional[str] = None
server_password: Optional[str] = None
"""Password needed to connect to FTP server. Default is None."""
config: ConfigSchema = Field(default_factory=lambda: FTPServer.ConfigSchema())
def __init__(self, **kwargs):
kwargs["name"] = "FTPServer"
kwargs["port"] = PORT_LOOKUP["FTP"]

View File

@@ -9,7 +9,6 @@ from primaite import getLogger
from primaite.simulator.network.protocols.ntp import NTPPacket
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -22,12 +21,11 @@ class NTPClient(Service, identifier="NTPClient"):
"""ConfigSchema for NTPClient."""
type: str = "NTPClient"
ntp_server_ip: Optional[IPV4Address] = None
ntp_server_ip: Optional[IPv4Address] = None
"The NTP server the client sends requests to."
config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema())
config: ConfigSchema = Field(default_factory=lambda: NTPClient.ConfigSchema())
time: Optional[datetime] = None