Tidy up node observation file
This commit is contained in:
@@ -1,17 +1,12 @@
|
||||
# TODO: make sure when config options are being passed down from higher-level observations to lower-level, but the lower-level also defines that option, don't overwrite.
|
||||
from __future__ import annotations
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
# from primaite.game.agent.observations.file_system_observations import FolderObservation
|
||||
# from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
# from primaite.game.agent.observations.software_observation import ServiceObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -420,7 +415,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"dest_ip_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"dest_port_id": 0,
|
||||
"protocol": 0,
|
||||
"protocol_id": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
@@ -444,7 +439,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"dest_ip_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"dest_port_id": 0,
|
||||
"protocol": 0,
|
||||
"protocol_id": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
@@ -470,7 +465,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"dest_ip_id": dst_node_ip,
|
||||
"dest_wildcard_id": dst_wildcard_id,
|
||||
"dest_port_id": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
"protocol_id": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
@@ -491,7 +486,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
|
||||
"dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2),
|
||||
"dest_port_id": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
"protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
|
||||
Reference in New Issue
Block a user