Merge '2887-Align_Node_Types' into 3062-discriminators

This commit is contained in:
Marek Wolan
2025-02-04 14:04:40 +00:00
78 changed files with 1429 additions and 832 deletions

View File

@@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Relabeled action parameters to match the new action config schemas, and updated the values to no longer rely on indices
- Removed action space options which were previously used for assigning meaning to action space IDs
- Updated tests that don't use YAMLs to still use the new action and agent schemas
- Nodes now use a config schema and are extensible, allowing for plugin support.
- Node tests have been updated to use the new node config schemas when not using YAML files.
### Fixed
- DNS client no longer fails to check its cache if a DNS server address is missing.

View File

@@ -0,0 +1,56 @@
.. only:: comment
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
.. _about:
Extensible Nodes
****************
Node classes within PrimAITE have been updated to allow for easier generation of custom nodes within simulations.
Changes to Node Class structure.
================================
Node classes all inherit from the base Node Class, though new classes should inherit from either HostNode or NetworkNode, subject to the intended application of the Node.
The use of an `__init__` method is not necessary, as configurable variables for the class should be specified within the `config` of the class, and passed at run time via your YAML configuration using the `from_config` method.
An example of how additional Node classes is below, taken from `router.py` withing PrimAITE.
.. code-block:: Python
class Router(NetworkNode, identifier="router"):
""" Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces."""
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
}
network_interfaces: Dict[str, RouterInterface] = {}
"The Router Interfaces on the node."
network_interface: Dict[int, RouterInterface] = {}
"The Router Interfaces on the node by port id."
sys_log: SysLog
config: "Router.ConfigSchema" = Field(default_factory=lambda: Router.ConfigSchema())
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Router Objects."""
num_ports: int = 5
hostname: str = "Router"
Changes to YAML file.
=====================
Nodes defined within configuration YAML files for use with PrimAITE 3.X should still be compatible following these changes.

View File

@@ -1,6 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
"""PrimAITE game - Encapsulates the simulation and agents."""
from ipaddress import IPv4Address
from typing import Dict, List, Optional, Union
import numpy as np
@@ -13,14 +12,8 @@ from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.creation import NetworkNodeAdder
from primaite.simulator.network.hardware.base import NetworkInterface, Node, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
@@ -274,66 +267,10 @@ class PrimaiteGame:
n_type = node_cfg["type"]
new_node = None
# Default PrimAITE nodes
if n_type == "computer":
new_node = Computer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type == "server":
new_node = Server(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type == "switch":
new_node = Switch(
hostname=node_cfg["hostname"],
num_ports=int(node_cfg.get("num_ports", "8")),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type == "router":
new_node = Router.from_config(node_cfg)
elif n_type == "firewall":
new_node = Firewall.from_config(node_cfg)
elif n_type == "wireless_router":
new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace)
elif n_type == "printer":
new_node = Printer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"],
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
# Handle extended nodes
elif n_type.lower() in Node._registry:
new_node = HostNode._registry[n_type](
hostname=node_cfg["hostname"],
ip_address=node_cfg.get("ip_address"),
subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
default_gateway=node_cfg.get("default_gateway"),
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
elif n_type in NetworkNode._registry:
new_node = NetworkNode._registry[n_type](**node_cfg)
if n_type in Node._registry:
if n_type == "wireless-router":
node_cfg["airspace"] = net.airspace
new_node = Node._registry[n_type].from_config(config=node_cfg)
else:
msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg)
@@ -341,11 +278,11 @@ class PrimaiteGame:
# TODO: handle simulation defaults more cleanly
if "node_start_up_duration" in defaults_config:
new_node.start_up_duration = defaults_config["node_startup_duration"]
new_node.config.start_up_duration = defaults_config["node_startup_duration"]
if "node_shut_down_duration" in defaults_config:
new_node.shut_down_duration = defaults_config["node_shut_down_duration"]
new_node.config.shut_down_duration = defaults_config["node_shut_down_duration"]
if "node_scan_duration" in defaults_config:
new_node.node_scan_duration = defaults_config["node_scan_duration"]
new_node.config.node_scan_duration = defaults_config["node_scan_duration"]
if "folder_scan_duration" in defaults_config:
new_node.file_system._default_folder_scan_duration = defaults_config["folder_scan_duration"]
if "folder_restore_duration" in defaults_config:
@@ -382,7 +319,7 @@ class PrimaiteGame:
service_class = SERVICE_TYPES_MAPPING[service_type]
if service_class is not None:
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
_LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}")
new_node.software_manager.install(service_class, software_config=service_cfg.get("options", {}))
new_service = new_node.software_manager.software[service_type]
@@ -400,7 +337,7 @@ class PrimaiteGame:
# TODO: handle simulation defaults more cleanly
if "service_fix_duration" in defaults_config:
new_service.fixing_duration = defaults_config["service_fix_duration"]
new_service.config.fixing_duration = defaults_config["service_fix_duration"]
if "service_restart_duration" in defaults_config:
new_service.restart_duration = defaults_config["service_restart_duration"]
if "service_install_duration" in defaults_config:
@@ -431,8 +368,8 @@ class PrimaiteGame:
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
# temporarily set to 0 so all nodes are initially on
new_node.start_up_duration = 0
new_node.shut_down_duration = 0
new_node.config.start_up_duration = 0
new_node.config.shut_down_duration = 0
net.add_node(new_node)
# run through the power on step if the node is to be turned on at the start
@@ -440,8 +377,8 @@ class PrimaiteGame:
new_node.power_on()
# set start up and shut down duration
new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3))
new_node.shut_down_duration = int(node_cfg.get("shut_down_duration", 3))
new_node.config.start_up_duration = int(node_cfg.get("start_up_duration", 3))
new_node.config.shut_down_duration = int(node_cfg.get("shut_down_duration", 3))
# 1.1 Create Node Sets
for node_set_cfg in node_sets_cfg:

View File

@@ -3,7 +3,7 @@
"""Core of the PrimAITE Simulator."""
import warnings
from abc import abstractmethod
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from uuid import uuid4
from prettytable import PrettyTable

View File

@@ -178,7 +178,7 @@ class AirSpace(BaseModel):
status = "Enabled" if interface.enabled else "Disabled"
table.add_row(
[
interface._connected_node.hostname, # noqa
interface._connected_node.config.hostname, # noqa
interface.mac_address,
interface.ip_address if hasattr(interface, "ip_address") else None,
interface.subnet_mask if hasattr(interface, "subnet_mask") else None,
@@ -320,7 +320,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
self.enabled = True
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
self.pcap = PacketCapture(
hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name
hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name
)
self.airspace.add_wireless_interface(self)

View File

@@ -180,7 +180,7 @@ class Network(SimComponent):
table.align = "l"
table.title = "Nodes"
for node in self.nodes.values():
table.add_row((node.hostname, type(node)._discriminator, node.operating_state.name))
table.add_row((node.config.hostname, type(node)._discriminator, node.operating_state.name))
print(table)
if ip_addresses:
@@ -196,7 +196,13 @@ class Network(SimComponent):
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
[
node.config.hostname,
port_str,
port.ip_address,
port.subnet_mask,
node.config.default_gateway,
]
)
print(table)
@@ -215,9 +221,9 @@ class Network(SimComponent):
if node in [link.endpoint_a.parent, link.endpoint_b.parent]:
table.add_row(
[
link.endpoint_a.parent.hostname,
link.endpoint_a.parent.config.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.hostname,
link.endpoint_b.parent.config.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
@@ -251,7 +257,7 @@ class Network(SimComponent):
state = super().describe_state()
state.update(
{
"nodes": {node.hostname: node.describe_state() for node in self.nodes.values()},
"nodes": {node.config.hostname: node.describe_state() for node in self.nodes.values()},
"links": {},
}
)
@@ -259,8 +265,8 @@ class Network(SimComponent):
for _, link in self.links.items():
node_a = link.endpoint_a._connected_node
node_b = link.endpoint_b._connected_node
hostname_a = node_a.hostname if node_a else None
hostname_b = node_b.hostname if node_b else None
hostname_a = node_a.config.hostname if node_a else None
hostname_b = node_b.config.hostname if node_b else None
port_a = link.endpoint_a.port_num
port_b = link.endpoint_b.port_num
link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}"
@@ -286,9 +292,11 @@ class Network(SimComponent):
self.nodes[node.uuid] = node
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
self._nx_graph.add_node(node.config.hostname)
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
self._node_request_manager.add_request(name=node.hostname, request_type=RequestType(func=node._request_manager))
self._node_request_manager.add_request(
name=node.config.hostname, request_type=RequestType(func=node._request_manager)
)
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
@@ -300,7 +308,7 @@ class Network(SimComponent):
:return: The Node if it exists in the network.
"""
for node in self.nodes.values():
if node.hostname == hostname:
if node.config.hostname == hostname:
return node
def remove_node(self, node: Node) -> None:
@@ -313,7 +321,7 @@ class Network(SimComponent):
:type node: Node
"""
if node not in self:
_LOGGER.warning(f"Can't remove node {node.hostname}. It's not in the network.")
_LOGGER.warning(f"Can't remove node {node.config.hostname}. It's not in the network.")
return
self.nodes.pop(node.uuid)
for i, _node in self._node_id_map.items():
@@ -321,8 +329,8 @@ class Network(SimComponent):
self._node_id_map.pop(i)
break
node.parent = None
self._node_request_manager.remove_request(name=node.hostname)
_LOGGER.info(f"Removed node {node.hostname} from network {self.uuid}")
self._node_request_manager.remove_request(name=node.config.hostname)
_LOGGER.info(f"Removed node {node.config.hostname} from network {self.uuid}")
def connect(
self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, bandwidth: int = 100, **kwargs
@@ -352,7 +360,7 @@ class Network(SimComponent):
link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth, **kwargs)
self.links[link.uuid] = link
self._link_id_map[len(self.links)] = link
self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname)
self._nx_graph.add_edge(endpoint_a.parent.config.hostname, endpoint_b.parent.config.hostname)
link.parent = self
_LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
return link

View File

@@ -154,7 +154,9 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
# Create a core switch if more than one edge switch is needed
if num_of_switches > 1:
core_switch = Switch(hostname=f"switch_core_{config.lan_name}", start_up_duration=0)
core_switch = Switch.from_config(
config={"type": "switch", "hostname": f"switch_core_{config.lan_name}", "start_up_duration": 0}
)
core_switch.power_on()
network.add_node(core_switch)
core_switch_port = 1
@@ -165,7 +167,10 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
# Optionally include a router in the LAN
if config.include_router:
default_gateway = IPv4Address(f"192.168.{config.subnet_base}.1")
router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0)
# router = Router(hostname=f"router_{config.lan_name}", start_up_duration=0)
router = Router.from_config(
config={"hostname": f"router_{config.lan_name}", "type": "router", "start_up_duration": 0}
)
router.power_on()
router.acl.add_rule(
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22
@@ -178,7 +183,9 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
# Initialise the first edge switch and connect to the router or core switch
switch_port = 0
switch_n = 1
switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0)
switch = Switch.from_config(
config={"type": "switch", "hostname": f"switch_edge_{switch_n}_{config.lan_name}", "start_up_duration": 0}
)
switch.power_on()
network.add_node(switch)
if num_of_switches > 1:
@@ -196,7 +203,13 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
if switch_port == effective_network_interface:
switch_n += 1
switch_port = 0
switch = Switch(hostname=f"switch_edge_{switch_n}_{config.lan_name}", start_up_duration=0)
switch = Switch.from_config(
config={
"type": "switch",
"hostname": f"switch_edge_{switch_n}_{config.lan_name}",
"start_up_duration": 0,
}
)
switch.power_on()
network.add_node(switch)
# Connect the new switch to the router or core switch
@@ -213,13 +226,14 @@ class OfficeLANAdder(NetworkNodeAdder, discriminator="office-lan"):
)
# Create and add a PC to the network
pc = Computer(
hostname=f"pc_{i}_{config.lan_name}",
ip_address=f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}",
subnet_mask="255.255.255.0",
default_gateway=default_gateway,
start_up_duration=0,
)
pc_cfg = {
"type": "computer",
"hostname": f"pc_{i}_{config.lan_name}",
"ip_address": f"192.168.{config.subnet_base}.{i+config.pcs_ip_block_start-1}",
"default_gateway": "192.168.10.1",
"start_up_duration": 0,
}
pc = Computer.from_config(config=pc_cfg)
pc.power_on()
network.add_node(pc)

View File

@@ -9,7 +9,7 @@ from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field, validate_call
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
from primaite.exceptions import NetworkError
@@ -431,7 +431,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
self.enabled = True
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
self.pcap = PacketCapture(
hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name
hostname=self._connected_node.config.hostname, port_num=self.port_num, port_name=self.port_name
)
if self._connected_link:
self._connected_link.endpoint_up()
@@ -1494,19 +1494,12 @@ class Node(SimComponent, ABC):
:param operating_state: The node operating state, either ON or OFF.
"""
hostname: str
"The node hostname on the network."
default_gateway: Optional[IPV4Address] = None
"The default gateway IP address for forwarding network traffic to other networks."
operating_state: NodeOperatingState = NodeOperatingState.OFF
"The hardware state of the node."
network_interfaces: Dict[str, NetworkInterface] = {}
"The Network Interfaces on the node."
network_interface: Dict[int, NetworkInterface] = {}
"The Network Interfaces on the node by port id."
dns_server: Optional[IPv4Address] = None
"List of IP addresses of DNS servers used for name resolution."
accounts: Dict[str, Account] = {}
"All accounts on the node."
applications: Dict[str, Application] = {}
@@ -1523,6 +1516,28 @@ class Node(SimComponent, ABC):
session_manager: SessionManager
software_manager: SoftwareManager
SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {}
"Base system software that must be preinstalled."
_registry: ClassVar[Dict[str, Type["Node"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
# TODO: this should not be set for abstract classes.
_discriminator: ClassVar[str]
"""discriminator for this particular class, used for printing and logging. Each subclass redefines this."""
config: Node.ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
"""Configuration items within Node"""
class ConfigSchema(BaseModel, ABC):
"""Configuration Schema for Node based classes."""
model_config = ConfigDict(arbitrary_types_allowed=True)
"""Configure pydantic to allow arbitrary types, let the instance have attributes not present in the model."""
hostname: str = "default"
"The node hostname on the network."
revealed_to_red: bool = False
"Informs whether the node has been revealed to a red agent."
@@ -1550,15 +1565,27 @@ class Node(SimComponent, ABC):
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {}
"Base system software that must be preinstalled."
dns_server: Optional[IPv4Address] = None
"List of IP addresses of DNS servers used for name resolution."
_registry: ClassVar[Dict[str, Type["Node"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
default_gateway: Optional[IPV4Address] = None
"The default gateway IP address for forwarding network traffic to other networks."
# TODO: this should not be set for abstract classes.
_discriminator: ClassVar[str]
"""discriminator for this particular class, used for printing and logging. Each subclass redefines this."""
operating_state: Any = None
@property
def dns_server(self) -> Optional[IPv4Address]:
"""Convenience method to access the dns_server IP."""
return self.config.dns_server
@classmethod
def from_config(cls, config: Dict) -> "Node":
"""Create Node object from a given configuration dictionary."""
if config["type"] not in cls._registry:
msg = f"Configuration contains an invalid Node type: {config['type']}"
return ValueError(msg)
obj = cls(config=cls.ConfigSchema(**config))
return obj
def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None:
"""
@@ -1585,11 +1612,11 @@ class Node(SimComponent, ABC):
provided.
"""
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["hostname"])
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
if not kwargs.get("session_manager"):
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"))
if not kwargs.get("root"):
kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"]
kwargs["root"] = SIM_OUTPUT.path / kwargs["config"].hostname
if not kwargs.get("file_system"):
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
if not kwargs.get("software_manager"):
@@ -1598,9 +1625,12 @@ class Node(SimComponent, ABC):
sys_log=kwargs.get("sys_log"),
session_manager=kwargs.get("session_manager"),
file_system=kwargs.get("file_system"),
dns_server=kwargs.get("dns_server"),
dns_server=kwargs["config"].dns_server,
)
super().__init__(**kwargs)
self.operating_state = (
NodeOperatingState.ON if not (p := kwargs["config"].operating_state) else NodeOperatingState[p.upper()]
)
self._install_system_software()
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
@@ -1694,7 +1724,7 @@ class Node(SimComponent, ABC):
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not powered on."
return f"Cannot perform request on node '{self.node.config.hostname}' because it is not powered on."
class _NodeIsOffValidator(RequestPermissionValidator):
"""
@@ -1713,7 +1743,7 @@ class Node(SimComponent, ABC):
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not turned off."
return f"Cannot perform request on node '{self.node.config.hostname}' because it is not turned off."
def _init_request_manager(self) -> RequestManager:
"""
@@ -1741,7 +1771,7 @@ class Node(SimComponent, ABC):
self.software_manager.install(application_class)
application_instance = self.software_manager.software.get(application_name)
self.applications[application_instance.uuid] = application_instance
_LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}")
_LOGGER.debug(f"Added application {application_instance.name} to node {self.config.hostname}")
self._application_request_manager.add_request(
application_name, RequestType(func=application_instance._request_manager)
)
@@ -1855,7 +1885,7 @@ class Node(SimComponent, ABC):
state = super().describe_state()
state.update(
{
"hostname": self.hostname,
"hostname": self.config.hostname,
"operating_state": self.operating_state.value,
"NICs": {
eth_num: network_interface.describe_state()
@@ -1865,7 +1895,7 @@ class Node(SimComponent, ABC):
"applications": {app.name: app.describe_state() for app in self.applications.values()},
"services": {svc.name: svc.describe_state() for svc in self.services.values()},
"process": {proc.name: proc.describe_state() for proc in self.processes.values()},
"revealed_to_red": self.revealed_to_red,
"revealed_to_red": self.config.revealed_to_red,
}
)
return state
@@ -1881,7 +1911,7 @@ class Node(SimComponent, ABC):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Open Ports"
table.title = f"{self.config.hostname} Open Ports"
for port in self.software_manager.get_open_ports():
if port > 0:
table.add_row([port])
@@ -1908,7 +1938,7 @@ class Node(SimComponent, ABC):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Network Interface Cards"
table.title = f"{self.config.hostname} Network Interface Cards"
for port, network_interface in self.network_interface.items():
ip_address = ""
if hasattr(network_interface, "ip_address"):
@@ -1943,38 +1973,38 @@ class Node(SimComponent, ABC):
network_interface.apply_timestep(timestep=timestep)
# count down to boot up
if self.start_up_countdown > 0:
self.start_up_countdown -= 1
if self.config.start_up_countdown > 0:
self.config.start_up_countdown -= 1
else:
if self.operating_state == NodeOperatingState.BOOTING:
self.operating_state = NodeOperatingState.ON
self.sys_log.info(f"{self.hostname}: Turned on")
self.sys_log.info(f"{self.config.hostname}: Turned on")
for network_interface in self.network_interfaces.values():
network_interface.enable()
self._start_up_actions()
# count down to shut down
if self.shut_down_countdown > 0:
self.shut_down_countdown -= 1
if self.config.shut_down_countdown > 0:
self.config.shut_down_countdown -= 1
else:
if self.operating_state == NodeOperatingState.SHUTTING_DOWN:
self.operating_state = NodeOperatingState.OFF
self.sys_log.info(f"{self.hostname}: Turned off")
self.sys_log.info(f"{self.config.hostname}: Turned off")
self._shut_down_actions()
# if resetting turn back on
if self.is_resetting:
self.is_resetting = False
if self.config.is_resetting:
self.config.is_resetting = False
self.power_on()
# time steps which require the node to be on
if self.operating_state == NodeOperatingState.ON:
# node scanning
if self.node_scan_countdown > 0:
self.node_scan_countdown -= 1
if self.config.node_scan_countdown > 0:
self.config.node_scan_countdown -= 1
if self.node_scan_countdown == 0:
if self.config.node_scan_countdown == 0:
# scan everything!
for process_id in self.processes:
self.processes[process_id].scan()
@@ -1990,10 +2020,10 @@ class Node(SimComponent, ABC):
# scan file system
self.file_system.scan(instant_scan=True)
if self.red_scan_countdown > 0:
self.red_scan_countdown -= 1
if self.config.red_scan_countdown > 0:
self.config.red_scan_countdown -= 1
if self.red_scan_countdown == 0:
if self.config.red_scan_countdown == 0:
# scan processes
for process_id in self.processes:
self.processes[process_id].reveal_to_red()
@@ -2050,7 +2080,7 @@ class Node(SimComponent, ABC):
to the red agent.
"""
self.node_scan_countdown = self.node_scan_duration
self.config.node_scan_countdown = self.config.node_scan_duration
return True
def reveal_to_red(self) -> bool:
@@ -2066,12 +2096,12 @@ class Node(SimComponent, ABC):
`revealed_to_red` to `True`.
"""
self.red_scan_countdown = self.node_scan_duration
self.config.red_scan_countdown = self.config.node_scan_duration
return True
def power_on(self) -> bool:
"""Power on the Node, enabling its NICs if it is in the OFF state."""
if self.start_up_duration <= 0:
if self.config.start_up_duration <= 0:
self.operating_state = NodeOperatingState.ON
self._start_up_actions()
self.sys_log.info("Power on")
@@ -2080,14 +2110,14 @@ class Node(SimComponent, ABC):
return True
if self.operating_state == NodeOperatingState.OFF:
self.operating_state = NodeOperatingState.BOOTING
self.start_up_countdown = self.start_up_duration
self.config.start_up_countdown = self.config.start_up_duration
return True
return False
def power_off(self) -> bool:
"""Power off the Node, disabling its NICs if it is in the ON state."""
if self.shut_down_duration <= 0:
if self.config.shut_down_duration <= 0:
self._shut_down_actions()
self.operating_state = NodeOperatingState.OFF
self.sys_log.info("Power off")
@@ -2096,7 +2126,7 @@ class Node(SimComponent, ABC):
for network_interface in self.network_interfaces.values():
network_interface.disable()
self.operating_state = NodeOperatingState.SHUTTING_DOWN
self.shut_down_countdown = self.shut_down_duration
self.config.shut_down_countdown = self.config.shut_down_duration
return True
return False
@@ -2108,7 +2138,7 @@ class Node(SimComponent, ABC):
Applying more timesteps will eventually turn the node back on.
"""
if self.operating_state.ON:
self.is_resetting = True
self.config.is_resetting = True
self.sys_log.info("Resetting")
self.power_off()
return True
@@ -2225,6 +2255,8 @@ class Node(SimComponent, ABC):
# Turn on all the applications in the node
for app_id in self.applications:
print(app_id)
print(f"Starting application:{self.applications[app_id].config.type}")
self.applications[app_id].run()
# Turn off all processes in the node

View File

@@ -1,6 +1,8 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import ClassVar, Dict
from pydantic import Field
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
@@ -35,4 +37,11 @@ class Computer(HostNode, discriminator="computer"):
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "ftp-client": FTPClient}
config: "Computer.ConfigSchema" = Field(default_factory=lambda: Computer.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Computer class."""
hostname: str = "Computer"
pass

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.hardware.base import (
IPWiredNetworkInterface,
@@ -44,8 +46,8 @@ class HostARP(ARP):
:return: The MAC address of the default gateway if present in the ARP cache; otherwise, None.
"""
if self.software_manager.node.default_gateway:
return self.get_arp_cache_mac_address(self.software_manager.node.default_gateway)
if self.software_manager.node.config.default_gateway:
return self.get_arp_cache_mac_address(self.software_manager.node.config.default_gateway)
def get_default_gateway_network_interface(self) -> Optional[NIC]:
"""
@@ -53,8 +55,11 @@ class HostARP(ARP):
:return: The NIC associated with the default gateway if it exists in the ARP cache; otherwise, None.
"""
if self.software_manager.node.default_gateway and self.software_manager.node.has_enabled_network_interface:
return self.get_arp_cache_network_interface(self.software_manager.node.default_gateway)
if (
self.software_manager.node.config.default_gateway
and self.software_manager.node.has_enabled_network_interface
):
return self.get_arp_cache_network_interface(self.software_manager.node.config.default_gateway)
def _get_arp_cache_mac_address(
self, ip_address: IPV4Address, is_reattempt: bool = False, is_default_gateway_attempt: bool = False
@@ -73,7 +78,7 @@ class HostARP(ARP):
if arp_entry:
return arp_entry.mac_address
if ip_address == self.software_manager.node.default_gateway:
if ip_address == self.software_manager.node.config.default_gateway:
is_reattempt = True
if not is_reattempt:
self.send_arp_request(ip_address)
@@ -81,11 +86,11 @@ class HostARP(ARP):
ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt
)
else:
if self.software_manager.node.default_gateway:
if self.software_manager.node.config.default_gateway:
if not is_default_gateway_attempt:
self.send_arp_request(self.software_manager.node.default_gateway)
self.send_arp_request(self.software_manager.node.config.default_gateway)
return self._get_arp_cache_mac_address(
ip_address=self.software_manager.node.default_gateway,
ip_address=self.software_manager.node.config.default_gateway,
is_reattempt=True,
is_default_gateway_attempt=True,
)
@@ -116,7 +121,7 @@ class HostARP(ARP):
if arp_entry:
return self.software_manager.node.network_interfaces[arp_entry.network_interface_uuid]
else:
if ip_address == self.software_manager.node.default_gateway:
if ip_address == self.software_manager.node.config.default_gateway:
is_reattempt = True
if not is_reattempt:
self.send_arp_request(ip_address)
@@ -124,11 +129,11 @@ class HostARP(ARP):
ip_address=ip_address, is_reattempt=True, is_default_gateway_attempt=is_default_gateway_attempt
)
else:
if self.software_manager.node.default_gateway:
if self.software_manager.node.config.default_gateway:
if not is_default_gateway_attempt:
self.send_arp_request(self.software_manager.node.default_gateway)
self.send_arp_request(self.software_manager.node.config.default_gateway)
return self._get_arp_cache_network_interface(
ip_address=self.software_manager.node.default_gateway,
ip_address=self.software_manager.node.config.default_gateway,
is_reattempt=True,
is_default_gateway_attempt=True,
)
@@ -325,9 +330,18 @@ class HostNode(Node, discriminator="host-node"):
network_interface: Dict[int, NIC] = {}
"The NICs on the node by port id."
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
config: HostNode.ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
class ConfigSchema(Node.ConfigSchema):
"""Configuration Schema for HostNode class."""
hostname: str = "HostNode"
subnet_mask: IPV4Address = "255.255.255.0"
ip_address: IPV4Address
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask))
@property
def nmap(self) -> Optional[NMAP]:

View File

@@ -1,4 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from pydantic import Field
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
@@ -30,8 +33,22 @@ class Server(HostNode, discriminator="server"):
* Web Browser
"""
config: "Server.ConfigSchema" = Field(default_factory=lambda: Server.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Server class."""
hostname: str = "server"
class Printer(HostNode, discriminator="printer"):
"""Printer? I don't even know her!."""
# TODO: Implement printer-specific behaviour
config: "Printer.ConfigSchema" = Field(default_factory=lambda: Printer.ConfigSchema())
class ConfigSchema(HostNode.ConfigSchema):
"""Configuration Schema for Printer class."""
hostname: str = "printer"

View File

@@ -6,7 +6,6 @@ from prettytable import MARKDOWN, PrettyTable
from pydantic import Field, validate_call
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.router import (
AccessControlList,
ACLAction,
@@ -99,11 +98,21 @@ class Firewall(Router, discriminator="firewall"):
)
"""Access Control List for managing traffic leaving towards an external network."""
def __init__(self, hostname: str, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname)
_identifier: str = "firewall"
super().__init__(hostname=hostname, num_ports=0, **kwargs)
config: "Firewall.ConfigSchema" = Field(default_factory=lambda: Firewall.ConfigSchema())
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for Firewall 'Nodes' within PrimAITE."""
hostname: str = "firewall"
num_ports: int = 0
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
super().__init__(**kwargs)
self.connect_nic(
RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external")
@@ -116,22 +125,23 @@ class Firewall(Router, discriminator="firewall"):
)
# Update ACL objects with firewall's hostname and syslog to allow accurate logging
self.internal_inbound_acl.sys_log = kwargs["sys_log"]
self.internal_inbound_acl.name = f"{hostname} - Internal Inbound"
self.internal_inbound_acl.name = f"{kwargs['config'].hostname} - Internal Inbound"
self.internal_outbound_acl.sys_log = kwargs["sys_log"]
self.internal_outbound_acl.name = f"{hostname} - Internal Outbound"
self.internal_outbound_acl.name = f"{kwargs['config'].hostname} - Internal Outbound"
self.dmz_inbound_acl.sys_log = kwargs["sys_log"]
self.dmz_inbound_acl.name = f"{hostname} - DMZ Inbound"
self.dmz_inbound_acl.name = f"{kwargs['config'].hostname} - DMZ Inbound"
self.dmz_outbound_acl.sys_log = kwargs["sys_log"]
self.dmz_outbound_acl.name = f"{hostname} - DMZ Outbound"
self.dmz_outbound_acl.name = f"{kwargs['config'].hostname} - DMZ Outbound"
self.external_inbound_acl.sys_log = kwargs["sys_log"]
self.external_inbound_acl.name = f"{hostname} - External Inbound"
self.external_inbound_acl.name = f"{kwargs['config'].hostname} - External Inbound"
self.external_outbound_acl.sys_log = kwargs["sys_log"]
self.external_outbound_acl.name = f"{hostname} - External Outbound"
self.external_outbound_acl.name = f"{kwargs['config'].hostname} - External Outbound"
self.power_on()
def _init_request_manager(self) -> RequestManager:
"""
@@ -231,7 +241,7 @@ class Firewall(Router, discriminator="firewall"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Network Interfaces"
table.title = f"{self.config.hostname} Network Interfaces"
ports = {"External": self.external_port, "Internal": self.internal_port, "DMZ": self.dmz_port}
for port, network_interface in ports.items():
table.add_row(
@@ -551,18 +561,14 @@ class Firewall(Router, discriminator="firewall"):
self.dmz_port.enable()
@classmethod
def from_config(cls, cfg: dict) -> "Firewall":
def from_config(cls, config: dict) -> "Firewall":
"""Create a firewall based on a config dict."""
firewall = Firewall(
hostname=cfg["hostname"],
operating_state=NodeOperatingState.ON
if not (p := cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
if "ports" in cfg:
internal_port = cfg["ports"]["internal_port"]
external_port = cfg["ports"]["external_port"]
dmz_port = cfg["ports"].get("dmz_port")
firewall = Firewall(config=cls.ConfigSchema(**config))
if "ports" in config:
internal_port = config["ports"]["internal_port"]
external_port = config["ports"]["external_port"]
dmz_port = config["ports"].get("dmz_port")
# configure internal port
firewall.configure_internal_port(
@@ -582,10 +588,10 @@ class Firewall(Router, discriminator="firewall"):
ip_address=IPV4Address(dmz_port.get("ip_address")),
subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")),
)
if "acl" in cfg:
if "acl" in config:
# acl rules for internal_inbound_acl
if cfg["acl"]["internal_inbound_acl"]:
for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items():
if config["acl"]["internal_inbound_acl"]:
for r_num, r_cfg in config["acl"]["internal_inbound_acl"].items():
firewall.internal_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -599,8 +605,8 @@ class Firewall(Router, discriminator="firewall"):
)
# acl rules for internal_outbound_acl
if cfg["acl"]["internal_outbound_acl"]:
for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items():
if config["acl"]["internal_outbound_acl"]:
for r_num, r_cfg in config["acl"]["internal_outbound_acl"].items():
firewall.internal_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -614,8 +620,8 @@ class Firewall(Router, discriminator="firewall"):
)
# acl rules for dmz_inbound_acl
if cfg["acl"]["dmz_inbound_acl"]:
for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items():
if config["acl"]["dmz_inbound_acl"]:
for r_num, r_cfg in config["acl"]["dmz_inbound_acl"].items():
firewall.dmz_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -629,8 +635,8 @@ class Firewall(Router, discriminator="firewall"):
)
# acl rules for dmz_outbound_acl
if cfg["acl"]["dmz_outbound_acl"]:
for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items():
if config["acl"]["dmz_outbound_acl"]:
for r_num, r_cfg in config["acl"]["dmz_outbound_acl"].items():
firewall.dmz_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -644,8 +650,8 @@ class Firewall(Router, discriminator="firewall"):
)
# acl rules for external_inbound_acl
if cfg["acl"].get("external_inbound_acl"):
for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items():
if config["acl"].get("external_inbound_acl"):
for r_num, r_cfg in config["acl"]["external_inbound_acl"].items():
firewall.external_inbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -659,8 +665,8 @@ class Firewall(Router, discriminator="firewall"):
)
# acl rules for external_outbound_acl
if cfg["acl"].get("external_outbound_acl"):
for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items():
if config["acl"].get("external_outbound_acl"):
for r_num, r_cfg in config["acl"]["external_outbound_acl"].items():
firewall.external_outbound_acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -673,16 +679,16 @@ class Firewall(Router, discriminator="firewall"):
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
if "routes" in config:
for route in config.get("routes"):
firewall.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if "default_route" in config:
next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
firewall.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)

View File

@@ -1207,7 +1207,6 @@ class Router(NetworkNode, discriminator="router"):
"Terminal": Terminal,
}
num_ports: int
network_interfaces: Dict[str, RouterInterface] = {}
"The Router Interfaces on the node."
network_interface: Dict[int, RouterInterface] = {}
@@ -1215,19 +1214,29 @@ class Router(NetworkNode, discriminator="router"):
acl: AccessControlList
route_table: RouteTable
def __init__(self, hostname: str, num_ports: int = 5, **kwargs):
config: "Router.ConfigSchema"
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Routers."""
hostname: str = "router"
num_ports: int = 5
def __init__(self, **kwargs):
if not kwargs.get("sys_log"):
kwargs["sys_log"] = SysLog(hostname)
kwargs["sys_log"] = SysLog(kwargs["config"].hostname)
if not kwargs.get("acl"):
kwargs["acl"] = AccessControlList(sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=hostname)
kwargs["acl"] = AccessControlList(
sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=kwargs["config"].hostname
)
if not kwargs.get("route_table"):
kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"])
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
super().__init__(**kwargs)
self.session_manager = RouterSessionManager(sys_log=self.sys_log)
self.session_manager.node = self
self.software_manager.session_manager = self.session_manager
self.session_manager.software_manager = self.software_manager
for i in range(1, self.num_ports + 1):
for i in range(1, self.config.num_ports + 1):
network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
self.connect_nic(network_interface)
self.network_interface[i] = network_interface
@@ -1337,7 +1346,7 @@ class Router(NetworkNode, discriminator="router"):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["num_ports"] = self.num_ports
state["num_ports"] = self.config.num_ports
state["acl"] = self.acl.describe_state()
return state
@@ -1545,7 +1554,7 @@ class Router(NetworkNode, discriminator="router"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Network Interfaces"
table.title = f"{self.config.hostname} Network Interfaces"
for port, network_interface in self.network_interface.items():
table.add_row(
[
@@ -1559,7 +1568,7 @@ class Router(NetworkNode, discriminator="router"):
print(table)
@classmethod
def from_config(cls, cfg: dict, **kwargs) -> "Router":
def from_config(cls, config: dict, **kwargs) -> "Router":
"""Create a router based on a config dict.
Schema:
@@ -1616,22 +1625,16 @@ class Router(NetworkNode, discriminator="router"):
:return: Configured router.
:rtype: Router
"""
router = Router(
hostname=cfg["hostname"],
num_ports=int(cfg.get("num_ports", "5")),
operating_state=NodeOperatingState.ON
if not (p := cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
if "ports" in cfg:
for port_num, port_cfg in cfg["ports"].items():
router = Router(config=Router.ConfigSchema(**config))
if "ports" in config:
for port_num, port_cfg in config["ports"].items():
router.configure_port(
port=port_num,
ip_address=port_cfg["ip_address"],
subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")),
)
if "acl" in cfg:
for r_num, r_cfg in cfg["acl"].items():
if "acl" in config:
for r_num, r_cfg in config["acl"].items():
router.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -1643,16 +1646,19 @@ class Router(NetworkNode, discriminator="router"):
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
if "routes" in config:
for route in config.get("routes"):
router.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if "default_route" in config:
next_hop_ip_address = config["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
router.operating_state = (
NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()]
)
return router

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite import getLogger
from primaite.exceptions import NetworkError
@@ -88,14 +89,8 @@ class SwitchPort(WiredNetworkInterface):
class Switch(NetworkNode, discriminator="switch"):
"""
A class representing a Layer 2 network switch.
"""A class representing a Layer 2 network switch."""
:ivar num_ports: The number of ports on the switch. Default is 24.
"""
num_ports: int = 24
"The number of ports on the switch."
network_interfaces: Dict[str, SwitchPort] = {}
"The SwitchPorts on the Switch."
network_interface: Dict[int, SwitchPort] = {}
@@ -103,9 +98,18 @@ class Switch(NetworkNode, discriminator="switch"):
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
config: "Switch.ConfigSchema" = Field(default_factory=lambda: Switch.ConfigSchema())
class ConfigSchema(NetworkNode.ConfigSchema):
"""Configuration Schema for Switch nodes within PrimAITE."""
hostname: str = "Switch"
num_ports: int = 24
"The number of ports on the switch. Default is 24."
def __init__(self, **kwargs):
super().__init__(**kwargs)
for i in range(1, self.num_ports + 1):
for i in range(1, kwargs["config"].num_ports + 1):
self.connect_nic(SwitchPort())
def _install_system_software(self):
@@ -121,7 +125,7 @@ class Switch(NetworkNode, discriminator="switch"):
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Switch Ports"
table.title = f"{self.config.hostname} Switch Ports"
for port_num, port in self.network_interface.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)
@@ -134,7 +138,7 @@ class Switch(NetworkNode, discriminator="switch"):
"""
state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.network_interface.items()}
state["num_ports"] = self.num_ports # redundant?
state["num_ports"] = self.config.num_ports # redundant?
state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
return state

View File

@@ -2,7 +2,7 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional, Union
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, FREQ_WIFI_2_4, IPWirelessNetworkInterface
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
@@ -122,13 +122,23 @@ class WirelessRouter(Router, discriminator="wireless-router"):
network_interfaces: Dict[str, Union[RouterInterface, WirelessAccessPoint]] = {}
network_interface: Dict[int, Union[RouterInterface, WirelessAccessPoint]] = {}
airspace: AirSpace
def __init__(self, hostname: str, airspace: AirSpace, **kwargs):
super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs)
config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema())
class ConfigSchema(Router.ConfigSchema):
"""Configuration Schema for WirelessRouter nodes within PrimAITE."""
hostname: str = "WirelessRouter"
airspace: AirSpace
num_ports: int = 0
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connect_nic(
WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=airspace)
WirelessAccessPoint(
ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=kwargs["config"].airspace
)
)
self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0"))
@@ -226,7 +236,7 @@ class WirelessRouter(Router, discriminator="wireless-router"):
)
@classmethod
def from_config(cls, cfg: Dict, **kwargs) -> "WirelessRouter":
def from_config(cls, config: Dict, **kwargs) -> "WirelessRouter":
"""Generate the wireless router from config.
Schema:
@@ -253,22 +263,22 @@ class WirelessRouter(Router, discriminator="wireless-router"):
:return: WirelessRouter instance.
:rtype: WirelessRouter
"""
operating_state = (
NodeOperatingState.ON if not (p := cfg.get("operating_state")) else NodeOperatingState[p.upper()]
router = cls(config=cls.ConfigSchema(**config))
router.operating_state = (
NodeOperatingState.ON if not (p := config.get("operating_state")) else NodeOperatingState[p.upper()]
)
router = cls(hostname=cfg["hostname"], operating_state=operating_state, airspace=kwargs["airspace"])
if "router_interface" in cfg:
ip_address = cfg["router_interface"]["ip_address"]
subnet_mask = cfg["router_interface"]["subnet_mask"]
if "router_interface" in config:
ip_address = config["router_interface"]["ip_address"]
subnet_mask = config["router_interface"]["subnet_mask"]
router.configure_router_interface(ip_address=ip_address, subnet_mask=subnet_mask)
if "wireless_access_point" in cfg:
ip_address = cfg["wireless_access_point"]["ip_address"]
subnet_mask = cfg["wireless_access_point"]["subnet_mask"]
frequency = AirSpaceFrequency._registry[cfg["wireless_access_point"]["frequency"]]
if "wireless_access_point" in config:
ip_address = config["wireless_access_point"]["ip_address"]
subnet_mask = config["wireless_access_point"]["subnet_mask"]
frequency = AirSpaceFrequency._registry[config["wireless_access_point"]["frequency"]]
router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency)
if "acl" in cfg:
for r_num, r_cfg in cfg["acl"].items():
if "acl" in config:
for r_num, r_cfg in config["acl"].items():
router.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else PORT_LOOKUP[p],
@@ -280,8 +290,8 @@ class WirelessRouter(Router, discriminator="wireless-router"):
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
if "routes" in config:
for route in config.get("routes"):
router.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),

View File

@@ -40,42 +40,46 @@ def client_server_routed() -> Network:
network = Network()
# Router 1
router_1 = Router(hostname="router_1", num_ports=3)
router_1 = Router(config=dict(hostname="router_1", num_ports=3))
router_1.power_on()
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.2.1", subnet_mask="255.255.255.0")
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=6)
switch_1 = Switch(config=dict(hostname="switch_1", num_ports=6))
switch_1.power_on()
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[6])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=6)
switch_2 = Switch(config=dict(hostname="switch_2", num_ports=6))
switch_2.power_on()
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[6])
router_1.enable_port(2)
# Client 1
client_1 = Computer(
config=dict(
hostname="client_1",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
start_up_duration=0,
)
)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
# Server 1
server_1 = Server(
config=dict(
hostname="server_1",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
)
server_1.power_on()
network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1])
@@ -128,32 +132,41 @@ def arcd_uc2_network() -> Network:
network = Network()
# Router 1
router_1 = Router(hostname="router_1", num_ports=5, start_up_duration=0)
router_1 = Router.from_config(
config={"type": "router", "hostname": "router_1", "num_ports": 5, "start_up_duration": 0}
)
router_1.power_on()
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
switch_1 = Switch.from_config(
config={"type": "switch", "hostname": "switch_1", "num_ports": 8, "start_up_duration": 0}
)
switch_1.power_on()
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
switch_2 = Switch.from_config(
config={"type": "switch", "hostname": "switch_2", "num_ports": 8, "start_up_duration": 0}
)
switch_2.power_on()
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])
router_1.enable_port(2)
# Client 1
client_1 = Computer(
hostname="client_1",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
client_1_cfg = {
"type": "computer",
"hostname": "client_1",
"ip_address": "192.168.10.21",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.10.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
client_1: Computer = Computer.from_config(config=client_1_cfg)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
client_1.software_manager.install(DatabaseClient)
@@ -172,14 +185,18 @@ def arcd_uc2_network() -> Network:
)
# Client 2
client_2 = Computer(
hostname="client_2",
ip_address="192.168.10.22",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
client_2_cfg = {
"type": "computer",
"hostname": "client_2",
"ip_address": "192.168.10.22",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.10.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
client_2: Computer = Computer.from_config(config=client_2_cfg)
client_2.power_on()
client_2.software_manager.install(DatabaseClient)
db_client_2 = client_2.software_manager.software.get("database-client")
@@ -193,27 +210,36 @@ def arcd_uc2_network() -> Network:
)
# Domain Controller
domain_controller = Server(
hostname="domain_controller",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
domain_controller_cfg = {
"type": "server",
"hostname": "domain_controller",
"ip_address": "192.168.1.10",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
domain_controller = Server.from_config(config=domain_controller_cfg)
domain_controller.power_on()
domain_controller.software_manager.install(DNSServer)
network.connect(endpoint_b=domain_controller.network_interface[1], endpoint_a=switch_1.network_interface[1])
# Database Server
database_server = Server(
hostname="database_server",
ip_address="192.168.1.14",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
database_server_cfg = {
"type": "server",
"hostname": "database_server",
"ip_address": "192.168.1.14",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
database_server = Server.from_config(config=database_server_cfg)
database_server.power_on()
network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3])
@@ -223,14 +249,18 @@ def arcd_uc2_network() -> Network:
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
# Web Server
web_server = Server(
hostname="web_server",
ip_address="192.168.1.12",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
web_server_cfg = {
"type": "server",
"hostname": "web_server",
"ip_address": "192.168.1.11",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
web_server = Server.from_config(config=web_server_cfg)
web_server.power_on()
web_server.software_manager.install(DatabaseClient)
@@ -247,27 +277,32 @@ def arcd_uc2_network() -> Network:
dns_server_service.dns_register("arcd.com", web_server.network_interface[1].ip_address)
# Backup Server
backup_server = Server(
hostname="backup_server",
ip_address="192.168.1.16",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
backup_server_cfg = {
"type": "server",
"hostname": "backup_server",
"ip_address": "192.168.1.16",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
backup_server: Server = Server.from_config(config=backup_server_cfg)
backup_server.power_on()
backup_server.software_manager.install(FTPServer)
network.connect(endpoint_b=backup_server.network_interface[1], endpoint_a=switch_1.network_interface[4])
# Security Suite
security_suite = Server(
hostname="security_suite",
ip_address="192.168.1.110",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
start_up_duration=0,
)
security_suite_cfg = {
"type": "server",
"hostname": "backup_server",
"ip_address": "192.168.1.110",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
"start_up_duration": 0,
}
security_suite: Server = Server.from_config(config=security_suite_cfg)
security_suite.power_on()
network.connect(endpoint_b=security_suite.network_interface[1], endpoint_a=switch_1.network_interface[7])
security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0"))

View File

@@ -208,7 +208,7 @@ class NMAP(Application, discriminator="nmap"):
if show:
table = PrettyTable(["IP Address", "Can Ping"])
table.align = "l"
table.title = f"{self.software_manager.node.hostname} NMAP Ping Scan"
table.title = f"{self.software_manager.node.config.hostname} NMAP Ping Scan"
ip_addresses = self._explode_ip_address_network_array(target_ip_address)
@@ -367,7 +367,7 @@ class NMAP(Application, discriminator="nmap"):
if show:
table = PrettyTable(["IP Address", "Port", "Protocol"])
table.align = "l"
table.title = f"{self.software_manager.node.hostname} NMAP Port Scan ({scan_type})"
table.title = f"{self.software_manager.node.config.hostname} NMAP Port Scan ({scan_type})"
self.sys_log.info(f"{self.name}: Starting port scan")
for ip_address in ip_addresses:
# Prevent port scan on this node

View File

@@ -140,6 +140,7 @@ class SoftwareManager:
elif isinstance(software, Service):
self.node.services[software.uuid] = software
self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager))
software.start()
software.install()
software.software_manager = self
self.software[software.name] = software

View File

@@ -32,6 +32,7 @@ class DatabaseService(Service, discriminator="database-service"):
type: str = "database-service"
backup_server_ip: Optional[IPv4Address] = None
db_password: Optional[str] = None
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
config: ConfigSchema = Field(default_factory=lambda: DatabaseService.ConfigSchema())
@@ -224,7 +225,7 @@ class DatabaseService(Service, discriminator="database-service"):
SoftwareHealthState.FIXING,
SoftwareHealthState.COMPROMISED,
]:
if self.password == password:
if self.config.db_password == password:
status_code = 200 # ok
connection_id = self._generate_connection_id()
# try to create connection

View File

@@ -36,7 +36,11 @@ class FTPServer(FTPServiceABC, discriminator="ftp-server"):
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
self.start()
self.server_password = self.config.server_password
@property
def server_password(self) -> Optional[str]:
"""Convenience method for accessing FTP server password."""
return self.config.server_password
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""

View File

@@ -26,8 +26,6 @@ class NTPClient(Service, discriminator="ntp-client"):
config: ConfigSchema = Field(default_factory=lambda: NTPClient.ConfigSchema())
ntp_server: Optional[IPv4Address] = None
"The NTP server the client sends requests to."
time: Optional[datetime] = None
def __init__(self, **kwargs):
@@ -45,8 +43,8 @@ class NTPClient(Service, discriminator="ntp-client"):
:param ntp_server_ip_address: IPv4 address of NTP server.
:param ntp_client_ip_Address: IPv4 address of NTP client.
"""
self.ntp_server = ntp_server_ip_address
self.sys_log.info(f"{self.name}: ntp_server: {self.ntp_server}")
self.config.ntp_server_ip = ntp_server_ip_address
self.sys_log.info(f"{self.name}: ntp_server: {self.config.ntp_server_ip}")
def describe_state(self) -> Dict:
"""
@@ -108,10 +106,10 @@ class NTPClient(Service, discriminator="ntp-client"):
def request_time(self) -> None:
"""Send request to ntp_server."""
if self.ntp_server:
if self.config.ntp_server_ip:
self.software_manager.session_manager.receive_payload_from_software_manager(
payload=NTPPacket(),
dst_ip_address=self.ntp_server,
dst_ip_address=self.config.ntp_server_ip,
src_port=self.port,
dst_port=self.port,
ip_protocol=self.protocol,

View File

@@ -315,7 +315,7 @@ class IOSoftware(Software, ABC):
"""
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
self.software_manager.node.sys_log.error(
f"{self.name} Error: {self.software_manager.node.hostname} is not powered on."
f"{self.name} Error: {self.software_manager.node.config.hostname} is not powered on."
)
return False
return True

View File

@@ -119,7 +119,14 @@ def application_class():
@pytest.fixture(scope="function")
def file_system() -> FileSystem:
computer = Computer(hostname="fs_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0)
computer_cfg = {
"type": "computer",
"hostname": "fs_node",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
computer = Computer.from_config(config=computer_cfg)
computer.power_on()
return computer.file_system
@@ -129,23 +136,29 @@ def client_server() -> Tuple[Computer, Server]:
network = Network()
# Create Computer
computer = Computer(
hostname="computer",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
computer_cfg = {
"type": "computer",
"hostname": "computer",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
computer: Computer = Computer.from_config(config=computer_cfg)
computer.power_on()
# Create Server
server = Server(
hostname="server",
ip_address="192.168.1.3",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
server_cfg = {
"type": "server",
"hostname": "server",
"ip_address": "192.168.1.3",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
server: Server = Server.from_config(config=server_cfg)
server.power_on()
# Connect Computer and Server
@@ -162,26 +175,33 @@ def client_switch_server() -> Tuple[Computer, Switch, Server]:
network = Network()
# Create Computer
computer = Computer(
hostname="computer",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
computer_cfg = {
"type": "computer",
"hostname": "computer",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
computer: Computer = Computer.from_config(config=computer_cfg)
computer.power_on()
# Create Server
server = Server(
hostname="server",
ip_address="192.168.1.3",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
server_cfg = {
"type": "server",
"hostname": "server",
"ip_address": "192.168.1.3",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
server: Server = Server.from_config(config=server_cfg)
server.power_on()
switch = Switch(hostname="switch", start_up_duration=0)
# Create Switch
switch: Switch = Switch.from_config(config={"type": "switch", "hostname": "switch", "start_up_duration": 0})
switch.power_on()
network.connect(endpoint_a=computer.network_interface[1], endpoint_b=switch.network_interface[1])
@@ -211,65 +231,96 @@ def example_network() -> Network:
network = Network()
# Router 1
router_1 = Router(hostname="router_1", start_up_duration=0)
router_1_cfg = {"hostname": "router_1", "type": "router", "start_up_duration": 0}
# router_1 = Router(hostname="router_1", start_up_duration=0)
router_1 = Router.from_config(config=router_1_cfg)
router_1.power_on()
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
switch_1_cfg = {"hostname": "switch_1", "type": "switch", "start_up_duration": 0}
switch_1 = Switch.from_config(config=switch_1_cfg)
# switch_1 = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
switch_1.power_on()
network.connect(endpoint_a=router_1.network_interface[1], endpoint_b=switch_1.network_interface[8])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
switch_2_config = {"hostname": "switch_2", "type": "switch", "num_ports": 8, "start_up_duration": 0}
# switch_2 = Switch(hostname="switch_2", num_ports=8, start_up_duration=0)
switch_2 = Switch.from_config(config=switch_2_config)
switch_2.power_on()
network.connect(endpoint_a=router_1.network_interface[2], endpoint_b=switch_2.network_interface[8])
router_1.enable_port(2)
# Client 1
client_1 = Computer(
hostname="client_1",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
start_up_duration=0,
)
# # Client 1
client_1_cfg = {
"type": "computer",
"hostname": "client_1",
"ip_address": "192.168.10.21",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.10.1",
"start_up_duration": 0,
}
client_1 = Computer.from_config(config=client_1_cfg)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
# Client 2
client_2 = Computer(
hostname="client_2",
ip_address="192.168.10.22",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
start_up_duration=0,
)
# # Client 2
client_2_cfg = {
"type": "computer",
"hostname": "client_2",
"ip_address": "192.168.10.22",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.10.1",
"start_up_duration": 0,
}
client_2 = Computer.from_config(config=client_2_cfg)
client_2.power_on()
network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2])
# Server 1
server_1 = Server(
hostname="server_1",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
# # Server 1
server_1_cfg = {
"type": "server",
"hostname": "server_1",
"ip_address": "192.168.1.10",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
server_1 = Server.from_config(config=server_1_cfg)
server_1.power_on()
network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1])
# DServer 2
server_2 = Server(
hostname="server_2",
ip_address="192.168.1.14",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
# # DServer 2
server_2_cfg = {
"type": "server",
"hostname": "server_2",
"ip_address": "192.168.1.14",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
server_2 = Server.from_config(config=server_2_cfg)
server_2.power_on()
network.connect(endpoint_b=server_2.network_interface[1], endpoint_a=switch_1.network_interface[2])
@@ -277,6 +328,8 @@ def example_network() -> Network:
assert all(link.is_up for link in network.links.values())
client_1.software_manager.show()
return network
@@ -309,29 +362,35 @@ def install_stuff_to_sim(sim: Simulation):
# 1: Set up network hardware
# 1.1: Configure the router
router = Router(hostname="router", num_ports=3, start_up_duration=0)
router = Router.from_config(config={"type": "router", "hostname": "router", "num_ports": 3, "start_up_duration": 0})
router.power_on()
router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0")
router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0")
# 1.2: Create and connect switches
switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
switch_1 = Switch.from_config(
config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0}
)
switch_1.power_on()
network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6])
router.enable_port(1)
switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0)
switch_2 = Switch.from_config(
config={"type": "switch", "hostname": "switch_2", "num_ports": 6, "start_up_duration": 0}
)
switch_2.power_on()
network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6])
router.enable_port(2)
# 1.3: Create and connect computer
client_1 = Computer(
hostname="client_1",
ip_address="10.0.1.2",
subnet_mask="255.255.255.0",
default_gateway="10.0.1.1",
start_up_duration=0,
)
client_1_cfg = {
"type": "computer",
"hostname": "client_1",
"ip_address": "10.0.1.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "10.0.1.1",
"start_up_duration": 0,
}
client_1: Computer = Computer.from_config(config=client_1_cfg)
client_1.power_on()
network.connect(
endpoint_a=client_1.network_interface[1],
@@ -339,23 +398,28 @@ def install_stuff_to_sim(sim: Simulation):
)
# 1.4: Create and connect servers
server_1 = Server(
hostname="server_1",
ip_address="10.0.2.2",
subnet_mask="255.255.255.0",
default_gateway="10.0.2.1",
start_up_duration=0,
)
server_1_cfg = {
"type": "server",
"hostname": "server_1",
"ip_address": "10.0.2.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "10.0.2.1",
"start_up_duration": 0,
}
server_1: Server = Server.from_config(config=server_1_cfg)
server_1.power_on()
network.connect(endpoint_a=server_1.network_interface[1], endpoint_b=switch_2.network_interface[1])
server_2_cfg = {
"type": "server",
"hostname": "server_2",
"ip_address": "10.0.2.3",
"subnet_mask": "255.255.255.0",
"default_gateway": "10.0.2.1",
"start_up_duration": 0,
}
server_2 = Server(
hostname="server_2",
ip_address="10.0.2.3",
subnet_mask="255.255.255.0",
default_gateway="10.0.2.1",
start_up_duration=0,
)
server_2: Server = Server.from_config(config=server_2_cfg)
server_2.power_on()
network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2])
@@ -403,18 +467,18 @@ def install_stuff_to_sim(sim: Simulation):
assert acl_rule is None
# 5.2: Assert the client is correctly configured
c: Computer = [node for node in sim.network.nodes.values() if node.hostname == "client_1"][0]
c: Computer = [node for node in sim.network.nodes.values() if node.config.hostname == "client_1"][0]
assert c.software_manager.software.get("web-browser") is not None
assert c.software_manager.software.get("dns-client") is not None
assert str(c.network_interface[1].ip_address) == "10.0.1.2"
# 5.3: Assert that server_1 is correctly configured
s1: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_1"][0]
s1: Server = [node for node in sim.network.nodes.values() if node.config.hostname == "server_1"][0]
assert str(s1.network_interface[1].ip_address) == "10.0.2.2"
assert s1.software_manager.software.get("dns-server") is not None
# 5.4: Assert that server_2 is correctly configured
s2: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_2"][0]
s2: Server = [node for node in sim.network.nodes.values() if node.config.hostname == "server_2"][0]
assert str(s2.network_interface[1].ip_address) == "10.0.2.3"
assert s2.software_manager.software.get("web-server") is not None

View File

@@ -12,6 +12,7 @@ from sb3_contrib import MaskablePPO
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from tests import TEST_ASSETS_ROOT
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"

View File

@@ -12,12 +12,18 @@ def test_passing_actions_down(monkeypatch) -> None:
sim = Simulation()
pc1 = Computer(hostname="PC-1", ip_address="10.10.1.1", subnet_mask="255.255.255.0")
pc1 = Computer.from_config(
config={"type": "computer", "hostname": "PC-1", "ip_address": "10.10.1.1", "subnet_mask": "255.255.255.0"}
)
pc1.start_up_duration = 0
pc1.power_on()
pc2 = Computer(hostname="PC-2", ip_address="10.10.1.2", subnet_mask="255.255.255.0")
srv = Server(hostname="WEBSERVER", ip_address="10.10.1.100", subnet_mask="255.255.255.0")
s1 = Switch(hostname="switch1")
pc2 = Computer.from_config(
config={"type": "computer", "hostname": "PC-2", "ip_address": "10.10.1.2", "subnet_mask": "255.255.255.0"}
)
srv = Server.from_config(
config={"type": "server", "hostname": "WEBSERVER", "ip_address": "10.10.1.100", "subnet_mask": "255.255.255.0"}
)
s1 = Switch.from_config(config={"type": "switch", "hostname": "switch1"})
for n in [pc1, pc2, srv, s1]:
sim.network.add_node(n)
@@ -48,6 +54,6 @@ def test_passing_actions_down(monkeypatch) -> None:
assert not action_invoked
# call the patched method
sim.apply_request(["network", "node", pc1.hostname, "file_system", "folder", "downloads", "repair"])
sim.apply_request(["network", "node", pc1.config.hostname, "file_system", "folder", "downloads", "repair"])
assert action_invoked

View File

@@ -5,6 +5,7 @@ from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP

View File

@@ -3,6 +3,8 @@ from primaite.config.load import data_manipulation_config_path
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config

View File

@@ -4,6 +4,7 @@ import yaml
from primaite.session.environment import PrimaiteGymEnv
from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from tests.conftest import TEST_ASSETS_ROOT
folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders"

View File

@@ -36,8 +36,8 @@ class SuperComputer(HostNode, discriminator="supercomputer"):
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "ftp-client": FTPClient}
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
def __init__(self, **kwargs):
print("--- Extended Component: SuperComputer ---")
super().__init__(ip_address=ip_address, subnet_mask=subnet_mask, **kwargs)
super().__init__(**kwargs)
pass

View File

@@ -31,14 +31,14 @@ class ExtendedService(Service, discriminator="extended-service"):
type: str = "extended-service"
backup_server_ip: IPv4Address = None
"""IP address of the backup server."""
config: "ExtendedService.ConfigSchema" = Field(default_factory=lambda: ExtendedService.ConfigSchema())
password: Optional[str] = None
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
backup_server_ip: IPv4Address = None
"""IP address of the backup server."""
latest_backup_directory: str = None
"""Directory of latest backup."""

View File

@@ -25,6 +25,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.config.shut_down_duration = 3
assert client_1.operating_state == NodeOperatingState.ON
@@ -35,13 +36,15 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN
for i in range(client_1.shut_down_duration + 1):
for i in range(client_1.config.shut_down_duration + 1):
action = ("do-nothing", {})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.OFF
client_1.config.start_up_duration = 3
# turn it on
action = ("node-startup", {"node_name": "client_1"})
agent.store_action(action)
@@ -49,7 +52,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
assert client_1.operating_state == NodeOperatingState.BOOTING
for i in range(client_1.start_up_duration + 1):
for i in range(client_1.config.start_up_duration + 1):
action = ("do-nothing", {})
agent.store_action(action)
game.step()
@@ -79,7 +82,7 @@ def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture:
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.power_off()
for i in range(client_1.shut_down_duration + 1):
for i in range(client_1.config.shut_down_duration + 1):
action = ("do-nothing", {})
agent.store_action(action)
game.step()

View File

@@ -36,7 +36,7 @@ def test_acl_observations(simulation):
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=PORT_LOOKUP["NTP"], src_port=PORT_LOOKUP["NTP"], position=1)
acl_obs = ACLObservation(
where=["network", "nodes", router.hostname, "acl", "acl"],
where=["network", "nodes", router.config.hostname, "acl", "acl"],
ip_list=[],
port_list=[123, 80, 5432],
protocol_list=["tcp", "udp", "icmp"],

View File

@@ -24,7 +24,7 @@ def test_file_observation(simulation):
file = pc.file_system.create_file(file_name="dog.png")
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
)
@@ -52,7 +52,7 @@ def test_folder_observation(simulation):
file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder")
root_folder_obs = FolderObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"],
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "test_folder"],
include_num_access=False,
file_system_requires_scan=True,
num_files=1,

View File

@@ -25,7 +25,8 @@ def check_default_rules(acl_obs):
def test_firewall_observation():
"""Test adding/removing acl rules and enabling/disabling ports."""
net = Network()
firewall = Firewall(hostname="firewall", operating_state=NodeOperatingState.ON)
firewall_cfg = {"type": "firewall", "hostname": "firewall"}
firewall = Firewall.from_config(config=firewall_cfg)
firewall_observation = FirewallObservation(
where=[],
num_rules=7,
@@ -116,7 +117,9 @@ def test_firewall_observation():
assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4))
# connect a switch to the firewall and check that only the correct port is updated
switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON)
switch: Switch = Switch.from_config(
config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": "ON"}
)
link = net.connect(firewall.network_interface[1], switch.network_interface[1])
assert firewall.network_interface[1].enabled
observation = firewall_observation.observe(firewall.describe_state())

View File

@@ -56,12 +56,26 @@ def test_link_observation():
"""Check the shape and contents of the link observation."""
net = Network()
sim = Simulation(network=net)
switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON)
computer_1 = Computer(
hostname="computer_1", ip_address="10.0.0.1", subnet_mask="255.255.255.0", start_up_duration=0
switch: Switch = Switch.from_config(
config={"type": "switch", "hostname": "switch", "num_ports": 5, "operating_state": "ON"}
)
computer_2 = Computer(
hostname="computer_2", ip_address="10.0.0.2", subnet_mask="255.255.255.0", start_up_duration=0
computer_1: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "computer_1",
"ip_address": "10.0.0.1",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
computer_2: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "computer_2",
"ip_address": "10.0.0.2",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
computer_1.power_on()
computer_2.power_on()

View File

@@ -75,7 +75,7 @@ def test_nic(simulation):
nic: NIC = pc.network_interface[1]
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True)
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
nmne_config = {
@@ -108,7 +108,7 @@ def test_nic_categories(simulation):
"""Test the NIC observation nmne count categories."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True)
assert nic_obs.high_nmne_threshold == 10 # default
assert nic_obs.med_nmne_threshold == 5 # default
@@ -163,7 +163,9 @@ def test_nic_monitored_traffic(simulation):
pc2: Computer = simulation.network.get_node_by_hostname("client_2")
nic_obs = NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic
where=["network", "nodes", pc.config.hostname, "NICs", 1],
include_nmne=False,
monitored_traffic=monitored_traffic,
)
simulation.pre_timestep(0) # apply timestep to whole sim

View File

@@ -25,7 +25,7 @@ def test_host_observation(simulation):
pc: Computer = simulation.network.get_node_by_hostname("client_1")
host_obs = HostObservation(
where=["network", "nodes", pc.hostname],
where=["network", "nodes", pc.config.hostname],
num_applications=0,
num_files=1,
num_folders=1,
@@ -56,7 +56,7 @@ def test_host_observation(simulation):
observation_state = host_obs.observe(simulation.describe_state())
assert observation_state.get("operating_status") == 4 # shutting down
for i in range(pc.shut_down_duration + 1):
for i in range(pc.config.shut_down_duration + 1):
pc.apply_timestep(i)
observation_state = host_obs.observe(simulation.describe_state())

View File

@@ -16,7 +16,9 @@ from primaite.utils.validation.port import PORT_LOOKUP
def test_router_observation():
"""Test adding/removing acl rules and enabling/disabling ports."""
net = Network()
router = Router(hostname="router", num_ports=5, operating_state=NodeOperatingState.ON)
router = Router.from_config(
config={"type": "router", "hostname": "router", "num_ports": 5, "operating_state": "ON"}
)
ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)]
acl = ACLObservation(
@@ -89,7 +91,9 @@ def test_router_observation():
assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6))
# connect a switch to the router and check that only the correct port is updated
switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON)
switch: Switch = Switch.from_config(
config={"type": "switch", "hostname": "switch", "num_ports": 1, "operating_state": "ON"}
)
link = net.connect(router.network_interface[1], switch.network_interface[1])
assert router.network_interface[1].enabled
observed_output = router_observation.observe(router.describe_state())

View File

@@ -29,7 +29,7 @@ def test_service_observation(simulation):
ntp_server = pc.software_manager.software.get("ntp-server")
assert ntp_server
service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "ntp-server"])
service_obs = ServiceObservation(where=["network", "nodes", pc.config.hostname, "services", "ntp-server"])
assert service_obs.space["operating_status"] == spaces.Discrete(7)
assert service_obs.space["health_status"] == spaces.Discrete(5)
@@ -54,7 +54,7 @@ def test_application_observation(simulation):
web_browser: WebBrowser = pc.software_manager.software.get("web-browser")
assert web_browser
app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "web-browser"])
app_obs = ApplicationObservation(where=["network", "nodes", pc.config.hostname, "applications", "web-browser"])
web_browser.close()
observation_state = app_obs.observe(simulation.describe_state())

View File

@@ -2,6 +2,7 @@
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.system.services.service import ServiceOperatingState
from tests.conftest import TEST_ASSETS_ROOT

View File

@@ -21,6 +21,7 @@ from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.software import SoftwareHealthState

View File

@@ -8,14 +8,16 @@ from primaite.simulator.sim_container import Simulation
def test_file_observation():
sim = Simulation()
pc = Computer(hostname="beep", ip_address="123.123.123.123", subnet_mask="255.255.255.0")
pc: Computer = Computer.from_config(
config={"type": "computer", "hostname": "beep", "ip_address": "123.123.123.123", "subnet_mask": "255.255.255.0"}
)
sim.network.add_node(pc)
f = pc.file_system.create_file(file_name="dog.png")
state = sim.describe_state()
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=False,
)

View File

@@ -2,6 +2,7 @@
import yaml
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from tests import TEST_ASSETS_ROOT

View File

@@ -16,7 +16,7 @@ def test_wireless_link_loading(wireless_wan_network):
# Configure Router 2 ACLs
router_2.acl.add_rule(action=ACLAction.PERMIT, position=1)
airspace = router_1.airspace
airspace = router_1.config.airspace
client.software_manager.install(FTPClient)
ftp_client: FTPClient = client.software_manager.software.get("ftp-client")

View File

@@ -84,44 +84,55 @@ class BroadcastTestClient(Application, discriminator="broadcast-test-client"):
def broadcast_network() -> Network:
network = Network()
client_1 = Computer(
hostname="client_1",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
client_1_cfg = {
"type": "computer",
"hostname": "client_1",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
client_1: Computer = Computer.from_config(config=client_1_cfg)
client_1.power_on()
client_1.software_manager.install(BroadcastTestClient)
application_1 = client_1.software_manager.software["broadcast-test-client"]
application_1.run()
client_2_cfg = {
"type": "computer",
"hostname": "client_2",
"ip_address": "192.168.1.3",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
client_2 = Computer(
hostname="client_2",
ip_address="192.168.1.3",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
client_2: Computer = Computer.from_config(config=client_2_cfg)
client_2.power_on()
client_2.software_manager.install(BroadcastTestClient)
application_2 = client_2.software_manager.software["broadcast-test-client"]
application_2.run()
server_1 = Server(
hostname="server_1",
ip_address="192.168.1.1",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
server_1_cfg = {
"type": "server",
"hostname": "server_1",
"ip_address": "192.168.1.1",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
server_1: Server = Server.from_config(config=server_1_cfg)
server_1.power_on()
server_1.software_manager.install(BroadcastTestService)
service: BroadcastTestService = server_1.software_manager.software["BroadcastService"]
service.start()
switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
switch_1: Switch = Switch.from_config(
config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0}
)
switch_1.power_on()
network.connect(endpoint_a=client_1.network_interface[1], endpoint_b=switch_1.network_interface[1])

View File

@@ -41,7 +41,9 @@ def dmz_external_internal_network() -> Network:
"""
network = Network()
firewall_node: Firewall = Firewall(hostname="firewall_1", start_up_duration=0)
firewall_node: Firewall = Firewall.from_config(
config={"type": "firewall", "hostname": "firewall_1", "start_up_duration": 0}
)
firewall_node.power_on()
# configure firewall ports
firewall_node.configure_external_port(
@@ -81,12 +83,15 @@ def dmz_external_internal_network() -> Network:
)
# external node
external_node = Computer(
hostname="external_node",
ip_address="192.168.10.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
start_up_duration=0,
external_node: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "external_node",
"ip_address": "192.168.10.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.10.1",
"start_up_duration": 0,
}
)
external_node.power_on()
external_node.software_manager.install(NTPServer)
@@ -96,12 +101,15 @@ def dmz_external_internal_network() -> Network:
network.connect(endpoint_b=external_node.network_interface[1], endpoint_a=firewall_node.external_port)
# internal node
internal_node = Computer(
hostname="internal_node",
ip_address="192.168.0.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.0.1",
start_up_duration=0,
internal_node: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "internal_node",
"ip_address": "192.168.0.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.0.1",
"start_up_duration": 0,
}
)
internal_node.power_on()
internal_node.software_manager.install(NTPClient)
@@ -112,12 +120,15 @@ def dmz_external_internal_network() -> Network:
network.connect(endpoint_b=internal_node.network_interface[1], endpoint_a=firewall_node.internal_port)
# dmz node
dmz_node = Computer(
hostname="dmz_node",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
dmz_node: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "dmz_node",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
)
dmz_node.power_on()
dmz_ntp_client: NTPClient = dmz_node.software_manager.software["ntp-client"]
@@ -155,9 +166,9 @@ def test_nodes_can_ping_default_gateway(dmz_external_internal_network):
internal_node = dmz_external_internal_network.get_node_by_hostname("internal_node")
dmz_node = dmz_external_internal_network.get_node_by_hostname("dmz_node")
assert internal_node.ping(internal_node.default_gateway) # default gateway internal
assert dmz_node.ping(dmz_node.default_gateway) # default gateway dmz
assert external_node.ping(external_node.default_gateway) # default gateway external
assert internal_node.ping(internal_node.config.default_gateway) # default gateway internal
assert dmz_node.ping(dmz_node.config.default_gateway) # default gateway dmz
assert external_node.ping(external_node.config.default_gateway) # default gateway external
def test_nodes_can_ping_default_gateway_on_another_subnet(dmz_external_internal_network):
@@ -171,14 +182,14 @@ def test_nodes_can_ping_default_gateway_on_another_subnet(dmz_external_internal_
internal_node = dmz_external_internal_network.get_node_by_hostname("internal_node")
dmz_node = dmz_external_internal_network.get_node_by_hostname("dmz_node")
assert internal_node.ping(external_node.default_gateway) # internal node to external default gateway
assert internal_node.ping(dmz_node.default_gateway) # internal node to dmz default gateway
assert internal_node.ping(external_node.config.default_gateway) # internal node to external default gateway
assert internal_node.ping(dmz_node.config.default_gateway) # internal node to dmz default gateway
assert dmz_node.ping(internal_node.default_gateway) # dmz node to internal default gateway
assert dmz_node.ping(external_node.default_gateway) # dmz node to external default gateway
assert dmz_node.ping(internal_node.config.default_gateway) # dmz node to internal default gateway
assert dmz_node.ping(external_node.config.default_gateway) # dmz node to external default gateway
assert external_node.ping(external_node.default_gateway) # external node to internal default gateway
assert external_node.ping(dmz_node.default_gateway) # external node to dmz default gateway
assert external_node.ping(external_node.config.default_gateway) # external node to internal default gateway
assert external_node.ping(dmz_node.config.default_gateway) # external node to dmz default gateway
def test_nodes_can_ping_each_other(dmz_external_internal_network):

View File

@@ -10,25 +10,31 @@ def test_node_to_node_ping():
"""Tests two Computers are able to ping each other."""
network = Network()
client_1 = Computer(
hostname="client_1",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
client_1: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "client_1",
"ip_address": "192.168.1.10",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
)
client_1.power_on()
server_1 = Server(
hostname="server_1",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
server_1: Server = Server.from_config(
config={
"type": "server",
"hostname": "server_1",
"ip_address": "192.168.1.11",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
)
server_1.power_on()
switch_1 = Switch(hostname="switch_1", start_up_duration=0)
switch_1: Switch = Switch.from_config(config={"type": "switch", "hostname": "switch_1", "start_up_duration": 0})
switch_1.power_on()
network.connect(endpoint_a=client_1.network_interface[1], endpoint_b=switch_1.network_interface[1])
@@ -41,14 +47,38 @@ def test_multi_nic():
"""Tests that Computers with multiple NICs can ping each other and the data go across the correct links."""
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "node_a",
"ip_address": "192.168.0.10",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
node_a.power_on()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "node_b",
"ip_address": "192.168.0.11",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
node_b.power_on()
node_b.connect_nic(NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0"))
node_c = Computer(hostname="node_c", ip_address="10.0.0.13", subnet_mask="255.0.0.0", start_up_duration=0)
node_c: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "node_c",
"ip_address": "10.0.0.13",
"subnet_mask": "255.0.0.0",
"start_up_duration": 0,
}
)
node_c.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])

View File

@@ -1,6 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.networks import multi_lan_internet_network_example
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser

View File

@@ -27,7 +27,15 @@ def test_network(example_network):
def test_adding_removing_nodes():
"""Check that we can create and add a node to a network."""
net = Network()
n1 = Computer(hostname="computer", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0)
n1 = Computer.from_config(
config={
"type": "computer",
"hostname": "computer",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
net.add_node(n1)
assert n1.parent is net
assert n1 in net
@@ -37,10 +45,18 @@ def test_adding_removing_nodes():
assert n1 not in net
def test_readding_node():
"""Check that warning is raised when readding a node."""
def test_reading_node():
"""Check that warning is raised when reading a node."""
net = Network()
n1 = Computer(hostname="computer", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0)
n1 = Computer.from_config(
config={
"type": "computer",
"hostname": "computer",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
net.add_node(n1)
net.add_node(n1)
assert n1.parent is net
@@ -50,7 +66,15 @@ def test_readding_node():
def test_removing_nonexistent_node():
"""Check that warning is raised when trying to remove a node that is not in the network."""
net = Network()
n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0)
n1 = Computer.from_config(
config={
"type": "computer",
"hostname": "computer1",
"ip_address": "192.168.1.1",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
net.remove_node(n1)
assert n1.parent is None
assert n1 not in net
@@ -59,8 +83,24 @@ def test_removing_nonexistent_node():
def test_connecting_nodes():
"""Check that two nodes on the network can be connected."""
net = Network()
n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0)
n2 = Computer(hostname="computer2", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0)
n1: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "computer1",
"ip_address": "192.168.1.1",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
n2: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "computer2",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
net.add_node(n1)
net.add_node(n2)
@@ -75,7 +115,15 @@ def test_connecting_nodes():
def test_connecting_node_to_itself_fails():
net = Network()
node = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node = Computer.from_config(
config={
"type": "computer",
"hostname": "node_b",
"ip_address": "192.168.0.11",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
node.power_on()
node.connect_nic(NIC(ip_address="10.0.0.12", subnet_mask="255.0.0.0"))
@@ -92,8 +140,24 @@ def test_connecting_node_to_itself_fails():
def test_disconnecting_nodes():
net = Network()
n1 = Computer(hostname="computer1", ip_address="192.168.1.1", subnet_mask="255.255.255.0", start_up_duration=0)
n2 = Computer(hostname="computer2", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0)
n1 = Computer.from_config(
config={
"type": "computer",
"hostname": "computer1",
"ip_address": "192.168.1.1",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
n2 = Computer.from_config(
config={
"type": "computer",
"hostname": "computer2",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
net.connect(n1.network_interface[1], n2.network_interface[1])
assert len(net.links) == 1

View File

@@ -15,25 +15,31 @@ from primaite.utils.validation.port import PORT_LOOKUP
@pytest.fixture(scope="function")
def pc_a_pc_b_router_1() -> Tuple[Computer, Computer, Router]:
network = Network()
pc_a = Computer(
hostname="pc_a",
ip_address="192.168.0.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.0.1",
start_up_duration=0,
pc_a = Computer.from_config(
config={
"type": "computer",
"hostname": "pc_a",
"ip_address": "192.168.0.10",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.0.1",
"start_up_duration": 0,
}
)
pc_a.power_on()
pc_b = Computer(
hostname="pc_b",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
pc_b = Computer.from_config(
config={
"type": "computer",
"hostname": "pc_b",
"ip_address": "192.168.1.10",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
)
pc_b.power_on()
router_1 = Router(hostname="router_1", start_up_duration=0)
router_1 = Router.from_config(config={"type": "router", "hostname": "router_1", "start_up_duration": 0})
router_1.power_on()
router_1.configure_port(1, "192.168.0.1", "255.255.255.0")
@@ -52,18 +58,21 @@ def multi_hop_network() -> Network:
network = Network()
# Configure PC A
pc_a = Computer(
hostname="pc_a",
ip_address="192.168.0.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.0.1",
start_up_duration=0,
pc_a: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "pc_a",
"ip_address": "192.168.0.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.0.1",
"start_up_duration": 0,
}
)
pc_a.power_on()
network.add_node(pc_a)
# Configure Router 1
router_1 = Router(hostname="router_1", start_up_duration=0)
router_1: Router = Router.from_config(config={"type": "router", "hostname": "router_1", "start_up_duration": 0})
router_1.power_on()
network.add_node(router_1)
@@ -79,18 +88,21 @@ def multi_hop_network() -> Network:
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
# Configure PC B
pc_b = Computer(
hostname="pc_b",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
start_up_duration=0,
pc_b: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "pc_b",
"ip_address": "192.168.2.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.2.1",
"start_up_duration": 0,
}
)
pc_b.power_on()
network.add_node(pc_b)
# Configure Router 2
router_2 = Router(hostname="router_2", start_up_duration=0)
router_2: Router = Router.from_config(config={"type": "router", "hostname": "router_2", "start_up_duration": 0})
router_2.power_on()
network.add_node(router_2)
@@ -113,13 +125,13 @@ def multi_hop_network() -> Network:
def test_ping_default_gateway(pc_a_pc_b_router_1):
pc_a, pc_b, router_1 = pc_a_pc_b_router_1
assert pc_a.ping(pc_a.default_gateway)
assert pc_a.ping(pc_a.config.default_gateway)
def test_ping_other_router_port(pc_a_pc_b_router_1):
pc_a, pc_b, router_1 = pc_a_pc_b_router_1
assert pc_a.ping(pc_b.default_gateway)
assert pc_a.ping(pc_b.config.default_gateway)
def test_host_on_other_subnet(pc_a_pc_b_router_1):

View File

@@ -17,18 +17,23 @@ def wireless_wan_network():
network = Network()
# Configure PC A
pc_a = Computer(
hostname="pc_a",
ip_address="192.168.0.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.0.1",
start_up_duration=0,
pc_a = Computer.from_config(
config={
"type": "computer",
"hostname": "pc_a",
"ip_address": "192.168.0.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.0.1",
"start_up_duration": 0,
}
)
pc_a.power_on()
network.add_node(pc_a)
# Configure Router 1
router_1 = WirelessRouter(hostname="router_1", start_up_duration=0, airspace=network.airspace)
router_1 = WirelessRouter.from_config(
config={"type": "wireless_router", "hostname": "router_1", "start_up_duration": 0, "airspace": network.airspace}
)
router_1.power_on()
network.add_node(router_1)
@@ -43,18 +48,23 @@ def wireless_wan_network():
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
# Configure PC B
pc_b = Computer(
hostname="pc_b",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
start_up_duration=0,
pc_b: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "pc_b",
"ip_address": "192.168.2.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.2.1",
"start_up_duration": 0,
}
)
pc_b.power_on()
network.add_node(pc_b)
# Configure Router 2
router_2 = WirelessRouter(hostname="router_2", start_up_duration=0, airspace=network.airspace)
router_2: WirelessRouter = WirelessRouter.from_config(
config={"type": "wireless_router", "hostname": "router_2", "start_up_duration": 0, "airspace": network.airspace}
)
router_2.power_on()
network.add_node(router_2)
@@ -98,8 +108,8 @@ def wireless_wan_network_from_config_yaml():
def test_cross_wireless_wan_connectivity(wireless_wan_network):
pc_a, pc_b, router_1, router_2 = wireless_wan_network
# Ensure that PCs can ping across routers before any frequency change
assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully."
assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully."
assert pc_a.ping(pc_a.config.default_gateway), "PC A should ping its default gateway successfully."
assert pc_b.ping(pc_b.config.default_gateway), "PC B should ping its default gateway successfully."
assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully."
assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully."
@@ -109,8 +119,8 @@ def test_cross_wireless_wan_connectivity_from_yaml(wireless_wan_network_from_con
pc_a = wireless_wan_network_from_config_yaml.get_node_by_hostname("pc_a")
pc_b = wireless_wan_network_from_config_yaml.get_node_by_hostname("pc_b")
assert pc_a.ping(pc_a.default_gateway), "PC A should ping its default gateway successfully."
assert pc_b.ping(pc_b.default_gateway), "PC B should ping its default gateway successfully."
assert pc_a.ping(pc_a.config.default_gateway), "PC A should ping its default gateway successfully."
assert pc_b.ping(pc_b.config.default_gateway), "PC B should ping its default gateway successfully."
assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should ping PC B across routers successfully."
assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should ping PC A across routers successfully."

View File

@@ -34,52 +34,64 @@ def basic_network() -> Network:
# Creating two generic nodes for the C2 Server and the C2 Beacon.
node_a = Computer(
hostname="node_a",
ip_address="192.168.0.2",
subnet_mask="255.255.255.252",
default_gateway="192.168.0.1",
start_up_duration=0,
)
node_a_cfg = {
"type": "computer",
"hostname": "node_a",
"ip_address": "192.168.0.2",
"subnet_mask": "255.255.255.252",
"default_gateway": "192.168.0.1",
"start_up_duration": 0,
}
node_a: Computer = Computer.from_config(config=node_a_cfg)
node_a.power_on()
node_a.software_manager.get_open_ports()
node_a.software_manager.install(software_class=C2Server)
node_b = Computer(
hostname="node_b",
ip_address="192.168.255.2",
subnet_mask="255.255.255.248",
default_gateway="192.168.255.1",
start_up_duration=0,
)
node_b_cfg = {
"type": "computer",
"hostname": "node_b",
"ip_address": "192.168.255.2",
"subnet_mask": "255.255.255.248",
"default_gateway": "192.168.255.1",
"start_up_duration": 0,
}
node_b: Computer = Computer.from_config(config=node_b_cfg)
node_b.power_on()
node_b.software_manager.install(software_class=C2Beacon)
# Creating a generic computer for testing remote terminal connections.
node_c = Computer(
hostname="node_c",
ip_address="192.168.255.3",
subnet_mask="255.255.255.248",
default_gateway="192.168.255.1",
start_up_duration=0,
)
node_c_cfg = {
"type": "computer",
"hostname": "node_c",
"ip_address": "192.168.255.3",
"subnet_mask": "255.255.255.248",
"default_gateway": "192.168.255.1",
"start_up_duration": 0,
}
node_c: Computer = Computer.from_config(config=node_c_cfg)
node_c.power_on()
# Creating a router to sit between node 1 and node 2.
router = Router(hostname="router", num_ports=3, start_up_duration=0)
router = Router.from_config(config={"type": "router", "hostname": "router", "num_ports": 3, "start_up_duration": 0})
# Default allow all.
router.acl.add_rule(action=ACLAction.PERMIT)
router.power_on()
# Creating switches for each client.
switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
switch_1 = Switch.from_config(
config={"type": "switch", "hostname": "switch_1", "num_ports": 6, "start_up_duration": 0}
)
switch_1.power_on()
# Connecting the switches to the router.
router.configure_port(port=1, ip_address="192.168.0.1", subnet_mask="255.255.255.252")
network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6])
switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0)
switch_2 = Switch.from_config(
config={"type": "switch", "hostname": "switch_2", "num_ports": 6, "start_up_duration": 0}
)
switch_2.power_on()
network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6])

View File

@@ -10,13 +10,16 @@ from primaite.simulator.system.applications.application import Application, Appl
@pytest.fixture(scope="function")
def populated_node(application_class) -> Tuple[Application, Computer]:
computer: Computer = Computer(
hostname="test_computer",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
shut_down_duration=0,
computer: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "test_computer",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
"shut_down_duration": 0,
}
)
computer.power_on()
computer.software_manager.install(application_class)
@@ -29,13 +32,16 @@ def populated_node(application_class) -> Tuple[Application, Computer]:
def test_application_on_offline_node(application_class):
"""Test to check that the application cannot be interacted with when node it is on is off."""
computer: Computer = Computer(
hostname="test_computer",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
shut_down_duration=0,
computer: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "test_computer",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
"shut_down_duration": 0,
}
)
computer.software_manager.install(application_class)

View File

@@ -20,11 +20,27 @@ from primaite.simulator.system.software import SoftwareHealthState
@pytest.fixture(scope="function")
def peer_to_peer() -> Tuple[Computer, Computer]:
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "node_a",
"ip_address": "192.168.0.10",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
node_a.power_on()
node_a.software_manager.get_open_ports()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "node_b",
"ip_address": "192.168.0.11",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
@@ -46,7 +62,7 @@ def peer_to_peer_secure_db(peer_to_peer) -> Tuple[Computer, Computer]:
database_service: DatabaseService = node_b.software_manager.software["database-service"] # noqa
database_service.stop()
database_service.password = "12345"
database_service.config.db_password = "12345"
database_service.start()
return node_a, node_b
@@ -338,7 +354,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network):
assert db_connection.query("INSERT") is True
db_server.power_off()
for i in range(db_server.shut_down_duration + 1):
for i in range(db_server.config.shut_down_duration + 1):
uc2_network.apply_timestep(timestep=i)
assert db_server.operating_state is NodeOperatingState.OFF
@@ -412,8 +428,14 @@ def test_database_service_can_terminate_connection(peer_to_peer):
def test_client_connection_terminate_does_not_terminate_another_clients_connection():
network = Network()
db_server = Server(
hostname="db_client", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0
db_server: Server = Server.from_config(
config={
"type": "server",
"hostname": "db_client",
"ip_address": "192.168.0.11",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
db_server.power_on()
@@ -421,8 +443,14 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti
db_service: DatabaseService = db_server.software_manager.software["database-service"] # noqa
db_service.start()
client_a = Computer(
hostname="client_a", ip_address="192.168.0.12", subnet_mask="255.255.255.0", start_up_duration=0
client_a = Computer.from_config(
config={
"type": "computer",
"hostname": "client_a",
"ip_address": "192.168.0.12",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
client_a.power_on()
@@ -430,8 +458,14 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti
client_a.software_manager.software["database-client"].configure(server_ip_address=IPv4Address("192.168.0.11"))
client_a.software_manager.software["database-client"].run()
client_b = Computer(
hostname="client_b", ip_address="192.168.0.13", subnet_mask="255.255.255.0", start_up_duration=0
client_b = Computer.from_config(
config={
"type": "computer",
"hostname": "client_b",
"ip_address": "192.168.0.13",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
client_b.power_on()
@@ -439,7 +473,7 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti
client_b.software_manager.software["database-client"].configure(server_ip_address=IPv4Address("192.168.0.11"))
client_b.software_manager.software["database-client"].run()
switch = Switch(hostname="switch", start_up_duration=0, num_ports=3)
switch = Switch.from_config(config={"type": "switch", "hostname": "switch", "start_up_duration": 0, "num_ports": 3})
switch.power_on()
network.connect(endpoint_a=switch.network_interface[1], endpoint_b=db_server.network_interface[1])
@@ -465,6 +499,14 @@ def test_client_connection_terminate_does_not_terminate_another_clients_connecti
def test_database_server_install_ftp_client():
server = Server(hostname="db_server", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0)
server: Server = Server.from_config(
config={
"type": "server",
"hostname": "db_server",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
server.software_manager.install(DatabaseService)
assert server.software_manager.software.get("ftp-client")

View File

@@ -72,7 +72,7 @@ def test_dns_client_requests_offline_dns_server(dns_client_and_dns_server):
server.power_off()
for i in range(server.shut_down_duration + 1):
for i in range(server.config.shut_down_duration + 1):
server.apply_timestep(timestep=i)
assert server.operating_state == NodeOperatingState.OFF

View File

@@ -87,7 +87,7 @@ def test_ftp_client_tries_to_connect_to_offline_server(ftp_client_and_ftp_server
server.power_off()
for i in range(server.shut_down_duration + 1):
for i in range(server.config.shut_down_duration + 1):
server.apply_timestep(timestep=i)
assert ftp_client.operating_state == ServiceOperatingState.RUNNING

View File

@@ -13,13 +13,15 @@ from primaite.simulator.system.services.service import Service, ServiceOperating
def populated_node(
service_class,
) -> Tuple[Server, Service]:
server = Server(
hostname="server",
ip_address="192.168.0.1",
subnet_mask="255.255.255.0",
start_up_duration=0,
shut_down_duration=0,
)
server_cfg = {
"type": "server",
"hostname": "server",
"ip_address": "192.168.0.1",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
"shut_down_duration": 0,
}
server: Server = Server.from_config(config=server_cfg)
server.power_on()
server.software_manager.install(service_class)
@@ -31,14 +33,16 @@ def populated_node(
def test_service_on_offline_node(service_class):
"""Test to check that the service cannot be interacted with when node it is on is off."""
computer: Computer = Computer(
hostname="test_computer",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
shut_down_duration=0,
)
computer_cfg = {
"type": "computer",
"hostname": "test_computer",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
"shut_down_duration": 0,
}
computer: Computer = Computer.from_config(config=computer_cfg)
computer.power_on()
computer.software_manager.install(service_class)

View File

@@ -14,21 +14,27 @@ from primaite.simulator.network.hardware.nodes.host.server import Server
def client_server_network() -> Tuple[Computer, Server, Network]:
network = Network()
client = Computer(
hostname="client",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
client = Computer.from_config(
config={
"type": "computer",
"hostname": "client",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
)
client.power_on()
server = Server(
hostname="server",
ip_address="192.168.1.3",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
server = Server.from_config(
config={
"type": "server",
"hostname": "server",
"ip_address": "192.168.1.3",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
)
server.power_on()

View File

@@ -94,7 +94,7 @@ def test_web_page_request_from_shut_down_server(web_client_and_web_server):
server.power_off()
for i in range(server.shut_down_duration + 1):
for i in range(server.config.shut_down_duration + 1):
server.apply_timestep(timestep=i)
# node should be off

View File

@@ -111,7 +111,7 @@ def test_request_fails_if_node_off(example_network, node_request):
"""Test that requests succeed when the node is on, and fail if the node is off."""
net = example_network
client_1: HostNode = net.get_node_by_hostname("client_1")
client_1.shut_down_duration = 0
client_1.config.shut_down_duration = 0
assert client_1.operating_state == NodeOperatingState.ON
resp_1 = net.apply_request(node_request)
@@ -144,9 +144,9 @@ class TestDataManipulationGreenRequests:
client_1 = net.get_node_by_hostname("client_1")
client_2 = net.get_node_by_hostname("client_2")
client_1.shut_down_duration = 0
client_1.config.shut_down_duration = 0
client_1.power_off()
client_2.shut_down_duration = 0
client_2.config.shut_down_duration = 0
client_2.power_off()
client_1_browser_execute_off = net.apply_request(["node", "client_1", "application", "web-browser", "execute"])

View File

@@ -25,7 +25,8 @@ def router_with_acl_rules():
:return: A configured Router object with ACL rules.
"""
router = Router("Router")
router_cfg = {"hostname": "router_1", "type": "router"}
router = Router.from_config(config=router_cfg)
acl = router.acl
# Add rules here as needed
acl.add_rule(
@@ -62,7 +63,8 @@ def router_with_wildcard_acl():
:return: A Router object with configured ACL rules, including rules with wildcard masking.
"""
router = Router("Router")
router_cfg = {"hostname": "router_1", "type": "router"}
router = Router.from_config(config=router_cfg)
acl = router.acl
# Rule to permit traffic from a specific source IP and port to a specific destination IP and port
acl.add_rule(
@@ -243,7 +245,8 @@ def test_ip_traffic_from_specific_subnet():
- Traffic from outside the 192.168.1.0/24 subnet is denied.
"""
router = Router("Router")
router_cfg = {"hostname": "router_1", "type": "router"}
router = Router.from_config(config=router_cfg)
acl = router.acl
# Add rules here as needed
acl.add_rule(

View File

@@ -50,9 +50,9 @@ def test_wireless_router_from_config():
},
}
rt = Router.from_config(cfg=cfg)
rt = Router.from_config(config=cfg)
assert rt.num_ports == 6
assert rt.config.num_ports == 6
assert rt.network_interface[1].ip_address == IPv4Address("192.168.1.1")
assert rt.network_interface[1].subnet_mask == IPv4Address("255.255.255.0")

View File

@@ -7,7 +7,8 @@ from primaite.simulator.network.hardware.nodes.network.switch import Switch
@pytest.fixture(scope="function")
def switch() -> Switch:
switch: Switch = Switch(hostname="switch_1", num_ports=8, start_up_duration=0)
switch_cfg = {"type": "switch", "hostname": "switch_1", "num_ports": 8, "start_up_duration": 0}
switch: Switch = Switch.from_config(config=switch_cfg)
switch.power_on()
switch.show()
return switch

View File

@@ -7,7 +7,10 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer
@pytest.fixture
def node() -> Node:
return Computer(hostname="test", ip_address="192.168.1.2", subnet_mask="255.255.255.0")
computer_cfg = {"type": "computer", "hostname": "test", "ip_address": "192.168.1.2", "subnet_mask": "255.255.255.0"}
computer = Computer.from_config(config=computer_cfg)
return computer
def test_nic_enabled_validator(node):

View File

@@ -12,7 +12,16 @@ from tests.conftest import DummyApplication, DummyService
@pytest.fixture
def node() -> Node:
return Computer(hostname="test", ip_address="192.168.1.2", subnet_mask="255.255.255.0")
computer_cfg = {
"type": "computer",
"hostname": "test",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"operating_state": "OFF",
}
computer = Computer.from_config(config=computer_cfg)
return computer
def test_node_startup(node):
@@ -166,7 +175,7 @@ def test_node_is_on_validator(node):
"""Test that the node is on validator."""
node.power_on()
for i in range(node.start_up_duration + 1):
for i in range(node.config.start_up_duration + 1):
node.apply_timestep(i)
validator = Node._NodeIsOnValidator(node=node)
@@ -174,7 +183,7 @@ def test_node_is_on_validator(node):
assert validator(request=[], context={})
node.power_off()
for i in range(node.shut_down_duration + 1):
for i in range(node.config.shut_down_duration + 1):
node.apply_timestep(i)
assert validator(request=[], context={}) is False
@@ -184,7 +193,7 @@ def test_node_is_off_validator(node):
"""Test that the node is on validator."""
node.power_on()
for i in range(node.start_up_duration + 1):
for i in range(node.config.start_up_duration + 1):
node.apply_timestep(i)
validator = Node._NodeIsOffValidator(node=node)
@@ -192,7 +201,7 @@ def test_node_is_off_validator(node):
assert validator(request=[], context={}) is False
node.power_off()
for i in range(node.shut_down_duration + 1):
for i in range(node.config.shut_down_duration + 1):
node.apply_timestep(i)
assert validator(request=[], context={})

View File

@@ -61,12 +61,12 @@ def test_apply_timestep_to_nodes(network):
client_1.power_off()
assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN
for i in range(client_1.shut_down_duration + 1):
for i in range(client_1.config.shut_down_duration + 1):
network.apply_timestep(timestep=i)
assert client_1.operating_state is NodeOperatingState.OFF
network.apply_timestep(client_1.shut_down_duration + 2)
network.apply_timestep(client_1.config.shut_down_duration + 2)
assert client_1.operating_state is NodeOperatingState.OFF
@@ -74,7 +74,16 @@ def test_removing_node_that_does_not_exist(network):
"""Node that does not exist on network should not affect existing nodes."""
assert len(network.nodes) is 7
network.remove_node(Computer(hostname="new_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0"))
network.remove_node(
Computer.from_config(
config={
"type": "computer",
"hostname": "new_node",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
}
)
)
assert len(network.nodes) is 7

View File

@@ -19,7 +19,7 @@ def _assert_valid_creation(net: Network, lan_name, subnet_base, pcs_ip_block_sta
num_routers = 1 if include_router else 0
total_nodes = num_pcs + num_switches + num_routers
assert all((n.hostname.endswith(lan_name) for n in net.nodes.values()))
assert all((n.config.hostname.endswith(lan_name) for n in net.nodes.values()))
assert len(net.computer_nodes) == num_pcs
assert len(net.switch_nodes) == num_switches
assert len(net.router_nodes) == num_routers

View File

@@ -16,19 +16,27 @@ def basic_c2_network() -> Network:
network = Network()
# Creating two generic nodes for the C2 Server and the C2 Beacon.
computer_a_cfg = {
"type": "computer",
"hostname": "computer_a",
"ip_address": "192.168.0.1",
"subnet_mask": "255.255.255.252",
"start_up_duration": 0,
}
computer_a = Computer.from_config(config=computer_a_cfg)
computer_a = Computer(
hostname="computer_a",
ip_address="192.168.0.1",
subnet_mask="255.255.255.252",
start_up_duration=0,
)
computer_a.power_on()
computer_a.software_manager.install(software_class=C2Server)
computer_b = Computer(
hostname="computer_b", ip_address="192.168.0.2", subnet_mask="255.255.255.252", start_up_duration=0
)
computer_b_cfg = {
"type": "computer",
"hostname": "computer_b",
"ip_address": "192.168.0.2",
"subnet_mask": "255.255.255.252",
"start_up_duration": 0,
}
computer_b = Computer.from_config(config=computer_b_cfg)
computer_b.power_on()
computer_b.software_manager.install(software_class=C2Beacon)

View File

@@ -12,9 +12,15 @@ from primaite.utils.validation.port import PORT_LOOKUP
@pytest.fixture(scope="function")
def dos_bot() -> DoSBot:
computer = Computer(
hostname="compromised_pc", ip_address="192.168.0.1", subnet_mask="255.255.255.0", start_up_duration=0
)
computer_cfg = {
"type": "computer",
"hostname": "compromised_pc",
"ip_address": "192.168.0.1",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
computer: Computer = Computer.from_config(config=computer_cfg)
computer.power_on()
computer.software_manager.install(DoSBot)
@@ -34,7 +40,7 @@ def test_dos_bot_cannot_run_when_node_offline(dos_bot):
dos_bot_node.power_off()
for i in range(dos_bot_node.shut_down_duration + 1):
for i in range(dos_bot_node.config.shut_down_duration + 1):
dos_bot_node.apply_timestep(timestep=i)
assert dos_bot_node.operating_state is NodeOperatingState.OFF

View File

@@ -17,13 +17,27 @@ from primaite.simulator.system.services.database.database_service import Databas
def database_client_on_computer() -> Tuple[DatabaseClient, Computer]:
network = Network()
db_server = Server(hostname="db_server", ip_address="192.168.0.1", subnet_mask="255.255.255.0", start_up_duration=0)
db_server: Server = Server.from_config(
config={
"type": "server",
"hostname": "db_server",
"ip_address": "192.168.0.1",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
db_server.power_on()
db_server.software_manager.install(DatabaseService)
db_server.software_manager.software["database-service"].start()
db_client = Computer(
hostname="db_client", ip_address="192.168.0.2", subnet_mask="255.255.255.0", start_up_duration=0
db_client: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "db_client",
"ip_address": "192.168.0.2",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
db_client.power_on()
db_client.software_manager.install(DatabaseClient)

View File

@@ -12,13 +12,17 @@ from primaite.utils.validation.port import PORT_LOOKUP
@pytest.fixture(scope="function")
def web_browser() -> WebBrowser:
computer = Computer(
hostname="web_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
computer_cfg = {
"type": "computer",
"hostname": "web_client",
"ip_address": "192.168.1.11",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
computer: Computer = Computer.from_config(config=computer_cfg)
computer.power_on()
# Web Browser should be pre-installed in computer
web_browser: WebBrowser = computer.software_manager.software.get("web-browser")
@@ -28,13 +32,17 @@ def web_browser() -> WebBrowser:
def test_create_web_client():
computer = Computer(
hostname="web_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
computer_cfg = {
"type": "computer",
"hostname": "web_client",
"ip_address": "192.168.1.11",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
computer: Computer = Computer.from_config(config=computer_cfg)
computer.power_on()
# Web Browser should be pre-installed in computer
web_browser: WebBrowser = computer.software_manager.software.get("web-browser")

View File

@@ -8,7 +8,15 @@ from primaite.simulator.system.services.database.database_service import Databas
@pytest.fixture(scope="function")
def database_server() -> Node:
node = Computer(hostname="db_node", ip_address="192.168.1.2", subnet_mask="255.255.255.0", start_up_duration=0)
node_cfg = {
"type": "computer",
"hostname": "db_node",
"ip_address": "192.168.1.2",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
node = Computer.from_config(config=node_cfg)
node.power_on()
node.software_manager.install(DatabaseService)
node.software_manager.software.get("database-service").start()

View File

@@ -14,13 +14,15 @@ from primaite.utils.validation.port import PORT_LOOKUP
@pytest.fixture(scope="function")
def dns_client() -> Computer:
node = Computer(
hostname="dns_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
)
node_cfg = {
"type": "computer",
"hostname": "dns_client",
"ip_address": "192.168.1.11",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"dns_server": IPv4Address("192.168.1.10"),
}
node = Computer.from_config(config=node_cfg)
return node
@@ -34,6 +36,16 @@ def test_create_dns_client(dns_client):
def test_dns_client_add_domain_to_cache_when_not_running(dns_client):
dns_client_service: DNSClient = dns_client.software_manager.software.get("dns-client")
# shutdown the dns_client
dns_client.power_off()
# wait for dns_client to turn off
idx = 0
while dns_client.operating_state == NodeOperatingState.SHUTTING_DOWN:
dns_client.apply_timestep(idx)
idx += 1
assert dns_client.operating_state is NodeOperatingState.OFF
assert dns_client_service.operating_state is ServiceOperatingState.STOPPED
@@ -61,7 +73,7 @@ def test_dns_client_check_domain_exists_when_not_running(dns_client):
dns_client.power_off()
for i in range(dns_client.shut_down_duration + 1):
for i in range(dns_client.config.shut_down_duration + 1):
dns_client.apply_timestep(timestep=i)
assert dns_client.operating_state is NodeOperatingState.OFF

View File

@@ -16,13 +16,15 @@ from primaite.utils.validation.port import PORT_LOOKUP
@pytest.fixture(scope="function")
def dns_server() -> Node:
node = Server(
hostname="dns_server",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
node_cfg = {
"type": "server",
"hostname": "dns_server",
"ip_address": "192.168.1.10",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
node = Server.from_config(config=node_cfg)
node.power_on()
node.software_manager.install(software_class=DNSServer)
return node
@@ -55,9 +57,16 @@ def test_dns_server_receive(dns_server):
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
client = Computer(hostname="client", ip_address="192.168.1.11", subnet_mask="255.255.255.0", start_up_duration=0)
client_cfg = {
"type": "computer",
"hostname": "client",
"ip_address": "192.168.1.11",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
client = Computer.from_config(config=client_cfg)
client.power_on()
client.dns_server = IPv4Address("192.168.1.10")
client.config.dns_server = IPv4Address("192.168.1.10")
network = Network()
network.connect(dns_server.network_interface[1], client.network_interface[1])
dns_client: DNSClient = client.software_manager.software["dns-client"] # noqa

View File

@@ -16,13 +16,15 @@ from primaite.utils.validation.port import PORT_LOOKUP
@pytest.fixture(scope="function")
def ftp_client() -> Node:
node = Computer(
hostname="ftp_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
node_cfg = {
"type": "computer",
"hostname": "ftp_client",
"ip_address": "192.168.1.11",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
node = Computer.from_config(config=node_cfg)
node.power_on()
return node
@@ -94,7 +96,7 @@ def test_offline_ftp_client_receives_request(ftp_client):
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client")
ftp_client.power_off()
for i in range(ftp_client.shut_down_duration + 1):
for i in range(ftp_client.config.shut_down_duration + 1):
ftp_client.apply_timestep(timestep=i)
assert ftp_client.operating_state is NodeOperatingState.OFF

View File

@@ -14,13 +14,15 @@ from primaite.utils.validation.port import PORT_LOOKUP
@pytest.fixture(scope="function")
def ftp_server() -> Node:
node = Server(
hostname="ftp_server",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
node_cfg = {
"type": "server",
"hostname": "ftp_server",
"ip_address": "192.168.1.10",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
node = Server.from_config(config=node_cfg)
node.power_on()
node.software_manager.install(software_class=FTPServer)
return node

View File

@@ -12,6 +12,7 @@ from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.protocols.ssh import (
SSHConnectionMessage,
SSHPacket,
@@ -29,8 +30,14 @@ from primaite.utils.validation.port import PORT_LOOKUP
@pytest.fixture(scope="function")
def terminal_on_computer() -> Tuple[Terminal, Computer]:
computer: Computer = Computer(
hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0
computer: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "node_a",
"ip_address": "192.168.0.10",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
computer.power_on()
terminal: Terminal = computer.software_manager.software.get("terminal")
@@ -41,11 +48,27 @@ def terminal_on_computer() -> Tuple[Terminal, Computer]:
@pytest.fixture(scope="function")
def basic_network() -> Network:
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a = Computer.from_config(
config={
"type": "computer",
"hostname": "node_a",
"ip_address": "192.168.0.10",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
node_a.power_on()
node_a.software_manager.get_open_ports()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b = Computer.from_config(
config={
"type": "computer",
"hostname": "node_b",
"ip_address": "192.168.0.11",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
@@ -57,18 +80,23 @@ def wireless_wan_network():
network = Network()
# Configure PC A
pc_a = Computer(
hostname="pc_a",
ip_address="192.168.0.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.0.1",
start_up_duration=0,
)
pc_a_cfg = {
"type": "computer",
"hostname": "pc_a",
"ip_address": "192.168.0.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.0.1",
"start_up_duration": 0,
}
pc_a = Computer.from_config(config=pc_a_cfg)
pc_a.power_on()
network.add_node(pc_a)
# Configure Router 1
router_1 = WirelessRouter(hostname="router_1", start_up_duration=0, airspace=network.airspace)
router_1 = WirelessRouter.from_config(
config={"type": "wireless_router", "hostname": "router_1", "start_up_duration": 0, "airspace": network.airspace}
)
router_1.power_on()
network.add_node(router_1)
@@ -88,41 +116,30 @@ def wireless_wan_network():
)
# Configure PC B
pc_b = Computer(
hostname="pc_b",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
start_up_duration=0,
)
pc_b_cfg = {
"type": "computer",
"hostname": "pc_b",
"ip_address": "192.168.2.2",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.2.1",
"start_up_duration": 0,
}
pc_b = Computer.from_config(config=pc_b_cfg)
pc_b.power_on()
network.add_node(pc_b)
# Configure Router 2
router_2 = WirelessRouter(hostname="router_2", start_up_duration=0, airspace=network.airspace)
router_2.power_on()
network.add_node(router_2)
# Configure the connection between PC B and Router 2 port 2
router_2.configure_router_interface("192.168.2.1", "255.255.255.0")
network.connect(pc_b.network_interface[1], router_2.network_interface[2])
# Configure Router 2 ACLs
# Configure the wireless connection between Router 1 port 1 and Router 2 port 1
router_1.configure_wireless_access_point("192.168.1.1", "255.255.255.0")
router_2.configure_wireless_access_point("192.168.1.2", "255.255.255.0")
router_1.route_table.add_route(
address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2"
)
# Configure Route from Router 2 to PC A subnet
router_2.route_table.add_route(
address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1"
)
return pc_a, pc_b, router_1, router_2
return network
@pytest.fixture
@@ -131,7 +148,7 @@ def game_and_agent_fixture(game_and_agent):
game, agent = game_and_agent
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
client_1.start_up_duration = 3
client_1.config.start_up_duration = 3
return game, agent
@@ -142,8 +159,16 @@ def test_terminal_creation(terminal_on_computer):
def test_terminal_install_default():
"""terminal should be auto installed onto Nodes"""
computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
"""Terminal should be auto installed onto Nodes"""
computer: Computer = Computer.from_config(
config={
"type": "computer",
"hostname": "node_a",
"ip_address": "192.168.0.10",
"subnet_mask": "255.255.255.0",
"start_up_duration": 0,
}
)
computer.power_on()
assert computer.software_manager.software.get("terminal")
@@ -151,7 +176,7 @@ def test_terminal_install_default():
def test_terminal_not_on_switch():
"""Ensure terminal does not auto-install to switch"""
test_switch = Switch(hostname="Test")
test_switch = Switch.from_config(config={"type": "switch", "hostname": "Test"})
assert not test_switch.software_manager.software.get("terminal")
@@ -274,7 +299,10 @@ def test_terminal_ignores_when_off(basic_network):
def test_computer_remote_login_to_router(wireless_wan_network):
"""Test to confirm that a computer can SSH into a router."""
pc_a, _, router_1, _ = wireless_wan_network
pc_a = wireless_wan_network.get_node_by_hostname("pc_a")
router_1 = wireless_wan_network.get_node_by_hostname("router_1")
pc_a_terminal: Terminal = pc_a.software_manager.software.get("terminal")
@@ -293,7 +321,9 @@ def test_computer_remote_login_to_router(wireless_wan_network):
def test_router_remote_login_to_computer(wireless_wan_network):
"""Test to confirm that a router can ssh into a computer."""
pc_a, _, router_1, _ = wireless_wan_network
pc_a = wireless_wan_network.get_node_by_hostname("pc_a")
router_1 = wireless_wan_network.get_node_by_hostname("router_1")
router_1_terminal: Terminal = router_1.software_manager.software.get("terminal")
@@ -311,8 +341,10 @@ def test_router_remote_login_to_computer(wireless_wan_network):
def test_router_blocks_SSH_traffic(wireless_wan_network):
"""Test to check that router will block SSH traffic if no acl rule."""
pc_a, _, router_1, _ = wireless_wan_network
"""Test to check that router will block SSH traffic if no ACL rule."""
pc_a = wireless_wan_network.get_node_by_hostname("pc_a")
router_1 = wireless_wan_network.get_node_by_hostname("router_1")
# Remove rule that allows SSH traffic.
router_1.acl.remove_rule(position=21)
@@ -326,20 +358,22 @@ def test_router_blocks_SSH_traffic(wireless_wan_network):
assert len(pc_a_terminal._connections) == 0
def test_SSH_across_network(wireless_wan_network):
def test_SSH_across_network():
"""Test to show ability to SSH across a network."""
pc_a, pc_b, router_1, router_2 = wireless_wan_network
network: Network = arcd_uc2_network()
pc_a = network.get_node_by_hostname("client_1")
router_1 = network.get_node_by_hostname("router_1")
terminal_a: Terminal = pc_a.software_manager.software.get("terminal")
terminal_b: Terminal = pc_b.software_manager.software.get("terminal")
terminal_a: Terminal = pc_a.software_manager.software.get("Terminal")
router_2.acl.add_rule(
router_1.acl.add_rule(
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21
)
assert len(terminal_a._connections) == 0
terminal_b_on_terminal_a = terminal_b.login(username="admin", password="admin", ip_address="192.168.0.2")
# Login to the Domain Controller
terminal_a.login(username="admin", password="admin", ip_address="192.168.1.10")
assert len(terminal_a._connections) == 1
@@ -357,8 +391,6 @@ def test_multiple_remote_terminals_same_node(basic_network):
for attempt in range(3):
remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11")
terminal_a.show()
assert len(terminal_a._connections) == 3

View File

@@ -16,13 +16,15 @@ from primaite.utils.validation.port import PORT_LOOKUP
@pytest.fixture(scope="function")
def web_server() -> Server:
node = Server(
hostname="web_server",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
node_cfg = {
"type": "server",
"hostname": "web_server",
"ip_address": "192.168.1.10",
"subnet_mask": "255.255.255.0",
"default_gateway": "192.168.1.1",
"start_up_duration": 0,
}
node = Server.from_config(config=node_cfg)
node.power_on()
node.software_manager.install(WebServer)
node.software_manager.software.get("web-server").start()