From d55b6a5b48bf0faa6aeed6bd5ee94c65ab90912b Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 28 Feb 2024 12:03:58 +0000 Subject: [PATCH] #2238 - Fixed the observations issue causing tests to fail --- src/primaite/game/agent/observations.py | 7 +++++-- src/primaite/simulator/network/hardware/base.py | 2 ++ .../simulator/network/hardware/nodes/host/host_node.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index 7ccc3f11..82e11fe0 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -351,6 +351,8 @@ class NicObservation(AbstractObservation): def default_observation(self) -> Dict: """The default NIC observation dict.""" data = {"nic_status": 0} + if CAPTURE_NMNE: + data.update({"nmne": {"inbound": 0, "outbound": 0}}) return data @@ -404,8 +406,9 @@ class NicObservation(AbstractObservation): if nic_state is NOT_PRESENT_IN_STATE: return self.default_observation else: - obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2, "nmne": {}} - if CAPTURE_NMNE and nic_state.get("nmne"): + obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} + if CAPTURE_NMNE: + obs_dict.update({"nmne": {}}) direction_dict = nic_state["nmne"].get("direction", {}) inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) inbound_count = inbound_keywords.get("*", 0) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index b22bea25..35c90d05 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -123,6 +123,8 @@ class NetworkInterface(SimComponent, ABC): "enabled": self.enabled, } ) + if CAPTURE_NMNE: + state.update({"nmne": self.nmne}) return state def reset_component_for_episode(self, episode: int): diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 8e104924..b48950b7 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -205,7 +205,7 @@ class NIC(IPWiredNetworkInterface): state = super().describe_state() # Update the state with NIC-specific information - state.update({"wake_on_lan": self.wake_on_lan, "nmne": self.nmne}) + state.update({"wake_on_lan": self.wake_on_lan}) return state