From 64e8b3bceaa5ef75208791757d24f01c85b27db7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 15 Nov 2023 16:04:16 +0000 Subject: [PATCH] Add basic primaite session e2e tests --- .../assets/configs/bad_primaite_session.yaml | 725 +++++++++++++++++ .../configs/eval_only_primaite_session.yaml | 729 ++++++++++++++++++ .../assets/configs/test_primaite_session.yaml | 729 ++++++++++++++++++ .../configs/train_only_primaite_session.yaml | 729 ++++++++++++++++++ tests/conftest.py | 104 +-- .../test_primaite_session.py | 51 ++ 6 files changed, 2981 insertions(+), 86 deletions(-) create mode 100644 tests/assets/configs/bad_primaite_session.yaml create mode 100644 tests/assets/configs/eval_only_primaite_session.yaml create mode 100644 tests/assets/configs/test_primaite_session.yaml create mode 100644 tests/assets/configs/train_only_primaite_session.yaml create mode 100644 tests/e2e_integration_tests/test_primaite_session.py diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml new file mode 100644 index 00000000..752d98a5 --- /dev/null +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -0,0 +1,725 @@ +training_config: + rl_framework: SB3 + rl_algorithm: PPO + se3ed: 333 + n_learn_steps: 2560 + n_eval_episodes: 5 + + + +game_config: + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + + agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # FileSystem: # PrimAITE v2 stuff -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 -class TempPrimaiteSession: # PrimaiteSession): +class TempPrimaiteSession(PrimaiteSession): """ A temporary PrimaiteSession class. Uses context manager for deletion of files upon exit. """ - # def __init__( - # self, - # training_config_path: Union[str, Path], - # lay_down_config_path: Union[str, Path], - # ): - # super().__init__(training_config_path, lay_down_config_path) - # self.setup() + @classmethod + def from_config(cls, config_path: Union[str, Path]) -> "TempPrimaiteSession": + """Create a temporary PrimaiteSession object from a config file.""" + config_path = Path(config_path) + with open(config_path, "r") as f: + config = yaml.safe_load(f) - # @property - # def env(self) -> Primaite: - # """Direct access to the env for ease of testing.""" - # return self._agent_session._env # noqa + return super().from_config(cfg=config) - # def __enter__(self): - # return self + def __enter__(self): + return self - # def __exit__(self, type, value, tb): - # shutil.rmtree(self.session_path) - # _LOGGER.debug(f"Deleted temp session directory: {self.session_path}") + def __exit__(self, type, value, tb): + pass -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 @pytest.fixture -def temp_primaite_session(request): - """ - Provides a temporary PrimaiteSession instance. +def temp_primaite_session(request) -> TempPrimaiteSession: + """Create a temporary PrimaiteSession object.""" - It's temporary as it uses a temporary directory as the session path. - - To use this fixture you need to: - - - parametrize your test function with: - - - "temp_primaite_session" - - [[path to training config, path to lay down config]] - - Include the temp_primaite_session fixture as a param in your test - function. - - use the temp_primaite_session as a context manager assigning is the - name 'session'. - - .. code:: python - - from primaite.config.lay_down_config import dos_very_basic_config_path - from primaite.config.training_config import main_training_config_path - @pytest.mark.parametrize( - "temp_primaite_session", - [ - [main_training_config_path(), dos_very_basic_config_path()] - ], - indirect=True - ) - def test_primaite_session(temp_primaite_session): - with temp_primaite_session as session: - # Learning outputs are saved in session.learning_path - session.learn() - - # Evaluation outputs are saved in session.evaluation_path - session.evaluate() - - # To ensure that all files are written, you must call .close() - session.close() - - # If you need to inspect any session outputs, it must be done - # inside the context manager - - # Now that we've exited the context manager, the - # session.session_path directory and its contents are deleted - """ - training_config_path = request.param[0] - lay_down_config_path = request.param[1] - with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck: - mck.session_timestamp = datetime.now() - - return TempPrimaiteSession(training_config_path, lay_down_config_path) - - -@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3 -@pytest.fixture -def temp_session_path() -> Path: - """ - Get a temp directory session path the test session will output to. - - :return: The session directory path. - """ - session_timestamp = datetime.now() - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path(tempfile.gettempdir()) / "_primaite" / date_dir / session_path - session_path.mkdir(exist_ok=True, parents=True) - - return session_path + config_path = request.param[0] + return TempPrimaiteSession.from_config(config_path=config_path) diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py new file mode 100644 index 00000000..5e1da4ff --- /dev/null +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -0,0 +1,51 @@ +import pytest + +from tests.conftest import TempPrimaiteSession + +CFG_PATH = "tests/assets/configs/test_primaite_session.yaml" +TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml" +EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml" + + +class TestPrimaiteSession: + @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) + def test_creating_session(self, temp_primaite_session): + """Check that creating a session from config works.""" + with temp_primaite_session as session: + if not isinstance(session, TempPrimaiteSession): + raise AssertionError + + assert session is not None + assert session.simulation + assert len(session.agents) == 3 + assert len(session.rl_agents) == 1 + + assert session.policy + assert session.env + + assert session.simulation.network + assert len(session.simulation.network.nodes) == 10 + + @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) + def test_start_session(self, temp_primaite_session): + """Make sure you can go all the way through the session without errors.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + # TODO: check that env was closed, that the model was saved, etc. + + @pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True) + def test_training_only_session(self, temp_primaite_session): + """Check that you can run a training-only session.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + # TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved? + + @pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True) + def test_eval_only_session(self, temp_primaite_session): + """Check that you can load a model and run an eval-only session.""" + with temp_primaite_session as session: + session: TempPrimaiteSession + session.start_session() + # TODO: include checks that the model was loaded and that the eval-only session ran