Change ACL observation to 0-index and remove weird off-by-one offset
This commit is contained in:
@@ -65,8 +65,7 @@ class ACLObservation(AbstractObservation, discriminator="acl"):
|
|||||||
self.port_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
self.port_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
||||||
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
|
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
|
||||||
self.default_observation: Dict = {
|
self.default_observation: Dict = {
|
||||||
i
|
i: {
|
||||||
+ 1: {
|
|
||||||
"position": i,
|
"position": i,
|
||||||
"permission": 0,
|
"permission": 0,
|
||||||
"source_ip_id": 0,
|
"source_ip_id": 0,
|
||||||
@@ -94,12 +93,11 @@ class ACLObservation(AbstractObservation, discriminator="acl"):
|
|||||||
return self.default_observation
|
return self.default_observation
|
||||||
obs = {}
|
obs = {}
|
||||||
acl_items = dict(acl_state.items())
|
acl_items = dict(acl_state.items())
|
||||||
i = 1 # don't show rule 0 for compatibility reasons.
|
for i in range(self.num_rules):
|
||||||
while i < self.num_rules + 1:
|
|
||||||
rule_state = acl_items[i]
|
rule_state = acl_items[i]
|
||||||
if rule_state is None:
|
if rule_state is None:
|
||||||
obs[i] = {
|
obs[i] = {
|
||||||
"position": i - 1,
|
"position": i,
|
||||||
"permission": 0,
|
"permission": 0,
|
||||||
"source_ip_id": 0,
|
"source_ip_id": 0,
|
||||||
"source_wildcard_id": 0,
|
"source_wildcard_id": 0,
|
||||||
@@ -125,7 +123,7 @@ class ACLObservation(AbstractObservation, discriminator="acl"):
|
|||||||
protocol = rule_state["protocol"]
|
protocol = rule_state["protocol"]
|
||||||
protocol_id = self.protocol_to_id.get(protocol, 1)
|
protocol_id = self.protocol_to_id.get(protocol, 1)
|
||||||
obs[i] = {
|
obs[i] = {
|
||||||
"position": i - 1,
|
"position": i,
|
||||||
"permission": rule_state["action"],
|
"permission": rule_state["action"],
|
||||||
"source_ip_id": src_node_id,
|
"source_ip_id": src_node_id,
|
||||||
"source_wildcard_id": src_wildcard_id,
|
"source_wildcard_id": src_wildcard_id,
|
||||||
@@ -135,7 +133,6 @@ class ACLObservation(AbstractObservation, discriminator="acl"):
|
|||||||
"dest_port_id": dst_port_id,
|
"dest_port_id": dst_port_id,
|
||||||
"protocol_id": protocol_id,
|
"protocol_id": protocol_id,
|
||||||
}
|
}
|
||||||
i += 1
|
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -148,8 +145,7 @@ class ACLObservation(AbstractObservation, discriminator="acl"):
|
|||||||
"""
|
"""
|
||||||
return spaces.Dict(
|
return spaces.Dict(
|
||||||
{
|
{
|
||||||
i
|
i: spaces.Dict(
|
||||||
+ 1: spaces.Dict(
|
|
||||||
{
|
{
|
||||||
"position": spaces.Discrete(self.num_rules),
|
"position": spaces.Discrete(self.num_rules),
|
||||||
"permission": spaces.Discrete(3),
|
"permission": spaces.Discrete(3),
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def test_acl_observations(simulation):
|
|||||||
observation_space = acl_obs.observe(simulation.describe_state())
|
observation_space = acl_obs.observe(simulation.describe_state())
|
||||||
assert observation_space.get(1) is not None
|
assert observation_space.get(1) is not None
|
||||||
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
|
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
|
||||||
assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1)
|
assert rule_obs.get("position") == 1 # rule was put at position 1
|
||||||
assert rule_obs.get("permission") == 1 # permit = 1 deny = 2
|
assert rule_obs.get("permission") == 1 # permit = 1 deny = 2
|
||||||
assert rule_obs.get("source_ip_id") == 1 # applies to all source nodes
|
assert rule_obs.get("source_ip_id") == 1 # applies to all source nodes
|
||||||
assert rule_obs.get("dest_ip_id") == 1 # applies to all destination nodes
|
assert rule_obs.get("dest_ip_id") == 1 # applies to all destination nodes
|
||||||
@@ -60,7 +60,7 @@ def test_acl_observations(simulation):
|
|||||||
observation_space = acl_obs.observe(simulation.describe_state())
|
observation_space = acl_obs.observe(simulation.describe_state())
|
||||||
assert observation_space.get(1) is not None
|
assert observation_space.get(1) is not None
|
||||||
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
|
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
|
||||||
assert rule_obs.get("position") == 0
|
assert rule_obs.get("position") == 1
|
||||||
assert rule_obs.get("permission") == 0
|
assert rule_obs.get("permission") == 0
|
||||||
assert rule_obs.get("source_ip_id") == 0
|
assert rule_obs.get("source_ip_id") == 0
|
||||||
assert rule_obs.get("dest_ip_id") == 0
|
assert rule_obs.get("dest_ip_id") == 0
|
||||||
|
|||||||
@@ -11,15 +11,15 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
|||||||
|
|
||||||
def check_default_rules(acl_obs):
|
def check_default_rules(acl_obs):
|
||||||
assert len(acl_obs) == 7
|
assert len(acl_obs) == 7
|
||||||
assert all(acl_obs[i]["position"] == i - 1 for i in range(1, 8))
|
assert all(acl_obs[i]["position"] == i for i in range(7))
|
||||||
assert all(acl_obs[i]["permission"] == 0 for i in range(1, 8))
|
assert all(acl_obs[i]["permission"] == 0 for i in range(7))
|
||||||
assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(1, 8))
|
assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(7))
|
||||||
assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(1, 8))
|
assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(7))
|
||||||
assert all(acl_obs[i]["source_port_id"] == 0 for i in range(1, 8))
|
assert all(acl_obs[i]["source_port_id"] == 0 for i in range(7))
|
||||||
assert all(acl_obs[i]["dest_ip_id"] == 0 for i in range(1, 8))
|
assert all(acl_obs[i]["dest_ip_id"] == 0 for i in range(7))
|
||||||
assert all(acl_obs[i]["dest_wildcard_id"] == 0 for i in range(1, 8))
|
assert all(acl_obs[i]["dest_wildcard_id"] == 0 for i in range(7))
|
||||||
assert all(acl_obs[i]["dest_port_id"] == 0 for i in range(1, 8))
|
assert all(acl_obs[i]["dest_port_id"] == 0 for i in range(7))
|
||||||
assert all(acl_obs[i]["protocol_id"] == 0 for i in range(1, 8))
|
assert all(acl_obs[i]["protocol_id"] == 0 for i in range(7))
|
||||||
|
|
||||||
|
|
||||||
def test_firewall_observation():
|
def test_firewall_observation():
|
||||||
@@ -75,7 +75,7 @@ def test_firewall_observation():
|
|||||||
|
|
||||||
observation = firewall_observation.observe(firewall.describe_state())
|
observation = firewall_observation.observe(firewall.describe_state())
|
||||||
observed_rule = observation["ACL"]["INTERNAL"]["INBOUND"][5]
|
observed_rule = observation["ACL"]["INTERNAL"]["INBOUND"][5]
|
||||||
assert observed_rule["position"] == 4
|
assert observed_rule["position"] == 5
|
||||||
assert observed_rule["permission"] == 2
|
assert observed_rule["permission"] == 2
|
||||||
assert observed_rule["source_ip_id"] == 2
|
assert observed_rule["source_ip_id"] == 2
|
||||||
assert observed_rule["source_wildcard_id"] == 3
|
assert observed_rule["source_wildcard_id"] == 3
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ def test_router_observation():
|
|||||||
# Observe the state using the RouterObservation instance
|
# Observe the state using the RouterObservation instance
|
||||||
observed_output = router_observation.observe(router.describe_state())
|
observed_output = router_observation.observe(router.describe_state())
|
||||||
observed_rule = observed_output["ACL"][5]
|
observed_rule = observed_output["ACL"][5]
|
||||||
assert observed_rule["position"] == 4
|
assert observed_rule["position"] == 5
|
||||||
assert observed_rule["permission"] == 2
|
assert observed_rule["permission"] == 2
|
||||||
assert observed_rule["source_ip_id"] == 2
|
assert observed_rule["source_ip_id"] == 2
|
||||||
assert observed_rule["source_wildcard_id"] == 3
|
assert observed_rule["source_wildcard_id"] == 3
|
||||||
@@ -77,7 +77,7 @@ def test_router_observation():
|
|||||||
)
|
)
|
||||||
observed_output = router_observation.observe(router.describe_state())
|
observed_output = router_observation.observe(router.describe_state())
|
||||||
observed_rule = observed_output["ACL"][2]
|
observed_rule = observed_output["ACL"][2]
|
||||||
assert observed_rule["position"] == 1
|
assert observed_rule["position"] == 2
|
||||||
assert observed_rule["permission"] == 1
|
assert observed_rule["permission"] == 1
|
||||||
assert observed_rule["source_ip_id"] == 1
|
assert observed_rule["source_ip_id"] == 1
|
||||||
assert observed_rule["source_wildcard_id"] == 1
|
assert observed_rule["source_wildcard_id"] == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user