#2417 Add categorisation and updated new configs from merge

This commit is contained in:
Marek Wolan
2024-04-01 22:03:28 +01:00
parent e4300faa1c
commit d2c7ae481c
9 changed files with 188 additions and 112 deletions

View File

@@ -43,6 +43,26 @@ class FileObservation(AbstractObservation, identifier="FILE"):
if self.include_num_access:
self.default_observation["num_access"] = 0
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
def _categorise_num_access(self, num_access: int) -> int:
"""
Represent number of file accesses as a categorical variable.
:param num_access: Number of file accesses.
:return: Bin number corresponding to the number of accesses.
"""
if num_access > self.high_threshold:
return 3
elif num_access > self.med_threshold:
return 2
elif num_access > self.low_threshold:
return 1
return 0
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
@@ -57,8 +77,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
return self.default_observation
obs = {"health_status": file_state["visible_status"]}
if self.include_num_access:
obs["num_access"] = file_state["num_access"]
# raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.")
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
return obs
@property

View File

@@ -214,9 +214,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:return: Constructed firewall observation instance.
:rtype: FirewallObservation
"""
where = parent_where + ["nodes", config.hostname]
return cls(
where=where,
where=parent_where + ["nodes", config.hostname],
ip_list=config.ip_list,
wildcard_list=config.wildcard_list,
port_list=config.port_list,

View File

@@ -185,9 +185,11 @@ class ObservationManager:
Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.
It should contain the key 'type' which selects which observation class to use (from a choice of:
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
The other key is 'options' which are passed to the constructor of the selected observation class.
If None, a blank observation space is created.
Otherwise, this must be a Dict with a type field and options field.
type: string that corresponds to one of the observation identifiers that are provided when subclassing
AbstractObservation
options: this must adhere to the chosen observation type's ConfigSchema nested class.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame

View File

@@ -98,6 +98,26 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
self.where = where
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
def _categorise_num_executions(self, num_executions: int) -> int:
"""
Represent number of file accesses as a categorical variable.
:param num_access: Number of file accesses.
:return: Bin number corresponding to the number of accesses.
"""
if num_executions > self.high_threshold:
return 3
elif num_executions > self.med_threshold:
return 2
elif num_executions > self.low_threshold:
return 1
return 0
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
@@ -113,7 +133,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
return {
"operating_status": application_state["operating_state"],
"health_status": application_state["health_state_visible"],
"num_executions": application_state["num_executions"],
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
}
@property

View File

@@ -64,25 +64,67 @@ agents:
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: client_1
links:
- link_ref: client_1___switch_1
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: client_1
nic_num: 1
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: client_1
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.0.10
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- client_1___switch_1
- type: "NONE"
label: ICS
options: {}
# observation_space:
# type: UC2BlueObservation
# options:
# num_services_per_node: 1
# num_folders_per_node: 1
# num_files_per_folder: 1
# num_nics_per_node: 2
# nodes:
# - node_hostname: client_1
# links:
# - link_ref: client_1___switch_1
# acl:
# options:
# max_acl_rules: 10
# router_hostname: router_1
# ip_address_order:
# - node_hostname: client_1
# nic_num: 1
# ics: null
action_space:
action_list:
- type: DONOTHING

View File

@@ -41,8 +41,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -91,8 +90,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -141,10 +139,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -177,61 +172,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -72,3 +72,6 @@ def test_folder_observation(simulation):
observation_state = root_folder_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too
# TODO: Add tests to check num access is correct.

View File

@@ -51,31 +51,8 @@ def simulation() -> Simulation:
return sim
def test_link_observation(simulation):
"""Test the link observation."""
# get a link
link: Link = next(iter(simulation.network.links.values()))
computer: Computer = simulation.network.get_node_by_hostname("computer")
server: Server = simulation.network.get_node_by_hostname("server")
simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep
link_obs = LinkObservation(where=["network", "links", link.uuid])
assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10
observation_state = link_obs.observe(simulation.describe_state())
assert observation_state.get("PROTOCOLS") is not None
assert observation_state["PROTOCOLS"]["ALL"] == 0
computer.ping(server.network_interface.get(1).ip_address)
observation_state = link_obs.observe(simulation.describe_state())
assert observation_state["PROTOCOLS"]["ALL"] == 1
def test_link_observation_again():
def test_link_observation():
"""Check the shape and contents of the link observation."""
net = Network()
sim = Simulation(network=net)
switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON)
@@ -102,6 +79,8 @@ def test_link_observation_again():
assert "PROTOCOLS" in link_2_obs
assert "ALL" in link_1_obs["PROTOCOLS"]
assert "ALL" in link_2_obs["PROTOCOLS"]
assert link_1_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11)
assert link_2_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11)
assert link_1_obs["PROTOCOLS"]["ALL"] == 0
assert link_2_obs["PROTOCOLS"]["ALL"] == 0

View File

@@ -19,3 +19,8 @@ def test_file_observation():
)
assert dog_file_obs.observe(state) == {"health_status": 1}
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})
# TODO:
# def test_file_num_access():
# ...