#2350: apply PR suggestions

This commit is contained in:
Czar Echavez
2024-03-12 12:20:02 +00:00
parent ec4818e4d3
commit f2c6f10c21
7 changed files with 19 additions and 19 deletions

View File

@@ -27,7 +27,7 @@ class NicObservation(AbstractObservation):
"""The default NIC observation dict."""
data = {"nic_status": 0}
if CAPTURE_NMNE:
data.update({"nmne": {"inbound": 0, "outbound": 0}})
data.update({"NMNE": {"inbound": 0, "outbound": 0}})
return data
@@ -133,14 +133,14 @@ class NicObservation(AbstractObservation):
else:
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
if CAPTURE_NMNE:
obs_dict.update({"nmne": {}})
obs_dict.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
self.nmne_inbound_last_step = inbound_count
self.nmne_outbound_last_step = outbound_count
return obs_dict
@@ -151,7 +151,7 @@ class NicObservation(AbstractObservation):
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
if CAPTURE_NMNE:
space["nmne"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
return space

View File

@@ -86,7 +86,7 @@ class NodeObservation(AbstractObservation):
self.default_observation: Dict = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NETWORK_INTERFACES": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"operating_status": 0,
}
if self.logon_status:
@@ -111,7 +111,7 @@ class NodeObservation(AbstractObservation):
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NETWORK_INTERFACES"] = {
obs["NICS"] = {
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
}
@@ -127,7 +127,7 @@ class NodeObservation(AbstractObservation):
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NETWORK_INTERFACES": spaces.Dict(
"NICS": spaces.Dict(
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
),
}

View File

@@ -48,6 +48,7 @@ class DatabaseClient(Application):
def execute(self) -> bool:
"""Execution definition for db client: perform a select query."""
self.num_executions += 1 # trying to connect counts as an execution
if self.connections:
can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1])
else:
@@ -82,8 +83,6 @@ class DatabaseClient(Application):
if not self._can_perform_action():
return False
self.num_executions += 1 # trying to connect counts as an execution
if not connection_id:
connection_id = str(uuid4())

View File

@@ -193,6 +193,8 @@ class DataManipulationBot(Application):
if not self._can_perform_action():
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
self.run()
self.num_executions += 1
return self._application_loop()
def _application_loop(self) -> bool:
@@ -202,7 +204,6 @@ class DataManipulationBot(Application):
This is the core loop where the bot sequentially goes through the stages of the attack.
"""
if not self._can_perform_action():
self.num_executions += 1
return False
if self.server_ip_address and self.payload:
self.sys_log.info(f"{self.name}: Running")

View File

@@ -43,14 +43,14 @@ def test_nic(simulation):
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
assert nic_obs.space["nmne"]["inbound"] == spaces.Discrete(4)
assert nic_obs.space["nmne"]["outbound"] == spaces.Discrete(4)
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4)
observation_state = nic_obs.observe(simulation.describe_state())
assert observation_state.get("nic_status") == 1 # enabled
assert observation_state.get("nmne") is not None
assert observation_state["nmne"].get("inbound") == 0
assert observation_state["nmne"].get("outbound") == 0
assert observation_state.get("NMNE") is not None
assert observation_state["NMNE"].get("inbound") == 0
assert observation_state["NMNE"].get("outbound") == 0
nic.disable()
observation_state = nic_obs.observe(simulation.describe_state())

View File

@@ -32,7 +32,7 @@ def test_node_observation(simulation):
assert observation_state.get("SERVICES") is not None
assert observation_state.get("FOLDERS") is not None
assert observation_state.get("NETWORK_INTERFACES") is not None
assert observation_state.get("NICS") is not None
# turn off computer
pc.power_off()

View File

@@ -179,8 +179,8 @@ def test_capture_nmne_observations(uc2_network):
# Observe the current state of NMNEs from the NICs of both the database and web servers
state = sim.describe_state()
db_nic_obs = db_server_nic_obs.observe(state)["nmne"]
web_nic_obs = web_server_nic_obs.observe(state)["nmne"]
db_nic_obs = db_server_nic_obs.observe(state)["NMNE"]
web_nic_obs = web_server_nic_obs.observe(state)["NMNE"]
# Define expected NMNE values based on the iteration count
if i > 10: