Merge '2887-Align_Node_Types' into 3062-discriminators

This commit is contained in:
Marek Wolan
2025-02-04 14:04:40 +00:00
78 changed files with 1429 additions and 832 deletions

View File

@@ -1,6 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
"""PrimAITE game - Encapsulates the simulation and agents."""
from ipaddress import IPv4Address
from typing import Dict, List, Optional, Union
import numpy as np
@@ -13,14 +12,8 @@ from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.creation import NetworkNodeAdder
from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
@@ -274,66 +267,10 @@ class PrimaiteGame:
n_type = node_cfg["type"]
new_node = None
# Default PrimAITE nodes
if n_type == "computer":
new_node = Computer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type == "server":
new_node = Server(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type == "switch":
new_node = Switch(
hostname=node_cfg["hostname"],
num_ports=int(node_cfg.get("num_ports", "8")),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type == "router":
new_node = Router.from_config(node_cfg)
elif n_type == "firewall":
new_node = Firewall.from_config(node_cfg)
elif n_type == "wireless_router":
new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace)
elif n_type == "printer":
new_node = Printer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"],
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
# Handle extended nodes
elif n_type.lower() in Node._registry:
new_node = HostNode._registry[n_type](
hostname=node_cfg["hostname"],
ip_address=node_cfg.get("ip_address"),
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type in NetworkNode._registry:
new_node = NetworkNode._registry[n_type](**node_cfg)
if n_type in Node._registry:
if n_type == "wireless-router":
node_cfg["airspace"] = net.airspace
new_node = Node._registry[n_type].from_config(config=node_cfg)
else:
msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg)
@@ -341,11 +278,11 @@ class PrimaiteGame:
# TODO: handle simulation defaults more cleanly
if "node_start_up_duration" in defaults_config:
new_node.start_up_duration = defaults_config["node_startup_duration"]
new_node.config.start_up_duration = defaults_config["node_startup_duration"]
if "node_shut_down_duration" in defaults_config:
new_node.shut_down_duration = defaults_config["node_shut_down_duration"]
new_node.config.shut_down_duration = defaults_config["node_shut_down_duration"]
if "node_scan_duration" in defaults_config:
new_node.node_scan_duration = defaults_config["node_scan_duration"]
new_node.config.node_scan_duration = defaults_config["node_scan_duration"]
if "folder_scan_duration" in defaults_config:
new_node.file_system._default_folder_scan_duration = defaults_config["folder_scan_duration"]
if "folder_restore_duration" in defaults_config:
@@ -382,7 +319,7 @@ class PrimaiteGame:
service_class = SERVICE_TYPES_MAPPING[service_type]
if service_class is not None:
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
_LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}")
new_node.software_manager.install(service_class, software_config=service_cfg.get("options", {}))
new_service = new_node.software_manager.software[service_type]
@@ -400,7 +337,7 @@ class PrimaiteGame:
# TODO: handle simulation defaults more cleanly
if "service_fix_duration" in defaults_config:
new_service.fixing_duration = defaults_config["service_fix_duration"]
new_service.config.fixing_duration = defaults_config["service_fix_duration"]
if "service_restart_duration" in defaults_config:
new_service.restart_duration = defaults_config["service_restart_duration"]
if "service_install_duration" in defaults_config:
@@ -431,8 +368,8 @@ class PrimaiteGame:
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
# temporarily set to 0 so all nodes are initially on
new_node.start_up_duration = 0
new_node.shut_down_duration = 0
new_node.config.start_up_duration = 0
new_node.config.shut_down_duration = 0
net.add_node(new_node)
# run through the power on step if the node is to be turned on at the start
@@ -440,8 +377,8 @@ class PrimaiteGame:
new_node.power_on()
# set start up and shut down duration
new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3))
new_node.shut_down_duration = int(node_cfg.get("shut_down_duration", 3))
new_node.config.start_up_duration = int(node_cfg.get("start_up_duration", 3))
new_node.config.shut_down_duration = int(node_cfg.get("shut_down_duration", 3))
# 1.1 Create Node Sets
for node_set_cfg in node_sets_cfg:

View File

@@ -3,7 +3,7 @@
"""Core of the PrimAITE Simulator."""
import warnings
from abc import abstractmethod
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from uuid import uuid4
from prettytable import PrettyTable

View File

@@ -178,7 +178,7 @@ class AirSpace(BaseModel):
status = "Enabled" if interface.enabled else "Disabled"
table.add_row(
[
interface._connected_node.hostname, # noqa
interface._connected_node.config.hostname, # noqa
interface.mac_address,
interface.ip_address if hasattr(interface, "ip_address") else None,
interface.subnet_mask if hasattr(interface, "subnet_mask") else None,
@@ -320,7 +320,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
self.enabled = True
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
self.pcap = PacketCapture(
hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name
hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name
)
self.airspace.add_wireless_interface(self)

View File

@@ -180,7 +180,7 @@ class Network(SimComponent):
table.align = "l"
table.title = "Nodes"
for node in self.nodes.values():
table.add_row((node.hostname, type(node)._discriminator, node.operating_state.name))
table.add_row((node.config.hostname, type(node)._discriminator, node.operating_state.name))
print(table)
if ip_addresses:
@@ -196,7 +196,13 @@ class Network(SimComponent):
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
[
node.config.hostname,
port_str,
port.ip_address,
port.subnet_mask,
node.config.default_gateway,
]
)
print(table)
@@ -215,9 +221,9 @@ class Network(SimComponent):
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
table.add_row(
[
link.endpoint_a.parent.hostname,
link.endpoint_a.parent.config.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.hostname,
link.endpoint_b.parent.config.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
@@ -251,7 +257,7 @@ class Network(SimComponent):
state = super().describe_state()
state.update(
{
"nodes": {node.hostname: node.describe_state() for node in self.nodes.values()},
"nodes": {node.config.hostname: node.describe_state() for node in self.nodes.values()},
"links": {},
}
)
@@ -259,8 +265,8 @@ class Network(SimComponent):
for _, link in self.links.items():
node_a = link.endpoint_a._connected_node
node_b = link.endpoint_b._connected_node
hostname_a = node_a.hostname if node_a else None
hostname_b = node_b.hostname if node_b else None
hostname_a = node_a.config.hostname if node_a else None
hostname_b = node_b.config.hostname if node_b else None
port_a = link.endpoint_a.port_num
port_b = link.endpoint_b.port_num
link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}"
@@ -286,9 +292,11 @@ class Network(SimComponent):
self.nodes[node.uuid] = node
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
self._nx_graph.add_node(node.config.hostname)
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
self._node_request_manager.add_request(name=node.hostname, request_type=RequestType(func=node._request_manager))
self._node_request_manager.add_request(
name=node.config.hostname, request_type=RequestType(func=node._request_manager)
)
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
@@ -300,7 +308,7 @@ class Network(SimComponent):
:return: The Node if it exists in the network.
"""
for node in self.nodes.values():
if node.hostname == hostname:
if node.config.hostname == hostname:
return node
def remove_node(self, node: Node) -> None:
@@ -313,7 +321,7 @@ class Network(SimComponent):
:type node: Node
"""
if node not in self:
_LOGGER.warning(f"Can't remove node {node.hostname}. It's not in the network.")
_LOGGER.warning(f"Can't remove node {node.config.hostname}. It's not in the network.")
return
self.nodes.pop(node.uuid)
for i, _node in self._node_id_map.items():
@@ -321,8 +329,8 @@ class Network(SimComponent):
self._node_id_map.pop(i)
break
node.parent = None
self._node_request_manager.remove_request(name=node.hostname)
_LOGGER.info(f"Removed node {node.hostname} from network {self.uuid}")
self._node_request_manager.remove_request(name=node.config.hostname)
_LOGGER.info(f"Removed node {node.config.hostname} from network {self.uuid}")
def connect(
self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, bandwidth: int = 100, **kwargs
@@ -352,7 +360,7 @@ class Network(SimComponent):
link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth, **kwargs)
self.links[link.uuid] = link
self._link_id_map[len(self.links)] = link
self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname)
self._nx_graph.add_edge(endpoint_a.parent.config.hostname, endpoint_b.parent.config.hostname)
link.parent = self
_LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
return link

View File

@@ -154,7 +154,9 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
# Create a core switch if more than one edge switch is needed
if num_of_switches > 1:
core_switch = Switch(hostname=f"switch_core_{config.lan_name}", start_up_duration=0)
core_switch = Switch.from_config(
config={"type": "switch", "hostname": f"switch_core_{config.lan_name}", "start_up_duration": 0}
)
core_switch.power_on()
network.add_node(core_switch)
core_switch_port = 1
@@ -165,7 +167,10 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
# Optionally include a router in the LAN
if config.include_router:
default_gateway = IPv4Address(f"192.168.{config.subnet_base}.1")
router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0)
# router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0)
router = Router.from_config(
config={"hostname": f"router_{config.lan_name}", "type": "router", "start_up_duration": 0}
)
router.power_on()
router.acl.add_rule(
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22
@@ -178,7 +183,9 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
# Initialise the first edge switch and connect to the router or core switch
switch_port = 0
switch_n = 1
switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0)
switch = Switch.from_config(
config={"type": "switch", "hostname": f"switch_edge_{switch_n}_{config.lan_name}", "start_up_duration": 0}
)
switch.power_on()
network.add_node(switch)
if num_of_switches > 1:
@@ -196,7 +203,13 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
if switch_port == effective_network_interface:
switch_n += 1
switch_port = 0
switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0)
switch = Switch.from_config(
config={
"type": "switch",
"hostname": f"switch_edge_{switch_n}_{config.lan_name}",
"start_up_duration": 0,
}
)
switch.power_on()
network.add_node(switch)
# Connect the new switch to the router or core switch
@@ -213,13 +226,14 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
)
# Create and add a PC to the network
pc = Computer(
hostname=f"pc_{i}_{config.lan_name}",
ip_address=f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}",
subnet_mask="255.255.255.0",
default_gateway=default_gateway,
start_up_duration=0,
)
pc_cfg = {
"type": "computer",
"hostname": f"pc_{i}_{config.lan_name}",
"ip_address": f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}",
"default_gateway": "192.168.10.1",
"start_up_duration": 0,
}
pc = Computer.from_config(config=pc_cfg)
pc.power_on()
network.add_node(pc)

View File

@@ -9,7 +9,7 @@ from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field, validate_call
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
from primaite.exceptions import NetworkError
@@ -431,7 +431,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
self.enabled = True
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
self.pcap = PacketCapture(
hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name
hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name
)
if self._connected_link:
self._connected_link.endpoint_up()
@@ -1494,19 +1494,12 @@ class Node(SimComponent, ABC):
:param operating_state: The node operating state, either ON or OFF.
"""
hostname: str
"The node hostname on the network."
default_gateway: Optional[IPV4Address] = None
"The default gateway IP address for forwarding network traffic to other networks."
operating_state: NodeOperatingState = NodeOperatingState.OFF
"The hardware state of the node."
network_interfaces: Dict[str, NetworkInterface] = {}
"The Network Interfaces on the node."
network_interface: Dict[int, NetworkInterface] = {}
"The Network Interfaces on the node by port id."
dns_server: Optional[IPv4Address] = None
"List of IP addresses of DNS servers used for name resolution."
accounts: Dict[str, Account] = {}
"All accounts on the node."
applications: Dict[str, Application] = {}
@@ -1523,33 +1516,6 @@ class Node(SimComponent, ABC):
session_manager: SessionManager
software_manager: SoftwareManager
revealed_to_red: bool = False
"Informs whether the node has been revealed to a red agent."
start_up_duration: int = 3
"Time steps needed for the node to start up."
start_up_countdown: int = 0
"Time steps needed until node is booted up."
shut_down_duration: int = 3
"Time steps needed for the node to shut down."
shut_down_countdown: int = 0
"Time steps needed until node is shut down."
is_resetting: bool = False
"If true, the node will try turning itself off then back on again."
node_scan_duration: int = 10
"How many timesteps until the whole node is scanned. Default 10 time steps."
node_scan_countdown: int = 0
"Time steps until scan is complete"
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {}
"Base system software that must be preinstalled."
@@ -1560,6 +1526,67 @@ class Node(SimComponent, ABC):
_discriminator: ClassVar[str]
"""discriminator for this particular class, used for printing and logging. Each subclass redefines this."""
config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
"""Configuration items within Node"""
class ConfigSchema(BaseModel, ABC):
"""Configuration Schema for Node based classes."""
model_config = ConfigDict(arbitrary_types_allowed=True)
"""Configure pydantic to allow arbitrary types, let the instance have attributes not present in the model."""
hostname: str = "default"
"The node hostname on the network."
revealed_to_red: bool = False
"Informs whether the node has been revealed to a red agent."
start_up_duration: int = 3
"Time steps needed for the node to start up."
start_up_countdown: int = 0
"Time steps needed until node is booted up."
shut_down_duration: int = 3
"Time steps needed for the node to shut down."
shut_down_countdown: int = 0
"Time steps needed until node is shut down."
is_resetting: bool = False
"If true, the node will try turning itself off then back on again."
node_scan_duration: int = 10
"How many timesteps until the whole node is scanned. Default 10 time steps."
node_scan_countdown: int = 0
"Time steps until scan is complete"
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
dns_server: Optional[IPv4Address] = None
"List of IP addresses of DNS servers used for name resolution."
default_gateway: Optional[IPV4Address] = None
"The default gateway IP address for forwarding network traffic to other networks."
operating_state: Any = None
@property
def dns_server(self) -> Optional[IPv4Address]:
"""Convenience method to access the dns_server IP."""
return self.config.dns_server
@classmethod
def from_config(cls, config: Dict) -> "Node":
"""Create Node object from a given configuration dictionary."""
if config["type"] not in cls._registry:
msg = f"Configuration contains an invalid Node type: {config['type']}"
return ValueError(msg)
obj = cls(config=cls.ConfigSchema(**config))
return obj
def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None:
"""
Register a node type.
@@ -1585,11 +1612,11 @@ class Node(SimComponent, ABC):
provided.
"""
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["hostname"])
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
if not kwargs.get("session_manager"):
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"))
if not kwargs.get("root"):
kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"]
kwargs["root"] = SIM_OUTPUT.path / kwargs["config"].hostname
if not kwargs.get("file_system"):
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
if not kwargs.get("software_manager"):
@@ -1598,9 +1625,12 @@ class Node(SimComponent, ABC):
sys_log=kwargs.get("sys_log"),
session_manager=kwargs.get("session_manager"),
file_system=kwargs.get("file_system"),
dns_server=kwargs.get("dns_server"),
dns_server=kwargs["config"].dns_server,
)
super().__init__(**kwargs)
self.operating_state = (
NodeOperatingState.ON if not (p := kwargs["config"].operating_state) else NodeOperatingState[p.upper()]
)
self._install_system_software()
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
@@ -1694,7 +1724,7 @@ class Node(SimComponent, ABC):
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not powered on."
return f"Cannot perform request on node '{self.node.config.hostname}' because it is not powered on."
class _NodeIsOffValidator(RequestPermissionValidator):
"""
@@ -1713,7 +1743,7 @@ class Node(SimComponent, ABC):
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not turned off."
return f"Cannot perform request on node '{self.node.config.hostname}' because it is not turned off."
def _init_request_manager(self) -> RequestManager:
"""
@@ -1741,7 +1771,7 @@ class Node(SimComponent, ABC):
self.software_manager.install(application_class)
application_instance = self.software_manager.software.get(application_name)
self.applications[application_instance.uuid] = application_instance
_LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}")
_LOGGER.debug(f"Added application {application_instance.name} to node {self.config.hostname}")
self._application_request_manager.add_request(
application_name, RequestType(func=application_instance._request_manager)
)
@@ -1855,7 +1885,7 @@ class Node(SimComponent, ABC):
state = super().describe_state()
state.update(
{
"hostname": self.hostname,
"hostname": self.config.hostname,
"operating_state": self.operating_state.value,
"NICs": {
eth_num: network_interface.describe_state()
@@ -1865,7 +1895,7 @@ class Node(SimComponent, ABC):
"applications": {app.name: app.describe_state() for app in self.applications.values()},
"services": {svc.name: svc.describe_state() for svc in self.services.values()},
"process": {proc.name: proc.describe_state() for proc in self.processes.values()},
"revealed_to_red": self.revealed_to_red,
"revealed_to_red": self.config.revealed_to_red,
}
)
return state
@@ -1881,7 +1911,7 @@ class Node(SimComponent, ABC):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Open Ports"
table.title = f"{self.config.hostname} Open Ports"
for port in self.software_manager.get_open_ports():
if port > 0:
table.add_row([port])
@@ -1908,7 +1938,7 @@ class Node(SimComponent, ABC):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Network Interface Cards"
table.title = f"{self.config.hostname} Network Interface Cards"
for port, network_interface in self.network_interface.items():
ip_address = ""
if hasattr(network_interface, "ip_address"):
@@ -1943,38 +1973,38 @@ class Node(SimComponent, ABC):
network_interface.apply_timestep(timestep=timestep)
# count down to boot up
if self.start_up_countdown > 0:
self.start_up_countdown -= 1
if self.config.start_up_countdown > 0:
self.config.start_up_countdown -= 1
else:
if self.operating_state == NodeOperatingState.BOOTING:
self.operating_state = NodeOperatingState.ON
self.sys_log.info(f"{self.hostname}: Turned on")
self.sys_log.info(f"{self.config.hostname}: Turned on")
for network_interface in self.network_interfaces.values():
network_interface.enable()
self._start_up_actions()
# count down to shut down
if self.shut_down_countdown > 0:
self.shut_down_countdown -= 1
if self.config.shut_down_countdown > 0:
self.config.shut_down_countdown -= 1
else:
if self.operating_state == NodeOperatingState.SHUTTING_DOWN:
self.operating_state = NodeOperatingState.OFF
self.sys_log.info(f"{self.hostname}: Turned off")
self.sys_log.info(f"{self.config.hostname}: Turned off")
self._shut_down_actions()
# if resetting turn back on
if self.is_resetting:
self.is_resetting = False
if self.config.is_resetting:
self.config.is_resetting = False
self.power_on()
# time steps which require the node to be on
if self.operating_state == NodeOperatingState.ON:
# node scanning
if self.node_scan_countdown > 0:
self.node_scan_countdown -= 1
if self.config.node_scan_countdown > 0:
self.config.node_scan_countdown -= 1
if self.node_scan_countdown == 0:
if self.config.node_scan_countdown == 0:
# scan everything!
for process_id in self.processes:
self.processes[process_id].scan()
@@ -1990,10 +2020,10 @@ class Node(SimComponent, ABC):
# scan file system
self.file_system.scan(instant_scan=True)
if self.red_scan_countdown > 0:
self.red_scan_countdown -= 1
if self.config.red_scan_countdown > 0:
self.config.red_scan_countdown -= 1
if self.red_scan_countdown == 0:
if self.config.red_scan_countdown == 0:
# scan processes
for process_id in self.processes:
self.processes[process_id].reveal_to_red()
@@ -2050,7 +2080,7 @@ class Node(SimComponent, ABC):
to the red agent.
"""
self.node_scan_countdown = self.node_scan_duration
self.config.node_scan_countdown = self.config.node_scan_duration
return True
def reveal_to_red(self) -> bool:
@@ -2066,12 +2096,12 @@ class Node(SimComponent, ABC):
`revealed_to_red` to `True`.
"""
self.red_scan_countdown = self.node_scan_duration
self.config.red_scan_countdown = self.config.node_scan_duration
return True
def power_on(self) -> bool:
"""Power on the Node, enabling its NICs if it is in the OFF state."""
if self.start_up_duration <= 0:
if self.config.start_up_duration <= 0:
self.operating_state = NodeOperatingState.ON
self._start_up_actions()
self.sys_log.info("Power on")
@@ -2080,14 +2110,14 @@ class Node(SimComponent, ABC):
return True
if self.operating_state == NodeOperatingState.OFF:
self.operating_state = NodeOperatingState.BOOTING
self.start_up_countdown = self.start_up_duration
self.config.start_up_countdown = self.config.start_up_duration
return True
return False
def power_off(self) -> bool:
"""Power off the Node, disabling its NICs if it is in the ON state."""
if self.shut_down_duration <= 0:
if self.config.shut_down_duration <= 0:
self._shut_down_actions()
self.operating_state = NodeOperatingState.OFF
self.sys_log.info("Power off")
@@ -2096,7 +2126,7 @@ class Node(SimComponent, ABC):
for network_interface in self.network_interfaces.values():
network_interface.disable()
self.operating_state = NodeOperatingState.SHUTTING_DOWN
self.shut_down_countdown = self.shut_down_duration
self.config.shut_down_countdown = self.config.shut_down_duration
return True
return False
@@ -2108,7 +2138,7 @@ class Node(SimComponent, ABC):
Applying more timesteps will eventually turn the node back on.
"""
if self.operating_state.ON:
self.is_resetting = True
self.config.is_resetting = True
self.sys_log.info("Resetting")
self.power_off()
return True
@@ -2225,6 +2255,8 @@ class Node(SimComponent, ABC):
# Turn on all the applications in the node
for app_id in self.applications:
print(app_id)
print(f"Starting application:{self.applications[app_id].config.type}")
self.applications[app_id].run()
# Turn off all processes in the node

View File

@@ -1,6 +1,8 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import ClassVar, Dict
from pydantic import Field
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
@@ -35,4 +37,11 @@ class Computer(HostNode, discriminator="computer"):
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "ftp-client": FTPClient}
config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Computer class."""
hostname: str = "Computer"
pass

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.hardware.base import (
IPWiredNetworkInterface,
@@ -44,8 +46,8 @@ class HostARP(ARP):
:return: The MAC address of the default gateway if present in the ARP cache; otherwise, None.
"""
if self.software_manager.node.default_gateway:
return self.get_arp_cache_mac_address(self.software_manager.node.default_gateway)
if self.software_manager.node.config.default_gateway:
return self.get_arp_cache_mac_address(self.software_manager.node.config.default_gateway)
def get_default_gateway_network_interface(self) -> Optional[NIC]:
"""
@@ -53,8 +55,11 @@ class HostARP(ARP):
:return: The NIC associated with the default gateway if it exists in the ARP cache; otherwise, None.
"""
if self.software_manager.node.default_gateway and self.software_manager.node.has_enabled_network_interface:
return self.get_arp_cache_network_interface(self.software_manager.node.default_gateway)
if (
self.software_manager.node.config.default_gateway
and self.software_manager.node.has_enabled_network_interface
):
return self.get_arp_cache_network_interface(self.software_manager.node.config.default_gateway)
def _get_arp_cache_mac_address(
self, ip_address: IPV4Address, is_reattempt: bool = False, is_default_gateway_attempt: bool = False
@@ -73,7 +78,7 @@ class HostARP(ARP):
if arp_entry:
return arp_entry.mac_address
if ip_address == self.software_manager.node.default_gateway:
if ip_address == self.software_manager.node.config.default_gateway:
is_reattempt = True
if not is_reattempt:
self.send_arp_request(ip_address)
@@ -81,11 +86,11 @@ class HostARP(ARP):
ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt
)
else:
if self.software_manager.node.default_gateway:
if self.software_manager.node.config.default_gateway:
if not is_default_gateway_attempt:
self.send_arp_request(self.software_manager.node.default_gateway)
self.send_arp_request(self.software_manager.node.config.default_gateway)
return self._get_arp_cache_mac_address(
ip_address=self.software_manager.node.default_gateway,
ip_address=self.software_manager.node.config.default_gateway,
is_reattempt=True,
is_default_gateway_attempt=True,
)
@@ -116,7 +121,7 @@ class HostARP(ARP):
if arp_entry:
return self.software_manager.node.network_interfaces[arp_entry.network_interface_uuid]
else:
if ip_address == self.software_manager.node.default_gateway:
if ip_address == self.software_manager.node.config.default_gateway:
is_reattempt = True
if not is_reattempt:
self.send_arp_request(ip_address)
@@ -124,11 +129,11 @@ class HostARP(ARP):
ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt
)
else:
if self.software_manager.node.default_gateway:
if self.software_manager.node.config.default_gateway:
if not is_default_gateway_attempt:
self.send_arp_request(self.software_manager.node.default_gateway)
self.send_arp_request(self.software_manager.node.config.default_gateway)
return self._get_arp_cache_network_interface(
ip_address=self.software_manager.node.default_gateway,
ip_address=self.software_manager.node.config.default_gateway,
is_reattempt=True,
is_default_gateway_attempt=True,
)
@@ -325,9 +330,18 @@ class HostNode(Node, discriminator="host-node"):
network_interface: Dict[int, NIC] = {}
"The NICs on the node by port id."
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
config: HostNode.ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
class ConfigSchema(Node.ConfigSchema):
"""Configuration Schema for HostNode class."""
hostname: str = "HostNode"
subnet_mask: IPV4Address = "255.255.255.0"
ip_address: IPV4Address
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask))
@property
def nmap(self) -> Optional[NMAP]:

View File

@@ -1,4 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from pydantic import Field
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
@@ -30,8 +33,22 @@ class Server(HostNode, discriminator="server"):
* Web Browser
"""
config: "Server.ConfigSchema" = Field(default_factory=lambda: Server.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Server class."""
hostname: str = "server"
class Printer(HostNode, discriminator="printer"):
"""Printer? I don't even know her!."""
# TODO: Implement printer-specific behaviour
config: "Printer.ConfigSchema" = Field(default_factory=lambda: Printer.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Printer class."""
hostname: str = "printer"

View File

@@ -6,7 +6,6 @@ from prettytable import MARKDOWN, PrettyTable
from pydantic import Field, validate_call
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.router import (
AccessControlList,
ACLAction,
@@ -99,11 +98,21 @@ class Firewall(Router, discriminator="firewall"):
)
"""Access Control List for managing traffic leaving towards an external network."""
def __init__(self, hostname: str, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname)
_identifier: str = "firewall"
super().__init__(hostname=hostname, num_ports=0, **kwargs)
config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema())
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for Firewall 'Nodes' within PrimAITE."""
hostname: str = "firewall"
num_ports: int = 0
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
super().__init__(**kwargs)
self.connect_nic(
RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external")
@@ -116,22 +125,23 @@ class Firewall(Router, discriminator="firewall"):
)
# Update ACL objects with firewall's hostname and syslog to allow accurate logging
self.internal_inbound_acl.sys_log = kwargs["sys_log"]
self.internal_inbound_acl.name = f"{hostname} - Internal Inbound"
self.internal_inbound_acl.name = f"{kwargs['config'].hostname} - Internal Inbound"
self.internal_outbound_acl.sys_log = kwargs["sys_log"]
self.internal_outbound_acl.name = f"{hostname} - Internal Outbound"
self.internal_outbound_acl.name = f"{kwargs['config'].hostname} - Internal Outbound"
self.dmz_inbound_acl.sys_log = kwargs["sys_log"]
self.dmz_inbound_acl.name = f"{hostname} - DMZ Inbound"
self.dmz_inbound_acl.name = f"{kwargs['config'].hostname} - DMZ Inbound"
self.dmz_outbound_acl.sys_log = kwargs["sys_log"]
self.dmz_outbound_acl.name = f"{hostname} - DMZ Outbound"
self.dmz_outbound_acl.name = f"{kwargs['config'].hostname} - DMZ Outbound"
self.external_inbound_acl.sys_log = kwargs["sys_log"]
self.external_inbound_acl.name = f"{hostname} - External Inbound"
self.external_inbound_acl.name = f"{kwargs['config'].hostname} - External Inbound"
self.external_outbound_acl.sys_log = kwargs["sys_log"]
self.external_outbound_acl.name = f"{hostname} - External Outbound"
self.external_outbound_acl.name = f"{kwargs['config'].hostname} - External Outbound"
self.power_on()
def _init_request_manager(self) -> RequestManager:
"""
@@ -231,7 +241,7 @@ class Firewall(Router, discriminator="firewall"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Network Interfaces"
table.title = f"{self.config.hostname} Network Interfaces"
ports = {"External": self.external_port, "Internal": self.internal_port, "DMZ": self.dmz_port}
for port, network_interface in ports.items():
table.add_row(
@@ -551,18 +561,14 @@ class Firewall(Router, discriminator="firewall"):
self.dmz_port.enable()
@classmethod
def from_config(cls, cfg: dict) -> "Firewall":
def from_config(cls, config: dict) -> "Firewall":
"""Create a firewall based on a config dict."""
firewall = Firewall(
hostname=cfg["hostname"],
operating_state=NodeOperatingState.ON
if not (p := cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
if "ports" in cfg:
internal_port = cfg["ports"]["internal_port"]
external_port = cfg["ports"]["external_port"]
dmz_port = cfg["ports"].get("dmz_port")
firewall = Firewall(config=cls.ConfigSchema(**config))
if "ports" in config:
internal_port = config["ports"]["internal_port"]
external_port = config["ports"]["external_port"]
dmz_port = config["ports"].get("dmz_port")
# configure internal port
firewall.configure_internal_port(
@@ -582,10 +588,10 @@ class Firewall(Router, discriminator="firewall"):
ip_address=IPV4Address(dmz_port.get("ip_address")),
subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")),
)
if "acl" in cfg:
if "acl" in config:
# acl rules for internal_inbound_acl
if cfg["acl"]["internal_inbound_acl"]:
for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items():
if config["acl"]["internal_inbound_acl"]:
for r_num, r_cfg in config["acl"]["internal_inbound_acl"].items():
firewall.internal_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -599,8 +605,8 @@ class Firewall(Router, discriminator="firewall"):
)
# acl rules for internal_outbound_acl
if cfg["acl"]["internal_outbound_acl"]:
for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items():
if config["acl"]["internal_outbound_acl"]:
for r_num, r_cfg in config["acl"]["internal_outbound_acl"].items():
firewall.internal_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -614,8 +620,8 @@ class Firewall(Router, discriminator="firewall"):
)
# acl rules for dmz_inbound_acl
if cfg["acl"]["dmz_inbound_acl"]:
for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items():
if config["acl"]["dmz_inbound_acl"]:
for r_num, r_cfg in config["acl"]["dmz_inbound_acl"].items():
firewall.dmz_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -629,8 +635,8 @@ class Firewall(Router, discriminator="firewall"):
)
# acl rules for dmz_outbound_acl
if cfg["acl"]["dmz_outbound_acl"]:
for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items():
if config["acl"]["dmz_outbound_acl"]:
for r_num, r_cfg in config["acl"]["dmz_outbound_acl"].items():
firewall.dmz_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -644,8 +650,8 @@ class Firewall(Router, discriminator="firewall"):
)
# acl rules for external_inbound_acl
if cfg["acl"].get("external_inbound_acl"):
for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items():
if config["acl"].get("external_inbound_acl"):
for r_num, r_cfg in config["acl"]["external_inbound_acl"].items():
firewall.external_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -659,8 +665,8 @@ class Firewall(Router, discriminator="firewall"):
)
# acl rules for external_outbound_acl
if cfg["acl"].get("external_outbound_acl"):
for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items():
if config["acl"].get("external_outbound_acl"):
for r_num, r_cfg in config["acl"]["external_outbound_acl"].items():
firewall.external_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -673,16 +679,16 @@ class Firewall(Router, discriminator="firewall"):
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
if "routes" in config:
for route in config.get("routes"):
firewall.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if "default_route" in config:
next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
firewall.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)

View File

@@ -1207,7 +1207,6 @@ class Router(NetworkNode, discriminator="router"):
"Terminal": Terminal,
}
num_ports: int
network_interfaces: Dict[str, RouterInterface] = {}
"The Router Interfaces on the node."
network_interface: Dict[int, RouterInterface] = {}
@@ -1215,19 +1214,29 @@ class Router(NetworkNode, discriminator="router"):
acl: AccessControlList
route_table: RouteTable
def __init__(self, hostname: str, num_ports: int = 5, **kwargs):
config: "Router.ConfigSchema"
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Routers."""
hostname: str = "router"
num_ports: int = 5
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname)
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
if not kwargs.get("acl"):
kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=hostname)
kwargs["acl"] = AccessControlList(
sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname
)
if not kwargs.get("route_table"):
kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"])
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
super().__init__(**kwargs)
self.session_manager = RouterSessionManager(sys_log=self.sys_log)
self.session_manager.node = self
self.software_manager.session_manager = self.session_manager
self.session_manager.software_manager = self.software_manager
for i in range(1, self.num_ports + 1):
for i in range(1, self.config.num_ports + 1):
network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
self.connect_nic(network_interface)
self.network_interface[i] = network_interface
@@ -1337,7 +1346,7 @@ class Router(NetworkNode, discriminator="router"):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["num_ports"] = self.num_ports
state["num_ports"] = self.config.num_ports
state["acl"] = self.acl.describe_state()
return state
@@ -1545,7 +1554,7 @@ class Router(NetworkNode, discriminator="router"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Network Interfaces"
table.title = f"{self.config.hostname} Network Interfaces"
for port, network_interface in self.network_interface.items():
table.add_row(
[
@@ -1559,7 +1568,7 @@ class Router(NetworkNode, discriminator="router"):
print(table)
@classmethod
def from_config(cls, cfg: dict, **kwargs) -> "Router":
def from_config(cls, config: dict, **kwargs) -> "Router":
"""Create a router based on a config dict.
Schema:
@@ -1616,22 +1625,16 @@ class Router(NetworkNode, discriminator="router"):
:return: Configured router.
:rtype: Router
"""
router = Router(
hostname=cfg["hostname"],
num_ports=int(cfg.get("num_ports", "5")),
operating_state=NodeOperatingState.ON
if not (p := cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
if "ports" in cfg:
for port_num, port_cfg in cfg["ports"].items():
router = Router(config=Router.ConfigSchema(**config))
if "ports" in config:
for port_num, port_cfg in config["ports"].items():
router.configure_port(
port=port_num,
ip_address=port_cfg["ip_address"],
subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")),
)
if "acl" in cfg:
for r_num, r_cfg in cfg["acl"].items():
if "acl" in config:
for r_num, r_cfg in config["acl"].items():
router.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -1643,16 +1646,19 @@ class Router(NetworkNode, discriminator="router"):
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
if "routes" in config:
for route in config.get("routes"):
router.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if "default_route" in config:
next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
router.operating_state = (
NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()]
)
return router

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite import getLogger
from primaite.exceptions import NetworkError
@@ -88,14 +89,8 @@ class SwitchPort(WiredNetworkInterface):
class Switch(NetworkNode, discriminator="switch"):
"""
A class representing a Layer 2 network switch.
"""A class representing a Layer 2 network switch."""
:ivar num_ports: The number of ports on the switch. Default is 24.
"""
num_ports: int = 24
"The number of ports on the switch."
network_interfaces: Dict[str, SwitchPort] = {}
"The SwitchPorts on the Switch."
network_interface: Dict[int, SwitchPort] = {}
@@ -103,9 +98,18 @@ class Switch(NetworkNode, discriminator="switch"):
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema())
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Switch nodes within PrimAITE."""
hostname: str = "Switch"
num_ports: int = 24
"The number of ports on the switch. Default is 24."
def __init__(self, **kwargs):
super().__init__(**kwargs)
for i in range(1, self.num_ports + 1):
for i in range(1, kwargs["config"].num_ports + 1):
self.connect_nic(SwitchPort())
def _install_system_software(self):
@@ -121,7 +125,7 @@ class Switch(NetworkNode, discriminator="switch"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Switch Ports"
table.title = f"{self.config.hostname} Switch Ports"
for port_num, port in self.network_interface.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)
@@ -134,7 +138,7 @@ class Switch(NetworkNode, discriminator="switch"):
"""
state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()}
state["num_ports"] = self.num_ports # redundant?
state["num_ports"] = self.config.num_ports # redundant?
state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
return state

View File

@@ -2,7 +2,7 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional, Union
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, FREQ_WIFI_2_4, IPWirelessNetworkInterface
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
@@ -122,13 +122,23 @@ class WirelessRouter(Router, discriminator="wireless-router"):
network_interfaces: Dict[str, Union[RouterInterface, WirelessAccessPoint]] = {}
network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {}
airspace: AirSpace
def __init__(self, hostname: str, airspace: AirSpace, **kwargs):
super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs)
config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema())
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for WirelessRouter nodes within PrimAITE."""
hostname: str = "WirelessRouter"
airspace: AirSpace
num_ports: int = 0
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connect_nic(
WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=airspace)
WirelessAccessPoint(
ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=kwargs["config"].airspace
)
)
self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0"))
@@ -226,7 +236,7 @@ class WirelessRouter(Router, discriminator="wireless-router"):
)
@classmethod
def from_config(cls, cfg: Dict, **kwargs) -> "WirelessRouter":
def from_config(cls, config: Dict, **kwargs) -> "WirelessRouter":
"""Generate the wireless router from config.
Schema:
@@ -253,22 +263,22 @@ class WirelessRouter(Router, discriminator="wireless-router"):
:return: WirelessRouter instance.
:rtype: WirelessRouter
"""
operating_state = (
NodeOperatingState.ON if not (p := cfg.get("operating_state")) else NodeOperatingState[p.upper()]
router = cls(config=cls.ConfigSchema(**config))
router.operating_state = (
NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()]
)
router = cls(hostname=cfg["hostname"], operating_state=operating_state, airspace=kwargs["airspace"])
if "router_interface" in cfg:
ip_address = cfg["router_interface"]["ip_address"]
subnet_mask = cfg["router_interface"]["subnet_mask"]
if "router_interface" in config:
ip_address = config["router_interface"]["ip_address"]
subnet_mask = config["router_interface"]["subnet_mask"]
router.configure_router_interface(ip_address=ip_address, subnet_mask=subnet_mask)
if "wireless_access_point" in cfg:
ip_address = cfg["wireless_access_point"]["ip_address"]
subnet_mask = cfg["wireless_access_point"]["subnet_mask"]
frequency = AirSpaceFrequency._registry[cfg["wireless_access_point"]["frequency"]]
if "wireless_access_point" in config:
ip_address = config["wireless_access_point"]["ip_address"]
subnet_mask = config["wireless_access_point"]["subnet_mask"]
frequency = AirSpaceFrequency._registry[config["wireless_access_point"]["frequency"]]
router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency)
if "acl" in cfg:
for r_num, r_cfg in cfg["acl"].items():
if "acl" in config:
for r_num, r_cfg in config["acl"].items():
router.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -280,8 +290,8 @@ class WirelessRouter(Router, discriminator="wireless-router"):
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
if "routes" in config:
for route in config.get("routes"):
router.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),

View File

@@ -40,41 +40,45 @@ def client_server_routed() -> Network:
network = Network()
# Router 1
router_1 = Router(hostname="router_1", num_ports=3)
router_1 = Router(config=dict(hostname="router_1", num_ports=3))
router_1.power_on()
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.2.1", subnet_mask="255.255.255.0")
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=6)
switch_1 = Switch(config=dict(hostname="switch_1", num_ports=6))
switch_1.power_on()
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[6])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=6)
switch_2 = Switch(config=dict(hostname="switch_2", num_ports=6))
switch_2.power_on()
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[6])
router_1.enable_port(2)
# Client 1
client_1 = Computer(
hostname="client_1",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
start_up_duration=0,
config=dict(
hostname="client_1",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
start_up_duration=0,
)
)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
# Server 1
server_1 = Server(
hostname="server_1",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
config=dict(
hostname="server_1",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
)
server_1.power_on()
network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1])
@@ -128,32 +132,41 @@ def arcd_uc2_network() -> Network:
network = Network()
# Router 1
router_1 = Router(hostname="router_1", num_ports=5, start_up_duration=0)
router_1 = Router.from_config(
config={"type": "router", "hostname": "router_1", "num_ports": 5, "start_up_duration": 0}
)
router_1.power_on()
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
switch_1 = Switch.from_config(
config={"type": "switch", "hostname": "switch_1", "num_ports": 8, "start_up_duration": 0}
)
switch_1.power_on()
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
switch_2 = Switch.from_config(
config={"type": "switch", "hostname": "switch_2", "num_ports": 8, "start_up_duration": 0}
)
switch_2.power_on()
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])
router_1.enable_port(2)
# Client 1
client_1 = Computer(
hostname="client_1",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
client_1_cfg = {
"type": "computer",
"hostname": "client_1",
"ip_address": "192.168.10.21",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.10.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
client_1: Computer = Computer.from_config(config=client_1_cfg)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
client_1.software_manager.install(DatabaseClient)
@@ -172,14 +185,18 @@ def arcd_uc2_network() -> Network:
)
# Client 2
client_2 = Computer(
hostname="client_2",
ip_address="192.168.10.22",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
client_2_cfg = {
"type": "computer",
"hostname": "client_2",
"ip_address": "192.168.10.22",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.10.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
client_2: Computer = Computer.from_config(config=client_2_cfg)
client_2.power_on()
client_2.software_manager.install(DatabaseClient)
db_client_2 = client_2.software_manager.software.get("database-client")
@@ -193,27 +210,36 @@ def arcd_uc2_network() -> Network:
)
# Domain Controller
domain_controller = Server(
hostname="domain_controller",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
domain_controller_cfg = {
"type": "server",
"hostname": "domain_controller",
"ip_address": "192.168.1.10",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
domain_controller = Server.from_config(config=domain_controller_cfg)
domain_controller.power_on()
domain_controller.software_manager.install(DNSServer)
network.connect(endpoint_b=domain_controller.network_interface[1], endpoint_a=switch_1.network_interface[1])
# Database Server
database_server = Server(
hostname="database_server",
ip_address="192.168.1.14",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
database_server_cfg = {
"type": "server",
"hostname": "database_server",
"ip_address": "192.168.1.14",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
database_server = Server.from_config(config=database_server_cfg)
database_server.power_on()
network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3])
@@ -223,14 +249,18 @@ def arcd_uc2_network() -> Network:
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
# Web Server
web_server = Server(
hostname="web_server",
ip_address="192.168.1.12",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
web_server_cfg = {
"type": "server",
"hostname": "web_server",
"ip_address": "192.168.1.11",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
web_server = Server.from_config(config=web_server_cfg)
web_server.power_on()
web_server.software_manager.install(DatabaseClient)
@@ -247,27 +277,32 @@ def arcd_uc2_network() -> Network:
dns_server_service.dns_register("arcd.com", web_server.network_interface[1].ip_address)
# Backup Server
backup_server = Server(
hostname="backup_server",
ip_address="192.168.1.16",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
backup_server_cfg = {
"type": "server",
"hostname": "backup_server",
"ip_address": "192.168.1.16",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
backup_server: Server = Server.from_config(config=backup_server_cfg)
backup_server.power_on()
backup_server.software_manager.install(FTPServer)
network.connect(endpoint_b=backup_server.network_interface[1], endpoint_a=switch_1.network_interface[4])
# Security Suite
security_suite = Server(
hostname="security_suite",
ip_address="192.168.1.110",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
security_suite_cfg = {
"type": "server",
"hostname": "backup_server",
"ip_address": "192.168.1.110",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
security_suite: Server = Server.from_config(config=security_suite_cfg)
security_suite.power_on()
network.connect(endpoint_b=security_suite.network_interface[1], endpoint_a=switch_1.network_interface[7])
security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0"))

View File

@@ -208,7 +208,7 @@ class NMAP(Application, discriminator="nmap"):
if show:
table = PrettyTable(["IP Address", "Can Ping"])
table.align = "l"
table.title = f"{self.software_manager.node.hostname} NMAP Ping Scan"
table.title = f"{self.software_manager.node.config.hostname} NMAP Ping Scan"
ip_addresses = self._explode_ip_address_network_array(target_ip_address)
@@ -367,7 +367,7 @@ class NMAP(Application, discriminator="nmap"):
if show:
table = PrettyTable(["IP Address", "Port", "Protocol"])
table.align = "l"
table.title = f"{self.software_manager.node.hostname} NMAP Port Scan ({scan_type})"
table.title = f"{self.software_manager.node.config.hostname} NMAP Port Scan ({scan_type})"
self.sys_log.info(f"{self.name}: Starting port scan")
for ip_address in ip_addresses:
# Prevent port scan on this node

View File

@@ -140,6 +140,7 @@ class SoftwareManager:
elif isinstance(software, Service):
self.node.services[software.uuid] = software
self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager))
software.start()
software.install()
software.software_manager = self
self.software[software.name] = software

View File

@@ -32,6 +32,7 @@ class DatabaseService(Service, discriminator="database-service"):
type: str = "database-service"
backup_server_ip: Optional[IPv4Address] = None
db_password: Optional[str] = None
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
config: ConfigSchema = Field(default_factory=lambda: DatabaseService.ConfigSchema())
@@ -224,7 +225,7 @@ class DatabaseService(Service, discriminator="database-service"):
SoftwareHealthState.FIXING,
SoftwareHealthState.COMPROMISED,
]:
if self.password == password:
if self.config.db_password == password:
status_code = 200 # ok
connection_id = self._generate_connection_id()
# try to create connection

View File

@@ -36,7 +36,11 @@ class FTPServer(FTPServiceABC, discriminator="ftp-server"):
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
self.start()
self.server_password = self.config.server_password
@property
def server_password(self) -> Optional[str]:
"""Convenience method for accessing FTP server password."""
return self.config.server_password
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""

View File

@@ -26,8 +26,6 @@ class NTPClient(Service, discriminator="ntp-client"):
config: ConfigSchema = Field(default_factory=lambda: NTPClient.ConfigSchema())
ntp_server: Optional[IPv4Address] = None
"The NTP server the client sends requests to."
time: Optional[datetime] = None
def __init__(self, **kwargs):
@@ -45,8 +43,8 @@ class NTPClient(Service, discriminator="ntp-client"):
:param ntp_server_ip_address: IPv4 address of NTP server.
:param ntp_client_ip_Address: IPv4 address of NTP client.
"""
self.ntp_server = ntp_server_ip_address
self.sys_log.info(f"{self.name}: ntp_server: {self.ntp_server}")
self.config.ntp_server_ip = ntp_server_ip_address
self.sys_log.info(f"{self.name}: ntp_server: {self.config.ntp_server_ip}")
def describe_state(self) -> Dict:
"""
@@ -108,10 +106,10 @@ class NTPClient(Service, discriminator="ntp-client"):
def request_time(self) -> None:
"""Send request to ntp_server."""
if self.ntp_server:
if self.config.ntp_server_ip:
self.software_manager.session_manager.receive_payload_from_software_manager(
payload=NTPPacket(),
dst_ip_address=self.ntp_server,
dst_ip_address=self.config.ntp_server_ip,
src_port=self.port,
dst_port=self.port,
ip_protocol=self.protocol,

View File

@@ -315,7 +315,7 @@ class IOSoftware(Software, ABC):
"""
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
self.software_manager.node.sys_log.error(
f"{self.name} Error: {self.software_manager.node.hostname} is not powered on."
f"{self.name} Error: {self.software_manager.node.config.hostname} is not powered on."
)
return False
return True