#2658: change key of observation to match what is in CAOS document
This commit is contained in:
@@ -55,17 +55,18 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
)
|
||||
|
||||
def _default_monitored_traffic_observation(self, monitored_traffic_config: Dict) -> Dict:
|
||||
default_traffic_obs = {"monitored_traffic": {}}
|
||||
default_traffic_obs = {"TRAFFIC": {}}
|
||||
|
||||
for protocol in monitored_traffic_config:
|
||||
default_traffic_obs["monitored_traffic"][str(protocol).lower()] = {}
|
||||
protocol = str(protocol).lower()
|
||||
default_traffic_obs["TRAFFIC"][protocol] = {}
|
||||
|
||||
if protocol == "icmp":
|
||||
default_traffic_obs["monitored_traffic"]["icmp"] = {"inbound": 0, "outbound": 0}
|
||||
default_traffic_obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0}
|
||||
else:
|
||||
default_traffic_obs["monitored_traffic"][protocol] = {}
|
||||
default_traffic_obs["TRAFFIC"][protocol] = {}
|
||||
for port in monitored_traffic_config[protocol]:
|
||||
default_traffic_obs["monitored_traffic"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0}
|
||||
default_traffic_obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0}
|
||||
|
||||
return default_traffic_obs
|
||||
|
||||
@@ -120,16 +121,17 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
|
||||
# if the observation was configured to monitor traffic from ports/protocols
|
||||
if self.monitored_traffic:
|
||||
obs["monitored_traffic"] = {}
|
||||
obs["TRAFFIC"] = {}
|
||||
|
||||
# iterate through the protocols
|
||||
for protocol in self.monitored_traffic:
|
||||
obs["monitored_traffic"][str(protocol).lower()] = {}
|
||||
protocol = str(protocol).lower()
|
||||
obs["TRAFFIC"][protocol] = {}
|
||||
# check if the nic has seen traffic with this protocol
|
||||
if nic_state["traffic"].get(protocol):
|
||||
# deal with icmp
|
||||
if str(protocol).lower() == "icmp":
|
||||
obs["monitored_traffic"][protocol] = {
|
||||
if protocol == "icmp":
|
||||
obs["TRAFFIC"][protocol] = {
|
||||
"inbound": self._categorise_traffic(
|
||||
traffic_value=nic_state["traffic"]["icmp"]["inbound"], nic_state=nic_state
|
||||
),
|
||||
@@ -140,26 +142,26 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
else:
|
||||
for port in self.monitored_traffic[protocol]:
|
||||
port_enum = Port[port]
|
||||
obs["monitored_traffic"][str(protocol).lower()][port_enum.value] = {}
|
||||
obs["TRAFFIC"][protocol][port_enum.value] = {}
|
||||
traffic = {"inbound": 0, "outbound": 0}
|
||||
|
||||
if nic_state["traffic"][protocol].get(port_enum.value) is not None:
|
||||
traffic = nic_state["traffic"][protocol][port_enum.value]
|
||||
|
||||
obs["monitored_traffic"][protocol][port_enum.value]["inbound"] = self._categorise_traffic(
|
||||
obs["TRAFFIC"][protocol][port_enum.value]["inbound"] = self._categorise_traffic(
|
||||
traffic_value=traffic["inbound"], nic_state=nic_state
|
||||
)
|
||||
obs["monitored_traffic"][protocol][port_enum.value]["outbound"] = self._categorise_traffic(
|
||||
obs["TRAFFIC"][protocol][port_enum.value]["outbound"] = self._categorise_traffic(
|
||||
traffic_value=traffic["outbound"], nic_state=nic_state
|
||||
)
|
||||
|
||||
# set all the ports under the protocol to 0
|
||||
else:
|
||||
if str(protocol).lower() == "icmp":
|
||||
obs["monitored_traffic"]["icmp"] = {"inbound": 0, "outbound": 0}
|
||||
if protocol == "icmp":
|
||||
obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0}
|
||||
else:
|
||||
for port in self.monitored_traffic[protocol]:
|
||||
obs["monitored_traffic"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0}
|
||||
obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0}
|
||||
|
||||
if self.include_nmne:
|
||||
obs.update({"NMNE": {}})
|
||||
@@ -188,16 +190,17 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
|
||||
|
||||
if self.monitored_traffic:
|
||||
space["monitored_traffic"] = spaces.Dict({})
|
||||
space["TRAFFIC"] = spaces.Dict({})
|
||||
for protocol in self.monitored_traffic:
|
||||
protocol = str(protocol).lower()
|
||||
if protocol == "icmp":
|
||||
space["monitored_traffic"]["icmp"] = spaces.Dict(
|
||||
space["TRAFFIC"]["icmp"] = spaces.Dict(
|
||||
{"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)}
|
||||
)
|
||||
else:
|
||||
space["monitored_traffic"][protocol] = spaces.Dict({})
|
||||
space["TRAFFIC"][protocol] = spaces.Dict({})
|
||||
for port in self.monitored_traffic[protocol]:
|
||||
space["monitored_traffic"][protocol][Port[port].value] = spaces.Dict(
|
||||
space["TRAFFIC"][protocol][Port[port].value] = spaces.Dict(
|
||||
{"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)}
|
||||
)
|
||||
|
||||
|
||||
@@ -148,21 +148,21 @@ def test_nic_monitored_traffic(simulation):
|
||||
|
||||
simulation.pre_timestep(0) # apply timestep to whole sim
|
||||
simulation.apply_timestep(0) # apply timestep to whole sim
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic")
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
|
||||
|
||||
assert traffic_obs["icmp"]["inbound"] == 0
|
||||
assert traffic_obs["icmp"]["outbound"] == 0
|
||||
|
||||
# send a ping
|
||||
pc.ping(target_ip_address=pc2.network_interface[1].ip_address)
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic")
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
|
||||
|
||||
assert traffic_obs["icmp"]["inbound"] == 1
|
||||
assert traffic_obs["icmp"]["outbound"] == 1
|
||||
|
||||
simulation.pre_timestep(1) # apply timestep to whole sim
|
||||
simulation.apply_timestep(1) # apply timestep to whole sim
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic")
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
|
||||
|
||||
assert traffic_obs["icmp"]["inbound"] == 0
|
||||
assert traffic_obs["icmp"]["outbound"] == 0
|
||||
@@ -174,7 +174,7 @@ def test_nic_monitored_traffic(simulation):
|
||||
browser.target_url = f"http://arcd.com/"
|
||||
browser.get_webpage()
|
||||
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic")
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
|
||||
assert traffic_obs["icmp"]["inbound"] == 0
|
||||
assert traffic_obs["icmp"]["outbound"] == 0
|
||||
assert traffic_obs["tcp"][53]["inbound"] == 0
|
||||
@@ -182,7 +182,7 @@ def test_nic_monitored_traffic(simulation):
|
||||
|
||||
simulation.pre_timestep(2) # apply timestep to whole sim
|
||||
simulation.apply_timestep(2) # apply timestep to whole sim
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic")
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
|
||||
|
||||
assert traffic_obs["icmp"]["inbound"] == 0
|
||||
assert traffic_obs["icmp"]["outbound"] == 0
|
||||
@@ -198,7 +198,7 @@ def test_nic_monitored_traffic_config():
|
||||
defender_agent: ProxyAgent = game.agents.get("defender")
|
||||
cur_obs = defender_agent.observation_manager.current_observation
|
||||
|
||||
assert cur_obs["NODES"]["HOST0"]["NICS"][1]["monitored_traffic"] == {
|
||||
assert cur_obs["NODES"]["HOST0"]["NICS"][1]["TRAFFIC"] == {
|
||||
"icmp": {"inbound": 0, "outbound": 0},
|
||||
"tcp": {53: {"inbound": 0, "outbound": 0}},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user