#2887 - End of day commit. Updates to ConfigSchema inheritance, and some initials changes to Router to remove the custom from_config method
This commit is contained in:
@@ -277,67 +277,6 @@ class PrimaiteGame:
|
||||
if n_type in Node._registry:
|
||||
# simplify down Node creation:
|
||||
new_node = Node._registry["n_type"].from_config(config=node_config)
|
||||
|
||||
# Default PrimAITE nodes
|
||||
# if n_type == "computer":
|
||||
# new_node = Computer(
|
||||
# hostname=node_cfg["hostname"],
|
||||
# ip_address=node_cfg["ip_address"],
|
||||
# subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
|
||||
# default_gateway=node_cfg.get("default_gateway"),
|
||||
# dns_server=node_cfg.get("dns_server", None),
|
||||
# operating_state=NodeOperatingState.ON
|
||||
# if not (p := node_cfg.get("operating_state"))
|
||||
# else NodeOperatingState[p.upper()],
|
||||
# )
|
||||
# elif n_type == "server":
|
||||
# new_node = Server(
|
||||
# hostname=node_cfg["hostname"],
|
||||
# ip_address=node_cfg["ip_address"],
|
||||
# subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
|
||||
# default_gateway=node_cfg.get("default_gateway"),
|
||||
# dns_server=node_cfg.get("dns_server", None),
|
||||
# operating_state=NodeOperatingState.ON
|
||||
# if not (p := node_cfg.get("operating_state"))
|
||||
# else NodeOperatingState[p.upper()],
|
||||
# )
|
||||
# elif n_type == "switch":
|
||||
# new_node = Switch(
|
||||
# hostname=node_cfg["hostname"],
|
||||
# num_ports=int(node_cfg.get("num_ports", "8")),
|
||||
# operating_state=NodeOperatingState.ON
|
||||
# if not (p := node_cfg.get("operating_state"))
|
||||
# else NodeOperatingState[p.upper()],
|
||||
# )
|
||||
# elif n_type == "router":
|
||||
# new_node = Router.from_config(node_cfg)
|
||||
# elif n_type == "firewall":
|
||||
# new_node = Firewall.from_config(node_cfg)
|
||||
# elif n_type == "wireless_router":
|
||||
# new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace)
|
||||
# elif n_type == "printer":
|
||||
# new_node = Printer(
|
||||
# hostname=node_cfg["hostname"],
|
||||
# ip_address=node_cfg["ip_address"],
|
||||
# subnet_mask=node_cfg["subnet_mask"],
|
||||
# operating_state=NodeOperatingState.ON
|
||||
# if not (p := node_cfg.get("operating_state"))
|
||||
# else NodeOperatingState[p.upper()],
|
||||
# )
|
||||
# # Handle extended nodes
|
||||
# elif n_type.lower() in Node._registry:
|
||||
# new_node = HostNode._registry[n_type](
|
||||
# hostname=node_cfg["hostname"],
|
||||
# ip_address=node_cfg.get("ip_address"),
|
||||
# subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
|
||||
# default_gateway=node_cfg.get("default_gateway"),
|
||||
# dns_server=node_cfg.get("dns_server", None),
|
||||
# operating_state=NodeOperatingState.ON
|
||||
# if not (p := node_cfg.get("operating_state"))
|
||||
# else NodeOperatingState[p.upper()],
|
||||
# )
|
||||
# elif n_type in NetworkNode._registry:
|
||||
# new_node = NetworkNode._registry[n_type](**node_cfg)
|
||||
else:
|
||||
msg = f"invalid node type {n_type} in config"
|
||||
_LOGGER.error(msg)
|
||||
|
||||
@@ -9,7 +9,7 @@ from pathlib import Path
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel, Field, validate_call
|
||||
from pydantic import BaseModel, ConfigDict, Field, validate_call
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
@@ -1480,8 +1480,6 @@ class Node(SimComponent, ABC):
|
||||
:param operating_state: The node operating state, either ON or OFF.
|
||||
"""
|
||||
|
||||
hostname: str
|
||||
"The node hostname on the network."
|
||||
default_gateway: Optional[IPV4Address] = None
|
||||
"The default gateway IP address for forwarding network traffic to other networks."
|
||||
operating_state: NodeOperatingState = NodeOperatingState.OFF
|
||||
@@ -1519,13 +1517,18 @@ class Node(SimComponent, ABC):
|
||||
|
||||
config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
|
||||
|
||||
class ConfigSchema:
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Configuration Schema for Node based classes."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
"""Configure pydantic to allow arbitrary types and to let the instance have attributes not present in the model."""
|
||||
hostname: str
|
||||
"The node hostname on the network."
|
||||
|
||||
revealed_to_red: bool = False
|
||||
"Informs whether the node has been revealed to a red agent."
|
||||
|
||||
start_up_duration: int = 3
|
||||
start_up_duration: int = 0
|
||||
"Time steps needed for the node to start up."
|
||||
|
||||
start_up_countdown: int = 0
|
||||
@@ -1549,8 +1552,9 @@ class Node(SimComponent, ABC):
|
||||
red_scan_countdown: int = 0
|
||||
"Time steps until reveal to red scan is complete."
|
||||
|
||||
def from_config(cls, config: Dict) -> Node:
|
||||
"""Create Node object from a given configuration."""
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "Node":
|
||||
"""Create Node object from a given configuration dictionary."""
|
||||
if config["type"] not in cls._registry:
|
||||
msg = f"Configuration contains an invalid Node type: {config['type']}"
|
||||
return ValueError(msg)
|
||||
|
||||
@@ -42,6 +42,6 @@ class Computer(HostNode, identifier="computer"):
|
||||
class ConfigSchema(HostNode.ConfigSchema):
|
||||
"""Configuration Schema for Computer class."""
|
||||
|
||||
pass
|
||||
hostname: str = "Computer"
|
||||
|
||||
pass
|
||||
|
||||
@@ -332,7 +332,7 @@ class HostNode(Node, identifier="HostNode"):
|
||||
class ConfigSchema(Node.ConfigSchema):
|
||||
"""Configuration Schema for HostNode class."""
|
||||
|
||||
pass
|
||||
hostname: str = "HostNode"
|
||||
|
||||
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -104,13 +104,14 @@ class Firewall(Router, identifier="firewall"):
|
||||
class ConfigSchema(Router.ConfigSChema):
|
||||
"""Configuration Schema for Firewall 'Nodes' within PrimAITE."""
|
||||
|
||||
pass
|
||||
hostname: str = "Firewall"
|
||||
num_ports: int = 0
|
||||
|
||||
def __init__(self, hostname: str, **kwargs):
|
||||
def __init__(self, **kwargs):
|
||||
if not kwargs.get("sys_log"):
|
||||
kwargs["sys_log"] = SysLog(hostname)
|
||||
kwargs["sys_log"] = SysLog(self.config.hostname)
|
||||
|
||||
super().__init__(hostname=hostname, num_ports=0, **kwargs)
|
||||
super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs)
|
||||
|
||||
self.connect_nic(
|
||||
RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external")
|
||||
|
||||
@@ -1211,27 +1211,22 @@ class Router(NetworkNode, identifier="router"):
|
||||
"The Router Interfaces on the node."
|
||||
network_interface: Dict[int, RouterInterface] = {}
|
||||
"The Router Interfaces on the node by port id."
|
||||
acl: AccessControlList
|
||||
route_table: RouteTable
|
||||
|
||||
config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSChema())
|
||||
config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema())
|
||||
|
||||
class ConfigSChema(NetworkNode.ConfigSchema):
|
||||
class ConfigSchema(NetworkNode.ConfigSchema):
|
||||
"""Configuration Schema for Router Objects."""
|
||||
|
||||
num_ports: int = 10
|
||||
num_ports: int = 5
|
||||
hostname: str = "Router"
|
||||
ports: list = []
|
||||
sys_log: SysLog = SysLog(hostname)
|
||||
acl: AccessControlList = AccessControlList(sys_log=sys_log, implicit_action=ACLAction.DENY, name=hostname)
|
||||
route_table: RouteTable = RouteTable(sys_log=sys_log)
|
||||
|
||||
def __init__(self, hostname: str, num_ports: int = 5, **kwargs):
|
||||
if not kwargs.get("sys_log"):
|
||||
kwargs["sys_log"] = SysLog(hostname)
|
||||
if not kwargs.get("acl"):
|
||||
kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=hostname)
|
||||
if not kwargs.get("route_table"):
|
||||
kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"])
|
||||
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
|
||||
self.session_manager = RouterSessionManager(sys_log=self.sys_log)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(hostname=self.config.hostname, num_ports=self.config.num_ports, **kwargs)
|
||||
self.session_manager = RouterSessionManager(sys_log=self.config.sys_log)
|
||||
self.session_manager.node = self
|
||||
self.software_manager.session_manager = self.session_manager
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
@@ -1265,10 +1260,10 @@ class Router(NetworkNode, identifier="router"):
|
||||
Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP
|
||||
and ICMP, which are necessary for basic network operations and diagnostics.
|
||||
"""
|
||||
self.acl.add_rule(
|
||||
self.config.acl.add_rule(
|
||||
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22
|
||||
)
|
||||
self.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
|
||||
self.config.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
|
||||
|
||||
def setup_for_episode(self, episode: int):
|
||||
"""
|
||||
@@ -1292,7 +1287,7 @@ class Router(NetworkNode, identifier="router"):
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("acl", RequestType(func=self.acl._request_manager))
|
||||
rm.add_request("acl", RequestType(func=self.config.acl._request_manager))
|
||||
return rm
|
||||
|
||||
def ip_is_router_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool:
|
||||
@@ -1346,7 +1341,7 @@ class Router(NetworkNode, identifier="router"):
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["num_ports"] = self.config.num_ports
|
||||
state["acl"] = self.acl.describe_state()
|
||||
state["acl"] = self.config.acl.describe_state()
|
||||
return state
|
||||
|
||||
def check_send_frame_to_session_manager(self, frame: Frame) -> bool:
|
||||
@@ -1393,7 +1388,7 @@ class Router(NetworkNode, identifier="router"):
|
||||
return
|
||||
|
||||
# Check if it's permitted
|
||||
permitted, rule = self.acl.is_permitted(frame)
|
||||
permitted, rule = self.config.acl.is_permitted(frame)
|
||||
|
||||
if not permitted:
|
||||
at_port = self._get_port_of_nic(from_network_interface)
|
||||
@@ -1566,83 +1561,18 @@ class Router(NetworkNode, identifier="router"):
|
||||
)
|
||||
print(table)
|
||||
|
||||
# TODO: Remove - Cover normal config items with ConfigSchema. Move additional setup components to __init__ ?
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict, **kwargs) -> "Router":
|
||||
"""Create a router based on a config dict.
|
||||
|
||||
Schema:
|
||||
- hostname (str): unique name for this router.
|
||||
- num_ports (int, optional): Number of network ports on the router. 8 by default
|
||||
- ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying
|
||||
ip_address and subnet_mask assigned to that ports (as strings)
|
||||
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
|
||||
where the rule will be added (lower number is resolved first). The values should describe valid ACL
|
||||
Rules as:
|
||||
- action (str): either PERMIT or DENY
|
||||
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
|
||||
- src_ip_address (str, optional): IP address octet written in base 10
|
||||
- dst_ip_address (str, optional): IP address octet written in base 10
|
||||
- routes (list[dict]): List of route dicts with values:
|
||||
- address (str): The destination address of the route.
|
||||
- subnet_mask (str): The subnet mask of the route.
|
||||
- next_hop_ip_address (str): The next hop IP for the route.
|
||||
- metric (int): The metric of the route. Optional.
|
||||
- default_route:
|
||||
- next_hop_ip_address (str): The next hop IP for the route.
|
||||
|
||||
Example config:
|
||||
```
|
||||
{
|
||||
'hostname': 'router_1',
|
||||
'num_ports': 5,
|
||||
'ports': {
|
||||
1: {
|
||||
'ip_address' : '192.168.1.1',
|
||||
'subnet_mask' : '255.255.255.0',
|
||||
},
|
||||
2: {
|
||||
'ip_address' : '192.168.0.1',
|
||||
'subnet_mask' : '255.255.255.252',
|
||||
}
|
||||
},
|
||||
'acl' : {
|
||||
21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'},
|
||||
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
|
||||
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
|
||||
},
|
||||
'routes' : [
|
||||
{'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'}
|
||||
],
|
||||
'default_route': {'next_hop_ip_address': '192.168.0.2'}
|
||||
}
|
||||
```
|
||||
|
||||
:param cfg: Router config adhering to schema described in main docstring body
|
||||
:type cfg: dict
|
||||
:return: Configured router.
|
||||
:rtype: Router
|
||||
"""
|
||||
router = Router(
|
||||
hostname=cfg["hostname"],
|
||||
num_ports=int(cfg.get("num_ports", "5")),
|
||||
operating_state=NodeOperatingState.ON
|
||||
if not (p := cfg.get("operating_state"))
|
||||
else NodeOperatingState[p.upper()],
|
||||
)
|
||||
def setup_router(self, cfg: dict) -> Router:
|
||||
""" TODO: This is the extra bit of Router's from_config metho. Needs sorting."""
|
||||
if "ports" in cfg:
|
||||
for port_num, port_cfg in cfg["ports"].items():
|
||||
router.configure_port(
|
||||
self.configure_port(
|
||||
port=port_num,
|
||||
ip_address=port_cfg["ip_address"],
|
||||
subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")),
|
||||
)
|
||||
if "acl" in cfg:
|
||||
for r_num, r_cfg in cfg["acl"].items():
|
||||
router.acl.add_rule(
|
||||
self.config.acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else PORT_LOOKUP[p],
|
||||
@@ -1655,7 +1585,7 @@ class Router(NetworkNode, identifier="router"):
|
||||
)
|
||||
if "routes" in cfg:
|
||||
for route in cfg.get("routes"):
|
||||
router.route_table.add_route(
|
||||
self.config.route_table.add_route(
|
||||
address=IPv4Address(route.get("address")),
|
||||
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
|
||||
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
|
||||
@@ -1664,5 +1594,5 @@ class Router(NetworkNode, identifier="router"):
|
||||
if "default_route" in cfg:
|
||||
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
|
||||
if next_hop_ip_address:
|
||||
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
|
||||
return router
|
||||
self.config.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
|
||||
return self
|
||||
|
||||
@@ -107,8 +107,9 @@ class Switch(NetworkNode, identifier="switch"):
|
||||
class ConfigSchema(NetworkNode.ConfigSchema):
|
||||
"""Configuration Schema for Switch nodes within PrimAITE."""
|
||||
|
||||
hostname: str = "Switch"
|
||||
num_ports: int = 24
|
||||
"The number of ports on the switch."
|
||||
"The number of ports on the switch. Default is 24."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -124,12 +124,12 @@ class WirelessRouter(Router, identifier="wireless_router"):
|
||||
network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {}
|
||||
airspace: AirSpace
|
||||
|
||||
config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.Configschema())
|
||||
config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema())
|
||||
|
||||
class ConfigSchema(Router.ConfigSChema):
|
||||
"""Configuration Schema for WirelessRouter nodes within PrimAITE."""
|
||||
|
||||
pass
|
||||
hostname: str = "WirelessRouter"
|
||||
|
||||
def __init__(self, hostname: str, airspace: AirSpace, **kwargs):
|
||||
super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs)
|
||||
|
||||
@@ -195,12 +195,14 @@ def example_network() -> Network:
|
||||
network = Network()
|
||||
|
||||
# Router 1
|
||||
# router_1 = Router(hostname="router_1", start_up_duration=0)
|
||||
router_1 = Router(hostname="router_1", start_up_duration=0)
|
||||
router_1.power_on()
|
||||
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
|
||||
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
|
||||
|
||||
# Switch 1
|
||||
# switch_1_config = Switch.ConfigSchema()
|
||||
switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
|
||||
switch_1.power_on()
|
||||
|
||||
@@ -208,6 +210,7 @@ def example_network() -> Network:
|
||||
router_1.enable_port(1)
|
||||
|
||||
# Switch 2
|
||||
# switch_2_config = Switch.ConfigSchema()
|
||||
switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
|
||||
switch_2.power_on()
|
||||
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])
|
||||
|
||||
Reference in New Issue
Block a user