#2417 update observation tests and make old tests pass

This commit is contained in:
Marek Wolan
2024-04-01 00:54:55 +01:00
parent 0e0df1012f
commit 0ba767d2a0
22 changed files with 767 additions and 626 deletions

View File

@@ -40,8 +40,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -90,8 +89,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -140,10 +138,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -179,61 +174,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
@@ -730,61 +737,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -59,10 +59,10 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
"""
self.where = where
self.num_rules: int = num_rules
self.ip_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(ip_list)}
self.wildcard_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(wildcard_list)}
self.port_to_id: Dict[int, int] = {i + 2: p for i, p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(protocol_list)}
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
self.default_observation: Dict = {
i
+ 1: {
@@ -110,16 +110,16 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = self.ip_to_id.get(src_ip, 1)
src_node_id = 1 if src_ip is None else self.ip_to_id[src_ip]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = self.ip_to_id.get(dst_ip, 1)
src_wildcard = rule_state["source_wildcard_id"]
dst_node_id = 1 if dst_ip is None else self.ip_to_id[dst_ip]
src_wildcard = rule_state["src_wildcard_mask"]
src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
dst_wildcard = rule_state["dest_wildcard_id"]
dst_wildcard = rule_state["dst_wildcard_mask"]
dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
src_port = rule_state["source_port_id"]
src_port = rule_state["src_port"]
src_port_id = self.port_to_id.get(src_port, 1)
dst_port = rule_state["dest_port_id"]
dst_port = rule_state["dst_port"]
dst_port_id = self.port_to_id.get(dst_port, 1)
protocol = rule_state["protocol"]
protocol_id = self.protocol_to_id.get(protocol, 1)
@@ -129,7 +129,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
"source_ip_id": src_node_id,
"source_wildcard_id": src_wildcard_id,
"source_port_id": src_port_id,
"dest_ip_id": dst_node_ip,
"dest_ip_id": dst_node_id,
"dest_wildcard_id": dst_wildcard_id,
"dest_port_id": dst_port_id,
"protocol_id": protocol_id,

View File

@@ -133,8 +133,9 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
self.default_observation = {
"health_status": 0,
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
}
if self.files:
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
def observe(self, state: Dict) -> ObsType:
"""
@@ -154,7 +155,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
obs = {}
obs["health_status"] = health_status
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
if self.files:
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
return obs
@@ -166,12 +168,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
:return: Gymnasium space representing the observation space for folder status.
:rtype: spaces.Space
"""
return spaces.Dict(
{
"health_status": spaces.Discrete(6),
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
}
)
shape = {"health_status": spaces.Discrete(6)}
if self.files:
shape["FILES"] = spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)})
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation:

View File

@@ -123,21 +123,27 @@ class HostObservation(AbstractObservation, identifier="HOST"):
msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}"
_LOGGER.warning(msg)
self.network_interfaces: List[NICObservation] = network_interfaces
while len(self.network_interfaces) < num_nics:
self.network_interfaces.append(NICObservation(where=None, include_nmne=include_nmne))
while len(self.network_interfaces) > num_nics:
truncated_nic = self.network_interfaces.pop()
self.nics: List[NICObservation] = network_interfaces
while len(self.nics) < num_nics:
self.nics.append(NICObservation(where=None, include_nmne=include_nmne))
while len(self.nics) > num_nics:
truncated_nic = self.nics.pop()
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
_LOGGER.warning(msg)
self.default_observation: ObsType = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"operating_status": 0,
}
if self.services:
self.default_observation["SERVICES"] = {i + 1: s.default_observation for i, s in enumerate(self.services)}
if self.applications:
self.default_observation["APPLICATIONS"] = {
i + 1: a.default_observation for i, a in enumerate(self.applications)
}
if self.folders:
self.default_observation["FOLDERS"] = {i + 1: f.default_observation for i, f in enumerate(self.folders)}
if self.nics:
self.default_observation["NICS"] = {i + 1: n.default_observation for i, n in enumerate(self.nics)}
if self.include_num_access:
self.default_observation["num_file_creations"] = 0
self.default_observation["num_file_deletions"] = 0
@@ -156,13 +162,15 @@ class HostObservation(AbstractObservation, identifier="HOST"):
return self.default_observation
obs = {}
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
}
if self.services:
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
if self.applications:
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
if self.folders:
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
if self.nics:
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
@@ -177,14 +185,16 @@ class HostObservation(AbstractObservation, identifier="HOST"):
:rtype: spaces.Space
"""
shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict(
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
),
}
if self.services:
shape["SERVICES"] = spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)})
if self.applications:
shape["APPLICATIONS"] = spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)})
if self.folders:
shape["FOLDERS"] = spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)})
if self.nics:
shape["NICS"] = spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)})
if self.include_num_access:
shape["num_file_creations"] = spaces.Discrete(4)
shape["num_file_deletions"] = spaces.Discrete(4)

View File

@@ -23,7 +23,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
include_nmne: Optional[bool] = None
"""Whether to include number of malicious network events (NMNE) in the observation."""
def __init__(self, where: WhereType, include_nmne: bool) -> None:
def __init__(
self,
where: WhereType,
include_nmne: bool,
) -> None:
"""
Initialise a network interface observation instance.
@@ -40,6 +44,36 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
self.default_observation: ObsType = {"nic_status": 0}
if self.include_nmne:
self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}})
self.nmne_inbound_last_step: int = 0
self.nmne_outbound_last_step: int = 0
# TODO: allow these to be configured in yaml
self.high_nmne_threshold = 10
self.med_nmne_threshold = 5
self.low_nmne_threshold = 0
def _categorise_mne_count(self, nmne_count: int) -> int:
"""
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
Bins are defined as follows:
- 0: No MNEs detected (0 events).
- 1: Low number of MNEs (default 1-5 events).
- 2: Moderate number of MNEs (default 6-10 events).
- 3: High number of MNEs (default more than 10 events).
:param nmne_count: Number of MNEs detected.
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
"""
if nmne_count > self.high_nmne_threshold:
return 3
elif nmne_count > self.med_nmne_threshold:
return 2
elif nmne_count > self.low_nmne_threshold:
return 1
return 0
def observe(self, state: Dict) -> ObsType:
"""

View File

@@ -74,9 +74,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
_LOGGER.warning(msg)
self.default_observation = {
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
"ACL": self.acl.default_observation,
}
if self.ports:
self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)}
def observe(self, state: Dict) -> ObsType:
"""
@@ -92,8 +93,9 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
return self.default_observation
obs = {}
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
obs["ACL"] = self.acl.observe(state)
if self.ports:
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
return obs
@property
@@ -104,9 +106,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
:return: Gymnasium space representing the observation space for router status.
:rtype: spaces.Space
"""
return spaces.Dict(
{"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), "ACL": self.acl.space}
)
shape = {"ACL": self.acl.space}
if self.ports:
shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)})
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation: