From f2c6f10c21f445cf5d85db808ad4092ffa923993 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 12 Mar 2024 12:20:02 +0000 Subject: [PATCH] #2350: apply PR suggestions --- .../game/agent/observations/nic_observations.py | 10 +++++----- .../game/agent/observations/node_observations.py | 6 +++--- .../simulator/system/applications/database_client.py | 3 +-- .../red_applications/data_manipulation_bot.py | 3 ++- .../game_layer/observations/test_nic_observations.py | 10 +++++----- .../game_layer/observations/test_node_observations.py | 2 +- tests/integration_tests/network/test_capture_nmne.py | 4 ++-- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 735b41d4..de83e03a 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -27,7 +27,7 @@ class NicObservation(AbstractObservation): """The default NIC observation dict.""" data = {"nic_status": 0} if CAPTURE_NMNE: - data.update({"nmne": {"inbound": 0, "outbound": 0}}) + data.update({"NMNE": {"inbound": 0, "outbound": 0}}) return data @@ -133,14 +133,14 @@ class NicObservation(AbstractObservation): else: obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} if CAPTURE_NMNE: - obs_dict.update({"nmne": {}}) + obs_dict.update({"NMNE": {}}) direction_dict = nic_state["nmne"].get("direction", {}) inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) inbound_count = inbound_keywords.get("*", 0) outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) outbound_count = outbound_keywords.get("*", 0) - obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) - obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) + obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) + obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) self.nmne_inbound_last_step = inbound_count self.nmne_outbound_last_step = outbound_count return obs_dict @@ -151,7 +151,7 @@ class NicObservation(AbstractObservation): space = spaces.Dict({"nic_status": spaces.Discrete(3)}) if CAPTURE_NMNE: - space["nmne"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) + space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) return space diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index f211a6b5..94f0974b 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -86,7 +86,7 @@ class NodeObservation(AbstractObservation): self.default_observation: Dict = { "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, - "NETWORK_INTERFACES": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, + "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, "operating_status": 0, } if self.logon_status: @@ -111,7 +111,7 @@ class NodeObservation(AbstractObservation): obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} obs["operating_status"] = node_state["operating_state"] - obs["NETWORK_INTERFACES"] = { + obs["NICS"] = { i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) } @@ -127,7 +127,7 @@ class NodeObservation(AbstractObservation): "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), "operating_status": spaces.Discrete(5), - "NETWORK_INTERFACES": spaces.Dict( + "NICS": spaces.Dict( {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} ), } diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index bc51b3a2..d3afef59 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -48,6 +48,7 @@ class DatabaseClient(Application): def execute(self) -> bool: """Execution definition for db client: perform a select query.""" + self.num_executions += 1 # trying to connect counts as an execution if self.connections: can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1]) else: @@ -82,8 +83,6 @@ class DatabaseClient(Application): if not self._can_perform_action(): return False - self.num_executions += 1 # trying to connect counts as an execution - if not connection_id: connection_id = str(uuid4()) diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 2a6c2b11..ee276971 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -193,6 +193,8 @@ class DataManipulationBot(Application): if not self._can_perform_action(): _LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.") self.run() + + self.num_executions += 1 return self._application_loop() def _application_loop(self) -> bool: @@ -202,7 +204,6 @@ class DataManipulationBot(Application): This is the core loop where the bot sequentially goes through the stages of the attack. """ if not self._can_perform_action(): - self.num_executions += 1 return False if self.server_ip_address and self.payload: self.sys_log.info(f"{self.name}: Running") diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index c210b751..332bc1f7 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -43,14 +43,14 @@ def test_nic(simulation): nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) assert nic_obs.space["nic_status"] == spaces.Discrete(3) - assert nic_obs.space["nmne"]["inbound"] == spaces.Discrete(4) - assert nic_obs.space["nmne"]["outbound"] == spaces.Discrete(4) + assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4) + assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4) observation_state = nic_obs.observe(simulation.describe_state()) assert observation_state.get("nic_status") == 1 # enabled - assert observation_state.get("nmne") is not None - assert observation_state["nmne"].get("inbound") == 0 - assert observation_state["nmne"].get("outbound") == 0 + assert observation_state.get("NMNE") is not None + assert observation_state["NMNE"].get("inbound") == 0 + assert observation_state["NMNE"].get("outbound") == 0 nic.disable() observation_state = nic_obs.observe(simulation.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index b1563fbd..dce05b6a 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -32,7 +32,7 @@ def test_node_observation(simulation): assert observation_state.get("SERVICES") is not None assert observation_state.get("FOLDERS") is not None - assert observation_state.get("NETWORK_INTERFACES") is not None + assert observation_state.get("NICS") is not None # turn off computer pc.power_off() diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index 85fcf102..9efc70f7 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -179,8 +179,8 @@ def test_capture_nmne_observations(uc2_network): # Observe the current state of NMNEs from the NICs of both the database and web servers state = sim.describe_state() - db_nic_obs = db_server_nic_obs.observe(state)["nmne"] - web_nic_obs = web_server_nic_obs.observe(state)["nmne"] + db_nic_obs = db_server_nic_obs.observe(state)["NMNE"] + web_nic_obs = web_server_nic_obs.observe(state)["NMNE"] # Define expected NMNE values based on the iteration count if i > 10: