#2417 update observation tests and make old tests pass

This commit is contained in:
Marek Wolan
2024-04-01 00:54:55 +01:00
parent 0e0df1012f
commit 0ba767d2a0
22 changed files with 767 additions and 626 deletions

View File

@@ -59,10 +59,10 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
"""
self.where = where
self.num_rules: int = num_rules
self.ip_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(ip_list)}
self.wildcard_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(wildcard_list)}
self.port_to_id: Dict[int, int] = {i + 2: p for i, p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(protocol_list)}
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
self.default_observation: Dict = {
i
+ 1: {
@@ -110,16 +110,16 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = self.ip_to_id.get(src_ip, 1)
src_node_id = 1 if src_ip is None else self.ip_to_id[src_ip]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = self.ip_to_id.get(dst_ip, 1)
src_wildcard = rule_state["source_wildcard_id"]
dst_node_id = 1 if dst_ip is None else self.ip_to_id[dst_ip]
src_wildcard = rule_state["src_wildcard_mask"]
src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
dst_wildcard = rule_state["dest_wildcard_id"]
dst_wildcard = rule_state["dst_wildcard_mask"]
dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
src_port = rule_state["source_port_id"]
src_port = rule_state["src_port"]
src_port_id = self.port_to_id.get(src_port, 1)
dst_port = rule_state["dest_port_id"]
dst_port = rule_state["dst_port"]
dst_port_id = self.port_to_id.get(dst_port, 1)
protocol = rule_state["protocol"]
protocol_id = self.protocol_to_id.get(protocol, 1)
@@ -129,7 +129,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
"source_ip_id": src_node_id,
"source_wildcard_id": src_wildcard_id,
"source_port_id": src_port_id,
"dest_ip_id": dst_node_ip,
"dest_ip_id": dst_node_id,
"dest_wildcard_id": dst_wildcard_id,
"dest_port_id": dst_port_id,
"protocol_id": protocol_id,

View File

@@ -133,8 +133,9 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
self.default_observation = {
"health_status": 0,
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
}
if self.files:
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
def observe(self, state: Dict) -> ObsType:
"""
@@ -154,7 +155,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
obs = {}
obs["health_status"] = health_status
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
if self.files:
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
return obs
@@ -166,12 +168,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
:return: Gymnasium space representing the observation space for folder status.
:rtype: spaces.Space
"""
return spaces.Dict(
{
"health_status": spaces.Discrete(6),
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
}
)
shape = {"health_status": spaces.Discrete(6)}
if self.files:
shape["FILES"] = spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)})
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation:

View File

@@ -123,21 +123,27 @@ class HostObservation(AbstractObservation, identifier="HOST"):
msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}"
_LOGGER.warning(msg)
self.network_interfaces: List[NICObservation] = network_interfaces
while len(self.network_interfaces) < num_nics:
self.network_interfaces.append(NICObservation(where=None, include_nmne=include_nmne))
while len(self.network_interfaces) > num_nics:
truncated_nic = self.network_interfaces.pop()
self.nics: List[NICObservation] = network_interfaces
while len(self.nics) < num_nics:
self.nics.append(NICObservation(where=None, include_nmne=include_nmne))
while len(self.nics) > num_nics:
truncated_nic = self.nics.pop()
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
_LOGGER.warning(msg)
self.default_observation: ObsType = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"operating_status": 0,
}
if self.services:
self.default_observation["SERVICES"] = {i + 1: s.default_observation for i, s in enumerate(self.services)}
if self.applications:
self.default_observation["APPLICATIONS"] = {
i + 1: a.default_observation for i, a in enumerate(self.applications)
}
if self.folders:
self.default_observation["FOLDERS"] = {i + 1: f.default_observation for i, f in enumerate(self.folders)}
if self.nics:
self.default_observation["NICS"] = {i + 1: n.default_observation for i, n in enumerate(self.nics)}
if self.include_num_access:
self.default_observation["num_file_creations"] = 0
self.default_observation["num_file_deletions"] = 0
@@ -156,13 +162,15 @@ class HostObservation(AbstractObservation, identifier="HOST"):
return self.default_observation
obs = {}
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
}
if self.services:
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
if self.applications:
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
if self.folders:
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
if self.nics:
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
@@ -177,14 +185,16 @@ class HostObservation(AbstractObservation, identifier="HOST"):
:rtype: spaces.Space
"""
shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict(
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
),
}
if self.services:
shape["SERVICES"] = spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)})
if self.applications:
shape["APPLICATIONS"] = spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)})
if self.folders:
shape["FOLDERS"] = spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)})
if self.nics:
shape["NICS"] = spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)})
if self.include_num_access:
shape["num_file_creations"] = spaces.Discrete(4)
shape["num_file_deletions"] = spaces.Discrete(4)

View File

@@ -23,7 +23,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
include_nmne: Optional[bool] = None
"""Whether to include number of malicious network events (NMNE) in the observation."""
def __init__(self, where: WhereType, include_nmne: bool) -> None:
def __init__(
self,
where: WhereType,
include_nmne: bool,
) -> None:
"""
Initialise a network interface observation instance.
@@ -40,6 +44,36 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
self.default_observation: ObsType = {"nic_status": 0}
if self.include_nmne:
self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}})
self.nmne_inbound_last_step: int = 0
self.nmne_outbound_last_step: int = 0
# TODO: allow these to be configured in yaml
self.high_nmne_threshold = 10
self.med_nmne_threshold = 5
self.low_nmne_threshold = 0
def _categorise_mne_count(self, nmne_count: int) -> int:
"""
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
Bins are defined as follows:
- 0: No MNEs detected (0 events).
- 1: Low number of MNEs (default 1-5 events).
- 2: Moderate number of MNEs (default 6-10 events).
- 3: High number of MNEs (default more than 10 events).
:param nmne_count: Number of MNEs detected.
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
"""
if nmne_count > self.high_nmne_threshold:
return 3
elif nmne_count > self.med_nmne_threshold:
return 2
elif nmne_count > self.low_nmne_threshold:
return 1
return 0
def observe(self, state: Dict) -> ObsType:
"""

View File

@@ -74,9 +74,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
_LOGGER.warning(msg)
self.default_observation = {
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
"ACL": self.acl.default_observation,
}
if self.ports:
self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)}
def observe(self, state: Dict) -> ObsType:
"""
@@ -92,8 +93,9 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
return self.default_observation
obs = {}
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
obs["ACL"] = self.acl.observe(state)
if self.ports:
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
return obs
@property
@@ -104,9 +106,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
:return: Gymnasium space representing the observation space for router status.
:rtype: spaces.Space
"""
return spaces.Dict(
{"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), "ACL": self.acl.space}
)
shape = {"ACL": self.acl.space}
if self.ports:
shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)})
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation: