Fix airspace and remaining port problems from refactor

This commit is contained in:
Marek Wolan
2024-09-19 15:06:29 +01:00
parent dd931d900b
commit 08f1cf1fbd
34 changed files with 227 additions and 177 deletions

View File

@@ -49,7 +49,7 @@ additional steps to configure wireless settings:
wireless_router.configure_wireless_access_point(
port=1, ip_address="192.168.2.1",
subnet_mask="255.255.255.0",
frequency=AirSpaceFrequency["WIFI_2_4"],
frequency="WIFI_2_4",
)
@@ -130,13 +130,13 @@ ICMP traffic, ensuring basic network connectivity and ping functionality.
port=1,
ip_address="192.168.1.1",
subnet_mask="255.255.255.0",
frequency=AirSpaceFrequency["WIFI_2_4"],
frequency="WIFI_2_4",
)
router_2.configure_wireless_access_point(
port=1,
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
frequency=AirSpaceFrequency["WIFI_2_4"],
frequency="WIFI_2_4",
)
# Configure routes for inter-router communication

View File

@@ -60,9 +60,10 @@ def data_manipulation_marl_config_path() -> Path:
raise FileNotFoundError(msg)
return path
def get_extended_config_path() -> Path:
"""
Get the path to an 'extended' example config that contains nodes using the extension framework
Get the path to an 'extended' example config that contains nodes using the extension framework.
:return: Path to the extended example config
:rtype: Path
@@ -72,4 +73,4 @@ def get_extended_config_path() -> Path:
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path
return path

View File

@@ -10,8 +10,6 @@ from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
_LOGGER = getLogger(__name__)
@@ -63,7 +61,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {IPProtocol[p]: i + 2 for i, p in enumerate(protocol_list)}
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
self.default_observation: Dict = {
i
+ 1: {

View File

@@ -58,8 +58,14 @@ class HostObservation(AbstractObservation, identifier="HOST"):
include_users: Optional[bool] = True
"""If True, report user session information."""
@field_validator('monitored_traffic', mode='before')
def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]:
@field_validator("monitored_traffic", mode="before")
def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]:
"""
Convert monitored_traffic by lookup against Port and Protocol dicts.
This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3.
This method will be removed in PrimAITE >= 4.0
"""
if val is None:
return val
new_val = {}

View File

@@ -26,8 +26,14 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
monitored_traffic: Optional[Dict] = None
"""A dict containing which traffic types are to be included in the observation."""
@field_validator('monitored_traffic', mode='before')
def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]:
@field_validator("monitored_traffic", mode="before")
def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]:
"""
Convert monitored_traffic by lookup against Port and Protocol dicts.
This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3.
This method will be removed in PrimAITE >= 4.0
"""
if val is None:
return val
new_val = {}
@@ -41,7 +47,6 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
new_val[proto].append(port)
return new_val
def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None:
"""
Initialise a network interface observation instance.
@@ -76,7 +81,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
def _default_monitored_traffic_observation(self, monitored_traffic_config: Dict) -> Dict:
default_traffic_obs = {"TRAFFIC": {}}
for protocol in monitored_traffic_config:
for protocol in self.monitored_traffic:
protocol = str(protocol).lower()
default_traffic_obs["TRAFFIC"][protocol] = {}
@@ -84,8 +89,8 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
default_traffic_obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0}
else:
default_traffic_obs["TRAFFIC"][protocol] = {}
for port in monitored_traffic_config[protocol]:
default_traffic_obs["TRAFFIC"][protocol] = {"inbound": 0, "outbound": 0}
for port in self.monitored_traffic[protocol]:
default_traffic_obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0}
return default_traffic_obs

View File

@@ -63,8 +63,14 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
@field_validator('monitored_traffic', mode='before')
def traffic_lookup(cls, val:Optional[Dict]) -> Optional[Dict]:
@field_validator("monitored_traffic", mode="before")
def traffic_lookup(cls, val: Optional[Dict]) -> Optional[Dict]:
"""
Convert monitored_traffic by lookup against Port and Protocol dicts.
This is necessary for retaining compatiblility with configs written for PrimAITE <=3.3.
This method will be removed in PrimAITE >= 4.0
"""
if val is None:
return val
new_val = {}

View File

@@ -17,10 +17,9 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.airspace import AirSpaceFrequency
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
@@ -89,8 +88,8 @@ class PrimaiteGameOptions(BaseModel):
thresholds: Optional[Dict] = {}
"""A dict containing the thresholds used for determining what is acceptable during observations."""
@field_validator('ports', mode='before')
def ports_str2int(cls, vals:Union[List[str],List[int]]) -> List[int]:
@field_validator("ports", mode="before")
def ports_str2int(cls, vals: Union[List[str], List[int]]) -> List[int]:
"""
Convert named port strings to port integer values. Integer ports remain unaffected.
@@ -102,8 +101,8 @@ class PrimaiteGameOptions(BaseModel):
vals[i] = Port[port_val]
return vals
@field_validator('protocols', mode='before')
def protocols_str2int(cls, vals:List[str]) -> List[str]:
@field_validator("protocols", mode="before")
def protocols_str2int(cls, vals: List[str]) -> List[str]:
"""
Convert old-style named protocols to their proper values.
@@ -116,7 +115,6 @@ class PrimaiteGameOptions(BaseModel):
return vals
class PrimaiteGame:
"""
Primaite game encapsulates the simulation and agents which interact with it.
@@ -294,10 +292,7 @@ class PrimaiteGame:
network_config = simulation_config.get("network", {})
airspace_cfg = network_config.get("airspace", {})
frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {})
frequency_max_capacity_mbps_cfg = {AirSpaceFrequency[k]: v for k, v in frequency_max_capacity_mbps_cfg.items()}
net.airspace.frequency_max_capacity_mbps_ = frequency_max_capacity_mbps_cfg
net.airspace.set_frequency_max_capacity_mbps(frequency_max_capacity_mbps_cfg)
nodes_cfg = network_config.get("nodes", [])
links_cfg = network_config.get("links", [])
@@ -318,11 +313,10 @@ class PrimaiteGame:
dns_server=node_cfg.get("dns_server", None),
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()])
elif n_type in NetworkNode._registry:
new_node = NetworkNode._registry[n_type](
**node_cfg
else NodeOperatingState[p.upper()],
)
elif n_type in NetworkNode._registry:
new_node = NetworkNode._registry[n_type](**node_cfg)
# Default PrimAITE nodes
elif n_type == "computer":
new_node = Computer(
@@ -502,7 +496,7 @@ class PrimaiteGame:
opt = application_cfg["options"]
new_application.configure(
target_ip_address=IPv4Address(opt.get("target_ip_address")),
target_port = Port[opt.get("target_port", "POSTGRES_SERVER")],
target_port=Port[opt.get("target_port", "POSTGRES_SERVER")],
payload=opt.get("payload"),
repeat=bool(opt.get("repeat")),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),

View File

@@ -1,14 +1,12 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
import copy
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, ClassVar, Dict, List, Type
from pydantic._internal._generics import PydanticGenericMetadata
from typing_extensions import Unpack
from typing import Any, Dict, List
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field, validate_call
from primaite import getLogger
from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface
@@ -42,29 +40,31 @@ def format_hertz(hertz: float, format_terahertz: bool = False, decimals: int = 3
else: # Hertz
return format_str.format(hertz) + " Hz"
AirSpaceFrequencyRegistry: Dict[str,Dict] = {
"WIFI_2_4" : {'frequency': 2.4e9, 'data_rate_bps':100_000_000.0},
"WIFI_5" : {'frequency': 5e9, 'data_rate_bps':500_000_000.0},
_default_frequency_set: Dict[str, Dict] = {
"WIFI_2_4": {"frequency": 2.4e9, "data_rate_bps": 100_000_000.0},
"WIFI_5": {"frequency": 5e9, "data_rate_bps": 500_000_000.0},
}
"""Frequency configuration that is automatically used for any new airspace."""
def register_frequency(freq_name: str, freq_hz: int, data_rate_bps: int) -> None:
if freq_name in AirSpaceFrequencyRegistry:
raise RuntimeError(f"Cannot register new frequency {freq_name} because it's already registered.")
AirSpaceFrequencyRegistry.update({freq_name:{'frequency': freq_hz, 'data_rate_bps':data_rate_bps}})
def maximum_data_rate_mbps(frequency_name:str) -> float:
def register_default_frequency(freq_name: str, freq_hz: float, data_rate_bps: float):
"""Add to the default frequency configuration. This is intended as a plugin hook.
If your plugin makes use of bespoke frequencies for wireless communication, you should make a call to this method
whereever you define components that rely on the bespoke frequencies. That way, as soon as your components are
imported, this function automatically updates the default frequency set.
This should also be run before instances of AirSpace are created.
:param freq_name: The frequency name. If this clashes with an existing frequency name, it will be overwritten.
:type freq_name: str
:param freq_hz: The frequency itself, measured in Hertz.
:type freq_hz: float
:param data_rate_bps: The transmission capacity over this frequency, in bits per second.
:type data_rate_bps: float
"""
Retrieves the maximum data transmission rate in megabits per second (Mbps).
This is derived by converting the maximum data rate from bits per second, as defined
in `maximum_data_rate_bps`, to megabits per second.
:return: The maximum data rate in megabits per second.
"""
return AirSpaceFrequencyRegistry[frequency_name]['data_rate_bps']
return data_rate / 1_000_000.0
_default_frequency_set.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}})
class AirSpace(BaseModel):
@@ -77,27 +77,21 @@ class AirSpace(BaseModel):
"""
wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {})
wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field(
default_factory=lambda: {}
)
wireless_interfaces_by_frequency: Dict[int, List[WirelessNetworkInterface]] = Field(default_factory=lambda: {})
bandwidth_load: Dict[int, float] = Field(default_factory=lambda: {})
frequency_max_capacity_mbps_: Dict[int, float] = Field(default_factory=lambda: {})
frequencies: Dict[str, Dict] = Field(default_factory=lambda: copy.deepcopy(_default_frequency_set))
def get_frequency_max_capacity_mbps(self, frequency: str) -> float:
@validate_call
def get_frequency_max_capacity_mbps(self, freq_name: str) -> float:
"""
Retrieves the maximum data transmission capacity for a specified frequency.
This method checks a dictionary holding custom maximum capacities. If the frequency is found, it returns the
custom set maximum capacity. If the frequency is not found in the dictionary, it defaults to the standard
maximum data rate associated with that frequency.
:param frequency: The frequency for which the maximum capacity is queried.
:param freq_name: The frequency for which the maximum capacity is queried.
:return: The maximum capacity in Mbps for the specified frequency.
"""
if frequency in self.frequency_max_capacity_mbps_:
return self.frequency_max_capacity_mbps_[frequency]
return maximum_data_rate_mbps(frequency)
if freq_name in self.frequencies:
return self.frequencies[freq_name]["data_rate_bps"] / (1024.0 * 1024.0)
return 0.0
def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]):
"""
@@ -105,10 +99,29 @@ class AirSpace(BaseModel):
:param cfg: A dictionary mapping frequencies to their new maximum capacities in Mbps.
"""
self.frequency_max_capacity_mbps_ = cfg
for freq, mbps in cfg.items():
self.frequencies[freq]["data_rate_bps"] = mbps * 1024 * 1024
print(f"Overriding {freq} max capacity as {mbps:.3f} mbps")
def register_frequency(self, freq_name: str, freq_hz: float, data_rate_bps: float) -> None:
"""
Define a new frequency for this airspace.
:param freq_name: The frequency name. If this clashes with an existing frequency name, it will be overwritten.
:type freq_name: str
:param freq_hz: The frequency itself, measured in Hertz.
:type freq_hz: float
:param data_rate_bps: The transmission capacity over this frequency, in bits per second.
:type data_rate_bps: float
"""
if freq_name in self.frequencies:
_LOGGER.info(
f"Overwriting Air space frequency {freq_name}. "
f"Previous data rate: {self.frequencies[freq_name]['data_rate_bps']}. "
f"Current data rate: {data_rate_bps}."
)
self.frequencies.update({freq_name: {"frequency": freq_hz, "data_rate_bps": data_rate_bps}})
def show_bandwidth_load(self, markdown: bool = False):
"""
Prints a table of the current bandwidth load for each frequency on the airspace.
@@ -130,7 +143,13 @@ class AirSpace(BaseModel):
load_percent = load / maximum_capacity if maximum_capacity > 0 else 0.0
if load_percent > 1.0:
load_percent = 1.0
table.add_row([format_hertz(frequency), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"])
table.add_row(
[
format_hertz(self.frequencies[frequency]["frequency"]),
f"{load_percent:.0%}",
f"{maximum_capacity:.3f}",
]
)
print(table)
def show_wireless_interfaces(self, markdown: bool = False):
@@ -162,7 +181,7 @@ class AirSpace(BaseModel):
interface.mac_address,
interface.ip_address if hasattr(interface, "ip_address") else None,
interface.subnet_mask if hasattr(interface, "subnet_mask") else None,
format_hertz(interface.frequency),
format_hertz(self.frequencies[interface.frequency]["frequency"]),
f"{interface.speed:.3f}",
status,
]

View File

@@ -130,15 +130,15 @@ class Network(SimComponent):
def firewall_nodes(self) -> List[Node]:
"""The Firewalls in the Network."""
return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"]
@property
def extended_hostnodes(self) -> List[Node]:
"""Extended nodes that inherited HostNode in the network"""
"""Extended nodes that inherited HostNode in the network."""
return [node for node in self.nodes.values() if node.__class__.__name__.lower() in HostNode._registry]
@property
def extended_networknodes(self) -> List[Node]:
"""Extended nodes that inherited NetworkNode in the network"""
"""Extended nodes that inherited NetworkNode in the network."""
return [node for node in self.nodes.values() if node.__class__.__name__.lower() in NetworkNode._registry]
@property

View File

@@ -332,7 +332,7 @@ class HostNode(Node):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
"""
Register a hostnode type.
@@ -340,7 +340,7 @@ class HostNode(Node):
:type identifier: str
:raises ValueError: When attempting to register an hostnode with a name that is already allocated.
"""
if identifier == 'default':
if identifier == "default":
return
# Enforce lowercase registry entries because it makes comparisons everywhere else much easier.
identifier = identifier.lower()

View File

@@ -19,7 +19,7 @@ class NetworkNode(Node):
_registry: ClassVar[Dict[str, Type["NetworkNode"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
"""
Register a networknode type.
@@ -27,7 +27,7 @@ class NetworkNode(Node):
:type identifier: str
:raises ValueError: When attempting to register an networknode with a name that is already allocated.
"""
if identifier == 'default':
if identifier == "default":
return
identifier = identifier.lower()
super().__init_subclass__(**kwargs)

View File

@@ -130,19 +130,20 @@ class ACLRule(SimComponent):
dst_port: Optional[int] = None
match_count: int = 0
@field_validator('protocol', mode='before')
def protocol_valid(cls, val:Optional[str]) -> Optional[str]:
@field_validator("protocol", mode="before")
def protocol_valid(cls, val: Optional[str]) -> Optional[str]:
"""Assert that the protocol for the rule is predefined in the IPProtocol lookup."""
if val is not None:
assert val in IPProtocol.values(), f"Cannot create ACL rule with invalid protocol {val}"
return val
@field_validator('src_port', 'dst_port', mode='before')
def ports_valid(cls, val:Optional[int]) -> Optional[int]:
@field_validator("src_port", "dst_port", mode="before")
def ports_valid(cls, val: Optional[int]) -> Optional[int]:
"""Assert that the port for the rule is predefined in the Port lookup."""
if val is not None:
assert val in Port.values(), f"Cannot create ACL rule with invalid port {val}"
return val
def __str__(self) -> str:
rule_strings = []
for key, value in self.model_dump(exclude={"uuid", "request_manager"}).items():

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union
from pydantic import validate_call
from primaite.simulator.network.airspace import AirSpace, AirSpaceFrequency, IPWirelessNetworkInterface
from primaite.simulator.network.airspace import AirSpace, IPWirelessNetworkInterface
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface
from primaite.simulator.network.transmission.data_link_layer import Frame
@@ -116,7 +116,7 @@ class WirelessRouter(Router):
>>> wireless_router.configure_wireless_access_point(
... ip_address="10.10.10.1",
... subnet_mask="255.255.255.0"
... frequency=AirSpaceFrequency["WIFI_2_4"]
... frequency="WIFI_2_4"
... )
"""
@@ -153,7 +153,7 @@ class WirelessRouter(Router):
self,
ip_address: IPV4Address,
subnet_mask: IPV4Address,
frequency: Optional[int] = AirSpaceFrequency["WIFI_2_4"],
frequency: Optional[str] = "WIFI_2_4",
):
"""
Configures a wireless access point (WAP).
@@ -166,12 +166,12 @@ class WirelessRouter(Router):
:param ip_address: The IP address to be assigned to the wireless access point.
:param subnet_mask: The subnet mask associated with the IP address
:param frequency: The operating frequency of the wireless access point, defined by the AirSpaceFrequency
:param frequency: The operating frequency of the wireless access point, defined by the air space frequency
enum. This determines the frequency band (e.g., 2.4 GHz or 5 GHz) the access point will use for wireless
communication. Default is AirSpaceFrequency["WIFI_2_4"].
communication. Default is "WIFI_2_4".
"""
if not frequency:
frequency = AirSpaceFrequency["WIFI_2_4"]
frequency = "WIFI_2_4"
self.sys_log.info("Configuring wireless access point")
self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration
@@ -264,7 +264,7 @@ class WirelessRouter(Router):
if "wireless_access_point" in cfg:
ip_address = cfg["wireless_access_point"]["ip_address"]
subnet_mask = cfg["wireless_access_point"]["subnet_mask"]
frequency = AirSpaceFrequency[cfg["wireless_access_point"]["frequency"]]
frequency = cfg["wireless_access_point"]["frequency"]
router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency)
if "acl" in cfg:

View File

@@ -9,11 +9,11 @@ from primaite.utils.validators import IPV4Address
_LOGGER = getLogger(__name__)
IPProtocol : dict[str, str] = dict(
NONE = "none",
TCP = "tcp",
UDP = "udp",
ICMP = "icmp",
IPProtocol: dict[str, str] = dict(
NONE="none",
TCP="tcp",
UDP="udp",
ICMP="icmp",
)
# class IPProtocol(Enum):

View File

@@ -1,40 +1,39 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import Enum
from typing import List, Union
from typing import List
from pydantic import BaseModel
Port: dict[str, int] = dict(
UNUSED = -1,
NONE = 0,
WOL = 9,
FTP_DATA = 20,
FTP = 21,
SSH = 22,
SMTP = 25,
DNS = 53,
HTTP = 80,
POP3 = 110,
SFTP = 115,
NTP = 123,
IMAP = 143,
SNMP = 161,
SNMP_TRAP = 162,
ARP = 219,
LDAP = 389,
HTTPS = 443,
SMB = 445,
IPP = 631,
SQL_SERVER = 1433,
MYSQL = 3306,
RDP = 3389,
RTP = 5004,
RTP_ALT = 5005,
DNS_ALT = 5353,
HTTP_ALT = 8080,
HTTPS_ALT = 8443,
POSTGRES_SERVER = 5432,
UNUSED=-1,
NONE=0,
WOL=9,
FTP_DATA=20,
FTP=21,
SSH=22,
SMTP=25,
DNS=53,
HTTP=80,
POP3=110,
SFTP=115,
NTP=123,
IMAP=143,
SNMP=161,
SNMP_TRAP=162,
ARP=219,
LDAP=389,
HTTPS=443,
SMB=445,
IPP=631,
SQL_SERVER=1433,
MYSQL=3306,
RDP=3389,
RTP=5004,
RTP_ALT=5005,
DNS_ALT=5353,
HTTP_ALT=8080,
HTTPS_ALT=8443,
POSTGRES_SERVER=5432,
)
# class Port():

View File

@@ -44,7 +44,7 @@ class Application(IOSoftware):
_registry: ClassVar[Dict[str, Type["Application"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
"""
Register an application type.
@@ -52,7 +52,7 @@ class Application(IOSoftware):
:type identifier: str
:raises ValueError: When attempting to register an application with a name that is already allocated.
"""
if identifier == 'default':
if identifier == "default":
return
super().__init_subclass__(**kwargs)
if identifier in cls._registry:

View File

@@ -1,5 +1,4 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import Enum
from ipaddress import IPv4Address
from typing import Dict, Optional

View File

@@ -76,9 +76,7 @@ class SessionManager:
"""
def __init__(self, sys_log: SysLog):
self.sessions_by_key: Dict[
Tuple[str, IPv4Address, IPv4Address, Optional[int], Optional[int]], Session
] = {}
self.sessions_by_key: Dict[Tuple[str, IPv4Address, IPv4Address, Optional[int], Optional[int]], Session] = {}
self.sessions_by_uuid: Dict[str, Session] = {}
self.sys_log: SysLog = sys_log
self.software_manager: SoftwareManager = None # Noqa

View File

@@ -5,7 +5,6 @@ from typing import Dict, Optional
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import Service

View File

@@ -52,7 +52,7 @@ class Service(IOSoftware):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init_subclass__(cls, identifier: str = 'default', **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
"""
Register a hostnode type.
@@ -60,7 +60,7 @@ class Service(IOSoftware):
:type identifier: str
:raises ValueError: When attempting to register an hostnode with a name that is already allocated.
"""
if identifier == 'default':
if identifier == "default":
return
# Enforce lowercase registry entries because it makes comparisons everywhere else much easier.
identifier = identifier.lower()

View File

@@ -14,7 +14,6 @@ from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.core.sys_log import SysLog

View File

@@ -1,3 +1,4 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict
from prettytable import MARKDOWN, PrettyTable
@@ -27,7 +28,7 @@ class GigaSwitch(NetworkNode, identifier="gigaswitch"):
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
def __init__(self, **kwargs):
print('--- Extended Component: GigaSwitch ---')
print("--- Extended Component: GigaSwitch ---")
super().__init__(**kwargs)
for i in range(1, self.num_ports + 1):
self.connect_nic(SwitchPort())

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import ClassVar, Dict
from primaite.simulator.network.hardware.nodes.host.host_node import NIC, HostNode
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode, NIC
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.utils.validators import IPV4Address
@@ -37,7 +37,7 @@ class SuperComputer(HostNode, identifier="supercomputer"):
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient}
def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs):
print('--- Extended Component: SuperComputer ---')
print("--- Extended Component: SuperComputer ---")
super().__init__(ip_address=ip_address, subnet_mask=subnet_mask, **kwargs)
pass

View File

@@ -17,7 +17,7 @@ from primaite.simulator.system.software import SoftwareHealthState
_LOGGER = getLogger(__name__)
class ExtendedService(Service, identifier='extendedservice'):
class ExtendedService(Service, identifier="extendedservice"):
"""
A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE.
@@ -42,7 +42,7 @@ class ExtendedService(Service, identifier='extendedservice'):
kwargs["protocol"] = IPProtocol["TCP"]
super().__init__(**kwargs)
self._create_db_file()
if kwargs.get('options'):
if kwargs.get("options"):
opt = kwargs["options"]
self.password = opt.get("db_password", None)
if "backup_server_ip" in opt:
@@ -139,7 +139,9 @@ class ExtendedService(Service, identifier='extendedservice'):
old_visible_state = SoftwareHealthState.GOOD
# get db file regardless of whether or not it was deleted
db_file = self.file_system.get_file(folder_name="database", file_name="extended_service_database.db", include_deleted=True)
db_file = self.file_system.get_file(
folder_name="database", file_name="extended_service_database.db", include_deleted=True
)
if db_file is None:
self.sys_log.warning("Database file not initialised.")
@@ -153,7 +155,9 @@ class ExtendedService(Service, identifier='extendedservice'):
self.file_system.delete_file(folder_name="database", file_name="extended_service_database.db")
# replace db file
self.file_system.copy_file(src_folder_name="downloads", src_file_name="extended_service_database.db", dst_folder_name="database")
self.file_system.copy_file(
src_folder_name="downloads", src_file_name="extended_service_database.db", dst_folder_name="database"
)
if self.db_file is None:
self.sys_log.error("Copying database backup failed.")

View File

@@ -1,22 +1,22 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import os
from primaite.config.load import get_extended_config_path
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config
import os
from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication
from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch
# Import the extended components so that PrimAITE registers them
from tests.integration_tests.extensions.nodes.super_computer import SuperComputer
from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch
from tests.integration_tests.extensions.services.extended_service import ExtendedService
from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication
def test_extended_example_config():
"""Test that the example config can be parsed properly."""
config_path = os.path.join( "tests", "assets", "configs", "extended_config.yaml")
config_path = os.path.join("tests", "assets", "configs", "extended_config.yaml")
game = load_config(config_path)
network: Network = game.simulation.network
@@ -25,8 +25,8 @@ def test_extended_example_config():
assert len(network.router_nodes) == 1 # 1 router in network
assert len(network.switch_nodes) == 1 # 1 switches in network
assert len(network.server_nodes) == 5 # 5 servers in network
assert len(network.extended_hostnodes) == 1 # One extended node based on HostNode
assert len(network.extended_networknodes) == 1 # One extended node based on NetworkNode
assert len(network.extended_hostnodes) == 1 # One extended node based on HostNode
assert len(network.extended_networknodes) == 1 # One extended node based on NetworkNode
assert 'ExtendedApplication' in network.extended_hostnodes[0].software_manager.software
assert 'ExtendedService' in network.extended_hostnodes[0].software_manager.software
assert "ExtendedApplication" in network.extended_hostnodes[0].software_manager.software
assert "ExtendedService" in network.extended_hostnodes[0].software_manager.software

View File

@@ -38,8 +38,8 @@ def test_acl_observations(simulation):
acl_obs = ACLObservation(
where=["network", "nodes", router.hostname, "acl", "acl"],
ip_list=[],
port_list=["NTP", "HTTP", "POSTGRES_SERVER"],
protocol_list=["TCP", "UDP", "ICMP"],
port_list=[123, 80, 5432],
protocol_list=["tcp", "udp", "icmp"],
num_rules=10,
wildcard_list=[],
)

View File

@@ -31,8 +31,8 @@ def test_firewall_observation():
num_rules=7,
ip_list=["10.0.0.1", "10.0.0.2"],
wildcard_list=["0.0.0.255", "0.0.0.1"],
port_list=["HTTP", "DNS"],
protocol_list=["TCP"],
port_list=[80, 53],
protocol_list=["tcp"],
include_users=False,
)

View File

@@ -152,7 +152,12 @@ def test_config_nic_categories(simulation):
def test_nic_monitored_traffic(simulation):
monitored_traffic = {"icmp": ["NONE"], "tcp": [53,]}
monitored_traffic = {
"icmp": ["NONE"],
"tcp": [
53,
],
}
pc: Computer = simulation.network.get_node_by_hostname("client_1")
pc2: Computer = simulation.network.get_node_by_hostname("client_2")

View File

@@ -24,8 +24,8 @@ def test_router_observation():
num_rules=7,
ip_list=["10.0.0.1", "10.0.0.2"],
wildcard_list=["0.0.0.255", "0.0.0.1"],
port_list=["HTTP", "DNS"],
protocol_list=["TCP"],
port_list=[80, 53],
protocol_list=["tcp"],
)
router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False)

View File

@@ -65,7 +65,9 @@ def test_uc2_rewards(game_and_agent):
db_client.run()
router: Router = game.simulation.network.get_node_by_hostname("router")
router.acl.add_rule(ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=2)
router.acl.add_rule(
ACLAction.PERMIT, src_port=Port["POSTGRES_SERVER"], dst_port=Port["POSTGRES_SERVER"], position=2
)
comp = GreenAdminDatabaseUnreachablePenalty("client_1")

View File

@@ -2,7 +2,6 @@
import yaml
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.airspace import AirSpaceFrequency
from tests import TEST_ASSETS_ROOT
@@ -13,8 +12,8 @@ def test_override_freq_max_capacity_mbps():
config_dict = yaml.safe_load(f)
network = PrimaiteGame.from_config(cfg=config_dict).simulation.network
assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_2_4"]) == 123.45
assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_5"]) == 0.0
assert network.airspace.get_frequency_max_capacity_mbps("WIFI_2_4") == 123.45
assert network.airspace.get_frequency_max_capacity_mbps("WIFI_5") == 0.0
pc_a = network.get_node_by_hostname("pc_a")
pc_b = network.get_node_by_hostname("pc_b")
@@ -32,8 +31,8 @@ def test_override_freq_max_capacity_mbps_blocked():
config_dict = yaml.safe_load(f)
network = PrimaiteGame.from_config(cfg=config_dict).simulation.network
assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_2_4"]) == 0.0
assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency["WIFI_5"]) == 0.0
assert network.airspace.get_frequency_max_capacity_mbps("WIFI_2_4") == 0.0
assert network.airspace.get_frequency_max_capacity_mbps("WIFI_5") == 0.0
pc_a = network.get_node_by_hostname("pc_a")
pc_b = network.get_node_by_hostname("pc_b")

View File

@@ -73,8 +73,12 @@ def dmz_external_internal_network() -> Network:
firewall_node.external_outbound_acl.add_rule(
action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22
)
firewall_node.dmz_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22)
firewall_node.dmz_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22)
firewall_node.dmz_inbound_acl.add_rule(
action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22
)
firewall_node.dmz_outbound_acl.add_rule(
action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22
)
# external node
external_node = Computer(
@@ -262,8 +266,12 @@ def test_service_allowed_with_rule(dmz_external_internal_network):
assert not internal_ntp_client.time
firewall.internal_outbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1)
firewall.internal_inbound_acl.add_rule(action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1)
firewall.internal_outbound_acl.add_rule(
action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1
)
firewall.internal_inbound_acl.add_rule(
action=ACLAction.PERMIT, src_port=Port["NTP"], dst_port=Port["NTP"], position=1
)
internal_ntp_client.request_time()

View File

@@ -73,7 +73,9 @@ def test_port_scan_one_node_one_port(example_network):
client_2 = network.get_node_by_hostname("client_2")
actual_result = client_1_nmap.port_scan(
target_ip_address=client_2.network_interface[1].ip_address, target_port=Port["DNS"], target_protocol=IPProtocol["TCP"]
target_ip_address=client_2.network_interface[1].ip_address,
target_port=Port["DNS"],
target_protocol=IPProtocol["TCP"],
)
expected_result = {IPv4Address("192.168.10.22"): {IPProtocol["TCP"]: [Port["DNS"]]}}

View File

@@ -66,7 +66,9 @@ def test_nested_dicts():
The expected output should have string values of enums as keys at all levels.
"""
original_dict = {
IPProtocol["UDP"]: {Port["ARP"]: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol["TCP"]: {"latency": "low"}}}}
IPProtocol["UDP"]: {
Port["ARP"]: {"inbound": 0, "outbound": 1016.0, "details": {IPProtocol["TCP"]: {"latency": "low"}}}
}
}
expected_dict = {"udp": {219: {"inbound": 0, "outbound": 1016.0, "details": {"tcp": {"latency": "low"}}}}}
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict
@@ -79,6 +81,9 @@ def test_non_dict_values():
The original dictionary contains lists and tuples as values.
The expected output should preserve these non-dictionary values while converting enum keys to string values.
"""
original_dict = {IPProtocol["UDP"]: [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])}
original_dict = {
IPProtocol["UDP"]: [Port["ARP"], Port["HTTP"]],
"protocols": (IPProtocol["TCP"], IPProtocol["UDP"]),
}
expected_dict = {"udp": [Port["ARP"], Port["HTTP"]], "protocols": (IPProtocol["TCP"], IPProtocol["UDP"])}
assert convert_dict_enum_keys_to_enum_values(original_dict) == expected_dict