diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index b2f5e786..8a137629 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -65,8 +65,7 @@ class ACLObservation(AbstractObservation, discriminator="acl"): self.port_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(port_list)} self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { - i - + 1: { + i: { "position": i, "permission": 0, "source_ip_id": 0, @@ -94,12 +93,11 @@ class ACLObservation(AbstractObservation, discriminator="acl"): return self.default_observation obs = {} acl_items = dict(acl_state.items()) - i = 1 # don't show rule 0 for compatibility reasons. - while i < self.num_rules + 1: + for i in range(self.num_rules): rule_state = acl_items[i] if rule_state is None: obs[i] = { - "position": i - 1, + "position": i, "permission": 0, "source_ip_id": 0, "source_wildcard_id": 0, @@ -125,7 +123,7 @@ class ACLObservation(AbstractObservation, discriminator="acl"): protocol = rule_state["protocol"] protocol_id = self.protocol_to_id.get(protocol, 1) obs[i] = { - "position": i - 1, + "position": i, "permission": rule_state["action"], "source_ip_id": src_node_id, "source_wildcard_id": src_wildcard_id, @@ -135,7 +133,6 @@ class ACLObservation(AbstractObservation, discriminator="acl"): "dest_port_id": dst_port_id, "protocol_id": protocol_id, } - i += 1 return obs @property @@ -148,8 +145,7 @@ class ACLObservation(AbstractObservation, discriminator="acl"): """ return spaces.Dict( { - i - + 1: spaces.Dict( + i: spaces.Dict( { "position": spaces.Discrete(self.num_rules), "permission": spaces.Discrete(3), diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index 0a633b2d..c70c454d 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -47,7 +47,7 @@ def test_acl_observations(simulation): observation_space = acl_obs.observe(simulation.describe_state()) assert observation_space.get(1) is not None rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP - assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1) + assert rule_obs.get("position") == 1 # rule was put at position 1 assert rule_obs.get("permission") == 1 # permit = 1 deny = 2 assert rule_obs.get("source_ip_id") == 1 # applies to all source nodes assert rule_obs.get("dest_ip_id") == 1 # applies to all destination nodes @@ -60,7 +60,7 @@ def test_acl_observations(simulation): observation_space = acl_obs.observe(simulation.describe_state()) assert observation_space.get(1) is not None rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP - assert rule_obs.get("position") == 0 + assert rule_obs.get("position") == 1 assert rule_obs.get("permission") == 0 assert rule_obs.get("source_ip_id") == 0 assert rule_obs.get("dest_ip_id") == 0 diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 874fa49e..431fff91 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -11,15 +11,15 @@ from primaite.utils.validation.port import PORT_LOOKUP def check_default_rules(acl_obs): assert len(acl_obs) == 7 - assert all(acl_obs[i]["position"] == i - 1 for i in range(1, 8)) - assert all(acl_obs[i]["permission"] == 0 for i in range(1, 8)) - assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(1, 8)) - assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(1, 8)) - assert all(acl_obs[i]["source_port_id"] == 0 for i in range(1, 8)) - assert all(acl_obs[i]["dest_ip_id"] == 0 for i in range(1, 8)) - assert all(acl_obs[i]["dest_wildcard_id"] == 0 for i in range(1, 8)) - assert all(acl_obs[i]["dest_port_id"] == 0 for i in range(1, 8)) - assert all(acl_obs[i]["protocol_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["position"] == i for i in range(7)) + assert all(acl_obs[i]["permission"] == 0 for i in range(7)) + assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(7)) + assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(7)) + assert all(acl_obs[i]["source_port_id"] == 0 for i in range(7)) + assert all(acl_obs[i]["dest_ip_id"] == 0 for i in range(7)) + assert all(acl_obs[i]["dest_wildcard_id"] == 0 for i in range(7)) + assert all(acl_obs[i]["dest_port_id"] == 0 for i in range(7)) + assert all(acl_obs[i]["protocol_id"] == 0 for i in range(7)) def test_firewall_observation(): @@ -75,7 +75,7 @@ def test_firewall_observation(): observation = firewall_observation.observe(firewall.describe_state()) observed_rule = observation["ACL"]["INTERNAL"]["INBOUND"][5] - assert observed_rule["position"] == 4 + assert observed_rule["position"] == 5 assert observed_rule["permission"] == 2 assert observed_rule["source_ip_id"] == 2 assert observed_rule["source_wildcard_id"] == 3 diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index 495e102d..ddbbb655 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -53,7 +53,7 @@ def test_router_observation(): # Observe the state using the RouterObservation instance observed_output = router_observation.observe(router.describe_state()) observed_rule = observed_output["ACL"][5] - assert observed_rule["position"] == 4 + assert observed_rule["position"] == 5 assert observed_rule["permission"] == 2 assert observed_rule["source_ip_id"] == 2 assert observed_rule["source_wildcard_id"] == 3 @@ -77,7 +77,7 @@ def test_router_observation(): ) observed_output = router_observation.observe(router.describe_state()) observed_rule = observed_output["ACL"][2] - assert observed_rule["position"] == 1 + assert observed_rule["position"] == 2 assert observed_rule["permission"] == 1 assert observed_rule["source_ip_id"] == 1 assert observed_rule["source_wildcard_id"] == 1