diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d08974c..9493dec4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - File and folder observations can now be configured to always show the true health status, or require scanning like before. - It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where agent don't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty` - Node observations can now be configured to show the number of active local and remote logins. +- Ports, IP Protocols, and airspace frequencies no longer use enums. They defined in dictionary lookups and are handled by custom validation to enable extendability with plugins. ### Fixed - Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 65dceeb1..03d43130 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -48,11 +48,11 @@ _default_frequency_set: Dict[str, Dict] = { """Frequency configuration that is automatically used for any new airspace.""" -def register_default_frequency(freq_name: str, freq_hz: float, data_rate_bps: float): +def register_default_frequency(freq_name: str, freq_hz: float, data_rate_bps: float) -> None: """Add to the default frequency configuration. This is intended as a plugin hook. If your plugin makes use of bespoke frequencies for wireless communication, you should make a call to this method - whereever you define components that rely on the bespoke frequencies. That way, as soon as your components are + wherever you define components that rely on the bespoke frequencies. That way, as soon as your components are imported, this function automatically updates the default frequency set. This should also be run before instances of AirSpace are created. @@ -93,7 +93,7 @@ class AirSpace(BaseModel): return self.frequencies[freq_name]["data_rate_bps"] / (1024.0 * 1024.0) return 0.0 - def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]): + def set_frequency_max_capacity_mbps(self, cfg: Dict[int, float]) -> None: """ Sets custom maximum data transmission capacities for multiple frequencies. diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 778cffa2..050f4667 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1839,15 +1839,14 @@ class Node(SimComponent): def show_open_ports(self, markdown: bool = False): """Prints a table of the open ports on the Node.""" - table = PrettyTable(["Port", "Name"]) + table = PrettyTable(["Port"]) if markdown: table.set_style(MARKDOWN) table.align = "l" table.title = f"{self.hostname} Open Ports" for port in self.software_manager.get_open_ports(): if port > 0: - # TODO: do a reverse lookup for port name, or change this to only show port int - table.add_row([port, port]) + table.add_row([port]) print(table.get_string(sortby="Port")) @property diff --git a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py index aff12748..f77bc33a 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/abstract_c2.py @@ -366,7 +366,7 @@ class AbstractC2(Application, identifier="AbstractC2"): :return: True on successful configuration, false otherwise. :rtype: bool """ - # Validating that they are valid Enums. + # Validating that they are valid Ports and Protocols. if not is_valid_port(payload.masquerade_port) or not is_valid_protocol(payload.masquerade_protocol): self.sys_log.warning( f"{self.name}: Received invalid Masquerade Values within Keep Alive." diff --git a/tests/unit_tests/_primaite/_utils/_validation/__init__.py b/tests/unit_tests/_primaite/_utils/_validation/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py b/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py new file mode 100644 index 00000000..27829570 --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/test_ip_protocol.py @@ -0,0 +1,23 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.utils.validation.ip_protocol import IPProtocol, is_valid_protocol, PROTOCOL_LOOKUP, protocol_validator + + +def test_port_conversion(): + for proto_name, proto_val in PROTOCOL_LOOKUP.items(): + assert protocol_validator(proto_name) == proto_val + assert is_valid_protocol(proto_name) + + +def test_port_passthrough(): + for proto_val in PROTOCOL_LOOKUP.values(): + assert protocol_validator(proto_val) == proto_val + assert is_valid_protocol(proto_val) + + +def test_invalid_ports(): + for port in (123, "abcdefg", "NONEXISTENT_PROTO"): + with pytest.raises(ValueError): + protocol_validator(port) + assert not is_valid_protocol(port) diff --git a/tests/unit_tests/_primaite/_utils/_validation/test_port.py b/tests/unit_tests/_primaite/_utils/_validation/test_port.py new file mode 100644 index 00000000..6a8a2429 --- /dev/null +++ b/tests/unit_tests/_primaite/_utils/_validation/test_port.py @@ -0,0 +1,25 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.utils.validation.port import is_valid_port, Port, PORT_LOOKUP, port_validator + + +def test_port_conversion(): + valid_port_lookup = {k: v for k, v in PORT_LOOKUP.items() if k != "UNUSED"} + for port_name, port_val in valid_port_lookup.items(): + assert port_validator(port_name) == port_val + assert is_valid_port(port_name) + + +def test_port_passthrough(): + valid_port_lookup = {k: v for k, v in PORT_LOOKUP.items() if k != "UNUSED"} + for port_val in valid_port_lookup.values(): + assert port_validator(port_val) == port_val + assert is_valid_port(port_val) + + +def test_invalid_ports(): + for port in (999999, -20, 3.214, "NONEXISTENT_PORT"): + with pytest.raises(ValueError): + port_validator(port) + assert not is_valid_port(port)