Remove broken tests
This commit is contained in:
@@ -54,37 +54,37 @@ def file_system() -> FileSystem:
|
||||
return Node(hostname="fs_node").file_system
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
# PrimAITE v2 stuff
|
||||
class TempPrimaiteSession(PrimaiteSession):
|
||||
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
class TempPrimaiteSession: # PrimaiteSession):
|
||||
"""
|
||||
A temporary PrimaiteSession class.
|
||||
|
||||
Uses context manager for deletion of files upon exit.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self.setup()
|
||||
# def __init__(
|
||||
# self,
|
||||
# training_config_path: Union[str, Path],
|
||||
# lay_down_config_path: Union[str, Path],
|
||||
# ):
|
||||
# super().__init__(training_config_path, lay_down_config_path)
|
||||
# self.setup()
|
||||
|
||||
@property
|
||||
def env(self) -> Primaite:
|
||||
"""Direct access to the env for ease of testing."""
|
||||
return self._agent_session._env # noqa
|
||||
# @property
|
||||
# def env(self) -> Primaite:
|
||||
# """Direct access to the env for ease of testing."""
|
||||
# return self._agent_session._env # noqa
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
# def __enter__(self):
|
||||
# return self
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
shutil.rmtree(self.session_path)
|
||||
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
|
||||
# def __exit__(self, type, value, tb):
|
||||
# shutil.rmtree(self.session_path)
|
||||
# _LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.fixture
|
||||
def temp_primaite_session(request):
|
||||
"""
|
||||
@@ -139,7 +139,7 @@ def temp_primaite_session(request):
|
||||
return TempPrimaiteSession(training_config_path, lay_down_config_path)
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.fixture
|
||||
def temp_session_path() -> Path:
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations import FileObservation
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Used to tes the ACL functions."""
|
||||
|
||||
# from primaite.acl.access_control_list import AccessControlList
|
||||
# from primaite.acl.acl_rule import ACLRule
|
||||
# from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
def test_acl_address_match_1():
|
||||
"""Test that matching IP addresses produce True."""
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
def test_acl_address_match_2():
|
||||
"""Test that mismatching IP addresses produce False."""
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
def test_acl_address_match_3():
|
||||
"""Test the ANY condition for source IP addresses produce True."""
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "ANY", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
def test_acl_address_match_4():
|
||||
"""Test the ANY condition for dest IP addresses produce True."""
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule(RulePermissionType.ALLOW, "192.168.1.1", "ANY", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
def test_check_acl_block_affirmative():
|
||||
"""Test the block function (affirmative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
# Create a rule
|
||||
acl_rule_permission = RulePermissionType.ALLOW
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == False
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
def test_check_acl_block_negative():
|
||||
"""Test the block function (negative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
# Create a rule
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == True
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
def test_rule_hash():
|
||||
"""Test the rule hash."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList(RulePermissionType.DENY, 10)
|
||||
|
||||
rule = ACLRule(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_local = hash(rule)
|
||||
|
||||
hash_value_remote = acl.get_dictionary_hash(RulePermissionType.DENY, "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert hash_value_local == hash_value_remote
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
def test_delete_rule():
|
||||
"""Adds 3 rules and deletes 1 rule and checks its deletion."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList(RulePermissionType.ALLOW, 10)
|
||||
|
||||
# Create a first rule
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "192.168.1.1"
|
||||
acl_rule_destination = "192.168.1.2"
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
acl_position_in_list = "0"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
# Create a second rule
|
||||
acl_rule_permission = RulePermissionType.DENY
|
||||
acl_rule_source = "20"
|
||||
acl_rule_destination = "30"
|
||||
acl_rule_protocol = "FTP"
|
||||
acl_rule_port = "21"
|
||||
acl_position_in_list = "2"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
|
||||
# Create a third rule
|
||||
acl_rule_permission = RulePermissionType.ALLOW
|
||||
acl_rule_source = "192.168.1.3"
|
||||
acl_rule_destination = "192.168.1.1"
|
||||
acl_rule_protocol = "UDP"
|
||||
acl_rule_port = "60"
|
||||
acl_position_in_list = "4"
|
||||
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
acl_position_in_list,
|
||||
)
|
||||
# Remove the second ACL rule added from the list
|
||||
acl.remove_rule(RulePermissionType.DENY, "20", "30", "FTP", "21")
|
||||
|
||||
assert len(acl.acl) == 10
|
||||
assert acl.acl[2] is None
|
||||
@@ -1,126 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Used to test Active Node functions."""
|
||||
import pytest
|
||||
|
||||
# from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
# from primaite.nodes.active_node import ActiveNode
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_state",
|
||||
[
|
||||
(HardwareState.OFF, SoftwareState.GOOD),
|
||||
(HardwareState.ON, SoftwareState.OVERWHELMED),
|
||||
],
|
||||
)
|
||||
def test_os_state_change(operating_state, expected_state):
|
||||
"""
|
||||
Test that a node cannot change its Software State.
|
||||
|
||||
When its hardware state is OFF.
|
||||
"""
|
||||
active_node = ActiveNode(
|
||||
0,
|
||||
"node",
|
||||
"COMPUTER",
|
||||
"1",
|
||||
operating_state,
|
||||
"192.168.0.1",
|
||||
SoftwareState.GOOD,
|
||||
"GOOD",
|
||||
1,
|
||||
)
|
||||
|
||||
active_node.software_state = SoftwareState.OVERWHELMED
|
||||
|
||||
assert active_node.software_state == expected_state
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_state",
|
||||
[
|
||||
(HardwareState.OFF, SoftwareState.GOOD),
|
||||
(HardwareState.ON, SoftwareState.OVERWHELMED),
|
||||
],
|
||||
)
|
||||
def test_os_state_change_if_not_compromised(operating_state, expected_state):
|
||||
"""
|
||||
Test that a node cannot change its Software State.
|
||||
|
||||
If not compromised) when its hardware state is OFF.
|
||||
"""
|
||||
active_node = ActiveNode(
|
||||
0,
|
||||
"node",
|
||||
"COMPUTER",
|
||||
"1",
|
||||
operating_state,
|
||||
"192.168.0.1",
|
||||
SoftwareState.GOOD,
|
||||
"GOOD",
|
||||
1,
|
||||
)
|
||||
|
||||
active_node.set_software_state_if_not_compromised(SoftwareState.OVERWHELMED)
|
||||
|
||||
assert active_node.software_state == expected_state
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_state",
|
||||
[
|
||||
(HardwareState.OFF, FileSystemState.GOOD),
|
||||
(HardwareState.ON, FileSystemState.CORRUPT),
|
||||
],
|
||||
)
|
||||
def test_file_system_change(operating_state, expected_state):
|
||||
"""Test that a node cannot change its file system state when its hardware state is ON."""
|
||||
active_node = ActiveNode(
|
||||
0,
|
||||
"node",
|
||||
"COMPUTER",
|
||||
"1",
|
||||
operating_state,
|
||||
"192.168.0.1",
|
||||
"COMPROMISED",
|
||||
FileSystemState.GOOD,
|
||||
1,
|
||||
)
|
||||
|
||||
active_node.set_file_system_state(FileSystemState.CORRUPT)
|
||||
|
||||
assert active_node.file_system_state_actual == expected_state
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_state",
|
||||
[
|
||||
(HardwareState.OFF, FileSystemState.GOOD),
|
||||
(HardwareState.ON, FileSystemState.CORRUPT),
|
||||
],
|
||||
)
|
||||
def test_file_system_change_if_not_compromised(operating_state, expected_state):
|
||||
"""
|
||||
Test that a node cannot change its file system state.
|
||||
|
||||
If not compromised) when its hardware state is OFF.
|
||||
"""
|
||||
active_node = ActiveNode(
|
||||
0,
|
||||
"node",
|
||||
"COMPUTER",
|
||||
"1",
|
||||
operating_state,
|
||||
"192.168.0.1",
|
||||
"GOOD",
|
||||
FileSystemState.GOOD,
|
||||
1,
|
||||
)
|
||||
|
||||
active_node.set_file_system_state_if_not_compromised(FileSystemState.CORRUPT)
|
||||
|
||||
assert active_node.file_system_state_actual == expected_state
|
||||
@@ -1,30 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.main import run
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"legacy_file",
|
||||
[
|
||||
("legacy_config_1_DDOS_BASIC.yaml"),
|
||||
("legacy_config_2_DDOS_BASIC.yaml"),
|
||||
("legacy_config_3_DOS_VERY_BASIC.yaml"),
|
||||
("legacy_config_5_DATA_MANIPULATION.yaml"),
|
||||
],
|
||||
)
|
||||
def test_legacy_training_config_run_session(legacy_file):
|
||||
"""Tests using legacy training and lay down config files in PrimAITE session end-to-end."""
|
||||
legacy_training_config_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
|
||||
legacy_lay_down_config_path = TEST_CONFIG_ROOT / "legacy_conversion" / legacy_file
|
||||
|
||||
# Run a PrimAITE session using legacy training and lay down config file paths
|
||||
run(
|
||||
legacy_training_config_path,
|
||||
legacy_lay_down_config_path,
|
||||
legacy_training_config=True,
|
||||
legacy_lay_down_config=True,
|
||||
)
|
||||
@@ -1,45 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
# from primaite.config.lay_down_config import (
|
||||
# convert_legacy_lay_down_config,
|
||||
# data_manipulation_config_path,
|
||||
# ddos_basic_one_config_path,
|
||||
# ddos_basic_two_config_path,
|
||||
# dos_very_basic_config_path,
|
||||
# )
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"legacy_file, new_path",
|
||||
[
|
||||
("legacy_config_1_DDOS_BASIC.yaml", ddos_basic_one_config_path()),
|
||||
("legacy_config_2_DDOS_BASIC.yaml", ddos_basic_two_config_path()),
|
||||
("legacy_config_3_DOS_VERY_BASIC.yaml", dos_very_basic_config_path()),
|
||||
("legacy_config_5_DATA_MANIPULATION.yaml", data_manipulation_config_path()),
|
||||
],
|
||||
)
|
||||
def test_legacy_lay_down_config_load(legacy_file, new_path):
|
||||
"""Tests converting legacy lay down files into the new format."""
|
||||
with open(TEST_CONFIG_ROOT / "legacy_conversion" / legacy_file, "r") as file:
|
||||
legacy_lay_down_config = yaml.safe_load(file)
|
||||
|
||||
with open(new_path, "r") as file:
|
||||
new_lay_down_config = yaml.safe_load(file)
|
||||
|
||||
converted_lay_down_config = convert_legacy_lay_down_config(legacy_lay_down_config)
|
||||
|
||||
assert len(converted_lay_down_config) == len(new_lay_down_config)
|
||||
|
||||
for i, new_item in enumerate(new_lay_down_config):
|
||||
converted_item = converted_lay_down_config[i]
|
||||
|
||||
for key, val in new_item.items():
|
||||
if key == "position":
|
||||
continue
|
||||
assert key in converted_item
|
||||
|
||||
assert val == converted_item[key]
|
||||
@@ -1,383 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Test env creation and behaviour with different observation spaces."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
# from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_default_obs_space(temp_primaite_session):
|
||||
"""Create environment with no obs space defined in config and check that the default obs space was created."""
|
||||
with temp_primaite_session as session:
|
||||
session.env.update_environent_obs()
|
||||
|
||||
components = session.env.obs_handler.registered_obs_components
|
||||
|
||||
assert len(components) == 1
|
||||
assert isinstance(components[0], NodeLinkTable)
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_registering_components(temp_primaite_session):
|
||||
"""Test regitering and deregistering a component."""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
handler = ObservationsHandler()
|
||||
component = NodeStatuses(env)
|
||||
handler.register(component)
|
||||
assert component in handler.registered_obs_components
|
||||
handler.deregister(component)
|
||||
assert component not in handler.registered_obs_components
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_LINK_TABLE.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestNodeLinkTable:
|
||||
"""Test the NodeLinkTable observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with box observation space."""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
env.update_environent_obs()
|
||||
|
||||
# we have three nodes and two links, with two service
|
||||
# therefore the box observation space will have:
|
||||
# * 5 rows (3 nodes + 2 links)
|
||||
# * 6 columns (four fixed and two for the services)
|
||||
assert env.env_obs.shape == (5, 6)
|
||||
|
||||
def test_value(self, temp_primaite_session):
|
||||
"""
|
||||
Test that the observation is generated correctly.
|
||||
|
||||
The laydown has:
|
||||
* 3 nodes (2 service nodes and 1 active node)
|
||||
* 2 services
|
||||
* 2 links
|
||||
|
||||
Both nodes have both services, and all states are GOOD, therefore the expected observation value is:
|
||||
|
||||
* Node 1:
|
||||
* 1 (id)
|
||||
* 1 (good hardware state)
|
||||
* 3 (compromised OS state)
|
||||
* 1 (good file system state)
|
||||
* 1 (good TCP state)
|
||||
* 1 (good UDP state)
|
||||
* Node 2:
|
||||
* 2 (id)
|
||||
* 1 (good hardware state)
|
||||
* 1 (good OS state)
|
||||
* 1 (good file system state)
|
||||
* 1 (good TCP state)
|
||||
* 4 (overwhelmed UDP state)
|
||||
* Node 3 (active node):
|
||||
* 3 (id)
|
||||
* 1 (good hardware state)
|
||||
* 1 (good OS state)
|
||||
* 1 (good file system state)
|
||||
* 0 (doesn't have service1)
|
||||
* 0 (doesn't have service2)
|
||||
* Link 1:
|
||||
* 4 (id)
|
||||
* 0 (n/a hardware state)
|
||||
* 0 (n/a OS state)
|
||||
* 0 (n/a file system state)
|
||||
* 999 (999 traffic for service1)
|
||||
* 0 (no traffic for service2)
|
||||
* Link 2:
|
||||
* 5 (id)
|
||||
* 0 (good hardware state)
|
||||
* 0 (good OS state)
|
||||
* 0 (good file system state)
|
||||
* 999 (999 traffic service1)
|
||||
* 0 (no traffic for service2)
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
# act = np.asarray([0,])
|
||||
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
|
||||
|
||||
assert np.array_equal(
|
||||
obs,
|
||||
[
|
||||
[1, 1, 3, 1, 1, 1],
|
||||
[2, 1, 1, 1, 1, 4],
|
||||
[3, 1, 1, 1, 0, 0],
|
||||
[4, 0, 0, 0, 999, 0],
|
||||
[5, 0, 0, 0, 999, 0],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestNodeStatuses:
|
||||
"""Test the NodeStatuses observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with NodeStatuses as the only component."""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
assert env.env_obs.shape == (15,)
|
||||
|
||||
def test_values(self, temp_primaite_session):
|
||||
"""
|
||||
Test that the hardware and software states are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
* one node with a compromised operating system state
|
||||
* one node with two services, and the second service is overwhelmed.
|
||||
* all other states are good or null
|
||||
Therefore, the expected state is:
|
||||
* node 1:
|
||||
* hardware = good (1)
|
||||
* OS = compromised (3)
|
||||
* file system = good (1)
|
||||
* service 1 = good (1)
|
||||
* service 2 = good (1)
|
||||
* node 2:
|
||||
* hardware = good (1)
|
||||
* OS = good (1)
|
||||
* file system = good (1)
|
||||
* service 1 = good (1)
|
||||
* service 2 = overwhelmed (4)
|
||||
* node 3 (switch):
|
||||
* hardware = good (1)
|
||||
* OS = good (1)
|
||||
* file system = good (1)
|
||||
* service 1 = n/a (0)
|
||||
* service 2 = n/a (0)
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
|
||||
print(obs)
|
||||
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestLinkTrafficLevels:
|
||||
"""Test the LinkTrafficLevels observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with MultiDiscrete observation space."""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
env.update_environent_obs()
|
||||
|
||||
# we have two links and two services, so the shape should be 2 * 2
|
||||
assert env.env_obs.shape == (2 * 2,)
|
||||
|
||||
def test_values(self, temp_primaite_session):
|
||||
"""
|
||||
Test that traffic values are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
* two services
|
||||
* three nodes
|
||||
* two links
|
||||
* an IER trying to send 999 bits of data over both links the whole time (via the first service)
|
||||
* link bandwidth of 1000, therefore the utilisation is 99.9%
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
obs, reward, done, info = env.step(0)
|
||||
obs, reward, done, info = env.step(0)
|
||||
|
||||
# the observation space has combine_service_traffic set to False, so the space has this format:
|
||||
# [link1_service1, link1_service2, link2_service1, link2_service2]
|
||||
# we send 999 bits of data via link1 and link2 on service 1.
|
||||
# therefore the first and third elements should be 6 and all others 0
|
||||
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
|
||||
assert np.array_equal(obs, [6, 0, 6, 0])
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_ACCESS_CONTROL_LIST.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestAccessControlList:
|
||||
"""Test the AccessControlList observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, temp_primaite_session):
|
||||
"""Try creating env with MultiDiscrete observation space.
|
||||
|
||||
The laydown has 3 ACL Rules - that is the maximum_acl_rules it can have.
|
||||
Each ACL Rule in the observation space has 6 different elements:
|
||||
|
||||
6 * 3 = 18
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
env.update_environent_obs()
|
||||
|
||||
assert env.env_obs.shape == (18,)
|
||||
|
||||
def test_values(self, temp_primaite_session):
|
||||
"""Test that traffic values are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
* one ACL IMPLICIT DENY rule
|
||||
|
||||
Therefore, the ACL is full of NAs aka zeros and just 6 non-zero elements representing DENY ANY ANY ANY at
|
||||
Position 2.
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
obs, reward, done, info = env.step(0)
|
||||
obs, reward, done, info = env.step(0)
|
||||
|
||||
assert np.array_equal(obs, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2])
|
||||
|
||||
def test_observation_space_with_implicit_rule(self, temp_primaite_session):
|
||||
"""
|
||||
Test observation space is what is expected when an agent adds ACLs during an episode.
|
||||
|
||||
At the start of the episode, there is a single implicit DENY rule
|
||||
In the observation space IMPLICIT DENY: 1,1,1,1,1,0
|
||||
0 shows the rule is the start (when episode began no other rules were created) so this is correct.
|
||||
|
||||
On Step 2, there is an ACL rule added at Position 0: 2,2,3,2,3,0
|
||||
|
||||
On Step 4, there is a second ACL rule added at POSITION 1: 2,4,2,3,3,1
|
||||
|
||||
The final observation space should be this:
|
||||
[2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
|
||||
The ACL Rule from Step 2 is added first and has a HIGHER position than the ACL rule from Step 4
|
||||
but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List.
|
||||
"""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
training_config = env.training_config
|
||||
for episode in range(0, training_config.num_train_episodes):
|
||||
for step in range(0, training_config.num_train_steps):
|
||||
# Do nothing action
|
||||
action = 0
|
||||
if step == 2:
|
||||
# Action to add the first ACL rule
|
||||
action = 43
|
||||
elif step == 4:
|
||||
# Action to add the second ACL rule
|
||||
action = 96
|
||||
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
obs = env.env_obs
|
||||
|
||||
assert np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
|
||||
def test_observation_space_with_different_positions(self, temp_primaite_session):
|
||||
"""
|
||||
Test observation space is what is expected when an agent adds ACLs during an episode.
|
||||
|
||||
At the start of the episode, there is a single implicit DENY rule
|
||||
In the observation space IMPLICIT DENY: 1,1,1,1,1,0
|
||||
0 shows the rule is the start (when episode began no other rules were created) so this is correct.
|
||||
|
||||
On Step 2, there is an ACL rule added at Position 1: 2,2,3,2,3,1
|
||||
|
||||
On Step 4 there is a second ACL rule added at Position 0: 2,4,2,3,3,0
|
||||
|
||||
The final observation space should be this:
|
||||
[2 , 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
|
||||
The ACL Rule from Step 2 is added before and has a LOWER position than the ACL rule from Step 4
|
||||
but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List.
|
||||
"""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
training_config = env.training_config
|
||||
for episode in range(0, training_config.num_train_episodes):
|
||||
for step in range(0, training_config.num_train_steps):
|
||||
# Do nothing action
|
||||
action = 0
|
||||
if step == 2:
|
||||
# Action to add the first ACL rule
|
||||
action = 44
|
||||
elif step == 4:
|
||||
# Action to add the second ACL rule
|
||||
action = 95
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
obs = env.env_obs
|
||||
|
||||
assert np.array_equal(obs, [2, 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
@@ -1,79 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
# from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
# [TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()],
|
||||
[TEST_CONFIG_ROOT / "session_test/training_config_main_sb3.yaml", dos_very_basic_config_path()],
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_primaite_session(temp_primaite_session):
|
||||
"""
|
||||
Tests the PrimaiteSession class and all of its outputs.
|
||||
|
||||
This test runs for both a Stable Baselines3 agent, and a Ray RLlib agent.
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
session_path = session.session_path
|
||||
assert session_path.exists()
|
||||
session.learn()
|
||||
# Learning outputs are saved in session.learning_path
|
||||
session.evaluate()
|
||||
# Evaluation outputs are saved in session.evaluation_path
|
||||
|
||||
# If you need to inspect any session outputs, it must be done inside
|
||||
# the context manager
|
||||
|
||||
# Check that the metadata json file exists
|
||||
assert (session_path / "session_metadata.json").exists()
|
||||
|
||||
# Check that the network png file exists
|
||||
assert (session_path / f"network_{session.timestamp_str}.png").exists()
|
||||
|
||||
# Check that the saved agent exists
|
||||
assert session._agent_session._saved_agent_path.exists()
|
||||
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.learning_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
|
||||
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.evaluation_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
|
||||
|
||||
# Check that the average reward per episode plots exist
|
||||
assert (session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists()
|
||||
assert (session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists()
|
||||
|
||||
# Check that the metadata has captured the correct number of learning and eval episodes and steps
|
||||
assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
|
||||
assert len(session.learn_all_transactions_dict().keys()) == 10 * 256
|
||||
|
||||
assert len(session.eval_av_reward_per_episode_dict().keys()) == 3
|
||||
assert len(session.eval_all_transactions_dict().keys()) == 3 * 256
|
||||
|
||||
_LOGGER.debug("Inspecting files in temp session path...")
|
||||
for dir_path, dir_names, file_names in os.walk(session_path):
|
||||
for file in file_names:
|
||||
path = os.path.join(dir_path, file)
|
||||
file_str = path.split(str(session_path))[-1]
|
||||
_LOGGER.debug(f" {file_str}")
|
||||
|
||||
# Now that we've exited the context manager, the session.session_path
|
||||
# directory and its contents are deleted
|
||||
assert not session_path.exists()
|
||||
@@ -1,40 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
|
||||
# from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
# from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "test_random_red_main_config.yaml",
|
||||
data_manipulation_config_path(),
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_random_red_agent_behaviour(temp_primaite_session):
|
||||
"""Test that red agent POL is randomised each episode."""
|
||||
list_of_node_instructions = []
|
||||
|
||||
with temp_primaite_session as session:
|
||||
session.evaluate()
|
||||
list_of_node_instructions.append(session.env.red_node_pol)
|
||||
|
||||
session.evaluate()
|
||||
list_of_node_instructions.append(session.env.red_node_pol)
|
||||
|
||||
# compare instructions to make sure that red instructions are truly random
|
||||
for index, instruction in enumerate(list_of_node_instructions):
|
||||
for key in list_of_node_instructions[index].keys():
|
||||
instruction: NodeStateInstructionRed = list_of_node_instructions[index][key]
|
||||
print(f"run {index}")
|
||||
print(f"{key} start step: {instruction.get_start_step()}")
|
||||
print(f"{key} end step: {instruction.get_end_step()}")
|
||||
print(f"{key} target node id: {instruction.get_target_node_id()}")
|
||||
print("")
|
||||
assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1])
|
||||
@@ -1,89 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Used to test Active Node functions."""
|
||||
import pytest
|
||||
|
||||
# from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
# from primaite.common.service import Service
|
||||
# from primaite.config.training_config import TrainingConfig
|
||||
# from primaite.nodes.active_node import ActiveNode
|
||||
# from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"starting_operating_state, expected_operating_state",
|
||||
[(HardwareState.RESETTING, HardwareState.ON)],
|
||||
)
|
||||
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
|
||||
"""Tests that a node resets correctly."""
|
||||
active_node = ActiveNode(
|
||||
node_id="0",
|
||||
name="node",
|
||||
node_type=NodeType.COMPUTER,
|
||||
priority=Priority.P1,
|
||||
hardware_state=starting_operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.COMPROMISED,
|
||||
file_system_state=FileSystemState.CORRUPT,
|
||||
config_values=TrainingConfig(),
|
||||
)
|
||||
|
||||
for x in range(5):
|
||||
active_node.update_resetting_status()
|
||||
|
||||
assert active_node.software_state == SoftwareState.GOOD
|
||||
assert active_node.file_system_state_actual == FileSystemState.GOOD
|
||||
assert active_node.hardware_state == expected_operating_state
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_operating_state",
|
||||
[(HardwareState.BOOTING, HardwareState.ON)],
|
||||
)
|
||||
def test_node_boots_correctly(operating_state, expected_operating_state):
|
||||
"""Tests that a node boots correctly."""
|
||||
service_node = ServiceNode(
|
||||
node_id=0,
|
||||
name="node",
|
||||
node_type="COMPUTER",
|
||||
priority="1",
|
||||
hardware_state=operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.GOOD,
|
||||
file_system_state="GOOD",
|
||||
config_values=1,
|
||||
)
|
||||
service_attributes = Service(name="node", port="80", software_state=SoftwareState.COMPROMISED)
|
||||
service_node.add_service(service_attributes)
|
||||
|
||||
for x in range(5):
|
||||
service_node.update_booting_status()
|
||||
|
||||
assert service_attributes.software_state == SoftwareState.GOOD
|
||||
assert service_node.hardware_state == expected_operating_state
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_operating_state",
|
||||
[(HardwareState.SHUTTING_DOWN, HardwareState.OFF)],
|
||||
)
|
||||
def test_node_shutdown_correctly(operating_state, expected_operating_state):
|
||||
"""Tests that a node shutdown correctly."""
|
||||
active_node = ActiveNode(
|
||||
node_id=0,
|
||||
name="node",
|
||||
node_type="COMPUTER",
|
||||
priority="1",
|
||||
hardware_state=operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.GOOD,
|
||||
file_system_state="GOOD",
|
||||
config_values=1,
|
||||
)
|
||||
|
||||
for x in range(5):
|
||||
active_node.update_shutdown_status()
|
||||
|
||||
assert active_node.hardware_state == expected_operating_state
|
||||
@@ -1,54 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
|
||||
TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_rewards_are_being_penalised_at_each_step_function(
|
||||
temp_primaite_session,
|
||||
):
|
||||
"""
|
||||
Test that hardware state is penalised at each step.
|
||||
|
||||
When the initial state is OFF compared to reference state which is ON.
|
||||
|
||||
The config 'one_node_states_on_off_lay_down_config.yaml' has 15 steps:
|
||||
On different steps, the laydown config has Pattern of Life (PoLs) which change a state of the node's attribute.
|
||||
For example, turning the nodes' file system state to CORRUPT from its original state GOOD.
|
||||
As a result these are the following rewards are activated:
|
||||
File System State: corrupt_should_be_good = -10 * 2 (on Steps 1 & 2)
|
||||
Hardware State: off_should_be_on = -10 * 2 (on Steps 4 & 5)
|
||||
Service State: compromised_should_be_good = -20 * 2 (on Steps 7 & 8)
|
||||
Software State: compromised_should_be_good = -20 * 2 (on Steps 10 & 11)
|
||||
|
||||
The Pattern of Life (PoLs) last for 2 steps, so the agent is penalised twice.
|
||||
|
||||
Note: This test run inherits from conftest.py where the PrimAITE environment is ran and the blue agent is hard-coded
|
||||
to do NOTHING on every step.
|
||||
We use Pattern of Lifes (PoLs) to change the nodes states and display that the agent is being penalised on all steps
|
||||
where the live network node differs from the network reference node.
|
||||
|
||||
Total Reward: -10 + -10 + -10 + -10 + -20 + -20 + -20 + -20 = -120
|
||||
Step Count: 15
|
||||
|
||||
For the 4 steps where this occurs the average reward is:
|
||||
Average Reward: -8 (-120 / 15)
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
session.evaluate()
|
||||
ev_rewards = session.eval_av_reward_per_episode_dict()
|
||||
assert ev_rewards[1] == -8.0
|
||||
@@ -1,66 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import pytest as pytest
|
||||
|
||||
# from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
|
||||
indirect=True,
|
||||
)
|
||||
def test_seeded_learning(temp_primaite_session):
|
||||
"""
|
||||
Test running seeded learning produces the same output when ran twice.
|
||||
|
||||
.. note::
|
||||
|
||||
If this is failing, the hard-coded expected_mean_reward_per_episode
|
||||
from a pre-trained agent will probably need to be updated. If the
|
||||
env changes and those changed how this agent is trained, chances are
|
||||
the mean rewards are going to be different.
|
||||
|
||||
Run the test, but print out the session.learn_av_reward_per_episode()
|
||||
before comparing it. Then copy the printed dict and replace the
|
||||
expected_mean_reward_per_episode with those values. The test should
|
||||
now work. If not, then you've got a bug :).
|
||||
"""
|
||||
expected_mean_reward_per_episode = {
|
||||
1: -20.7421875,
|
||||
2: -19.82421875,
|
||||
3: -17.01171875,
|
||||
4: -19.08203125,
|
||||
5: -21.93359375,
|
||||
6: -20.21484375,
|
||||
7: -15.546875,
|
||||
8: -12.08984375,
|
||||
9: -17.59765625,
|
||||
10: -14.6875,
|
||||
}
|
||||
|
||||
with temp_primaite_session as session:
|
||||
assert (
|
||||
session._training_config.seed == 67890
|
||||
), "Expected output is based upon a agent that was trained with seed 67890"
|
||||
session.learn()
|
||||
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
|
||||
|
||||
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
|
||||
indirect=True,
|
||||
)
|
||||
def test_deterministic_evaluation(temp_primaite_session):
|
||||
"""Test running deterministic evaluation gives same av eward per episode."""
|
||||
with temp_primaite_session as session:
|
||||
# do stuff
|
||||
session.learn()
|
||||
session.evaluate()
|
||||
eval_mean_reward = session.eval_av_reward_per_episode_dict()
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
@@ -1,73 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Used to test Service Node functions."""
|
||||
import pytest
|
||||
|
||||
# from primaite.common.enums import HardwareState, SoftwareState
|
||||
# from primaite.common.service import Service
|
||||
# from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_state",
|
||||
[
|
||||
(HardwareState.OFF, SoftwareState.GOOD),
|
||||
(HardwareState.ON, SoftwareState.OVERWHELMED),
|
||||
],
|
||||
)
|
||||
def test_service_state_change(operating_state, expected_state):
|
||||
"""
|
||||
Test that a node cannot change the state of a running service.
|
||||
|
||||
When its hardware state is OFF.
|
||||
"""
|
||||
service_node = ServiceNode(
|
||||
0,
|
||||
"node",
|
||||
"COMPUTER",
|
||||
"1",
|
||||
operating_state,
|
||||
"192.168.0.1",
|
||||
"COMPROMISED",
|
||||
"RESTORING",
|
||||
1,
|
||||
)
|
||||
service = Service("TCP", 80, SoftwareState.GOOD)
|
||||
service_node.add_service(service)
|
||||
|
||||
service_node.set_service_state("TCP", SoftwareState.OVERWHELMED)
|
||||
|
||||
assert service_node.get_service_state("TCP") == expected_state
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_state",
|
||||
[
|
||||
(HardwareState.OFF, SoftwareState.GOOD),
|
||||
(HardwareState.ON, SoftwareState.OVERWHELMED),
|
||||
],
|
||||
)
|
||||
def test_service_state_change_if_not_comprised(operating_state, expected_state):
|
||||
"""
|
||||
Test that a node cannot change the state of a running service.
|
||||
|
||||
If not compromised when its hardware state is ON.
|
||||
"""
|
||||
service_node = ServiceNode(
|
||||
0,
|
||||
"node",
|
||||
"COMPUTER",
|
||||
"1",
|
||||
operating_state,
|
||||
"192.168.0.1",
|
||||
"GOOD",
|
||||
"RESTORING",
|
||||
1,
|
||||
)
|
||||
service = Service("TCP", 80, SoftwareState.GOOD)
|
||||
service_node.add_service(service)
|
||||
|
||||
service_node.set_service_state_if_not_compromised("TCP", SoftwareState.OVERWHELMED)
|
||||
|
||||
assert service_node.get_service_state("TCP") == expected_state
|
||||
@@ -1,194 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import os.path
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
# from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.cli import app
|
||||
|
||||
# from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.main import run
|
||||
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
from primaite.utils.session_output_reader import av_rewards_dict
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
sb3_expected_avg_reward_per_episode = {
|
||||
10: 0.0,
|
||||
11: -0.0011074218750000008,
|
||||
12: -0.0010000000000000007,
|
||||
13: -0.0016601562500000013,
|
||||
14: -0.001400390625000001,
|
||||
15: -0.0009863281250000007,
|
||||
16: -0.0011855468750000008,
|
||||
17: -0.0009511718750000007,
|
||||
18: -0.0008789062500000007,
|
||||
19: -0.0012226562500000009,
|
||||
20: -0.0010292968750000007,
|
||||
}
|
||||
|
||||
sb3_expected_eval_rewards = -0.0018515625000000014
|
||||
|
||||
|
||||
def copy_session_asset(asset_path: Union[str, Path]) -> str:
|
||||
"""Copies the asset into a temporary test folder."""
|
||||
if asset_path is None:
|
||||
raise Exception("No path provided")
|
||||
|
||||
if isinstance(asset_path, Path):
|
||||
asset_path = str(os.path.normpath(asset_path))
|
||||
|
||||
copy_path = str(Path(tempfile.gettempdir()) / "primaite" / str(uuid4()))
|
||||
|
||||
# copy the asset into a temp path
|
||||
try:
|
||||
shutil.copytree(asset_path, copy_path)
|
||||
except Exception as e:
|
||||
msg = f"Unable to copy directory: {asset_path}"
|
||||
_LOGGER.error(msg, e)
|
||||
print(msg, e)
|
||||
|
||||
_LOGGER.debug(f"Copied test asset to: {copy_path}")
|
||||
|
||||
# return the copied assets path
|
||||
return copy_path
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
def test_load_sb3_session():
|
||||
"""Test that loading an SB3 agent works."""
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
loaded_agent = SB3Agent(session_path=test_path)
|
||||
|
||||
# loaded agent should have the same UUID as the previous agent
|
||||
assert loaded_agent.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
|
||||
assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
|
||||
assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
|
||||
assert loaded_agent._training_config.deterministic
|
||||
assert loaded_agent._training_config.seed == 12345
|
||||
assert str(loaded_agent.session_path) == str(test_path)
|
||||
|
||||
# run another learn session
|
||||
loaded_agent.learn()
|
||||
|
||||
learn_mean_rewards = av_rewards_dict(
|
||||
loaded_agent.learning_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
|
||||
|
||||
# run an evaluation
|
||||
loaded_agent.evaluate()
|
||||
|
||||
# load the evaluation average reward csv file
|
||||
eval_mean_reward = av_rewards_dict(
|
||||
loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
|
||||
)
|
||||
|
||||
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
|
||||
# the evaluation should be the same as a previous run
|
||||
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
def test_load_primaite_session():
|
||||
"""Test that loading a Primaite session works."""
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
# create loaded session
|
||||
session = PrimaiteSession(session_path=test_path)
|
||||
|
||||
# run setup on session
|
||||
session.setup()
|
||||
|
||||
# make sure that the session was loaded correctly
|
||||
assert session._agent_session.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
|
||||
assert session._agent_session._training_config.agent_framework == AgentFramework.SB3.name
|
||||
assert session._agent_session._training_config.agent_identifier == AgentIdentifier.PPO.name
|
||||
assert session._agent_session._training_config.deterministic
|
||||
assert session._agent_session._training_config.seed == 12345
|
||||
assert str(session._agent_session.session_path) == str(test_path)
|
||||
|
||||
# run another learn session
|
||||
session.learn()
|
||||
|
||||
learn_mean_rewards = av_rewards_dict(
|
||||
session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
|
||||
|
||||
# run an evaluation
|
||||
session.evaluate()
|
||||
|
||||
# load the evaluation average reward csv file
|
||||
eval_mean_reward = av_rewards_dict(
|
||||
session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
|
||||
)
|
||||
|
||||
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
|
||||
# the evaluation should be the same as a previous run
|
||||
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
@pytest.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
def test_run_loading():
|
||||
"""Test loading session via main.run."""
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
# create loaded session
|
||||
run(session_path=test_path)
|
||||
|
||||
learn_mean_rewards = av_rewards_dict(
|
||||
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
def test_cli():
|
||||
"""Test loading session via CLI."""
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
result = runner.invoke(app, ["session", "--load", test_path])
|
||||
|
||||
# cli should work
|
||||
assert result.exit_code == 0
|
||||
|
||||
learn_mean_rewards = av_rewards_dict(
|
||||
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
@@ -1,132 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
# from primaite.acl.acl_rule import ACLRule
|
||||
# from primaite.common.enums import HardwareState
|
||||
# from primaite.environment.primaite_env import Primaite
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
|
||||
@pytest.skip("Deprecated")
|
||||
def run_generic_set_actions(env: Primaite):
|
||||
"""Run against a generic agent with specified blue agent actions."""
|
||||
# Reset the environment at the start of the episode
|
||||
# env.reset()
|
||||
training_config = env.training_config
|
||||
for episode in range(0, training_config.num_train_episodes):
|
||||
for step in range(0, training_config.num_train_steps):
|
||||
# Send the observation space to the agent to get an action
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = 0
|
||||
# print("Episode:", episode, "\nStep:", step)
|
||||
if step == 5:
|
||||
# [1, 1, 2, 1, 1, 1, 1(position)]
|
||||
# Creates an ACL rule
|
||||
# Allows traffic from server_1 to node_1 on port FTP
|
||||
action = 56
|
||||
elif step == 7:
|
||||
# [1, 1, 2, 0] Node Action
|
||||
# Sets Node 1 Hardware State to OFF
|
||||
# Does not resolve any service
|
||||
action = 128
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(training_config.time_delay / 1000)
|
||||
|
||||
# Reset the environment at the end of the episode
|
||||
# env.reset()
|
||||
|
||||
# env.close()
|
||||
|
||||
|
||||
@pytest.skip("Deprecated")
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
|
||||
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_single_action_space_is_valid(temp_primaite_session):
|
||||
"""Test single action space is valid."""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
|
||||
run_generic_set_actions(env)
|
||||
# Retrieve the action space dictionary values from environment
|
||||
env_action_space_dict = env.action_dict.values()
|
||||
# Flags to check the conditions of the action space
|
||||
contains_acl_actions = False
|
||||
contains_node_actions = False
|
||||
both_action_spaces = False
|
||||
# Loop through each element of the list (which is every value from the dictionary)
|
||||
for dict_item in env_action_space_dict:
|
||||
# Node action detected
|
||||
if len(dict_item) == 4:
|
||||
contains_node_actions = True
|
||||
# Link action detected
|
||||
elif len(dict_item) == 7:
|
||||
contains_acl_actions = True
|
||||
# If both are there then the ANY action type is working
|
||||
if contains_node_actions and contains_acl_actions:
|
||||
both_action_spaces = True
|
||||
# Check condition should be True
|
||||
assert both_action_spaces
|
||||
|
||||
|
||||
@pytest.skip("Deprecated")
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session):
|
||||
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
|
||||
# TODO: Refactor this at some point to build a custom ACL Hardcoded
|
||||
# Agent and then patch the AgentIdentifier Enum class so that it
|
||||
# has ACL_AGENT. This then allows us to set the agent identified in
|
||||
# the main config and is a bit cleaner.
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
# Run environment with specified fixed blue agent actions only
|
||||
run_generic_set_actions(env)
|
||||
# Retrieve hardware state of computer_1 node in laydown config
|
||||
# Agent turned this off in Step 5
|
||||
computer_node_hardware_state = env.nodes["1"].hardware_state
|
||||
# Retrieve the Access Control List object stored by the environment at the end of the episode
|
||||
access_control_list = env.acl
|
||||
# Use the Access Control List object acl object attribute to get dictionary
|
||||
# Use dictionary.values() to get total list of all items in the dictionary
|
||||
acl_rules_list = access_control_list.acl
|
||||
# Length of this list tells you how many items are in the dictionary
|
||||
# This number is the frequency of Access Control Rules in the environment
|
||||
# In the scenario, we specified that the agent should create only 1 acl rule
|
||||
# This 1 rule added to the implicit deny means there should be 2 rules in total.
|
||||
rules_count = 0
|
||||
for rule in acl_rules_list:
|
||||
if isinstance(rule, ACLRule):
|
||||
rules_count += 1
|
||||
# Therefore these statements below MUST be true
|
||||
assert computer_node_hardware_state == HardwareState.OFF
|
||||
assert rules_count == 2
|
||||
@@ -1,45 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
# from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.skip("Deprecated")
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[TEST_CONFIG_ROOT / "train_episode_step.yaml", dos_very_basic_config_path()]],
|
||||
indirect=True,
|
||||
)
|
||||
def test_eval_steps_differ_from_training(temp_primaite_session):
|
||||
"""Uses PrimaiteSession class to compare number of episodes used for training and evaluation.
|
||||
|
||||
Train_episode_step.yaml main config:
|
||||
num_train_steps = 25
|
||||
num_train_episodes = 3
|
||||
num_eval_steps = 17
|
||||
num_eval_episodes = 1
|
||||
"""
|
||||
expected_learning_metadata = {"total_episodes": 3, "total_time_steps": 75}
|
||||
expected_evaluation_metadata = {"total_episodes": 1, "total_time_steps": 17}
|
||||
|
||||
with temp_primaite_session as session:
|
||||
# Run learning and check episode and step counts
|
||||
session.learn()
|
||||
assert session.env.actual_episode_count == expected_learning_metadata["total_episodes"]
|
||||
assert session.env.total_step_count == expected_learning_metadata["total_time_steps"]
|
||||
|
||||
# Run evaluation and check episode and step counts
|
||||
session.evaluate()
|
||||
assert session.env.actual_episode_count == expected_evaluation_metadata["total_episodes"]
|
||||
assert session.env.total_step_count == expected_evaluation_metadata["total_time_steps"]
|
||||
|
||||
# Load the session_metadata.json file and check that the both the
|
||||
# learning and evaluation match what is expected above
|
||||
metadata = session.metadata_file_as_dict()
|
||||
assert metadata["learning"] == expected_learning_metadata
|
||||
assert metadata["evaluation"] == expected_evaluation_metadata
|
||||
@@ -1,40 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
# from primaite.config import training_config
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
|
||||
@pytest.skip("Deprecated")
|
||||
def test_legacy_lay_down_config_yaml_conversion():
|
||||
"""Tests the conversion of legacy lay down config files."""
|
||||
legacy_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
|
||||
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
|
||||
|
||||
with open(legacy_path, "r") as file:
|
||||
legacy_dict = yaml.safe_load(file)
|
||||
|
||||
with open(new_path, "r") as file:
|
||||
new_dict = yaml.safe_load(file)
|
||||
|
||||
converted_dict = training_config.convert_legacy_training_config_dict(legacy_dict)
|
||||
|
||||
for key, value in new_dict.items():
|
||||
assert converted_dict[key] == value
|
||||
|
||||
|
||||
@pytest.skip("Deprecated")
|
||||
def test_create_config_values_main_from_file():
|
||||
"""Tests creating an instance of TrainingConfig from file."""
|
||||
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
|
||||
|
||||
training_config.load(new_path)
|
||||
|
||||
|
||||
@pytest.skip("Deprecated")
|
||||
def test_create_config_values_main_from_legacy_file():
|
||||
"""Tests creating an instance of TrainingConfig from legacy file."""
|
||||
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
|
||||
|
||||
training_config.load(new_path, legacy_file=True)
|
||||
Reference in New Issue
Block a user