From 70d9fe2fd97317f2516b1629ca222da4e8008147 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 15 Jan 2025 16:33:11 +0000 Subject: [PATCH] #2887 - End of day commit. Updates to ConfigSchema inheritance, and some initials changes to Router to remove the custom from_config method --- src/primaite/game/game.py | 61 ---------- .../simulator/network/hardware/base.py | 18 +-- .../network/hardware/nodes/host/computer.py | 2 +- .../network/hardware/nodes/host/host_node.py | 2 +- .../hardware/nodes/network/firewall.py | 9 +- .../network/hardware/nodes/network/router.py | 112 ++++-------------- .../network/hardware/nodes/network/switch.py | 3 +- .../hardware/nodes/network/wireless_router.py | 4 +- tests/conftest.py | 3 + 9 files changed, 46 insertions(+), 168 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 48d9df87..6599430a 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -277,67 +277,6 @@ class PrimaiteGame: if n_type in Node._registry: # simplify down Node creation: new_node = Node._registry["n_type"].from_config(config=node_config) - - # Default PrimAITE nodes - # if n_type == "computer": - # new_node = Computer( - # hostname=node_cfg["hostname"], - # ip_address=node_cfg["ip_address"], - # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - # default_gateway=node_cfg.get("default_gateway"), - # dns_server=node_cfg.get("dns_server", None), - # operating_state=NodeOperatingState.ON - # if not (p := node_cfg.get("operating_state")) - # else NodeOperatingState[p.upper()], - # ) - # elif n_type == "server": - # new_node = Server( - # hostname=node_cfg["hostname"], - # ip_address=node_cfg["ip_address"], - # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - # default_gateway=node_cfg.get("default_gateway"), - # dns_server=node_cfg.get("dns_server", None), - # operating_state=NodeOperatingState.ON - # if not (p := node_cfg.get("operating_state")) - # else NodeOperatingState[p.upper()], - # ) - # elif n_type == "switch": - # new_node = Switch( - # hostname=node_cfg["hostname"], - # num_ports=int(node_cfg.get("num_ports", "8")), - # operating_state=NodeOperatingState.ON - # if not (p := node_cfg.get("operating_state")) - # else NodeOperatingState[p.upper()], - # ) - # elif n_type == "router": - # new_node = Router.from_config(node_cfg) - # elif n_type == "firewall": - # new_node = Firewall.from_config(node_cfg) - # elif n_type == "wireless_router": - # new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace) - # elif n_type == "printer": - # new_node = Printer( - # hostname=node_cfg["hostname"], - # ip_address=node_cfg["ip_address"], - # subnet_mask=node_cfg["subnet_mask"], - # operating_state=NodeOperatingState.ON - # if not (p := node_cfg.get("operating_state")) - # else NodeOperatingState[p.upper()], - # ) - # # Handle extended nodes - # elif n_type.lower() in Node._registry: - # new_node = HostNode._registry[n_type]( - # hostname=node_cfg["hostname"], - # ip_address=node_cfg.get("ip_address"), - # subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), - # default_gateway=node_cfg.get("default_gateway"), - # dns_server=node_cfg.get("dns_server", None), - # operating_state=NodeOperatingState.ON - # if not (p := node_cfg.get("operating_state")) - # else NodeOperatingState[p.upper()], - # ) - # elif n_type in NetworkNode._registry: - # new_node = NetworkNode._registry[n_type](**node_cfg) else: msg = f"invalid node type {n_type} in config" _LOGGER.error(msg) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index b003009b..822714cb 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, Field, validate_call +from pydantic import BaseModel, ConfigDict, Field, validate_call from primaite import getLogger from primaite.exceptions import NetworkError @@ -1480,8 +1480,6 @@ class Node(SimComponent, ABC): :param operating_state: The node operating state, either ON or OFF. """ - hostname: str - "The node hostname on the network." default_gateway: Optional[IPV4Address] = None "The default gateway IP address for forwarding network traffic to other networks." operating_state: NodeOperatingState = NodeOperatingState.OFF @@ -1519,13 +1517,18 @@ class Node(SimComponent, ABC): config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) - class ConfigSchema: + class ConfigSchema(BaseModel, ABC): """Configuration Schema for Node based classes.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + """Configure pydantic to allow arbitrary types and to let the instance have attributes not present in the model.""" + hostname: str + "The node hostname on the network." + revealed_to_red: bool = False "Informs whether the node has been revealed to a red agent." - start_up_duration: int = 3 + start_up_duration: int = 0 "Time steps needed for the node to start up." start_up_countdown: int = 0 @@ -1549,8 +1552,9 @@ class Node(SimComponent, ABC): red_scan_countdown: int = 0 "Time steps until reveal to red scan is complete." - def from_config(cls, config: Dict) -> Node: - """Create Node object from a given configuration.""" + @classmethod + def from_config(cls, config: Dict) -> "Node": + """Create Node object from a given configuration dictionary.""" if config["type"] not in cls._registry: msg = f"Configuration contains an invalid Node type: {config['type']}" return ValueError(msg) diff --git a/src/primaite/simulator/network/hardware/nodes/host/computer.py b/src/primaite/simulator/network/hardware/nodes/host/computer.py index 1fb63b2e..85857a44 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/computer.py +++ b/src/primaite/simulator/network/hardware/nodes/host/computer.py @@ -42,6 +42,6 @@ class Computer(HostNode, identifier="computer"): class ConfigSchema(HostNode.ConfigSchema): """Configuration Schema for Computer class.""" - pass + hostname: str = "Computer" pass diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index fa73bf10..00f21342 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -332,7 +332,7 @@ class HostNode(Node, identifier="HostNode"): class ConfigSchema(Node.ConfigSchema): """Configuration Schema for HostNode class.""" - pass + hostname: str = "HostNode" def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 0a397c49..c7e22d49 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -104,13 +104,14 @@ class Firewall(Router, identifier="firewall"): class ConfigSchema(Router.ConfigSChema): """Configuration Schema for Firewall 'Nodes' within PrimAITE.""" - pass + hostname: str = "Firewall" + num_ports: int = 0 - def __init__(self, hostname: str, **kwargs): + def __init__(self, **kwargs): if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(hostname) + kwargs["sys_log"] = SysLog(self.config.hostname) - super().__init__(hostname=hostname, num_ports=0, **kwargs) + super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs) self.connect_nic( RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external") diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 132b1462..83fa066d 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1211,27 +1211,22 @@ class Router(NetworkNode, identifier="router"): "The Router Interfaces on the node." network_interface: Dict[int, RouterInterface] = {} "The Router Interfaces on the node by port id." - acl: AccessControlList - route_table: RouteTable - config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSChema()) + config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema()) - class ConfigSChema(NetworkNode.ConfigSchema): + class ConfigSchema(NetworkNode.ConfigSchema): """Configuration Schema for Router Objects.""" - num_ports: int = 10 + num_ports: int = 5 hostname: str = "Router" ports: list = [] + sys_log: SysLog = SysLog(hostname) + acl: AccessControlList = AccessControlList(sys_log=sys_log, implicit_action=ACLAction.DENY, name=hostname) + route_table: RouteTable = RouteTable(sys_log=sys_log) - def __init__(self, hostname: str, num_ports: int = 5, **kwargs): - if not kwargs.get("sys_log"): - kwargs["sys_log"] = SysLog(hostname) - if not kwargs.get("acl"): - kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=hostname) - if not kwargs.get("route_table"): - kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"]) - super().__init__(hostname=hostname, num_ports=num_ports, **kwargs) - self.session_manager = RouterSessionManager(sys_log=self.sys_log) + def __init__(self, **kwargs): + super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs) + self.session_manager = RouterSessionManager(sys_log=self.config.sys_log) self.session_manager.node = self self.software_manager.session_manager = self.session_manager self.session_manager.software_manager = self.software_manager @@ -1265,10 +1260,10 @@ class Router(NetworkNode, identifier="router"): Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP and ICMP, which are necessary for basic network operations and diagnostics. """ - self.acl.add_rule( + self.config.acl.add_rule( action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22 ) - self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) + self.config.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) def setup_for_episode(self, episode: int): """ @@ -1292,7 +1287,7 @@ class Router(NetworkNode, identifier="router"): More information in user guide and docstring for SimComponent._init_request_manager. """ rm = super()._init_request_manager() - rm.add_request("acl", RequestType(func=self.acl._request_manager)) + rm.add_request("acl", RequestType(func=self.config.acl._request_manager)) return rm def ip_is_router_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool: @@ -1346,7 +1341,7 @@ class Router(NetworkNode, identifier="router"): """ state = super().describe_state() state["num_ports"] = self.config.num_ports - state["acl"] = self.acl.describe_state() + state["acl"] = self.config.acl.describe_state() return state def check_send_frame_to_session_manager(self, frame: Frame) -> bool: @@ -1393,7 +1388,7 @@ class Router(NetworkNode, identifier="router"): return # Check if it's permitted - permitted, rule = self.acl.is_permitted(frame) + permitted, rule = self.config.acl.is_permitted(frame) if not permitted: at_port = self._get_port_of_nic(from_network_interface) @@ -1566,83 +1561,18 @@ class Router(NetworkNode, identifier="router"): ) print(table) - # TODO: Remove - Cover normal config items with ConfigSchema. Move additional setup components to __init__ ? - - @classmethod - def from_config(cls, cfg: dict, **kwargs) -> "Router": - """Create a router based on a config dict. - - Schema: - - hostname (str): unique name for this router. - - num_ports (int, optional): Number of network ports on the router. 8 by default - - ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying - ip_address and subnet_mask assigned to that ports (as strings) - - acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL - where the rule will be added (lower number is resolved first). The values should describe valid ACL - Rules as: - - action (str): either PERMIT or DENY - - src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER - - dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER - - protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP - - src_ip_address (str, optional): IP address octet written in base 10 - - dst_ip_address (str, optional): IP address octet written in base 10 - - routes (list[dict]): List of route dicts with values: - - address (str): The destination address of the route. - - subnet_mask (str): The subnet mask of the route. - - next_hop_ip_address (str): The next hop IP for the route. - - metric (int): The metric of the route. Optional. - - default_route: - - next_hop_ip_address (str): The next hop IP for the route. - - Example config: - ``` - { - 'hostname': 'router_1', - 'num_ports': 5, - 'ports': { - 1: { - 'ip_address' : '192.168.1.1', - 'subnet_mask' : '255.255.255.0', - }, - 2: { - 'ip_address' : '192.168.0.1', - 'subnet_mask' : '255.255.255.252', - } - }, - 'acl' : { - 21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'}, - 22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'}, - 23: {'action': 'PERMIT', 'protocol': 'ICMP'}, - }, - 'routes' : [ - {'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'} - ], - 'default_route': {'next_hop_ip_address': '192.168.0.2'} - } - ``` - - :param cfg: Router config adhering to schema described in main docstring body - :type cfg: dict - :return: Configured router. - :rtype: Router - """ - router = Router( - hostname=cfg["hostname"], - num_ports=int(cfg.get("num_ports", "5")), - operating_state=NodeOperatingState.ON - if not (p := cfg.get("operating_state")) - else NodeOperatingState[p.upper()], - ) + def setup_router(self, cfg: dict) -> Router: + """ TODO: This is the extra bit of Router's from_config metho. Needs sorting.""" if "ports" in cfg: for port_num, port_cfg in cfg["ports"].items(): - router.configure_port( + self.configure_port( port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")), ) if "acl" in cfg: for r_num, r_cfg in cfg["acl"].items(): - router.acl.add_rule( + self.config.acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p], dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p], @@ -1655,7 +1585,7 @@ class Router(NetworkNode, identifier="router"): ) if "routes" in cfg: for route in cfg.get("routes"): - router.route_table.add_route( + self.config.route_table.add_route( address=IPv4Address(route.get("address")), subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), @@ -1664,5 +1594,5 @@ class Router(NetworkNode, identifier="router"): if "default_route" in cfg: next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None) if next_hop_ip_address: - router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) - return router + self.config.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address) + return self diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index b73af3cb..a2d0050b 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -107,8 +107,9 @@ class Switch(NetworkNode, identifier="switch"): class ConfigSchema(NetworkNode.ConfigSchema): """Configuration Schema for Switch nodes within PrimAITE.""" + hostname: str = "Switch" num_ports: int = 24 - "The number of ports on the switch." + "The number of ports on the switch. Default is 24." def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 0a527ab8..2c4b5976 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -124,12 +124,12 @@ class WirelessRouter(Router, identifier="wireless_router"): network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {} airspace: AirSpace - config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.Configschema()) + config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema()) class ConfigSchema(Router.ConfigSChema): """Configuration Schema for WirelessRouter nodes within PrimAITE.""" - pass + hostname: str = "WirelessRouter" def __init__(self, hostname: str, airspace: AirSpace, **kwargs): super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 0d2cc363..6cbcfa84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -195,12 +195,14 @@ def example_network() -> Network: network = Network() # Router 1 + # router_1 = Router(hostname="router_1", start_up_duration=0) router_1 = Router(hostname="router_1", start_up_duration=0) router_1.power_on() router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") # Switch 1 + # switch_1_config = Switch.ConfigSchema() switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0) switch_1.power_on() @@ -208,6 +210,7 @@ def example_network() -> Network: router_1.enable_port(1) # Switch 2 + # switch_2_config = Switch.ConfigSchema() switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0) switch_2.power_on() network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])