#2887 - End of day commit. Updates to ConfigSchema inheritance, and some initials changes to Router to remove the custom from_config method

This commit is contained in:
Charlie Crane
2025-01-15 16:33:11 +00:00
parent 582e7cfec7
commit 70d9fe2fd9
9 changed files with 46 additions and 168 deletions

View File

@@ -277,67 +277,6 @@ class PrimaiteGame:
if n_type in Node._registry:
# simplify down Node creation:
new_node = Node._registry["n_type"].from_config(config=node_config)
# Default PrimAITE nodes
# if n_type == "computer":
# new_node = Computer(
# hostname=node_cfg["hostname"],
# ip_address=node_cfg["ip_address"],
# subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
# default_gateway=node_cfg.get("default_gateway"),
# dns_server=node_cfg.get("dns_server", None),
# operating_state=NodeOperatingState.ON
# if not (p := node_cfg.get("operating_state"))
# else NodeOperatingState[p.upper()],
# )
# elif n_type == "server":
# new_node = Server(
# hostname=node_cfg["hostname"],
# ip_address=node_cfg["ip_address"],
# subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
# default_gateway=node_cfg.get("default_gateway"),
# dns_server=node_cfg.get("dns_server", None),
# operating_state=NodeOperatingState.ON
# if not (p := node_cfg.get("operating_state"))
# else NodeOperatingState[p.upper()],
# )
# elif n_type == "switch":
# new_node = Switch(
# hostname=node_cfg["hostname"],
# num_ports=int(node_cfg.get("num_ports", "8")),
# operating_state=NodeOperatingState.ON
# if not (p := node_cfg.get("operating_state"))
# else NodeOperatingState[p.upper()],
# )
# elif n_type == "router":
# new_node = Router.from_config(node_cfg)
# elif n_type == "firewall":
# new_node = Firewall.from_config(node_cfg)
# elif n_type == "wireless_router":
# new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace)
# elif n_type == "printer":
# new_node = Printer(
# hostname=node_cfg["hostname"],
# ip_address=node_cfg["ip_address"],
# subnet_mask=node_cfg["subnet_mask"],
# operating_state=NodeOperatingState.ON
# if not (p := node_cfg.get("operating_state"))
# else NodeOperatingState[p.upper()],
# )
# # Handle extended nodes
# elif n_type.lower() in Node._registry:
# new_node = HostNode._registry[n_type](
# hostname=node_cfg["hostname"],
# ip_address=node_cfg.get("ip_address"),
# subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
# default_gateway=node_cfg.get("default_gateway"),
# dns_server=node_cfg.get("dns_server", None),
# operating_state=NodeOperatingState.ON
# if not (p := node_cfg.get("operating_state"))
# else NodeOperatingState[p.upper()],
# )
# elif n_type in NetworkNode._registry:
# new_node = NetworkNode._registry[n_type](**node_cfg)
else:
msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg)

View File

@@ -9,7 +9,7 @@ from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field, validate_call
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
from primaite.exceptions import NetworkError
@@ -1480,8 +1480,6 @@ class Node(SimComponent, ABC):
:param operating_state: The node operating state, either ON or OFF.
"""
hostname: str
"The node hostname on the network."
default_gateway: Optional[IPV4Address] = None
"The default gateway IP address for forwarding network traffic to other networks."
operating_state: NodeOperatingState = NodeOperatingState.OFF
@@ -1519,13 +1517,18 @@ class Node(SimComponent, ABC):
config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
class ConfigSchema:
class ConfigSchema(BaseModel, ABC):
"""Configuration Schema for Node based classes."""
model_config = ConfigDict(arbitrary_types_allowed=True)
"""Configure pydantic to allow arbitrary types and to let the instance have attributes not present in the model."""
hostname: str
"The node hostname on the network."
revealed_to_red: bool = False
"Informs whether the node has been revealed to a red agent."
start_up_duration: int = 3
start_up_duration: int = 0
"Time steps needed for the node to start up."
start_up_countdown: int = 0
@@ -1549,8 +1552,9 @@ class Node(SimComponent, ABC):
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
def from_config(cls, config: Dict) -> Node:
"""Create Node object from a given configuration."""
@classmethod
def from_config(cls, config: Dict) -> "Node":
"""Create Node object from a given configuration dictionary."""
if config["type"] not in cls._registry:
msg = f"Configuration contains an invalid Node type: {config['type']}"
return ValueError(msg)

View File

@@ -42,6 +42,6 @@ class Computer(HostNode, identifier="computer"):
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Computer class."""
pass
hostname: str = "Computer"
pass

View File

@@ -332,7 +332,7 @@ class HostNode(Node, identifier="HostNode"):
class ConfigSchema(Node.ConfigSchema):
"""Configuration Schema for HostNode class."""
pass
hostname: str = "HostNode"
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
super().__init__(**kwargs)

View File

@@ -104,13 +104,14 @@ class Firewall(Router, identifier="firewall"):
class ConfigSchema(Router.ConfigSChema):
"""Configuration Schema for Firewall 'Nodes' within PrimAITE."""
pass
hostname: str = "Firewall"
num_ports: int = 0
def __init__(self, hostname: str, **kwargs):
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname)
kwargs["sys_log"] = SysLog(self.config.hostname)
super().__init__(hostname=hostname, num_ports=0, **kwargs)
super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs)
self.connect_nic(
RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external")

View File

@@ -1211,27 +1211,22 @@ class Router(NetworkNode, identifier="router"):
"The Router Interfaces on the node."
network_interface: Dict[int, RouterInterface] = {}
"The Router Interfaces on the node by port id."
acl: AccessControlList
route_table: RouteTable
config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSChema())
config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema())
class ConfigSChema(NetworkNode.ConfigSchema):
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Router Objects."""
num_ports: int = 10
num_ports: int = 5
hostname: str = "Router"
ports: list = []
sys_log: SysLog = SysLog(hostname)
acl: AccessControlList = AccessControlList(sys_log=sys_log, implicit_action=ACLAction.DENY, name=hostname)
route_table: RouteTable = RouteTable(sys_log=sys_log)
def __init__(self, hostname: str, num_ports: int = 5, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname)
if not kwargs.get("acl"):
kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=hostname)
if not kwargs.get("route_table"):
kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"])
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
self.session_manager = RouterSessionManager(sys_log=self.sys_log)
def __init__(self, **kwargs):
super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs)
self.session_manager = RouterSessionManager(sys_log=self.config.sys_log)
self.session_manager.node = self
self.software_manager.session_manager = self.session_manager
self.session_manager.software_manager = self.software_manager
@@ -1265,10 +1260,10 @@ class Router(NetworkNode, identifier="router"):
Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP
and ICMP, which are necessary for basic network operations and diagnostics.
"""
self.acl.add_rule(
self.config.acl.add_rule(
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22
)
self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
self.config.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
def setup_for_episode(self, episode: int):
"""
@@ -1292,7 +1287,7 @@ class Router(NetworkNode, identifier="router"):
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request("acl", RequestType(func=self.acl._request_manager))
rm.add_request("acl", RequestType(func=self.config.acl._request_manager))
return rm
def ip_is_router_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool:
@@ -1346,7 +1341,7 @@ class Router(NetworkNode, identifier="router"):
"""
state = super().describe_state()
state["num_ports"] = self.config.num_ports
state["acl"] = self.acl.describe_state()
state["acl"] = self.config.acl.describe_state()
return state
def check_send_frame_to_session_manager(self, frame: Frame) -> bool:
@@ -1393,7 +1388,7 @@ class Router(NetworkNode, identifier="router"):
return
# Check if it's permitted
permitted, rule = self.acl.is_permitted(frame)
permitted, rule = self.config.acl.is_permitted(frame)
if not permitted:
at_port = self._get_port_of_nic(from_network_interface)
@@ -1566,83 +1561,18 @@ class Router(NetworkNode, identifier="router"):
)
print(table)
# TODO: Remove - Cover normal config items with ConfigSchema. Move additional setup components to __init__ ?
@classmethod
def from_config(cls, cfg: dict, **kwargs) -> "Router":
"""Create a router based on a config dict.
Schema:
- hostname (str): unique name for this router.
- num_ports (int, optional): Number of network ports on the router. 8 by default
- ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying
ip_address and subnet_mask assigned to that ports (as strings)
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
where the rule will be added (lower number is resolved first). The values should describe valid ACL
Rules as:
- action (str): either PERMIT or DENY
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
- src_ip_address (str, optional): IP address octet written in base 10
- dst_ip_address (str, optional): IP address octet written in base 10
- routes (list[dict]): List of route dicts with values:
- address (str): The destination address of the route.
- subnet_mask (str): The subnet mask of the route.
- next_hop_ip_address (str): The next hop IP for the route.
- metric (int): The metric of the route. Optional.
- default_route:
- next_hop_ip_address (str): The next hop IP for the route.
Example config:
```
{
'hostname': 'router_1',
'num_ports': 5,
'ports': {
1: {
'ip_address' : '192.168.1.1',
'subnet_mask' : '255.255.255.0',
},
2: {
'ip_address' : '192.168.0.1',
'subnet_mask' : '255.255.255.252',
}
},
'acl' : {
21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'},
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
},
'routes' : [
{'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'}
],
'default_route': {'next_hop_ip_address': '192.168.0.2'}
}
```
:param cfg: Router config adhering to schema described in main docstring body
:type cfg: dict
:return: Configured router.
:rtype: Router
"""
router = Router(
hostname=cfg["hostname"],
num_ports=int(cfg.get("num_ports", "5")),
operating_state=NodeOperatingState.ON
if not (p := cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
def setup_router(self, cfg: dict) -> Router:
""" TODO: This is the extra bit of Router's from_config metho. Needs sorting."""
if "ports" in cfg:
for port_num, port_cfg in cfg["ports"].items():
router.configure_port(
self.configure_port(
port=port_num,
ip_address=port_cfg["ip_address"],
subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")),
)
if "acl" in cfg:
for r_num, r_cfg in cfg["acl"].items():
router.acl.add_rule(
self.config.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p],
@@ -1655,7 +1585,7 @@ class Router(NetworkNode, identifier="router"):
)
if "routes" in cfg:
for route in cfg.get("routes"):
router.route_table.add_route(
self.config.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
@@ -1664,5 +1594,5 @@ class Router(NetworkNode, identifier="router"):
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
return router
self.config.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
return self

View File

@@ -107,8 +107,9 @@ class Switch(NetworkNode, identifier="switch"):
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Switch nodes within PrimAITE."""
hostname: str = "Switch"
num_ports: int = 24
"The number of ports on the switch."
"The number of ports on the switch. Default is 24."
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -124,12 +124,12 @@ class WirelessRouter(Router, identifier="wireless_router"):
network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {}
airspace: AirSpace
config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.Configschema())
config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema())
class ConfigSchema(Router.ConfigSChema):
"""Configuration Schema for WirelessRouter nodes within PrimAITE."""
pass
hostname: str = "WirelessRouter"
def __init__(self, hostname: str, airspace: AirSpace, **kwargs):
super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs)

View File

@@ -195,12 +195,14 @@ def example_network() -> Network:
network = Network()
# Router 1
# router_1 = Router(hostname="router_1", start_up_duration=0)
router_1 = Router(hostname="router_1", start_up_duration=0)
router_1.power_on()
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
# Switch 1
# switch_1_config = Switch.ConfigSchema()
switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
switch_1.power_on()
@@ -208,6 +210,7 @@ def example_network() -> Network:
router_1.enable_port(1)
# Switch 2
# switch_2_config = Switch.ConfigSchema()
switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
switch_2.power_on()
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])