Merge branch '4.0.0a1-dev' into feature/3075_Migrate_notebooks_to_MilPac_(Core_changes)

This commit is contained in:
Nick Todd
2025-02-05 08:36:59 +00:00
78 changed files with 1434 additions and 846 deletions

View File

@@ -1,6 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
"""PrimAITE game - Encapsulates the simulation and agents."""
from ipaddress import IPv4Address
from typing import Dict, List, Optional, Union
import numpy as np
@@ -13,14 +12,8 @@ from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.creation import NetworkNodeAdder
from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
@@ -274,66 +267,10 @@ class PrimaiteGame:
n_type = node_cfg["type"]
new_node = None
# Default PrimAITE nodes
if n_type == "computer":
new_node = Computer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type == "server":
new_node = Server(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type == "switch":
new_node = Switch(
hostname=node_cfg["hostname"],
num_ports=int(node_cfg.get("num_ports", "8")),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type == "router":
new_node = Router.from_config(node_cfg)
elif n_type == "firewall":
new_node = Firewall.from_config(node_cfg)
elif n_type == "wireless_router":
new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace)
elif n_type == "printer":
new_node = Printer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"],
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
# Handle extended nodes
elif n_type.lower() in Node._registry:
new_node = HostNode._registry[n_type](
hostname=node_cfg["hostname"],
ip_address=node_cfg.get("ip_address"),
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type in NetworkNode._registry:
new_node = NetworkNode._registry[n_type](**node_cfg)
if n_type in Node._registry:
if n_type == "wireless_router":
node_cfg["airspace"] = net.airspace
new_node = Node._registry[n_type].from_config(config=node_cfg)
else:
msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg)
@@ -341,11 +278,11 @@ class PrimaiteGame:
# TODO: handle simulation defaults more cleanly
if "node_start_up_duration" in defaults_config:
new_node.start_up_duration = defaults_config["node_startup_duration"]
new_node.config.start_up_duration = defaults_config["node_startup_duration"]
if "node_shut_down_duration" in defaults_config:
new_node.shut_down_duration = defaults_config["node_shut_down_duration"]
new_node.config.shut_down_duration = defaults_config["node_shut_down_duration"]
if "node_scan_duration" in defaults_config:
new_node.node_scan_duration = defaults_config["node_scan_duration"]
new_node.config.node_scan_duration = defaults_config["node_scan_duration"]
if "folder_scan_duration" in defaults_config:
new_node.file_system._default_folder_scan_duration = defaults_config["folder_scan_duration"]
if "folder_restore_duration" in defaults_config:
@@ -382,7 +319,7 @@ class PrimaiteGame:
service_class = SERVICE_TYPES_MAPPING[service_type]
if service_class is not None:
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
_LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}")
new_node.software_manager.install(service_class, software_config=service_cfg.get("options", {}))
new_service = new_node.software_manager.software[service_class.__name__]
@@ -400,7 +337,7 @@ class PrimaiteGame:
# TODO: handle simulation defaults more cleanly
if "service_fix_duration" in defaults_config:
new_service.fixing_duration = defaults_config["service_fix_duration"]
new_service.config.fixing_duration = defaults_config["service_fix_duration"]
if "service_restart_duration" in defaults_config:
new_service.restart_duration = defaults_config["service_restart_duration"]
if "service_install_duration" in defaults_config:
@@ -431,8 +368,8 @@ class PrimaiteGame:
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
# temporarily set to 0 so all nodes are initially on
new_node.start_up_duration = 0
new_node.shut_down_duration = 0
new_node.config.start_up_duration = 0
new_node.config.shut_down_duration = 0
net.add_node(new_node)
# run through the power on step if the node is to be turned on at the start
@@ -440,8 +377,8 @@ class PrimaiteGame:
new_node.power_on()
# set start up and shut down duration
new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3))
new_node.shut_down_duration = int(node_cfg.get("shut_down_duration", 3))
new_node.config.start_up_duration = int(node_cfg.get("start_up_duration", 3))
new_node.config.shut_down_duration = int(node_cfg.get("shut_down_duration", 3))
# 1.1 Create Node Sets
for node_set_cfg in node_sets_cfg:

View File

@@ -178,7 +178,7 @@ class AirSpace(BaseModel):
status = "Enabled" if interface.enabled else "Disabled"
table.add_row(
[
interface._connected_node.hostname, # noqa
interface._connected_node.config.hostname, # noqa
interface.mac_address,
interface.ip_address if hasattr(interface, "ip_address") else None,
interface.subnet_mask if hasattr(interface, "subnet_mask") else None,
@@ -320,7 +320,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
self.enabled = True
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
self.pcap = PacketCapture(
hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name
hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name
)
self.airspace.add_wireless_interface(self)

View File

@@ -180,7 +180,7 @@ class Network(SimComponent):
table.align = "l"
table.title = "Nodes"
for node in self.nodes.values():
table.add_row((node.hostname, type(node)._identifier, node.operating_state.name))
table.add_row((node.config.hostname, type(node)._identifier, node.operating_state.name))
print(table)
if ip_addresses:
@@ -196,7 +196,13 @@ class Network(SimComponent):
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
[
node.config.hostname,
port_str,
port.ip_address,
port.subnet_mask,
node.config.default_gateway,
]
)
print(table)
@@ -215,9 +221,9 @@ class Network(SimComponent):
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
table.add_row(
[
link.endpoint_a.parent.hostname,
link.endpoint_a.parent.config.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.hostname,
link.endpoint_b.parent.config.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
@@ -251,7 +257,7 @@ class Network(SimComponent):
state = super().describe_state()
state.update(
{
"nodes": {node.hostname: node.describe_state() for node in self.nodes.values()},
"nodes": {node.config.hostname: node.describe_state() for node in self.nodes.values()},
"links": {},
}
)
@@ -259,8 +265,8 @@ class Network(SimComponent):
for _, link in self.links.items():
node_a = link.endpoint_a._connected_node
node_b = link.endpoint_b._connected_node
hostname_a = node_a.hostname if node_a else None
hostname_b = node_b.hostname if node_b else None
hostname_a = node_a.config.hostname if node_a else None
hostname_b = node_b.config.hostname if node_b else None
port_a = link.endpoint_a.port_num
port_b = link.endpoint_b.port_num
link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}"
@@ -286,9 +292,11 @@ class Network(SimComponent):
self.nodes[node.uuid] = node
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
self._nx_graph.add_node(node.config.hostname)
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
self._node_request_manager.add_request(name=node.hostname, request_type=RequestType(func=node._request_manager))
self._node_request_manager.add_request(
name=node.config.hostname, request_type=RequestType(func=node._request_manager)
)
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
@@ -300,7 +308,7 @@ class Network(SimComponent):
:return: The Node if it exists in the network.
"""
for node in self.nodes.values():
if node.hostname == hostname:
if node.config.hostname == hostname:
return node
def remove_node(self, node: Node) -> None:
@@ -313,7 +321,7 @@ class Network(SimComponent):
:type node: Node
"""
if node not in self:
_LOGGER.warning(f"Can't remove node {node.hostname}. It's not in the network.")
_LOGGER.warning(f"Can't remove node {node.config.hostname}. It's not in the network.")
return
self.nodes.pop(node.uuid)
for i, _node in self._node_id_map.items():
@@ -321,8 +329,8 @@ class Network(SimComponent):
self._node_id_map.pop(i)
break
node.parent = None
self._node_request_manager.remove_request(name=node.hostname)
_LOGGER.info(f"Removed node {node.hostname} from network {self.uuid}")
self._node_request_manager.remove_request(name=node.config.hostname)
_LOGGER.info(f"Removed node {node.config.hostname} from network {self.uuid}")
def connect(
self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, bandwidth: int = 100, **kwargs
@@ -352,7 +360,7 @@ class Network(SimComponent):
link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth, **kwargs)
self.links[link.uuid] = link
self._link_id_map[len(self.links)] = link
self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname)
self._nx_graph.add_edge(endpoint_a.parent.config.hostname, endpoint_b.parent.config.hostname)
link.parent = self
_LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
return link

View File

@@ -154,7 +154,9 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"):
# Create a core switch if more than one edge switch is needed
if num_of_switches > 1:
core_switch = Switch(hostname=f"switch_core_{config.lan_name}", start_up_duration=0)
core_switch = Switch.from_config(
config={"type": "switch", "hostname": f"switch_core_{config.lan_name}", "start_up_duration": 0}
)
core_switch.power_on()
network.add_node(core_switch)
core_switch_port = 1
@@ -165,7 +167,9 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"):
# Optionally include a router in the LAN
if config.include_router:
default_gateway = IPv4Address(f"192.168.{config.subnet_base}.1")
router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0)
router = Router.from_config(
config={"hostname": f"router_{config.lan_name}", "type": "router", "start_up_duration": 0}
)
router.power_on()
router.acl.add_rule(
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22
@@ -178,7 +182,9 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"):
# Initialise the first edge switch and connect to the router or core switch
switch_port = 0
switch_n = 1
switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0)
switch = Switch.from_config(
config={"type": "switch", "hostname": f"switch_edge_{switch_n}_{config.lan_name}", "start_up_duration": 0}
)
switch.power_on()
network.add_node(switch)
if num_of_switches > 1:
@@ -196,7 +202,13 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"):
if switch_port == effective_network_interface:
switch_n += 1
switch_port = 0
switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0)
switch = Switch.from_config(
config={
"type": "switch",
"hostname": f"switch_edge_{switch_n}_{config.lan_name}",
"start_up_duration": 0,
}
)
switch.power_on()
network.add_node(switch)
# Connect the new switch to the router or core switch
@@ -213,13 +225,14 @@ class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"):
)
# Create and add a PC to the network
pc = Computer(
hostname=f"pc_{i}_{config.lan_name}",
ip_address=f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}",
subnet_mask="255.255.255.0",
default_gateway=default_gateway,
start_up_duration=0,
)
pc_cfg = {
"type": "computer",
"hostname": f"pc_{i}_{config.lan_name}",
"ip_address": f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}",
"default_gateway": default_gateway,
"start_up_duration": 0,
}
pc = Computer.from_config(config=pc_cfg)
pc.power_on()
network.add_node(pc)

View File

@@ -9,7 +9,7 @@ from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field, validate_call
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
from primaite.exceptions import NetworkError
@@ -431,7 +431,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
self.enabled = True
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
self.pcap = PacketCapture(
hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name
hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name
)
if self._connected_link:
self._connected_link.endpoint_up()
@@ -1483,7 +1483,7 @@ class UserSessionManager(Service, identifier="UserSessionManager"):
return self.local_session is not None
class Node(SimComponent):
class Node(SimComponent, ABC):
"""
A basic Node class that represents a node on the network.
@@ -1494,19 +1494,12 @@ class Node(SimComponent):
:param operating_state: The node operating state, either ON or OFF.
"""
hostname: str
"The node hostname on the network."
default_gateway: Optional[IPV4Address] = None
"The default gateway IP address for forwarding network traffic to other networks."
operating_state: NodeOperatingState = NodeOperatingState.OFF
"The hardware state of the node."
network_interfaces: Dict[str, NetworkInterface] = {}
"The Network Interfaces on the node."
network_interface: Dict[int, NetworkInterface] = {}
"The Network Interfaces on the node by port id."
dns_server: Optional[IPv4Address] = None
"List of IP addresses of DNS servers used for name resolution."
accounts: Dict[str, Account] = {}
"All accounts on the node."
applications: Dict[str, Application] = {}
@@ -1523,33 +1516,6 @@ class Node(SimComponent):
session_manager: SessionManager
software_manager: SoftwareManager
revealed_to_red: bool = False
"Informs whether the node has been revealed to a red agent."
start_up_duration: int = 3
"Time steps needed for the node to start up."
start_up_countdown: int = 0
"Time steps needed until node is booted up."
shut_down_duration: int = 3
"Time steps needed for the node to shut down."
shut_down_countdown: int = 0
"Time steps needed until node is shut down."
is_resetting: bool = False
"If true, the node will try turning itself off then back on again."
node_scan_duration: int = 10
"How many timesteps until the whole node is scanned. Default 10 time steps."
node_scan_countdown: int = 0
"Time steps until scan is complete"
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {}
"Base system software that must be preinstalled."
@@ -1559,6 +1525,67 @@ class Node(SimComponent):
_identifier: ClassVar[str] = "unknown"
"""Identifier for this particular class, used for printing and logging. Each subclass redefines this."""
class ConfigSchema(BaseModel, ABC):
"""Configuration Schema for Node based classes."""
model_config = ConfigDict(arbitrary_types_allowed=True)
"""Configure pydantic to allow arbitrary types, let the instance have attributes not present in the model."""
hostname: str
"The node hostname on the network."
revealed_to_red: bool = False
"Informs whether the node has been revealed to a red agent."
start_up_duration: int = 3
"Time steps needed for the node to start up."
start_up_countdown: int = 0
"Time steps needed until node is booted up."
shut_down_duration: int = 3
"Time steps needed for the node to shut down."
shut_down_countdown: int = 0
"Time steps needed until node is shut down."
is_resetting: bool = False
"If true, the node will try turning itself off then back on again."
node_scan_duration: int = 10
"How many timesteps until the whole node is scanned. Default 10 time steps."
node_scan_countdown: int = 0
"Time steps until scan is complete"
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
dns_server: Optional[IPv4Address] = None
"List of IP addresses of DNS servers used for name resolution."
default_gateway: Optional[IPV4Address] = None
"The default gateway IP address for forwarding network traffic to other networks."
operating_state: Any = None
config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
"""Configuration items within Node"""
@property
def dns_server(self) -> Optional[IPv4Address]:
"""Convenience method to access the dns_server IP."""
return self.config.dns_server
@classmethod
def from_config(cls, config: Dict) -> "Node":
"""Create Node object from a given configuration dictionary."""
if config["type"] not in cls._registry:
msg = f"Configuration contains an invalid Node type: {config['type']}"
return ValueError(msg)
obj = cls(config=cls.ConfigSchema(**config))
return obj
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
"""
Register a node type.
@@ -1584,11 +1611,11 @@ class Node(SimComponent):
provided.
"""
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["hostname"])
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
if not kwargs.get("session_manager"):
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"))
if not kwargs.get("root"):
kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"]
kwargs["root"] = SIM_OUTPUT.path / kwargs["config"].hostname
if not kwargs.get("file_system"):
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
if not kwargs.get("software_manager"):
@@ -1597,9 +1624,12 @@ class Node(SimComponent):
sys_log=kwargs.get("sys_log"),
session_manager=kwargs.get("session_manager"),
file_system=kwargs.get("file_system"),
dns_server=kwargs.get("dns_server"),
dns_server=kwargs["config"].dns_server,
)
super().__init__(**kwargs)
self.operating_state = (
NodeOperatingState.ON if not (p := kwargs["config"].operating_state) else NodeOperatingState[p.upper()]
)
self._install_system_software()
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
@@ -1693,7 +1723,7 @@ class Node(SimComponent):
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not powered on."
return f"Cannot perform request on node '{self.node.config.hostname}' because it is not powered on."
class _NodeIsOffValidator(RequestPermissionValidator):
"""
@@ -1712,7 +1742,7 @@ class Node(SimComponent):
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not turned off."
return f"Cannot perform request on node '{self.node.config.hostname}' because it is not turned off."
def _init_request_manager(self) -> RequestManager:
"""
@@ -1740,7 +1770,7 @@ class Node(SimComponent):
self.software_manager.install(application_class)
application_instance = self.software_manager.software.get(application_name)
self.applications[application_instance.uuid] = application_instance
_LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}")
_LOGGER.debug(f"Added application {application_instance.name} to node {self.config.hostname}")
self._application_request_manager.add_request(
application_name, RequestType(func=application_instance._request_manager)
)
@@ -1854,7 +1884,7 @@ class Node(SimComponent):
state = super().describe_state()
state.update(
{
"hostname": self.hostname,
"hostname": self.config.hostname,
"operating_state": self.operating_state.value,
"NICs": {
eth_num: network_interface.describe_state()
@@ -1864,7 +1894,7 @@ class Node(SimComponent):
"applications": {app.name: app.describe_state() for app in self.applications.values()},
"services": {svc.name: svc.describe_state() for svc in self.services.values()},
"process": {proc.name: proc.describe_state() for proc in self.processes.values()},
"revealed_to_red": self.revealed_to_red,
"revealed_to_red": self.config.revealed_to_red,
}
)
return state
@@ -1880,7 +1910,7 @@ class Node(SimComponent):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Open Ports"
table.title = f"{self.config.hostname} Open Ports"
for port in self.software_manager.get_open_ports():
if port > 0:
table.add_row([port])
@@ -1907,7 +1937,7 @@ class Node(SimComponent):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Network Interface Cards"
table.title = f"{self.config.hostname} Network Interface Cards"
for port, network_interface in self.network_interface.items():
ip_address = ""
if hasattr(network_interface, "ip_address"):
@@ -1942,38 +1972,38 @@ class Node(SimComponent):
network_interface.apply_timestep(timestep=timestep)
# count down to boot up
if self.start_up_countdown > 0:
self.start_up_countdown -= 1
if self.config.start_up_countdown > 0:
self.config.start_up_countdown -= 1
else:
if self.operating_state == NodeOperatingState.BOOTING:
self.operating_state = NodeOperatingState.ON
self.sys_log.info(f"{self.hostname}: Turned on")
self.sys_log.info(f"{self.config.hostname}: Turned on")
for network_interface in self.network_interfaces.values():
network_interface.enable()
self._start_up_actions()
# count down to shut down
if self.shut_down_countdown > 0:
self.shut_down_countdown -= 1
if self.config.shut_down_countdown > 0:
self.config.shut_down_countdown -= 1
else:
if self.operating_state == NodeOperatingState.SHUTTING_DOWN:
self.operating_state = NodeOperatingState.OFF
self.sys_log.info(f"{self.hostname}: Turned off")
self.sys_log.info(f"{self.config.hostname}: Turned off")
self._shut_down_actions()
# if resetting turn back on
if self.is_resetting:
self.is_resetting = False
if self.config.is_resetting:
self.config.is_resetting = False
self.power_on()
# time steps which require the node to be on
if self.operating_state == NodeOperatingState.ON:
# node scanning
if self.node_scan_countdown > 0:
self.node_scan_countdown -= 1
if self.config.node_scan_countdown > 0:
self.config.node_scan_countdown -= 1
if self.node_scan_countdown == 0:
if self.config.node_scan_countdown == 0:
# scan everything!
for process_id in self.processes:
self.processes[process_id].scan()
@@ -1989,10 +2019,10 @@ class Node(SimComponent):
# scan file system
self.file_system.scan(instant_scan=True)
if self.red_scan_countdown > 0:
self.red_scan_countdown -= 1
if self.config.red_scan_countdown > 0:
self.config.red_scan_countdown -= 1
if self.red_scan_countdown == 0:
if self.config.red_scan_countdown == 0:
# scan processes
for process_id in self.processes:
self.processes[process_id].reveal_to_red()
@@ -2049,7 +2079,7 @@ class Node(SimComponent):
to the red agent.
"""
self.node_scan_countdown = self.node_scan_duration
self.config.node_scan_countdown = self.config.node_scan_duration
return True
def reveal_to_red(self) -> bool:
@@ -2065,12 +2095,12 @@ class Node(SimComponent):
`revealed_to_red` to `True`.
"""
self.red_scan_countdown = self.node_scan_duration
self.config.red_scan_countdown = self.config.node_scan_duration
return True
def power_on(self) -> bool:
"""Power on the Node, enabling its NICs if it is in the OFF state."""
if self.start_up_duration <= 0:
if self.config.start_up_duration <= 0:
self.operating_state = NodeOperatingState.ON
self._start_up_actions()
self.sys_log.info("Power on")
@@ -2079,14 +2109,14 @@ class Node(SimComponent):
return True
if self.operating_state == NodeOperatingState.OFF:
self.operating_state = NodeOperatingState.BOOTING
self.start_up_countdown = self.start_up_duration
self.config.start_up_countdown = self.config.start_up_duration
return True
return False
def power_off(self) -> bool:
"""Power off the Node, disabling its NICs if it is in the ON state."""
if self.shut_down_duration <= 0:
if self.config.shut_down_duration <= 0:
self._shut_down_actions()
self.operating_state = NodeOperatingState.OFF
self.sys_log.info("Power off")
@@ -2095,7 +2125,7 @@ class Node(SimComponent):
for network_interface in self.network_interfaces.values():
network_interface.disable()
self.operating_state = NodeOperatingState.SHUTTING_DOWN
self.shut_down_countdown = self.shut_down_duration
self.config.shut_down_countdown = self.config.shut_down_duration
return True
return False
@@ -2107,7 +2137,7 @@ class Node(SimComponent):
Applying more timesteps will eventually turn the node back on.
"""
if self.operating_state.ON:
self.is_resetting = True
self.config.is_resetting = True
self.sys_log.info("Resetting")
self.power_off()
return True
@@ -2212,10 +2242,6 @@ class Node(SimComponent):
for app_id in self.applications:
self.applications[app_id].close()
# Turn off all processes in the node
# for process_id in self.processes:
# self.processes[process_id]
def _start_up_actions(self):
"""Actions to perform when the node is starting up."""
# Turn on all the services in the node
@@ -2226,10 +2252,6 @@ class Node(SimComponent):
for app_id in self.applications:
self.applications[app_id].run()
# Turn off all processes in the node
# for process_id in self.processes:
# self.processes[process_id]
def _install_system_software(self) -> None:
"""Preinstall required software."""
for _, software_class in self.SYSTEM_SOFTWARE.items():

View File

@@ -1,6 +1,8 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import ClassVar, Dict
from pydantic import Field
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
@@ -35,4 +37,11 @@ class Computer(HostNode, identifier="computer"):
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient}
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Computer class."""
hostname: str = "Computer"
config: ConfigSchema = Field(default_factory=lambda: Computer.ConfigSchema())
pass

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.hardware.base import (
IPWiredNetworkInterface,
@@ -44,8 +46,8 @@ class HostARP(ARP):
:return: The MAC address of the default gateway if present in the ARP cache; otherwise, None.
"""
if self.software_manager.node.default_gateway:
return self.get_arp_cache_mac_address(self.software_manager.node.default_gateway)
if self.software_manager.node.config.default_gateway:
return self.get_arp_cache_mac_address(self.software_manager.node.config.default_gateway)
def get_default_gateway_network_interface(self) -> Optional[NIC]:
"""
@@ -53,8 +55,11 @@ class HostARP(ARP):
:return: The NIC associated with the default gateway if it exists in the ARP cache; otherwise, None.
"""
if self.software_manager.node.default_gateway and self.software_manager.node.has_enabled_network_interface:
return self.get_arp_cache_network_interface(self.software_manager.node.default_gateway)
if (
self.software_manager.node.config.default_gateway
and self.software_manager.node.has_enabled_network_interface
):
return self.get_arp_cache_network_interface(self.software_manager.node.config.default_gateway)
def _get_arp_cache_mac_address(
self, ip_address: IPV4Address, is_reattempt: bool = False, is_default_gateway_attempt: bool = False
@@ -73,7 +78,7 @@ class HostARP(ARP):
if arp_entry:
return arp_entry.mac_address
if ip_address == self.software_manager.node.default_gateway:
if ip_address == self.software_manager.node.config.default_gateway:
is_reattempt = True
if not is_reattempt:
self.send_arp_request(ip_address)
@@ -81,11 +86,11 @@ class HostARP(ARP):
ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt
)
else:
if self.software_manager.node.default_gateway:
if self.software_manager.node.config.default_gateway:
if not is_default_gateway_attempt:
self.send_arp_request(self.software_manager.node.default_gateway)
self.send_arp_request(self.software_manager.node.config.default_gateway)
return self._get_arp_cache_mac_address(
ip_address=self.software_manager.node.default_gateway,
ip_address=self.software_manager.node.config.default_gateway,
is_reattempt=True,
is_default_gateway_attempt=True,
)
@@ -116,7 +121,7 @@ class HostARP(ARP):
if arp_entry:
return self.software_manager.node.network_interfaces[arp_entry.network_interface_uuid]
else:
if ip_address == self.software_manager.node.default_gateway:
if ip_address == self.software_manager.node.config.default_gateway:
is_reattempt = True
if not is_reattempt:
self.send_arp_request(ip_address)
@@ -124,11 +129,11 @@ class HostARP(ARP):
ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt
)
else:
if self.software_manager.node.default_gateway:
if self.software_manager.node.config.default_gateway:
if not is_default_gateway_attempt:
self.send_arp_request(self.software_manager.node.default_gateway)
self.send_arp_request(self.software_manager.node.config.default_gateway)
return self._get_arp_cache_network_interface(
ip_address=self.software_manager.node.default_gateway,
ip_address=self.software_manager.node.config.default_gateway,
is_reattempt=True,
is_default_gateway_attempt=True,
)
@@ -325,9 +330,18 @@ class HostNode(Node, identifier="HostNode"):
network_interface: Dict[int, NIC] = {}
"The NICs on the node by port id."
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
class ConfigSchema(Node.ConfigSchema):
"""Configuration Schema for HostNode class."""
hostname: str = "HostNode"
subnet_mask: IPV4Address = "255.255.255.0"
ip_address: IPV4Address
config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask))
@property
def nmap(self) -> Optional[NMAP]:

View File

@@ -1,4 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from pydantic import Field
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
@@ -30,8 +33,22 @@ class Server(HostNode, identifier="server"):
* Web Browser
"""
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Server class."""
hostname: str = "server"
config: ConfigSchema = Field(default_factory=lambda: Server.ConfigSchema())
class Printer(HostNode, identifier="printer"):
"""Printer? I don't even know her!."""
# TODO: Implement printer-specific behaviour
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Printer class."""
hostname: str = "printer"
config: ConfigSchema = Field(default_factory=lambda: Printer.ConfigSchema())

View File

@@ -6,7 +6,6 @@ from prettytable import MARKDOWN, PrettyTable
from pydantic import Field, validate_call
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.router import (
AccessControlList,
ACLAction,
@@ -99,11 +98,21 @@ class Firewall(Router, identifier="firewall"):
)
"""Access Control List for managing traffic leaving towards an external network."""
def __init__(self, hostname: str, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname)
_identifier: str = "firewall"
super().__init__(hostname=hostname, num_ports=0, **kwargs)
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for Firewall 'Nodes' within PrimAITE."""
hostname: str = "firewall"
num_ports: int = 0
config: ConfigSchema = Field(default_factory=lambda: Firewall.ConfigSchema())
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
super().__init__(**kwargs)
self.connect_nic(
RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external")
@@ -116,22 +125,23 @@ class Firewall(Router, identifier="firewall"):
)
# Update ACL objects with firewall's hostname and syslog to allow accurate logging
self.internal_inbound_acl.sys_log = kwargs["sys_log"]
self.internal_inbound_acl.name = f"{hostname} - Internal Inbound"
self.internal_inbound_acl.name = f"{kwargs['config'].hostname} - Internal Inbound"
self.internal_outbound_acl.sys_log = kwargs["sys_log"]
self.internal_outbound_acl.name = f"{hostname} - Internal Outbound"
self.internal_outbound_acl.name = f"{kwargs['config'].hostname} - Internal Outbound"
self.dmz_inbound_acl.sys_log = kwargs["sys_log"]
self.dmz_inbound_acl.name = f"{hostname} - DMZ Inbound"
self.dmz_inbound_acl.name = f"{kwargs['config'].hostname} - DMZ Inbound"
self.dmz_outbound_acl.sys_log = kwargs["sys_log"]
self.dmz_outbound_acl.name = f"{hostname} - DMZ Outbound"
self.dmz_outbound_acl.name = f"{kwargs['config'].hostname} - DMZ Outbound"
self.external_inbound_acl.sys_log = kwargs["sys_log"]
self.external_inbound_acl.name = f"{hostname} - External Inbound"
self.external_inbound_acl.name = f"{kwargs['config'].hostname} - External Inbound"
self.external_outbound_acl.sys_log = kwargs["sys_log"]
self.external_outbound_acl.name = f"{hostname} - External Outbound"
self.external_outbound_acl.name = f"{kwargs['config'].hostname} - External Outbound"
self.power_on()
def _init_request_manager(self) -> RequestManager:
"""
@@ -231,7 +241,7 @@ class Firewall(Router, identifier="firewall"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Network Interfaces"
table.title = f"{self.config.hostname} Network Interfaces"
ports = {"External": self.external_port, "Internal": self.internal_port, "DMZ": self.dmz_port}
for port, network_interface in ports.items():
table.add_row(
@@ -551,18 +561,14 @@ class Firewall(Router, identifier="firewall"):
self.dmz_port.enable()
@classmethod
def from_config(cls, cfg: dict) -> "Firewall":
def from_config(cls, config: dict) -> "Firewall":
"""Create a firewall based on a config dict."""
firewall = Firewall(
hostname=cfg["hostname"],
operating_state=NodeOperatingState.ON
if not (p := cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
if "ports" in cfg:
internal_port = cfg["ports"]["internal_port"]
external_port = cfg["ports"]["external_port"]
dmz_port = cfg["ports"].get("dmz_port")
firewall = Firewall(config=cls.ConfigSchema(**config))
if "ports" in config:
internal_port = config["ports"]["internal_port"]
external_port = config["ports"]["external_port"]
dmz_port = config["ports"].get("dmz_port")
# configure internal port
firewall.configure_internal_port(
@@ -582,10 +588,10 @@ class Firewall(Router, identifier="firewall"):
ip_address=IPV4Address(dmz_port.get("ip_address")),
subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")),
)
if "acl" in cfg:
if "acl" in config:
# acl rules for internal_inbound_acl
if cfg["acl"]["internal_inbound_acl"]:
for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items():
if config["acl"]["internal_inbound_acl"]:
for r_num, r_cfg in config["acl"]["internal_inbound_acl"].items():
firewall.internal_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -599,8 +605,8 @@ class Firewall(Router, identifier="firewall"):
)
# acl rules for internal_outbound_acl
if cfg["acl"]["internal_outbound_acl"]:
for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items():
if config["acl"]["internal_outbound_acl"]:
for r_num, r_cfg in config["acl"]["internal_outbound_acl"].items():
firewall.internal_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -614,8 +620,8 @@ class Firewall(Router, identifier="firewall"):
)
# acl rules for dmz_inbound_acl
if cfg["acl"]["dmz_inbound_acl"]:
for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items():
if config["acl"]["dmz_inbound_acl"]:
for r_num, r_cfg in config["acl"]["dmz_inbound_acl"].items():
firewall.dmz_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -629,8 +635,8 @@ class Firewall(Router, identifier="firewall"):
)
# acl rules for dmz_outbound_acl
if cfg["acl"]["dmz_outbound_acl"]:
for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items():
if config["acl"]["dmz_outbound_acl"]:
for r_num, r_cfg in config["acl"]["dmz_outbound_acl"].items():
firewall.dmz_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -644,8 +650,8 @@ class Firewall(Router, identifier="firewall"):
)
# acl rules for external_inbound_acl
if cfg["acl"].get("external_inbound_acl"):
for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items():
if config["acl"].get("external_inbound_acl"):
for r_num, r_cfg in config["acl"]["external_inbound_acl"].items():
firewall.external_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -659,8 +665,8 @@ class Firewall(Router, identifier="firewall"):
)
# acl rules for external_outbound_acl
if cfg["acl"].get("external_outbound_acl"):
for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items():
if config["acl"].get("external_outbound_acl"):
for r_num, r_cfg in config["acl"]["external_outbound_acl"].items():
firewall.external_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -673,16 +679,16 @@ class Firewall(Router, identifier="firewall"):
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
if "routes" in config:
for route in config.get("routes"):
firewall.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if "default_route" in config:
next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
firewall.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)

View File

@@ -7,7 +7,7 @@ from ipaddress import IPv4Address, IPv4Network
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
@@ -1201,13 +1201,20 @@ class Router(NetworkNode, identifier="router"):
RouteTable, RouterARP, and RouterICMP services.
"""
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Routers."""
hostname: str = "router"
num_ports: int = 5
config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema())
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
}
num_ports: int
network_interfaces: Dict[str, RouterInterface] = {}
"The Router Interfaces on the node."
network_interface: Dict[int, RouterInterface] = {}
@@ -1215,19 +1222,21 @@ class Router(NetworkNode, identifier="router"):
acl: AccessControlList
route_table: RouteTable
def __init__(self, hostname: str, num_ports: int = 5, **kwargs):
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname)
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
if not kwargs.get("acl"):
kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=hostname)
kwargs["acl"] = AccessControlList(
sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname
)
if not kwargs.get("route_table"):
kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"])
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
super().__init__(**kwargs)
self.session_manager = RouterSessionManager(sys_log=self.sys_log)
self.session_manager.node = self
self.software_manager.session_manager = self.session_manager
self.session_manager.software_manager = self.software_manager
for i in range(1, self.num_ports + 1):
for i in range(1, self.config.num_ports + 1):
network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
self.connect_nic(network_interface)
self.network_interface[i] = network_interface
@@ -1337,7 +1346,7 @@ class Router(NetworkNode, identifier="router"):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["num_ports"] = self.num_ports
state["num_ports"] = self.config.num_ports
state["acl"] = self.acl.describe_state()
return state
@@ -1545,7 +1554,7 @@ class Router(NetworkNode, identifier="router"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Network Interfaces"
table.title = f"{self.config.hostname} Network Interfaces"
for port, network_interface in self.network_interface.items():
table.add_row(
[
@@ -1559,7 +1568,7 @@ class Router(NetworkNode, identifier="router"):
print(table)
@classmethod
def from_config(cls, cfg: dict, **kwargs) -> "Router":
def from_config(cls, config: dict, **kwargs) -> "Router":
"""Create a router based on a config dict.
Schema:
@@ -1616,22 +1625,16 @@ class Router(NetworkNode, identifier="router"):
:return: Configured router.
:rtype: Router
"""
router = Router(
hostname=cfg["hostname"],
num_ports=int(cfg.get("num_ports", "5")),
operating_state=NodeOperatingState.ON
if not (p := cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
if "ports" in cfg:
for port_num, port_cfg in cfg["ports"].items():
router = Router(config=Router.ConfigSchema(**config))
if "ports" in config:
for port_num, port_cfg in config["ports"].items():
router.configure_port(
port=port_num,
ip_address=port_cfg["ip_address"],
subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")),
)
if "acl" in cfg:
for r_num, r_cfg in cfg["acl"].items():
if "acl" in config:
for r_num, r_cfg in config["acl"].items():
router.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -1643,16 +1646,19 @@ class Router(NetworkNode, identifier="router"):
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
if "routes" in config:
for route in config.get("routes"):
router.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if "default_route" in config:
next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
router.operating_state = (
NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()]
)
return router

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite import getLogger
from primaite.exceptions import NetworkError
@@ -88,14 +89,8 @@ class SwitchPort(WiredNetworkInterface):
class Switch(NetworkNode, identifier="switch"):
"""
A class representing a Layer 2 network switch.
"""A class representing a Layer 2 network switch."""
:ivar num_ports: The number of ports on the switch. Default is 24.
"""
num_ports: int = 24
"The number of ports on the switch."
network_interfaces: Dict[str, SwitchPort] = {}
"The SwitchPorts on the Switch."
network_interface: Dict[int, SwitchPort] = {}
@@ -103,9 +98,18 @@ class Switch(NetworkNode, identifier="switch"):
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Switch nodes within PrimAITE."""
hostname: str = "Switch"
num_ports: int = 24
"The number of ports on the switch. Default is 24."
config: ConfigSchema = Field(default_factory=lambda: Switch.ConfigSchema())
def __init__(self, **kwargs):
super().__init__(**kwargs)
for i in range(1, self.num_ports + 1):
for i in range(1, self.config.num_ports + 1):
self.connect_nic(SwitchPort())
def _install_system_software(self):
@@ -121,7 +125,7 @@ class Switch(NetworkNode, identifier="switch"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Switch Ports"
table.title = f"{self.config.hostname} Switch Ports"
for port_num, port in self.network_interface.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)
@@ -134,7 +138,7 @@ class Switch(NetworkNode, identifier="switch"):
"""
state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()}
state["num_ports"] = self.num_ports # redundant?
state["num_ports"] = self.config.num_ports # redundant?
state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
return state

View File

@@ -2,7 +2,7 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional, Union
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, FREQ_WIFI_2_4, IPWirelessNetworkInterface
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
@@ -122,13 +122,23 @@ class WirelessRouter(Router, identifier="wireless_router"):
network_interfaces: Dict[str, Union[RouterInterface, WirelessAccessPoint]] = {}
network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {}
airspace: AirSpace
def __init__(self, hostname: str, airspace: AirSpace, **kwargs):
super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs)
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for WirelessRouter nodes within PrimAITE."""
hostname: str = "WirelessRouter"
airspace: AirSpace
num_ports: int = 0
config: ConfigSchema = Field(default_factory=lambda: WirelessRouter.ConfigSchema())
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connect_nic(
WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=airspace)
WirelessAccessPoint(
ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=kwargs["config"].airspace
)
)
self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0"))
@@ -226,7 +236,7 @@ class WirelessRouter(Router, identifier="wireless_router"):
)
@classmethod
def from_config(cls, cfg: Dict, **kwargs) -> "WirelessRouter":
def from_config(cls, config: Dict, **kwargs) -> "WirelessRouter":
"""Generate the wireless router from config.
Schema:
@@ -253,22 +263,22 @@ class WirelessRouter(Router, identifier="wireless_router"):
:return: WirelessRouter instance.
:rtype: WirelessRouter
"""
operating_state = (
NodeOperatingState.ON if not (p := cfg.get("operating_state")) else NodeOperatingState[p.upper()]
router = cls(config=cls.ConfigSchema(**config))
router.operating_state = (
NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()]
)
router = cls(hostname=cfg["hostname"], operating_state=operating_state, airspace=kwargs["airspace"])
if "router_interface" in cfg:
ip_address = cfg["router_interface"]["ip_address"]
subnet_mask = cfg["router_interface"]["subnet_mask"]
if "router_interface" in config:
ip_address = config["router_interface"]["ip_address"]
subnet_mask = config["router_interface"]["subnet_mask"]
router.configure_router_interface(ip_address=ip_address, subnet_mask=subnet_mask)
if "wireless_access_point" in cfg:
ip_address = cfg["wireless_access_point"]["ip_address"]
subnet_mask = cfg["wireless_access_point"]["subnet_mask"]
frequency = AirSpaceFrequency._registry[cfg["wireless_access_point"]["frequency"]]
if "wireless_access_point" in config:
ip_address = config["wireless_access_point"]["ip_address"]
subnet_mask = config["wireless_access_point"]["subnet_mask"]
frequency = AirSpaceFrequency._registry[config["wireless_access_point"]["frequency"]]
router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency)
if "acl" in cfg:
for r_num, r_cfg in cfg["acl"].items():
if "acl" in config:
for r_num, r_cfg in config["acl"].items():
router.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -280,8 +290,8 @@ class WirelessRouter(Router, identifier="wireless_router"):
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
if "routes" in config:
for route in config.get("routes"):
router.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),

View File

@@ -40,41 +40,45 @@ def client_server_routed() -> Network:
network = Network()
# Router 1
router_1 = Router(hostname="router_1", num_ports=3)
router_1 = Router(config=dict(hostname="router_1", num_ports=3))
router_1.power_on()
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.2.1", subnet_mask="255.255.255.0")
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=6)
switch_1 = Switch(config=dict(hostname="switch_1", num_ports=6))
switch_1.power_on()
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[6])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=6)
switch_2 = Switch(config=dict(hostname="switch_2", num_ports=6))
switch_2.power_on()
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[6])
router_1.enable_port(2)
# Client 1
client_1 = Computer(
hostname="client_1",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
start_up_duration=0,
config=dict(
hostname="client_1",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
start_up_duration=0,
)
)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
# Server 1
server_1 = Server(
hostname="server_1",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
config=dict(
hostname="server_1",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
)
server_1.power_on()
network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1])
@@ -128,32 +132,41 @@ def arcd_uc2_network() -> Network:
network = Network()
# Router 1
router_1 = Router(hostname="router_1", num_ports=5, start_up_duration=0)
router_1 = Router.from_config(
config={"type": "router", "hostname": "router_1", "num_ports": 5, "start_up_duration": 0}
)
router_1.power_on()
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
switch_1 = Switch.from_config(
config={"type": "switch", "hostname": "switch_1", "num_ports": 8, "start_up_duration": 0}
)
switch_1.power_on()
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
switch_2 = Switch.from_config(
config={"type": "switch", "hostname": "switch_2", "num_ports": 8, "start_up_duration": 0}
)
switch_2.power_on()
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])
router_1.enable_port(2)
# Client 1
client_1 = Computer(
hostname="client_1",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
client_1_cfg = {
"type": "computer",
"hostname": "client_1",
"ip_address": "192.168.10.21",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.10.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
client_1: Computer = Computer.from_config(config=client_1_cfg)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
client_1.software_manager.install(DatabaseClient)
@@ -172,14 +185,18 @@ def arcd_uc2_network() -> Network:
)
# Client 2
client_2 = Computer(
hostname="client_2",
ip_address="192.168.10.22",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
client_2_cfg = {
"type": "computer",
"hostname": "client_2",
"ip_address": "192.168.10.22",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.10.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
client_2: Computer = Computer.from_config(config=client_2_cfg)
client_2.power_on()
client_2.software_manager.install(DatabaseClient)
db_client_2 = client_2.software_manager.software.get("DatabaseClient")
@@ -193,27 +210,36 @@ def arcd_uc2_network() -> Network:
)
# Domain Controller
domain_controller = Server(
hostname="domain_controller",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
domain_controller_cfg = {
"type": "server",
"hostname": "domain_controller",
"ip_address": "192.168.1.10",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
domain_controller = Server.from_config(config=domain_controller_cfg)
domain_controller.power_on()
domain_controller.software_manager.install(DNSServer)
network.connect(endpoint_b=domain_controller.network_interface[1], endpoint_a=switch_1.network_interface[1])
# Database Server
database_server = Server(
hostname="database_server",
ip_address="192.168.1.14",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
database_server_cfg = {
"type": "server",
"hostname": "database_server",
"ip_address": "192.168.1.14",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
database_server = Server.from_config(config=database_server_cfg)
database_server.power_on()
network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3])
@@ -223,14 +249,18 @@ def arcd_uc2_network() -> Network:
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
# Web Server
web_server = Server(
hostname="web_server",
ip_address="192.168.1.12",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
web_server_cfg = {
"type": "server",
"hostname": "web_server",
"ip_address": "192.168.1.11",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
web_server = Server.from_config(config=web_server_cfg)
web_server.power_on()
web_server.software_manager.install(DatabaseClient)
@@ -247,27 +277,32 @@ def arcd_uc2_network() -> Network:
dns_server_service.dns_register("arcd.com", web_server.network_interface[1].ip_address)
# Backup Server
backup_server = Server(
hostname="backup_server",
ip_address="192.168.1.16",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
backup_server_cfg = {
"type": "server",
"hostname": "backup_server",
"ip_address": "192.168.1.16",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
backup_server: Server = Server.from_config(config=backup_server_cfg)
backup_server.power_on()
backup_server.software_manager.install(FTPServer)
network.connect(endpoint_b=backup_server.network_interface[1], endpoint_a=switch_1.network_interface[4])
# Security Suite
security_suite = Server(
hostname="security_suite",
ip_address="192.168.1.110",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
security_suite_cfg = {
"type": "server",
"hostname": "backup_server",
"ip_address": "192.168.1.110",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
security_suite: Server = Server.from_config(config=security_suite_cfg)
security_suite.power_on()
network.connect(endpoint_b=security_suite.network_interface[1], endpoint_a=switch_1.network_interface[7])
security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0"))

View File

@@ -208,7 +208,7 @@ class NMAP(Application, identifier="NMAP"):
if show:
table = PrettyTable(["IP Address", "Can Ping"])
table.align = "l"
table.title = f"{self.software_manager.node.hostname} NMAP Ping Scan"
table.title = f"{self.software_manager.node.config.hostname} NMAP Ping Scan"
ip_addresses = self._explode_ip_address_network_array(target_ip_address)
@@ -367,7 +367,7 @@ class NMAP(Application, identifier="NMAP"):
if show:
table = PrettyTable(["IP Address", "Port", "Protocol"])
table.align = "l"
table.title = f"{self.software_manager.node.hostname} NMAP Port Scan ({scan_type})"
table.title = f"{self.software_manager.node.config.hostname} NMAP Port Scan ({scan_type})"
self.sys_log.info(f"{self.name}: Starting port scan")
for ip_address in ip_addresses:
# Prevent port scan on this node

View File

@@ -140,6 +140,7 @@ class SoftwareManager:
elif isinstance(software, Service):
self.node.services[software.uuid] = software
self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager))
software.start()
software.install()
software.software_manager = self
self.software[software.name] = software

View File

@@ -32,6 +32,7 @@ class DatabaseService(Service, identifier="DatabaseService"):
type: str = "DatabaseService"
backup_server_ip: Optional[IPv4Address] = None
db_password: Optional[str] = None
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema())
@@ -224,7 +225,7 @@ class DatabaseService(Service, identifier="DatabaseService"):
SoftwareHealthState.FIXING,
SoftwareHealthState.COMPROMISED,
]:
if self.password == password:
if self.config.db_password == password:
status_code = 200 # ok
connection_id = self._generate_connection_id()
# try to create connection

View File

@@ -25,9 +25,12 @@ class DNSClient(Service, identifier="DNSClient"):
"""ConfigSchema for DNSClient."""
type: str = "DNSClient"
dns_server: Optional[IPV4Address] = None
config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema())
dns_server: Optional[IPv4Address] = None
"The DNS Server the client sends requests to."
config: ConfigSchema = Field(default_factory=lambda: DNSClient.ConfigSchema())
dns_cache: Dict[str, IPv4Address] = {}
"A dict of known mappings between domain/URLs names and IPv4 addresses."

View File

@@ -20,14 +20,15 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"):
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
"""
config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema())
server_password: Optional[str] = None
class ConfigSchema(FTPServiceABC.ConfigSchema):
"""ConfigSchema for FTPServer."""
type: str = "FTPServer"
server_password: Optional[str] = None
"""Password needed to connect to FTP server. Default is None."""
config: ConfigSchema = Field(default_factory=lambda: FTPServer.ConfigSchema())
def __init__(self, **kwargs):
kwargs["name"] = "FTPServer"
@@ -35,7 +36,11 @@ class FTPServer(FTPServiceABC, identifier="FTPServer"):
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
self.start()
self.server_password = self.config.server_password
@property
def server_password(self) -> Optional[str]:
"""Convenience method for accessing FTP server password."""
return self.config.server_password
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""

View File

@@ -9,7 +9,6 @@ from primaite import getLogger
from primaite.simulator.network.protocols.ntp import NTPPacket
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -22,12 +21,12 @@ class NTPClient(Service, identifier="NTPClient"):
"""ConfigSchema for NTPClient."""
type: str = "NTPClient"
ntp_server_ip: Optional[IPV4Address] = None
config: "NTPClient.ConfigSchema" = Field(default_factory=lambda: NTPClient.ConfigSchema())
ntp_server_ip: Optional[IPv4Address] = None
"The NTP server the client sends requests to."
config: ConfigSchema = Field(default_factory=lambda: NTPClient.ConfigSchema())
ntp_server: Optional[IPv4Address] = None
"The NTP server the client sends requests to."
time: Optional[datetime] = None
def __init__(self, **kwargs):
@@ -45,8 +44,8 @@ class NTPClient(Service, identifier="NTPClient"):
:param ntp_server_ip_address: IPv4 address of NTP server.
:param ntp_client_ip_Address: IPv4 address of NTP client.
"""
self.ntp_server = ntp_server_ip_address
self.sys_log.info(f"{self.name}: ntp_server: {self.ntp_server}")
self.config.ntp_server_ip = ntp_server_ip_address
self.sys_log.info(f"{self.name}: ntp_server: {self.config.ntp_server_ip}")
def describe_state(self) -> Dict:
"""
@@ -108,10 +107,10 @@ class NTPClient(Service, identifier="NTPClient"):
def request_time(self) -> None:
"""Send request to ntp_server."""
if self.ntp_server:
if self.config.ntp_server_ip:
self.software_manager.session_manager.receive_payload_from_software_manager(
payload=NTPPacket(),
dst_ip_address=self.ntp_server,
dst_ip_address=self.config.ntp_server_ip,
src_port=self.port,
dst_port=self.port,
ip_protocol=self.protocol,

View File

@@ -315,7 +315,7 @@ class IOSoftware(Software, ABC):
"""
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
self.software_manager.node.sys_log.error(
f"{self.name} Error: {self.software_manager.node.hostname} is not powered on."
f"{self.name} Error: {self.software_manager.node.config.hostname} is not powered on."
)
return False
return True