#2417 test firewall and router obs
This commit is contained in:
@@ -64,8 +64,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
i: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
@@ -76,7 +75,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"dest_port_id": 0,
|
||||
"protocol_id": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
for i in range(1, self.num_rules + 1)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
@@ -98,7 +97,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
"source_wildcard_id": 0,
|
||||
@@ -124,7 +123,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = self.protocol_to_id.get(protocol, 1)
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"position": i,
|
||||
"permission": rule_state["action"],
|
||||
"source_ip_id": src_node_id,
|
||||
"source_wildcard_id": src_wildcard_id,
|
||||
|
||||
@@ -63,12 +63,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
self.where: WhereType = where
|
||||
|
||||
self.ports: List[PortObservation] = [
|
||||
PortObservation(where=self.where + ["port", port_num]) for port_num in (1, 2, 3)
|
||||
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
|
||||
]
|
||||
# TODO: check what the port nums are for firewall.
|
||||
|
||||
self.internal_inbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "internal", "inbound"],
|
||||
where=self.where + ["internal_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -76,7 +76,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.internal_outbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "internal", "outbound"],
|
||||
where=self.where + ["internal_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -84,7 +84,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.dmz_inbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "dmz", "inbound"],
|
||||
where=self.where + ["dmz_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -92,7 +92,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.dmz_outbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "dmz", "outbound"],
|
||||
where=self.where + ["dmz_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -100,7 +100,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.external_inbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "external", "inbound"],
|
||||
where=self.where + ["external_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -108,7 +108,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.external_outbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "external", "outbound"],
|
||||
where=self.where + ["external_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
@@ -118,17 +118,19 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
|
||||
self.default_observation = {
|
||||
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.internal_outbound_acl.default_observation,
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.dmz_outbound_acl.default_observation,
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.external_outbound_acl.default_observation,
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.internal_outbound_acl.default_observation,
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.dmz_outbound_acl.default_observation,
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.external_outbound_acl.default_observation,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -143,17 +145,19 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
"""
|
||||
obs = {
|
||||
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
},
|
||||
},
|
||||
}
|
||||
return obs
|
||||
@@ -169,22 +173,26 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
|
||||
"INTERNAL": spaces.Dict(
|
||||
"ACL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"DMZ": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"EXTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
"INTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"DMZ": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"EXTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user