From 709486d739b31ea474432bdcc9e5dc8f4b4d0bb6 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 1 Apr 2024 16:06:12 +0100 Subject: [PATCH] #2417 test firewall and router obs --- .../agent/observations/acl_observation.py | 9 +- .../observations/firewall_observation.py | 96 +++++++------ .../observations/test_firewall_observation.py | 128 ++++++++++++++++++ .../observations/test_router_observation.py | 108 +++++++++++++++ 4 files changed, 292 insertions(+), 49 deletions(-) create mode 100644 tests/integration_tests/game_layer/observations/test_firewall_observation.py create mode 100644 tests/integration_tests/game_layer/observations/test_router_observation.py diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index fc603a8a..8b3d8ab5 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -64,8 +64,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)} self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { - i - + 1: { + i: { "position": i, "permission": 0, "source_ip_id": 0, @@ -76,7 +75,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_port_id": 0, "protocol_id": 0, } - for i in range(self.num_rules) + for i in range(1, self.num_rules + 1) } def observe(self, state: Dict) -> ObsType: @@ -98,7 +97,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): rule_state = acl_items[i] if rule_state is None: obs[i] = { - "position": i - 1, + "position": i, "permission": 0, "source_ip_id": 0, "source_wildcard_id": 0, @@ -124,7 +123,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): protocol = rule_state["protocol"] protocol_id = self.protocol_to_id.get(protocol, 1) obs[i] = { - "position": i - 1, + "position": i, "permission": rule_state["action"], "source_ip_id": src_node_id, "source_wildcard_id": src_wildcard_id, diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 69398d96..ab48e606 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -63,12 +63,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): self.where: WhereType = where self.ports: List[PortObservation] = [ - PortObservation(where=self.where + ["port", port_num]) for port_num in (1, 2, 3) + PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3) ] # TODO: check what the port nums are for firewall. self.internal_inbound_acl = ACLObservation( - where=self.where + ["acl", "internal", "inbound"], + where=self.where + ["internal_inbound_acl", "acl"], num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -76,7 +76,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): protocol_list=protocol_list, ) self.internal_outbound_acl = ACLObservation( - where=self.where + ["acl", "internal", "outbound"], + where=self.where + ["internal_outbound_acl", "acl"], num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -84,7 +84,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): protocol_list=protocol_list, ) self.dmz_inbound_acl = ACLObservation( - where=self.where + ["acl", "dmz", "inbound"], + where=self.where + ["dmz_inbound_acl", "acl"], num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -92,7 +92,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): protocol_list=protocol_list, ) self.dmz_outbound_acl = ACLObservation( - where=self.where + ["acl", "dmz", "outbound"], + where=self.where + ["dmz_outbound_acl", "acl"], num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -100,7 +100,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): protocol_list=protocol_list, ) self.external_inbound_acl = ACLObservation( - where=self.where + ["acl", "external", "inbound"], + where=self.where + ["external_inbound_acl", "acl"], num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -108,7 +108,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): protocol_list=protocol_list, ) self.external_outbound_acl = ACLObservation( - where=self.where + ["acl", "external", "outbound"], + where=self.where + ["external_outbound_acl", "acl"], num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -118,17 +118,19 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): self.default_observation = { "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, - "INTERNAL": { - "INBOUND": self.internal_inbound_acl.default_observation, - "OUTBOUND": self.internal_outbound_acl.default_observation, - }, - "DMZ": { - "INBOUND": self.dmz_inbound_acl.default_observation, - "OUTBOUND": self.dmz_outbound_acl.default_observation, - }, - "EXTERNAL": { - "INBOUND": self.external_inbound_acl.default_observation, - "OUTBOUND": self.external_outbound_acl.default_observation, + "ACL": { + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.default_observation, + "OUTBOUND": self.internal_outbound_acl.default_observation, + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.default_observation, + "OUTBOUND": self.dmz_outbound_acl.default_observation, + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.default_observation, + "OUTBOUND": self.external_outbound_acl.default_observation, + }, }, } @@ -143,17 +145,19 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): """ obs = { "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, - "INTERNAL": { - "INBOUND": self.internal_inbound_acl.observe(state), - "OUTBOUND": self.internal_outbound_acl.observe(state), - }, - "DMZ": { - "INBOUND": self.dmz_inbound_acl.observe(state), - "OUTBOUND": self.dmz_outbound_acl.observe(state), - }, - "EXTERNAL": { - "INBOUND": self.external_inbound_acl.observe(state), - "OUTBOUND": self.external_outbound_acl.observe(state), + "ACL": { + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.observe(state), + "OUTBOUND": self.internal_outbound_acl.observe(state), + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.observe(state), + "OUTBOUND": self.dmz_outbound_acl.observe(state), + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.observe(state), + "OUTBOUND": self.external_outbound_acl.observe(state), + }, }, } return obs @@ -169,22 +173,26 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): space = spaces.Dict( { "PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), - "INTERNAL": spaces.Dict( + "ACL": spaces.Dict( { - "INBOUND": self.internal_inbound_acl.space, - "OUTBOUND": self.internal_outbound_acl.space, - } - ), - "DMZ": spaces.Dict( - { - "INBOUND": self.dmz_inbound_acl.space, - "OUTBOUND": self.dmz_outbound_acl.space, - } - ), - "EXTERNAL": spaces.Dict( - { - "INBOUND": self.external_inbound_acl.space, - "OUTBOUND": self.external_outbound_acl.space, + "INTERNAL": spaces.Dict( + { + "INBOUND": self.internal_inbound_acl.space, + "OUTBOUND": self.internal_outbound_acl.space, + } + ), + "DMZ": spaces.Dict( + { + "INBOUND": self.dmz_inbound_acl.space, + "OUTBOUND": self.dmz_outbound_acl.space, + } + ), + "EXTERNAL": spaces.Dict( + { + "INBOUND": self.external_inbound_acl.space, + "OUTBOUND": self.external_outbound_acl.space, + } + ), } ), } diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py new file mode 100644 index 00000000..12a84e9a --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -0,0 +1,128 @@ +from primaite.game.agent.observations.firewall_observation import FirewallObservation +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port + + +def check_default_rules(acl_obs): + assert len(acl_obs) == 7 + assert all(acl_obs[i]["position"] == i for i in range(1, 8)) + assert all(acl_obs[i]["permission"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["source_port_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["dest_ip_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["dest_wildcard_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["dest_port_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["protocol_id"] == 0 for i in range(1, 8)) + + +def test_firewall_observation(): + """Test adding/removing acl rules and enabling/disabling ports.""" + net = Network() + firewall = Firewall(hostname="firewall", operating_state=NodeOperatingState.ON) + firewall_observation = FirewallObservation( + where=[], + num_rules=7, + ip_list=["10.0.0.1", "10.0.0.2"], + wildcard_list=["0.0.0.255", "0.0.0.1"], + port_list=["HTTP", "DNS"], + protocol_list=["TCP"], + ) + + observation = firewall_observation.observe(firewall.describe_state()) + assert "ACL" in observation + assert "PORTS" in observation + assert "INTERNAL" in observation["ACL"] + assert "EXTERNAL" in observation["ACL"] + assert "DMZ" in observation["ACL"] + assert "INBOUND" in observation["ACL"]["INTERNAL"] + assert "OUTBOUND" in observation["ACL"]["INTERNAL"] + assert "INBOUND" in observation["ACL"]["EXTERNAL"] + assert "OUTBOUND" in observation["ACL"]["EXTERNAL"] + assert "INBOUND" in observation["ACL"]["DMZ"] + assert "OUTBOUND" in observation["ACL"]["DMZ"] + all_acls = ( + observation["ACL"]["INTERNAL"]["INBOUND"], + observation["ACL"]["INTERNAL"]["OUTBOUND"], + observation["ACL"]["EXTERNAL"]["INBOUND"], + observation["ACL"]["EXTERNAL"]["OUTBOUND"], + observation["ACL"]["DMZ"]["INBOUND"], + observation["ACL"]["DMZ"]["OUTBOUND"], + ) + for acl_obs in all_acls: + check_default_rules(acl_obs) + + # add a rule to the internal inbound and check that the observation is correct + firewall.internal_inbound_acl.add_rule( + action=ACLAction.DENY, + protocol=IPProtocol.TCP, + src_ip_address="10.0.0.1", + src_wildcard_mask="0.0.0.1", + dst_ip_address="10.0.0.2", + dst_wildcard_mask="0.0.0.1", + src_port=Port.HTTP, + dst_port=Port.HTTP, + position=5, + ) + + observation = firewall_observation.observe(firewall.describe_state()) + observed_rule = observation["ACL"]["INTERNAL"]["INBOUND"][5] + assert observed_rule["position"] == 5 + assert observed_rule["permission"] == 2 + assert observed_rule["source_ip_id"] == 2 + assert observed_rule["source_wildcard_id"] == 3 + assert observed_rule["source_port_id"] == 2 + assert observed_rule["dest_ip_id"] == 3 + assert observed_rule["dest_wildcard_id"] == 3 + assert observed_rule["dest_port_id"] == 2 + assert observed_rule["protocol_id"] == 2 + + # check that none of the other acls have changed + all_acls = ( + observation["ACL"]["INTERNAL"]["OUTBOUND"], + observation["ACL"]["EXTERNAL"]["INBOUND"], + observation["ACL"]["EXTERNAL"]["OUTBOUND"], + observation["ACL"]["DMZ"]["INBOUND"], + observation["ACL"]["DMZ"]["OUTBOUND"], + ) + for acl_obs in all_acls: + check_default_rules(acl_obs) + + # remove the rule and check that the observation is correct + firewall.internal_inbound_acl.remove_rule(5) + observation = firewall_observation.observe(firewall.describe_state()) + all_acls = ( + observation["ACL"]["INTERNAL"]["INBOUND"], + observation["ACL"]["INTERNAL"]["OUTBOUND"], + observation["ACL"]["EXTERNAL"]["INBOUND"], + observation["ACL"]["EXTERNAL"]["OUTBOUND"], + observation["ACL"]["DMZ"]["INBOUND"], + observation["ACL"]["DMZ"]["OUTBOUND"], + ) + for acl_obs in all_acls: + check_default_rules(acl_obs) + + # check that there are three ports in the observation + assert len(observation["PORTS"]) == 3 + + # check that the ports are all disabled + assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4)) + + # connect a switch to the firewall and check that only the correct port is updated + switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + link = net.connect(firewall.network_interface[1], switch.network_interface[1]) + assert firewall.network_interface[1].enabled + observation = firewall_observation.observe(firewall.describe_state()) + assert observation["PORTS"][1]["operating_status"] == 1 + assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(2, 4)) + + # disable the port and check that the operating status is updated + firewall.network_interface[1].disable() + assert not firewall.network_interface[1].enabled + observation = firewall_observation.observe(firewall.describe_state()) + assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4)) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py new file mode 100644 index 00000000..7db6a2c2 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -0,0 +1,108 @@ +from pprint import pprint + +from primaite.game.agent.observations.acl_observation import ACLObservation +from primaite.game.agent.observations.nic_observations import PortObservation +from primaite.game.agent.observations.router_observation import RouterObservation +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.sim_container import Simulation + + +def test_router_observation(): + """Test adding/removing acl rules and enabling/disabling ports.""" + net = Network() + router = Router(hostname="router", num_ports=5, operating_state=NodeOperatingState.ON) + + ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)] + acl = ACLObservation( + where=["acl", "acl"], + num_rules=7, + ip_list=["10.0.0.1", "10.0.0.2"], + wildcard_list=["0.0.0.255", "0.0.0.1"], + port_list=["HTTP", "DNS"], + protocol_list=["TCP"], + ) + router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl) + + # Observe the state using the RouterObservation instance + observed_output = router_observation.observe(router.describe_state()) + + # Check that the right number of ports and acls are in the router observation + assert len(observed_output["PORTS"]) == 8 + assert len(observed_output["ACL"]) == 7 + + # Add an ACL rule to the router + router.acl.add_rule( + action=ACLAction.DENY, + protocol=IPProtocol.TCP, + src_ip_address="10.0.0.1", + src_wildcard_mask="0.0.0.1", + dst_ip_address="10.0.0.2", + dst_wildcard_mask="0.0.0.1", + src_port=Port.HTTP, + dst_port=Port.HTTP, + position=5, + ) + # Observe the state using the RouterObservation instance + observed_output = router_observation.observe(router.describe_state()) + observed_rule = observed_output["ACL"][5] + assert observed_rule["position"] == 5 + assert observed_rule["permission"] == 2 + assert observed_rule["source_ip_id"] == 2 + assert observed_rule["source_wildcard_id"] == 3 + assert observed_rule["source_port_id"] == 2 + assert observed_rule["dest_ip_id"] == 3 + assert observed_rule["dest_wildcard_id"] == 3 + assert observed_rule["dest_port_id"] == 2 + assert observed_rule["protocol_id"] == 2 + + # Add an ACL rule with ALL/NONE values and check that the observation is correct + router.acl.add_rule( + action=ACLAction.PERMIT, + protocol=None, + src_ip_address=None, + src_wildcard_mask=None, + dst_ip_address=None, + dst_wildcard_mask=None, + src_port=None, + dst_port=None, + position=2, + ) + observed_output = router_observation.observe(router.describe_state()) + observed_rule = observed_output["ACL"][2] + assert observed_rule["position"] == 2 + assert observed_rule["permission"] == 1 + assert observed_rule["source_ip_id"] == 1 + assert observed_rule["source_wildcard_id"] == 1 + assert observed_rule["source_port_id"] == 1 + assert observed_rule["dest_ip_id"] == 1 + assert observed_rule["dest_wildcard_id"] == 1 + assert observed_rule["dest_port_id"] == 1 + assert observed_rule["protocol_id"] == 1 + + # Check that the router ports are all disabled + assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6)) + + # connect a switch to the router and check that only the correct port is updated + switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + link = net.connect(router.network_interface[1], switch.network_interface[1]) + assert router.network_interface[1].enabled + observed_output = router_observation.observe(router.describe_state()) + assert observed_output["PORTS"][1]["operating_status"] == 1 + assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(2, 6)) + + # disable the port and check that the operating status is updated + router.network_interface[1].disable() + assert not router.network_interface[1].enabled + observed_output = router_observation.observe(router.describe_state()) + assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6)) + + # Check that ports that are out of range are shown as unused + observed_output = router_observation.observe(router.describe_state()) + assert observed_output["PORTS"][6]["operating_status"] == 0 + assert observed_output["PORTS"][7]["operating_status"] == 0 + assert observed_output["PORTS"][8]["operating_status"] == 0