diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index 0cb3e8f6..e5216e4a 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -205,12 +205,15 @@ class LinkObservation(AbstractObservation): bandwidth = link_state["bandwidth"] load = link_state["current_load"] - utilisation_fraction = load / bandwidth - # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% - utilisation_category = int(utilisation_fraction * 10) + 1 + if load == 0: + utilisation_category = 0 + else: + utilisation_fraction = load / bandwidth + # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% + utilisation_category = int(utilisation_fraction * 9) + 1 # TODO: once the links support separte load per protocol, this needs amendment to reflect that. - return {"PROTOCOLS": {"ALL": utilisation_category}} + return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} @property def space(self) -> spaces.Space: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index a310a3f5..f41c1ab6 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1271,8 +1271,8 @@ class Node(SimComponent): self.start_up_countdown = self.start_up_duration if self.start_up_duration <= 0: - self._start_up_actions() self.operating_state = NodeOperatingState.ON + self._start_up_actions() self.sys_log.info("Turned on") for nic in self.nics.values(): if nic._connected_link: diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index 0dc2c031..dad6f879 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -22,7 +22,7 @@ def test_data_manipulation(uc2_network): assert db_client.query("SELECT") # Now we run the DataManipulationBot - db_manipulation_bot.run() + db_manipulation_bot.attack() # Now check that the DB client on the web_server cannot query the users table on the database assert not db_client.query("SELECT")