diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index b51ea1f2..ed930265 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,17 +1,12 @@ -# TODO: make sure when config options are being passed down from higher-level observations to lower-level, but the lower-level also defines that option, don't overwrite. from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, Iterable, List, Optional from gymnasium import spaces from gymnasium.core import ObsType -from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation -# from primaite.game.agent.observations.file_system_observations import FolderObservation -# from primaite.game.agent.observations.nic_observations import NicObservation -# from primaite.game.agent.observations.software_observation import ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) @@ -420,7 +415,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_ip_id": 0, "dest_wildcard_id": 0, "dest_port_id": 0, - "protocol": 0, + "protocol_id": 0, } for i in range(self.num_rules) } @@ -444,7 +439,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_ip_id": 0, "dest_wildcard_id": 0, "dest_port_id": 0, - "protocol": 0, + "protocol_id": 0, } else: src_ip = rule_state["src_ip_address"] @@ -470,7 +465,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_ip_id": dst_node_ip, "dest_wildcard_id": dst_wildcard_id, "dest_port_id": dst_port_id, - "protocol": protocol_id, + "protocol_id": protocol_id, } i += 1 return obs @@ -491,7 +486,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2), "dest_port_id": spaces.Discrete(len(self.port_to_id) + 2), - "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), + "protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2), } ) for i in range(self.num_rules)