diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index 8b3d8ab5..fc603a8a 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -64,7 +64,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"): self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)} self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { - i: { + i + + 1: { "position": i, "permission": 0, "source_ip_id": 0, @@ -75,7 +76,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_port_id": 0, "protocol_id": 0, } - for i in range(1, self.num_rules + 1) + for i in range(self.num_rules) } def observe(self, state: Dict) -> ObsType: @@ -97,7 +98,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): rule_state = acl_items[i] if rule_state is None: obs[i] = { - "position": i, + "position": i - 1, "permission": 0, "source_ip_id": 0, "source_wildcard_id": 0, @@ -123,7 +124,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): protocol = rule_state["protocol"] protocol_id = self.protocol_to_id.get(protocol, 1) obs[i] = { - "position": i, + "position": i - 1, "permission": rule_state["action"], "source_ip_id": src_node_id, "source_wildcard_id": src_wildcard_id, diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 12a84e9a..959e30f6 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -10,7 +10,7 @@ from primaite.simulator.network.transmission.transport_layer import Port def check_default_rules(acl_obs): assert len(acl_obs) == 7 - assert all(acl_obs[i]["position"] == i for i in range(1, 8)) + assert all(acl_obs[i]["position"] == i - 1 for i in range(1, 8)) assert all(acl_obs[i]["permission"] == 0 for i in range(1, 8)) assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(1, 8)) assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(1, 8)) @@ -72,7 +72,7 @@ def test_firewall_observation(): observation = firewall_observation.observe(firewall.describe_state()) observed_rule = observation["ACL"]["INTERNAL"]["INBOUND"][5] - assert observed_rule["position"] == 5 + assert observed_rule["position"] == 4 assert observed_rule["permission"] == 2 assert observed_rule["source_ip_id"] == 2 assert observed_rule["source_wildcard_id"] == 3 diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py index b13314f1..1a41cad4 100644 --- a/tests/integration_tests/game_layer/observations/test_link_observations.py +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -4,8 +4,10 @@ from gymnasium import spaces from primaite.game.agent.observations.link_observation import LinkObservation from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.base import Link, Node +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.sim_container import Simulation @@ -71,3 +73,43 @@ def test_link_observation(simulation): observation_state = link_obs.observe(simulation.describe_state()) assert observation_state["PROTOCOLS"]["ALL"] == 1 + + +def test_link_observation_again(): + net = Network() + sim = Simulation(network=net) + switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON) + computer_1 = Computer( + hostname="computer_1", ip_address="10.0.0.1", subnet_mask="255.255.255.0", start_up_duration=0 + ) + computer_2 = Computer( + hostname="computer_2", ip_address="10.0.0.2", subnet_mask="255.255.255.0", start_up_duration=0 + ) + computer_1.power_on() + computer_2.power_on() + link_1 = net.connect(switch.network_interface[1], computer_1.network_interface[1]) + link_2 = net.connect(switch.network_interface[2], computer_2.network_interface[1]) + assert link_1 is not None + assert link_2 is not None + + link_1_observation = LinkObservation(where=["network", "links", link_1.uuid]) + link_2_observation = LinkObservation(where=["network", "links", link_2.uuid]) + + state = sim.describe_state() + link_1_obs = link_1_observation.observe(state) + link_2_obs = link_2_observation.observe(state) + assert "PROTOCOLS" in link_1_obs + assert "PROTOCOLS" in link_2_obs + assert "ALL" in link_1_obs["PROTOCOLS"] + assert "ALL" in link_2_obs["PROTOCOLS"] + assert link_1_obs["PROTOCOLS"]["ALL"] == 0 + assert link_2_obs["PROTOCOLS"]["ALL"] == 0 + + # Test that the link observation is updated when a packet is sent + computer_1.ping("10.0.0.2") + computer_2.ping("10.0.0.1") + state = sim.describe_state() + link_1_obs = link_1_observation.observe(state) + link_2_obs = link_2_observation.observe(state) + assert link_1_obs["PROTOCOLS"]["ALL"] > 0 + assert link_2_obs["PROTOCOLS"]["ALL"] > 0 diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index 7db6a2c2..55471676 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -50,7 +50,7 @@ def test_router_observation(): # Observe the state using the RouterObservation instance observed_output = router_observation.observe(router.describe_state()) observed_rule = observed_output["ACL"][5] - assert observed_rule["position"] == 5 + assert observed_rule["position"] == 4 assert observed_rule["permission"] == 2 assert observed_rule["source_ip_id"] == 2 assert observed_rule["source_wildcard_id"] == 3 @@ -74,7 +74,7 @@ def test_router_observation(): ) observed_output = router_observation.observe(router.describe_state()) observed_rule = observed_output["ACL"][2] - assert observed_rule["position"] == 2 + assert observed_rule["position"] == 1 assert observed_rule["permission"] == 1 assert observed_rule["source_ip_id"] == 1 assert observed_rule["source_wildcard_id"] == 1