#2350: add tests to check spaces + acl obs test + nmne space changes

This commit is contained in:
Czar Echavez
2024-03-11 17:47:33 +00:00
parent a228a09917
commit cd6d6325db
9 changed files with 193 additions and 12 deletions

View File

@@ -22,8 +22,6 @@ io_settings:
game:
max_episode_length: 256
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:

View File

@@ -20,6 +20,8 @@ class NicObservation(AbstractObservation):
high_nmne_threshold: int = 10
"""The minimum number of malicious network events to be considered high."""
global CAPTURE_NMNE
@property
def default_observation(self) -> Dict:
"""The default NIC observation dict."""
@@ -47,6 +49,15 @@ class NicObservation(AbstractObservation):
super().__init__()
self.where: Optional[Tuple[str]] = where
global CAPTURE_NMNE
if CAPTURE_NMNE:
self.nmne_inbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
self.nmne_outbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
self._validate_nmne_categories(
low_nmne_threshold=low_nmne_threshold,
@@ -128,19 +139,21 @@ class NicObservation(AbstractObservation):
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count)
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count)
obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
self.nmne_inbound_last_step = inbound_count
self.nmne_outbound_last_step = outbound_count
return obs_dict
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict(
{
"nic_status": spaces.Discrete(3),
"nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}),
}
)
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
if CAPTURE_NMNE:
space["nmne"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
return space
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":

View File

@@ -51,7 +51,7 @@ class ServiceObservation(AbstractObservation):
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)})
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
@classmethod
def from_config(