#2417 fix last observation tests
This commit is contained in:
@@ -64,7 +64,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
|
||||
self.default_observation: Dict = {
|
||||
i: {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
@@ -75,7 +76,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"dest_port_id": 0,
|
||||
"protocol_id": 0,
|
||||
}
|
||||
for i in range(1, self.num_rules + 1)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
@@ -97,7 +98,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i,
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
"source_wildcard_id": 0,
|
||||
@@ -123,7 +124,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = self.protocol_to_id.get(protocol, 1)
|
||||
obs[i] = {
|
||||
"position": i,
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_ip_id": src_node_id,
|
||||
"source_wildcard_id": src_wildcard_id,
|
||||
|
||||
@@ -10,7 +10,7 @@ from primaite.simulator.network.transmission.transport_layer import Port
|
||||
|
||||
def check_default_rules(acl_obs):
|
||||
assert len(acl_obs) == 7
|
||||
assert all(acl_obs[i]["position"] == i for i in range(1, 8))
|
||||
assert all(acl_obs[i]["position"] == i - 1 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["permission"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(1, 8))
|
||||
@@ -72,7 +72,7 @@ def test_firewall_observation():
|
||||
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
observed_rule = observation["ACL"]["INTERNAL"]["INBOUND"][5]
|
||||
assert observed_rule["position"] == 5
|
||||
assert observed_rule["position"] == 4
|
||||
assert observed_rule["permission"] == 2
|
||||
assert observed_rule["source_ip_id"] == 2
|
||||
assert observed_rule["source_wildcard_id"] == 3
|
||||
|
||||
@@ -4,8 +4,10 @@ from gymnasium import spaces
|
||||
from primaite.game.agent.observations.link_observation import LinkObservation
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.base import Link, Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
@@ -71,3 +73,43 @@ def test_link_observation(simulation):
|
||||
|
||||
observation_state = link_obs.observe(simulation.describe_state())
|
||||
assert observation_state["PROTOCOLS"]["ALL"] == 1
|
||||
|
||||
|
||||
def test_link_observation_again():
|
||||
net = Network()
|
||||
sim = Simulation(network=net)
|
||||
switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON)
|
||||
computer_1 = Computer(
|
||||
hostname="computer_1", ip_address="10.0.0.1", subnet_mask="255.255.255.0", start_up_duration=0
|
||||
)
|
||||
computer_2 = Computer(
|
||||
hostname="computer_2", ip_address="10.0.0.2", subnet_mask="255.255.255.0", start_up_duration=0
|
||||
)
|
||||
computer_1.power_on()
|
||||
computer_2.power_on()
|
||||
link_1 = net.connect(switch.network_interface[1], computer_1.network_interface[1])
|
||||
link_2 = net.connect(switch.network_interface[2], computer_2.network_interface[1])
|
||||
assert link_1 is not None
|
||||
assert link_2 is not None
|
||||
|
||||
link_1_observation = LinkObservation(where=["network", "links", link_1.uuid])
|
||||
link_2_observation = LinkObservation(where=["network", "links", link_2.uuid])
|
||||
|
||||
state = sim.describe_state()
|
||||
link_1_obs = link_1_observation.observe(state)
|
||||
link_2_obs = link_2_observation.observe(state)
|
||||
assert "PROTOCOLS" in link_1_obs
|
||||
assert "PROTOCOLS" in link_2_obs
|
||||
assert "ALL" in link_1_obs["PROTOCOLS"]
|
||||
assert "ALL" in link_2_obs["PROTOCOLS"]
|
||||
assert link_1_obs["PROTOCOLS"]["ALL"] == 0
|
||||
assert link_2_obs["PROTOCOLS"]["ALL"] == 0
|
||||
|
||||
# Test that the link observation is updated when a packet is sent
|
||||
computer_1.ping("10.0.0.2")
|
||||
computer_2.ping("10.0.0.1")
|
||||
state = sim.describe_state()
|
||||
link_1_obs = link_1_observation.observe(state)
|
||||
link_2_obs = link_2_observation.observe(state)
|
||||
assert link_1_obs["PROTOCOLS"]["ALL"] > 0
|
||||
assert link_2_obs["PROTOCOLS"]["ALL"] > 0
|
||||
|
||||
@@ -50,7 +50,7 @@ def test_router_observation():
|
||||
# Observe the state using the RouterObservation instance
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
observed_rule = observed_output["ACL"][5]
|
||||
assert observed_rule["position"] == 5
|
||||
assert observed_rule["position"] == 4
|
||||
assert observed_rule["permission"] == 2
|
||||
assert observed_rule["source_ip_id"] == 2
|
||||
assert observed_rule["source_wildcard_id"] == 3
|
||||
@@ -74,7 +74,7 @@ def test_router_observation():
|
||||
)
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
observed_rule = observed_output["ACL"][2]
|
||||
assert observed_rule["position"] == 2
|
||||
assert observed_rule["position"] == 1
|
||||
assert observed_rule["permission"] == 1
|
||||
assert observed_rule["source_ip_id"] == 1
|
||||
assert observed_rule["source_wildcard_id"] == 1
|
||||
|
||||
Reference in New Issue
Block a user