diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb index 9b935346..8ee16d41 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -96,17 +96,6 @@ "source": [ "algo.save(\"temp/deleteme\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from primaite.config.load import example_config_path\n", - "from primaite.main import run\n", - "run(example_config_path())" - ] } ], "metadata": { diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 80567aea..b5e43ab3 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -7,7 +7,7 @@ training_config: -game_config: +game: ports: - ARP - DNS @@ -18,523 +18,523 @@ game_config: - TCP - UDP - agents: - - ref: client_1_green_user - team: GREEN - type: GreenWebBrowsingAgent - observation_space: - type: UC2GreenObservation - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: - - node_ref: client_1 + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + operating_status + health_status + folders: {} - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: - - node_ref: client_1 + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + operating_status + health_status + folders: {} - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: - - node_ref: client_1 + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + operating_status + health_status + folders: {} - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: - - node_ref: client_1 + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + operating_status + health_status + folders: {} - action_space: - action_list: - - type: DONOTHING - # FileSystem: # PrimAITE v2 stuff -class TempPrimaiteSession(PrimaiteGame): +class TempPrimaiteSession(PrimaiteSession): """ A temporary PrimaiteSession class. diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py new file mode 100644 index 00000000..0cf245b4 --- /dev/null +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -0,0 +1,43 @@ +import ray +import yaml +from ray import air, tune +from ray.rllib.algorithms.ppo import PPOConfig + +from primaite.config.load import example_config_path +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteRayMARLEnv + + +def test_rllib_multi_agent_compatibility(): + """Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system.""" + + with open(example_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + game = PrimaiteGame.from_config(cfg) + + ray.shutdown() + ray.init() + + env_config = {"game": game} + config = ( + PPOConfig() + .environment(env=PrimaiteRayMARLEnv, env_config={"game": game}) + .rollouts(num_rollout_workers=0) + .multi_agent( + policies={agent.agent_name for agent in game.rl_agents}, + policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id, + ) + .training(train_batch_size=128) + ) + + tune.Tuner( + "PPO", + run_config=air.RunConfig( + stop={"training_iteration": 128}, + checkpoint_config=air.CheckpointConfig( + checkpoint_frequency=10, + ), + ), + param_space=config, + ).fit() diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py new file mode 100644 index 00000000..ce23501a --- /dev/null +++ b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -0,0 +1,38 @@ +import tempfile +from pathlib import Path + +import ray +import yaml +from ray.rllib.algorithms import ppo + +from primaite.config.load import example_config_path +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteRayEnv + + +def test_rllib_single_agent_compatibility(): + """Test that the PrimaiteRayEnv class can be used with a single agent RLLIB system.""" + with open(example_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + game = PrimaiteGame.from_config(cfg) + + ray.shutdown() + ray.init() + + env_config = {"game": game} + config = { + "env": PrimaiteRayEnv, + "env_config": env_config, + "disable_env_checking": True, + "num_rollout_workers": 0, + } + + algo = ppo.PPO(config=config) + + for i in range(5): + result = algo.train() + + save_file = Path(tempfile.gettempdir()) / "ray/" + algo.save(save_file) + assert save_file.exists() diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py new file mode 100644 index 00000000..3907ff50 --- /dev/null +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -0,0 +1,27 @@ +"""Test that we can create a primaite environment and train sb3 agent with no crash.""" +import tempfile +from pathlib import Path + +import yaml +from stable_baselines3 import PPO + +from primaite.config.load import example_config_path +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteGymEnv + + +def test_sb3_compatibility(): + """Test that the Gymnasium environment can be used with an SB3 agent.""" + with open(example_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + game = PrimaiteGame.from_config(cfg) + gym = PrimaiteGymEnv(game=game) + model = PPO("MlpPolicy", gym) + + model.learn(total_timesteps=1000) + + save_path = Path(tempfile.gettempdir()) / "model.zip" + model.save(save_path) + + assert (save_path).exists() diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index b6122bad..68672b51 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -18,15 +18,15 @@ class TestPrimaiteSession: raise AssertionError assert session is not None - assert session.simulation - assert len(session.agents) == 3 - assert len(session.rl_agents) == 1 + assert session.game.simulation + assert len(session.game.agents) == 3 + assert len(session.game.rl_agents) == 1 assert session.policy assert session.env - assert session.simulation.network - assert len(session.simulation.network.nodes) == 10 + assert session.game.simulation.network + assert len(session.game.simulation.network.nodes) == 10 @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) def test_start_session(self, temp_primaite_session):