#2417 test firewall and router obs
This commit is contained in:
@@ -64,8 +64,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
i: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
@@ -76,7 +75,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"dest_port_id": 0,
|
||||
"protocol_id": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
for i in range(1, self.num_rules + 1)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
@@ -98,7 +97,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
"source_wildcard_id": 0,
|
||||
@@ -124,7 +123,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = self.protocol_to_id.get(protocol, 1)
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"position": i,
|
||||
"permission": rule_state["action"],
|
||||
"source_ip_id": src_node_id,
|
||||
"source_wildcard_id": src_wildcard_id,
|
||||
|
||||
@@ -63,12 +63,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
self.where: WhereType = where
|
||||
|
||||
self.ports: List[PortObservation] = [
|
||||
PortObservation(where=self.where + ["port", port_num]) for port_num in (1, 2, 3)
|
||||
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
|
||||
]
|
||||
# TODO: check what the port nums are for firewall.
|
||||
|
||||
self.internal_inbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "internal", "inbound"],
|
||||
where=self.where + ["internal_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -76,7 +76,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.internal_outbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "internal", "outbound"],
|
||||
where=self.where + ["internal_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -84,7 +84,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.dmz_inbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "dmz", "inbound"],
|
||||
where=self.where + ["dmz_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -92,7 +92,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.dmz_outbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "dmz", "outbound"],
|
||||
where=self.where + ["dmz_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -100,7 +100,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.external_inbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "external", "inbound"],
|
||||
where=self.where + ["external_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -108,7 +108,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.external_outbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "external", "outbound"],
|
||||
where=self.where + ["external_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -118,17 +118,19 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
|
||||
self.default_observation = {
|
||||
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.internal_outbound_acl.default_observation,
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.dmz_outbound_acl.default_observation,
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.external_outbound_acl.default_observation,
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.internal_outbound_acl.default_observation,
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.dmz_outbound_acl.default_observation,
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.external_outbound_acl.default_observation,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -143,17 +145,19 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
"""
|
||||
obs = {
|
||||
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
},
|
||||
},
|
||||
}
|
||||
return obs
|
||||
@@ -169,22 +173,26 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
|
||||
"INTERNAL": spaces.Dict(
|
||||
"ACL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"DMZ": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"EXTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
"INTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"DMZ": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"EXTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
from primaite.game.agent.observations.firewall_observation import FirewallObservation
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
|
||||
|
||||
def check_default_rules(acl_obs):
|
||||
assert len(acl_obs) == 7
|
||||
assert all(acl_obs[i]["position"] == i for i in range(1, 8))
|
||||
assert all(acl_obs[i]["permission"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["source_port_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["dest_ip_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["dest_wildcard_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["dest_port_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["protocol_id"] == 0 for i in range(1, 8))
|
||||
|
||||
|
||||
def test_firewall_observation():
|
||||
"""Test adding/removing acl rules and enabling/disabling ports."""
|
||||
net = Network()
|
||||
firewall = Firewall(hostname="firewall", operating_state=NodeOperatingState.ON)
|
||||
firewall_observation = FirewallObservation(
|
||||
where=[],
|
||||
num_rules=7,
|
||||
ip_list=["10.0.0.1", "10.0.0.2"],
|
||||
wildcard_list=["0.0.0.255", "0.0.0.1"],
|
||||
port_list=["HTTP", "DNS"],
|
||||
protocol_list=["TCP"],
|
||||
)
|
||||
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
assert "ACL" in observation
|
||||
assert "PORTS" in observation
|
||||
assert "INTERNAL" in observation["ACL"]
|
||||
assert "EXTERNAL" in observation["ACL"]
|
||||
assert "DMZ" in observation["ACL"]
|
||||
assert "INBOUND" in observation["ACL"]["INTERNAL"]
|
||||
assert "OUTBOUND" in observation["ACL"]["INTERNAL"]
|
||||
assert "INBOUND" in observation["ACL"]["EXTERNAL"]
|
||||
assert "OUTBOUND" in observation["ACL"]["EXTERNAL"]
|
||||
assert "INBOUND" in observation["ACL"]["DMZ"]
|
||||
assert "OUTBOUND" in observation["ACL"]["DMZ"]
|
||||
all_acls = (
|
||||
observation["ACL"]["INTERNAL"]["INBOUND"],
|
||||
observation["ACL"]["INTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["INBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["DMZ"]["INBOUND"],
|
||||
observation["ACL"]["DMZ"]["OUTBOUND"],
|
||||
)
|
||||
for acl_obs in all_acls:
|
||||
check_default_rules(acl_obs)
|
||||
|
||||
# add a rule to the internal inbound and check that the observation is correct
|
||||
firewall.internal_inbound_acl.add_rule(
|
||||
action=ACLAction.DENY,
|
||||
protocol=IPProtocol.TCP,
|
||||
src_ip_address="10.0.0.1",
|
||||
src_wildcard_mask="0.0.0.1",
|
||||
dst_ip_address="10.0.0.2",
|
||||
dst_wildcard_mask="0.0.0.1",
|
||||
src_port=Port.HTTP,
|
||||
dst_port=Port.HTTP,
|
||||
position=5,
|
||||
)
|
||||
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
observed_rule = observation["ACL"]["INTERNAL"]["INBOUND"][5]
|
||||
assert observed_rule["position"] == 5
|
||||
assert observed_rule["permission"] == 2
|
||||
assert observed_rule["source_ip_id"] == 2
|
||||
assert observed_rule["source_wildcard_id"] == 3
|
||||
assert observed_rule["source_port_id"] == 2
|
||||
assert observed_rule["dest_ip_id"] == 3
|
||||
assert observed_rule["dest_wildcard_id"] == 3
|
||||
assert observed_rule["dest_port_id"] == 2
|
||||
assert observed_rule["protocol_id"] == 2
|
||||
|
||||
# check that none of the other acls have changed
|
||||
all_acls = (
|
||||
observation["ACL"]["INTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["INBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["DMZ"]["INBOUND"],
|
||||
observation["ACL"]["DMZ"]["OUTBOUND"],
|
||||
)
|
||||
for acl_obs in all_acls:
|
||||
check_default_rules(acl_obs)
|
||||
|
||||
# remove the rule and check that the observation is correct
|
||||
firewall.internal_inbound_acl.remove_rule(5)
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
all_acls = (
|
||||
observation["ACL"]["INTERNAL"]["INBOUND"],
|
||||
observation["ACL"]["INTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["INBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["DMZ"]["INBOUND"],
|
||||
observation["ACL"]["DMZ"]["OUTBOUND"],
|
||||
)
|
||||
for acl_obs in all_acls:
|
||||
check_default_rules(acl_obs)
|
||||
|
||||
# check that there are three ports in the observation
|
||||
assert len(observation["PORTS"]) == 3
|
||||
|
||||
# check that the ports are all disabled
|
||||
assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4))
|
||||
|
||||
# connect a switch to the firewall and check that only the correct port is updated
|
||||
switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON)
|
||||
link = net.connect(firewall.network_interface[1], switch.network_interface[1])
|
||||
assert firewall.network_interface[1].enabled
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
assert observation["PORTS"][1]["operating_status"] == 1
|
||||
assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(2, 4))
|
||||
|
||||
# disable the port and check that the operating status is updated
|
||||
firewall.network_interface[1].disable()
|
||||
assert not firewall.network_interface[1].enabled
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4))
|
||||
@@ -0,0 +1,108 @@
|
||||
from pprint import pprint
|
||||
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.router_observation import RouterObservation
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
def test_router_observation():
|
||||
"""Test adding/removing acl rules and enabling/disabling ports."""
|
||||
net = Network()
|
||||
router = Router(hostname="router", num_ports=5, operating_state=NodeOperatingState.ON)
|
||||
|
||||
ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)]
|
||||
acl = ACLObservation(
|
||||
where=["acl", "acl"],
|
||||
num_rules=7,
|
||||
ip_list=["10.0.0.1", "10.0.0.2"],
|
||||
wildcard_list=["0.0.0.255", "0.0.0.1"],
|
||||
port_list=["HTTP", "DNS"],
|
||||
protocol_list=["TCP"],
|
||||
)
|
||||
router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl)
|
||||
|
||||
# Observe the state using the RouterObservation instance
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
|
||||
# Check that the right number of ports and acls are in the router observation
|
||||
assert len(observed_output["PORTS"]) == 8
|
||||
assert len(observed_output["ACL"]) == 7
|
||||
|
||||
# Add an ACL rule to the router
|
||||
router.acl.add_rule(
|
||||
action=ACLAction.DENY,
|
||||
protocol=IPProtocol.TCP,
|
||||
src_ip_address="10.0.0.1",
|
||||
src_wildcard_mask="0.0.0.1",
|
||||
dst_ip_address="10.0.0.2",
|
||||
dst_wildcard_mask="0.0.0.1",
|
||||
src_port=Port.HTTP,
|
||||
dst_port=Port.HTTP,
|
||||
position=5,
|
||||
)
|
||||
# Observe the state using the RouterObservation instance
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
observed_rule = observed_output["ACL"][5]
|
||||
assert observed_rule["position"] == 5
|
||||
assert observed_rule["permission"] == 2
|
||||
assert observed_rule["source_ip_id"] == 2
|
||||
assert observed_rule["source_wildcard_id"] == 3
|
||||
assert observed_rule["source_port_id"] == 2
|
||||
assert observed_rule["dest_ip_id"] == 3
|
||||
assert observed_rule["dest_wildcard_id"] == 3
|
||||
assert observed_rule["dest_port_id"] == 2
|
||||
assert observed_rule["protocol_id"] == 2
|
||||
|
||||
# Add an ACL rule with ALL/NONE values and check that the observation is correct
|
||||
router.acl.add_rule(
|
||||
action=ACLAction.PERMIT,
|
||||
protocol=None,
|
||||
src_ip_address=None,
|
||||
src_wildcard_mask=None,
|
||||
dst_ip_address=None,
|
||||
dst_wildcard_mask=None,
|
||||
src_port=None,
|
||||
dst_port=None,
|
||||
position=2,
|
||||
)
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
observed_rule = observed_output["ACL"][2]
|
||||
assert observed_rule["position"] == 2
|
||||
assert observed_rule["permission"] == 1
|
||||
assert observed_rule["source_ip_id"] == 1
|
||||
assert observed_rule["source_wildcard_id"] == 1
|
||||
assert observed_rule["source_port_id"] == 1
|
||||
assert observed_rule["dest_ip_id"] == 1
|
||||
assert observed_rule["dest_wildcard_id"] == 1
|
||||
assert observed_rule["dest_port_id"] == 1
|
||||
assert observed_rule["protocol_id"] == 1
|
||||
|
||||
# Check that the router ports are all disabled
|
||||
assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6))
|
||||
|
||||
# connect a switch to the router and check that only the correct port is updated
|
||||
switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON)
|
||||
link = net.connect(router.network_interface[1], switch.network_interface[1])
|
||||
assert router.network_interface[1].enabled
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
assert observed_output["PORTS"][1]["operating_status"] == 1
|
||||
assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(2, 6))
|
||||
|
||||
# disable the port and check that the operating status is updated
|
||||
router.network_interface[1].disable()
|
||||
assert not router.network_interface[1].enabled
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6))
|
||||
|
||||
# Check that ports that are out of range are shown as unused
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
assert observed_output["PORTS"][6]["operating_status"] == 0
|
||||
assert observed_output["PORTS"][7]["operating_status"] == 0
|
||||
assert observed_output["PORTS"][8]["operating_status"] == 0
|
||||
Reference in New Issue
Block a user