From 101fa3ebdbab7d58a2bc5d09fc32733ef80b3593 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 13 Jun 2024 16:29:02 +0100 Subject: [PATCH] #2658: change key of observation to match what is in CAOS document --- .../agent/observations/nic_observations.py | 41 ++++++++++--------- .../observations/test_nic_observations.py | 12 +++--- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index d6744cba..373d2c94 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -55,17 +55,18 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): ) def _default_monitored_traffic_observation(self, monitored_traffic_config: Dict) -> Dict: - default_traffic_obs = {"monitored_traffic": {}} + default_traffic_obs = {"TRAFFIC": {}} for protocol in monitored_traffic_config: - default_traffic_obs["monitored_traffic"][str(protocol).lower()] = {} + protocol = str(protocol).lower() + default_traffic_obs["TRAFFIC"][protocol] = {} if protocol == "icmp": - default_traffic_obs["monitored_traffic"]["icmp"] = {"inbound": 0, "outbound": 0} + default_traffic_obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0} else: - default_traffic_obs["monitored_traffic"][protocol] = {} + default_traffic_obs["TRAFFIC"][protocol] = {} for port in monitored_traffic_config[protocol]: - default_traffic_obs["monitored_traffic"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + default_traffic_obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} return default_traffic_obs @@ -120,16 +121,17 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): # if the observation was configured to monitor traffic from ports/protocols if self.monitored_traffic: - obs["monitored_traffic"] = {} + obs["TRAFFIC"] = {} # iterate through the protocols for protocol in self.monitored_traffic: - obs["monitored_traffic"][str(protocol).lower()] = {} + protocol = str(protocol).lower() + obs["TRAFFIC"][protocol] = {} # check if the nic has seen traffic with this protocol if nic_state["traffic"].get(protocol): # deal with icmp - if str(protocol).lower() == "icmp": - obs["monitored_traffic"][protocol] = { + if protocol == "icmp": + obs["TRAFFIC"][protocol] = { "inbound": self._categorise_traffic( traffic_value=nic_state["traffic"]["icmp"]["inbound"], nic_state=nic_state ), @@ -140,26 +142,26 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): else: for port in self.monitored_traffic[protocol]: port_enum = Port[port] - obs["monitored_traffic"][str(protocol).lower()][port_enum.value] = {} + obs["TRAFFIC"][protocol][port_enum.value] = {} traffic = {"inbound": 0, "outbound": 0} if nic_state["traffic"][protocol].get(port_enum.value) is not None: traffic = nic_state["traffic"][protocol][port_enum.value] - obs["monitored_traffic"][protocol][port_enum.value]["inbound"] = self._categorise_traffic( + obs["TRAFFIC"][protocol][port_enum.value]["inbound"] = self._categorise_traffic( traffic_value=traffic["inbound"], nic_state=nic_state ) - obs["monitored_traffic"][protocol][port_enum.value]["outbound"] = self._categorise_traffic( + obs["TRAFFIC"][protocol][port_enum.value]["outbound"] = self._categorise_traffic( traffic_value=traffic["outbound"], nic_state=nic_state ) # set all the ports under the protocol to 0 else: - if str(protocol).lower() == "icmp": - obs["monitored_traffic"]["icmp"] = {"inbound": 0, "outbound": 0} + if protocol == "icmp": + obs["TRAFFIC"]["icmp"] = {"inbound": 0, "outbound": 0} else: for port in self.monitored_traffic[protocol]: - obs["monitored_traffic"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} + obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} if self.include_nmne: obs.update({"NMNE": {}}) @@ -188,16 +190,17 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) if self.monitored_traffic: - space["monitored_traffic"] = spaces.Dict({}) + space["TRAFFIC"] = spaces.Dict({}) for protocol in self.monitored_traffic: + protocol = str(protocol).lower() if protocol == "icmp": - space["monitored_traffic"]["icmp"] = spaces.Dict( + space["TRAFFIC"]["icmp"] = spaces.Dict( {"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)} ) else: - space["monitored_traffic"][protocol] = spaces.Dict({}) + space["TRAFFIC"][protocol] = spaces.Dict({}) for port in self.monitored_traffic[protocol]: - space["monitored_traffic"][protocol][Port[port].value] = spaces.Dict( + space["TRAFFIC"][protocol][Port[port].value] = spaces.Dict( {"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)} ) diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index d7b1e347..7ac6dc1a 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -148,21 +148,21 @@ def test_nic_monitored_traffic(simulation): simulation.pre_timestep(0) # apply timestep to whole sim simulation.apply_timestep(0) # apply timestep to whole sim - traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic") + traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC") assert traffic_obs["icmp"]["inbound"] == 0 assert traffic_obs["icmp"]["outbound"] == 0 # send a ping pc.ping(target_ip_address=pc2.network_interface[1].ip_address) - traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic") + traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC") assert traffic_obs["icmp"]["inbound"] == 1 assert traffic_obs["icmp"]["outbound"] == 1 simulation.pre_timestep(1) # apply timestep to whole sim simulation.apply_timestep(1) # apply timestep to whole sim - traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic") + traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC") assert traffic_obs["icmp"]["inbound"] == 0 assert traffic_obs["icmp"]["outbound"] == 0 @@ -174,7 +174,7 @@ def test_nic_monitored_traffic(simulation): browser.target_url = f"http://arcd.com/" browser.get_webpage() - traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic") + traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC") assert traffic_obs["icmp"]["inbound"] == 0 assert traffic_obs["icmp"]["outbound"] == 0 assert traffic_obs["tcp"][53]["inbound"] == 0 @@ -182,7 +182,7 @@ def test_nic_monitored_traffic(simulation): simulation.pre_timestep(2) # apply timestep to whole sim simulation.apply_timestep(2) # apply timestep to whole sim - traffic_obs = nic_obs.observe(simulation.describe_state()).get("monitored_traffic") + traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC") assert traffic_obs["icmp"]["inbound"] == 0 assert traffic_obs["icmp"]["outbound"] == 0 @@ -198,7 +198,7 @@ def test_nic_monitored_traffic_config(): defender_agent: ProxyAgent = game.agents.get("defender") cur_obs = defender_agent.observation_manager.current_observation - assert cur_obs["NODES"]["HOST0"]["NICS"][1]["monitored_traffic"] == { + assert cur_obs["NODES"]["HOST0"]["NICS"][1]["TRAFFIC"] == { "icmp": {"inbound": 0, "outbound": 0}, "tcp": {53: {"inbound": 0, "outbound": 0}}, }