diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 56494748..125ec951 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -63,6 +63,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): if protocol == "icmp": default_traffic_obs["monitored_traffic"]["icmp"] = {"inbound": 0, "outbound": 0} else: + default_traffic_obs["monitored_traffic"][protocol] = {} for port in monitored_traffic_config[protocol]: default_traffic_obs["monitored_traffic"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} @@ -186,7 +187,21 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): if self.include_nmne: space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) - return space + if self.monitored_traffic: + space["monitored_traffic"] = spaces.Dict({}) + for protocol in self.monitored_traffic: + if protocol == "icmp": + space["monitored_traffic"]["icmp"] = spaces.Dict( + {"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)} + ) + else: + space["monitored_traffic"][protocol] = spaces.Dict({}) + for port in self.monitored_traffic[protocol]: + space["monitored_traffic"][protocol][Port[port].value] = spaces.Dict( + {"inbound": spaces.Discrete(11), "outbound": spaces.Discrete(11)} + ) + + return spaces.Dict(space) @classmethod def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation: