diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 894180c1..84ba2c6f 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -94,13 +94,13 @@ class TrainingConfig: "Stable Baselines3 learn/eval output verbosity level" # Access Control List/Rules - apply_implicit_rule: str = True + apply_implicit_rule: str = False "User choice to have Implicit ALLOW or DENY." implicit_acl_rule: RulePermissionType = RulePermissionType.DENY "ALLOW or DENY implicit firewall rule to go at the end of list of ACL list." - max_number_acl_rules: int = 0 + max_number_acl_rules: int = 10 "Sets a limit for number of acl rules allowed in the list and environment." # Reward values diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 023f55b0..aeccd933 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -519,30 +519,26 @@ class AccessControlList(AbstractObservationComponent): port_int = self.env.ports_list.index(port) + 2 else: _LOGGER.info(f"Port {port} could not be found.") + port_int = None # Either do the multiply on the obs space # Change the obs to - items_to_add = [ - permission_int, - source_ip_int, - dest_ip_int, - protocol_int, - port_int, - position, - ] - position = position * 6 - for item in items_to_add: - # print("position", position, "\nitem", int(item)) - obs.insert(position, int(item)) - position += 1 + obs.extend( + [ + permission_int, + source_ip_int, + dest_ip_int, + protocol_int, + port_int, + position, + ] + ) + else: - starting_position = index * 6 - for placeholder in range(6): - obs.insert(starting_position, 0) - starting_position += 1 + obs.extend([0, 0, 0, 0, 0, 0]) # print("current obs", obs, "\n" ,len(obs)) - self.current_observation = obs + self.current_observation[:] = obs def generate_structure(self): """Return a list of labels for the components of the flattened observation space.""" diff --git a/tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml b/tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml index 7aa30205..ff11d2c8 100644 --- a/tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml +++ b/tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml @@ -5,7 +5,8 @@ # "STABLE_BASELINES3_PPO" # "STABLE_BASELINES3_A2C" # "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C +agent_framework: SB3 +agent_identifier: PPO # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -21,7 +22,7 @@ apply_implicit_rule: True # Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY) implicit_acl_rule: DENY # Total number of ACL rules allowed in the environment -max_number_acl_rules: 10 +max_number_acl_rules: 3 observation_space: components: @@ -31,7 +32,7 @@ observation_space: time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/single_action_space_main_config.yaml b/tests/config/single_action_space_main_config.yaml index 501a4999..f72b43df 100644 --- a/tests/config/single_action_space_main_config.yaml +++ b/tests/config/single_action_space_main_config.yaml @@ -39,6 +39,8 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip] # The high value for the observation space observation_space_high_value: 1000000000 +# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE) +apply_implicit_rule: True implicit_acl_rule: DENY max_number_acl_rules: 10 # Reward values diff --git a/tests/conftest.py b/tests/conftest.py index 388bc034..c3799f15 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,6 +58,7 @@ class TempPrimaiteSession(PrimaiteSession): def __exit__(self, type, value, tb): shutil.rmtree(self.session_path) + # shutil.rmtree(self.session_path.parent) _LOGGER.debug(f"Deleted temp session directory: {self.session_path}") diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 9c0a340b..6d805992 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -257,7 +257,7 @@ class TestLinkTrafficLevels: "temp_primaite_session", [ [ - TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml", + TEST_CONFIG_ROOT / "obs_tests/main_config_ACCESS_CONTROL_LIST.yaml", TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml", ] ], @@ -273,7 +273,7 @@ class TestAccessControlList: env.update_environent_obs() # we have two ACLs - assert env.env_obs.shape == (6 * 3) + assert env.env_obs.shape == (18,) def test_values(self, temp_primaite_session): """Test that traffic values are encoded correctly. @@ -296,7 +296,7 @@ class TestAccessControlList: # therefore the first and third elements should be 6 and all others 0 # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) print(obs) - assert np.array_equal(obs, []) + assert np.array_equal(obs, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2]) def test_observation_space_with_implicit_rule(self, temp_primaite_session): """Test observation space is what is expected when an agent adds ACLs during an episode.""" diff --git a/tests/test_seeding_and_deterministic_session.py b/tests/test_seeding_and_deterministic_session.py index 34cb43fb..789e7d13 100644 --- a/tests/test_seeding_and_deterministic_session.py +++ b/tests/test_seeding_and_deterministic_session.py @@ -11,6 +11,7 @@ from tests import TEST_CONFIG_ROOT ) def test_seeded_learning(temp_primaite_session): """Test running seeded learning produces the same output when ran twice.""" + """ expected_mean_reward_per_episode = { 1: -90.703125, 2: -91.15234375, @@ -23,14 +24,22 @@ def test_seeded_learning(temp_primaite_session): 9: -112.79296875, 10: -100.01953125, } + """ with temp_primaite_session as session: assert session._training_config.seed == 67890, ( "Expected output is based upon a agent that was trained with " "seed 67890" ) session.learn() - actual_mean_reward_per_episode = session.learn_av_reward_per_episode() + actual_mean_reward_per_episode_run_1 = session.learn_av_reward_per_episode() - assert actual_mean_reward_per_episode == expected_mean_reward_per_episode + with temp_primaite_session as session: + assert session._training_config.seed == 67890, ( + "Expected output is based upon a agent that was trained with " "seed 67890" + ) + session.learn() + actual_mean_reward_per_episode_run_2 = session.learn_av_reward_per_episode() + + assert actual_mean_reward_per_episode_run_1 == actual_mean_reward_per_episode_run_2 @pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.")