From 1798674d39ad15f1b0c1d64a4dc1f36fcb08c45b Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 13 Jun 2024 12:37:20 +0100 Subject: [PATCH] #2658: fix space --- .../game/agent/observations/nic_observations.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 56494748..125ec951 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -63,6 +63,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): if protocol == "icmp": default_traffic_obs["monitored_traffic"]["icmp"] = {"inbound": 0, "outbound": 0} else: + default_traffic_obs["monitored_traffic"][protocol] = {} for port in monitored_traffic_config[protocol]: default_traffic_obs["monitored_traffic"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} @@ -186,7 +187,21 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): if self.include_nmne: space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) - return space + if self.monitored_traffic: + space["monitored_traffic"] = spaces.Dict({}) + for protocol in self.monitored_traffic: + if protocol == "icmp": + space["monitored_traffic"]["icmp"] = spaces.Dict( + {"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)} + ) + else: + space["monitored_traffic"][protocol] = spaces.Dict({}) + for port in self.monitored_traffic[protocol]: + space["monitored_traffic"][protocol][Port[port].value] = spaces.Dict( + {"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)} + ) + + return spaces.Dict(space) @classmethod def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation: