diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index aafa27eb..c743e41a 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -428,15 +428,15 @@ class AccessControlList(AbstractObservationComponent): acl_rule2 position, ... ] - """ - # Terms (for ACL observation space): - # [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW) - # [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses) - # [0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses) - # [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol) - # [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port) - # [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list) + Terms (for ACL Observation Space): + [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW) + [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses) + [0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses) + [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol) + [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port) + [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list) + """ _DATA_TYPE: type = np.int64 @@ -521,9 +521,6 @@ class AccessControlList(AbstractObservationComponent): _LOGGER.info(f"Port {port} could not be found.") port_int = None - # Either do the multiply on the obs space - # Change the obs to - print("current obs", port_int) obs.extend( [ permission_int, diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 432dd15d..d32dfa03 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -253,73 +253,70 @@ class TestAccessControlList: """Test the AccessControlList observation component (in isolation).""" def test_obs_shape(self, temp_primaite_session): - """Try creating env with MultiDiscrete observation space.""" + """Try creating env with MultiDiscrete observation space. + + The laydown has 3 ACL Rules - that is the maximum_acl_rules it can have. + Each ACL Rule in the observation space has 6 different elements: + + 6 * 3 = 18 + """ with temp_primaite_session as session: env = session.env env.update_environent_obs() - # we have two ACLs assert env.env_obs.shape == (18,) def test_values(self, temp_primaite_session): """Test that traffic values are encoded correctly. The laydown has: - * two services - * three nodes - * two links - * an IER trying to send 999 bits of data over both links the whole time (via the first service) - * link bandwidth of 1000, therefore the utilisation is 99.9% + * one ACL IMPLICIT DENY rule + + Therefore, the ACL is full of NAs aka zeros and just 6 non-zero elements representing DENY ANY ANY ANY at + Position 2. """ with temp_primaite_session as session: env = session.env obs, reward, done, info = env.step(0) obs, reward, done, info = env.step(0) - # the observation space has combine_service_traffic set to False, so the space has this format: - # [link1_service1, link1_service2, link2_service1, link2_service2] - # we send 999 bits of data via link1 and link2 on service 1. - # therefore the first and third elements should be 6 and all others 0 - # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) - print(obs) assert np.array_equal(obs, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2]) def test_observation_space_with_implicit_rule(self, temp_primaite_session): """ Test observation space is what is expected when an agent adds ACLs during an episode. - Observation space at the end of the episode. - At the start of the episode, there is a single implicit Deny rule = 1,1,1,1,1,0 - (0 represents its initial position at top of ACL list) - (1, 1, 1, 2, 1, 2, 0) - ACTION - On Step 5, there is a rule added at POSITION 2: 2,2,3,2,3,0 - (1, 3, 1, 2, 2, 1) - SECOND ACTION - On Step 7, there is a second rule added at POSITION 1: 2,4,2,3,3,1 - THINK THE RULES SHOULD BE THE OTHER WAY AROUND IN THE CURRENT OBSERVATION + At the start of the episode, there is a single implicit DENY rule + In the observation space IMPLICIT DENY: 1,1,1,1,1,0 + 0 shows the rule is the start (when episode began no other rules were created) so this is correct. + + On Step 2, there is an ACL rule added at Position 0: 2,2,3,2,3,0 + + On Step 4, there is a second ACL rule added at POSITION 1: 2,4,2,3,3,1 + + The final observation space should be this: + [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2] + + The ACL Rule from Step 2 is added first and has a HIGHER position than the ACL rule from Step 4 + but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List. """ # TODO: Refactor this at some point to build a custom ACL Hardcoded # Agent and then patch the AgentIdentifier Enum class so that it # has ACL_AGENT. This then allows us to set the agent identified in # the main config and is a bit cleaner. - # Used to use env from test fixture but AtrributeError function object has no 'training_config' + with temp_primaite_session as session: env = session.env - training_config = env.training_config for episode in range(0, training_config.num_train_episodes): for step in range(0, training_config.num_train_steps): - # Send the observation space to the agent to get an action - # TEMP - random action for now - # action = env.blue_agent_action(obs) + # Do nothing action action = 0 - print("Episode:", episode, "\nStep:", step) if step == 2: - # [1, 1, 2, 1, 1, 1, 1(position)] - # NEED [1, 1, 1, 2, 1, 1, 1] - # Creates an ACL rule - # Allows traffic from server_1 to node_1 on port FTP + # Action to add the first ACL rule action = 43 elif step == 4: + # Action to add the second ACL rule action = 96 # Run the simulation step on the live environment @@ -329,11 +326,51 @@ class TestAccessControlList: if done: break obs = env.env_obs - print("what i am testing", obs) - # acl rule 1 - # source is 1 should be 4 - # dest is 3 should be 2 - # [2 2 3 2 3 0 2 1?4 3?2 3 3 1 1 1 1 1 1 2] - # np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]) + assert np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]) - # assert obs == [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2] + + def test_observation_space_with_different_positions(self, temp_primaite_session): + """ + Test observation space is what is expected when an agent adds ACLs during an episode. + + At the start of the episode, there is a single implicit DENY rule + In the observation space IMPLICIT DENY: 1,1,1,1,1,0 + 0 shows the rule is the start (when episode began no other rules were created) so this is correct. + + On Step 2, there is an ACL rule added at Position 1: 2,2,3,2,3,1 + + On Step 4 there is a second ACL rule added at Position 0: 2,4,2,3,3,0 + + The final observation space should be this: + [2 , 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2] + + The ACL Rule from Step 2 is added before and has a LOWER position than the ACL rule from Step 4 + but both come before the IMPLICIT DENY which will ALWAYS be at the end of the ACL List. + """ + # TODO: Refactor this at some point to build a custom ACL Hardcoded + # Agent and then patch the AgentIdentifier Enum class so that it + # has ACL_AGENT. This then allows us to set the agent identified in + # the main config and is a bit cleaner. + + with temp_primaite_session as session: + env = session.env + training_config = env.training_config + for episode in range(0, training_config.num_train_episodes): + for step in range(0, training_config.num_train_steps): + # Do nothing action + action = 0 + if step == 2: + # Action to add the first ACL rule + action = 44 + elif step == 4: + # Action to add the second ACL rule + action = 95 + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + obs = env.env_obs + + assert np.array_equal(obs, [2, 4, 2, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2])