#2350: add tests to check spaces + acl obs test + nmne space changes
This commit is contained in:
@@ -22,8 +22,6 @@ io_settings:
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
|
||||
@@ -20,6 +20,8 @@ class NicObservation(AbstractObservation):
|
||||
high_nmne_threshold: int = 10
|
||||
"""The minimum number of malicious network events to be considered high."""
|
||||
|
||||
global CAPTURE_NMNE
|
||||
|
||||
@property
|
||||
def default_observation(self) -> Dict:
|
||||
"""The default NIC observation dict."""
|
||||
@@ -47,6 +49,15 @@ class NicObservation(AbstractObservation):
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
global CAPTURE_NMNE
|
||||
if CAPTURE_NMNE:
|
||||
self.nmne_inbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
|
||||
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
|
||||
self._validate_nmne_categories(
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
@@ -128,19 +139,21 @@ class NicObservation(AbstractObservation):
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count)
|
||||
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count)
|
||||
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
|
||||
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
|
||||
self.nmne_inbound_last_step = inbound_count
|
||||
self.nmne_outbound_last_step = outbound_count
|
||||
return obs_dict
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"nic_status": spaces.Discrete(3),
|
||||
"nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}),
|
||||
}
|
||||
)
|
||||
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
if CAPTURE_NMNE:
|
||||
space["nmne"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
|
||||
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
|
||||
@@ -51,7 +51,7 @@ class ServiceObservation(AbstractObservation):
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)})
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
|
||||
Reference in New Issue
Block a user