#2417 update observation tests and make old tests pass
This commit is contained in:
@@ -59,10 +59,10 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"""
|
||||
self.where = where
|
||||
self.num_rules: int = num_rules
|
||||
self.ip_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(ip_list)}
|
||||
self.wildcard_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(wildcard_list)}
|
||||
self.port_to_id: Dict[int, int] = {i + 2: p for i, p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(protocol_list)}
|
||||
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
|
||||
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
|
||||
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
@@ -110,16 +110,16 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = self.ip_to_id.get(src_ip, 1)
|
||||
src_node_id = 1 if src_ip is None else self.ip_to_id[src_ip]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = self.ip_to_id.get(dst_ip, 1)
|
||||
src_wildcard = rule_state["source_wildcard_id"]
|
||||
dst_node_id = 1 if dst_ip is None else self.ip_to_id[dst_ip]
|
||||
src_wildcard = rule_state["src_wildcard_mask"]
|
||||
src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
|
||||
dst_wildcard = rule_state["dest_wildcard_id"]
|
||||
dst_wildcard = rule_state["dst_wildcard_mask"]
|
||||
dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
|
||||
src_port = rule_state["source_port_id"]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = self.port_to_id.get(src_port, 1)
|
||||
dst_port = rule_state["dest_port_id"]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = self.port_to_id.get(dst_port, 1)
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = self.protocol_to_id.get(protocol, 1)
|
||||
@@ -129,7 +129,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"source_ip_id": src_node_id,
|
||||
"source_wildcard_id": src_wildcard_id,
|
||||
"source_port_id": src_port_id,
|
||||
"dest_ip_id": dst_node_ip,
|
||||
"dest_ip_id": dst_node_id,
|
||||
"dest_wildcard_id": dst_wildcard_id,
|
||||
"dest_port_id": dst_port_id,
|
||||
"protocol_id": protocol_id,
|
||||
|
||||
@@ -133,8 +133,9 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
|
||||
self.default_observation = {
|
||||
"health_status": 0,
|
||||
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
|
||||
}
|
||||
if self.files:
|
||||
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -154,7 +155,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
obs = {}
|
||||
|
||||
obs["health_status"] = health_status
|
||||
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
|
||||
if self.files:
|
||||
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
|
||||
|
||||
return obs
|
||||
|
||||
@@ -166,12 +168,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
:return: Gymnasium space representing the observation space for folder status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"health_status": spaces.Discrete(6),
|
||||
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
|
||||
}
|
||||
)
|
||||
shape = {"health_status": spaces.Discrete(6)}
|
||||
if self.files:
|
||||
shape["FILES"] = spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)})
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation:
|
||||
|
||||
@@ -123,21 +123,27 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.network_interfaces: List[NICObservation] = network_interfaces
|
||||
while len(self.network_interfaces) < num_nics:
|
||||
self.network_interfaces.append(NICObservation(where=None, include_nmne=include_nmne))
|
||||
while len(self.network_interfaces) > num_nics:
|
||||
truncated_nic = self.network_interfaces.pop()
|
||||
self.nics: List[NICObservation] = network_interfaces
|
||||
while len(self.nics) < num_nics:
|
||||
self.nics.append(NICObservation(where=None, include_nmne=include_nmne))
|
||||
while len(self.nics) > num_nics:
|
||||
truncated_nic = self.nics.pop()
|
||||
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation: ObsType = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
"APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)},
|
||||
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
|
||||
"operating_status": 0,
|
||||
}
|
||||
if self.services:
|
||||
self.default_observation["SERVICES"] = {i + 1: s.default_observation for i, s in enumerate(self.services)}
|
||||
if self.applications:
|
||||
self.default_observation["APPLICATIONS"] = {
|
||||
i + 1: a.default_observation for i, a in enumerate(self.applications)
|
||||
}
|
||||
if self.folders:
|
||||
self.default_observation["FOLDERS"] = {i + 1: f.default_observation for i, f in enumerate(self.folders)}
|
||||
if self.nics:
|
||||
self.default_observation["NICS"] = {i + 1: n.default_observation for i, n in enumerate(self.nics)}
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_file_creations"] = 0
|
||||
self.default_observation["num_file_deletions"] = 0
|
||||
@@ -156,13 +162,15 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
obs["NICS"] = {
|
||||
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
|
||||
}
|
||||
if self.services:
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
if self.applications:
|
||||
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
|
||||
if self.folders:
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
if self.nics:
|
||||
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
|
||||
if self.include_num_access:
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
@@ -177,14 +185,16 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
shape = {
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(5),
|
||||
"NICS": spaces.Dict(
|
||||
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
|
||||
),
|
||||
}
|
||||
if self.services:
|
||||
shape["SERVICES"] = spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)})
|
||||
if self.applications:
|
||||
shape["APPLICATIONS"] = spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)})
|
||||
if self.folders:
|
||||
shape["FOLDERS"] = spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)})
|
||||
if self.nics:
|
||||
shape["NICS"] = spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)})
|
||||
if self.include_num_access:
|
||||
shape["num_file_creations"] = spaces.Discrete(4)
|
||||
shape["num_file_deletions"] = spaces.Discrete(4)
|
||||
|
||||
@@ -23,7 +23,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
include_nmne: Optional[bool] = None
|
||||
"""Whether to include number of malicious network events (NMNE) in the observation."""
|
||||
|
||||
def __init__(self, where: WhereType, include_nmne: bool) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
include_nmne: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a network interface observation instance.
|
||||
|
||||
@@ -40,6 +44,36 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
self.default_observation: ObsType = {"nic_status": 0}
|
||||
if self.include_nmne:
|
||||
self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}})
|
||||
self.nmne_inbound_last_step: int = 0
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_nmne_threshold = 10
|
||||
self.med_nmne_threshold = 5
|
||||
self.low_nmne_threshold = 0
|
||||
|
||||
def _categorise_mne_count(self, nmne_count: int) -> int:
|
||||
"""
|
||||
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
|
||||
|
||||
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
|
||||
|
||||
Bins are defined as follows:
|
||||
- 0: No MNEs detected (0 events).
|
||||
- 1: Low number of MNEs (default 1-5 events).
|
||||
- 2: Moderate number of MNEs (default 6-10 events).
|
||||
- 3: High number of MNEs (default more than 10 events).
|
||||
|
||||
:param nmne_count: Number of MNEs detected.
|
||||
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
|
||||
"""
|
||||
if nmne_count > self.high_nmne_threshold:
|
||||
return 3
|
||||
elif nmne_count > self.med_nmne_threshold:
|
||||
return 2
|
||||
elif nmne_count > self.low_nmne_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
|
||||
@@ -74,9 +74,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
|
||||
"ACL": self.acl.default_observation,
|
||||
}
|
||||
if self.ports:
|
||||
self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -92,8 +93,9 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
if self.ports:
|
||||
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
|
||||
return obs
|
||||
|
||||
@property
|
||||
@@ -104,9 +106,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
:return: Gymnasium space representing the observation space for router status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), "ACL": self.acl.space}
|
||||
)
|
||||
shape = {"ACL": self.acl.space}
|
||||
if self.ports:
|
||||
shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)})
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation:
|
||||
|
||||
Reference in New Issue
Block a user