#2417 Add categorisation and updated new configs from merge
This commit is contained in:
@@ -43,6 +43,26 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_access"] = 0
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
|
||||
def _categorise_num_access(self, num_access: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_access > self.high_threshold:
|
||||
return 3
|
||||
elif num_access > self.med_threshold:
|
||||
return 2
|
||||
elif num_access > self.low_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
@@ -57,8 +77,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
return self.default_observation
|
||||
obs = {"health_status": file_state["visible_status"]}
|
||||
if self.include_num_access:
|
||||
obs["num_access"] = file_state["num_access"]
|
||||
# raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.")
|
||||
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
|
||||
return obs
|
||||
|
||||
@property
|
||||
|
||||
@@ -214,9 +214,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
:return: Constructed firewall observation instance.
|
||||
:rtype: FirewallObservation
|
||||
"""
|
||||
where = parent_where + ["nodes", config.hostname]
|
||||
return cls(
|
||||
where=where,
|
||||
where=parent_where + ["nodes", config.hostname],
|
||||
ip_list=config.ip_list,
|
||||
wildcard_list=config.wildcard_list,
|
||||
port_list=config.port_list,
|
||||
|
||||
@@ -185,9 +185,11 @@ class ObservationManager:
|
||||
Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
It should contain the key 'type' which selects which observation class to use (from a choice of:
|
||||
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
|
||||
The other key is 'options' which are passed to the constructor of the selected observation class.
|
||||
If None, a blank observation space is created.
|
||||
Otherwise, this must be a Dict with a type field and options field.
|
||||
type: string that corresponds to one of the observation identifiers that are provided when subclassing
|
||||
AbstractObservation
|
||||
options: this must adhere to the chosen observation type's ConfigSchema nested class.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
|
||||
@@ -98,6 +98,26 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
|
||||
def _categorise_num_executions(self, num_executions: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_executions > self.high_threshold:
|
||||
return 3
|
||||
elif num_executions > self.med_threshold:
|
||||
return 2
|
||||
elif num_executions > self.low_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
@@ -113,7 +133,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
return {
|
||||
"operating_status": application_state["operating_state"],
|
||||
"health_status": application_state["health_state_visible"],
|
||||
"num_executions": application_state["num_executions"],
|
||||
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
|
||||
}
|
||||
|
||||
@property
|
||||
|
||||
@@ -64,25 +64,67 @@ agents:
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: client_1
|
||||
links:
|
||||
- link_ref: client_1___switch_1
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: client_1
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.0.10
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- client_1___switch_1
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
# observation_space:
|
||||
# type: UC2BlueObservation
|
||||
# options:
|
||||
# num_services_per_node: 1
|
||||
# num_folders_per_node: 1
|
||||
# num_files_per_folder: 1
|
||||
# num_nics_per_node: 2
|
||||
# nodes:
|
||||
# - node_hostname: client_1
|
||||
# links:
|
||||
# - link_ref: client_1___switch_1
|
||||
# acl:
|
||||
# options:
|
||||
# max_acl_rules: 10
|
||||
# router_hostname: router_1
|
||||
# ip_address_order:
|
||||
# - node_hostname: client_1
|
||||
# nic_num: 1
|
||||
# ics: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
|
||||
@@ -41,8 +41,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -91,8 +90,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -141,10 +139,7 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -177,61 +172,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
|
||||
@@ -72,3 +72,6 @@ def test_folder_observation(simulation):
|
||||
|
||||
observation_state = root_folder_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too
|
||||
|
||||
|
||||
# TODO: Add tests to check num access is correct.
|
||||
|
||||
@@ -51,31 +51,8 @@ def simulation() -> Simulation:
|
||||
return sim
|
||||
|
||||
|
||||
def test_link_observation(simulation):
|
||||
"""Test the link observation."""
|
||||
# get a link
|
||||
link: Link = next(iter(simulation.network.links.values()))
|
||||
|
||||
computer: Computer = simulation.network.get_node_by_hostname("computer")
|
||||
server: Server = simulation.network.get_node_by_hostname("server")
|
||||
|
||||
simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep
|
||||
|
||||
link_obs = LinkObservation(where=["network", "links", link.uuid])
|
||||
|
||||
assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10
|
||||
|
||||
observation_state = link_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("PROTOCOLS") is not None
|
||||
assert observation_state["PROTOCOLS"]["ALL"] == 0
|
||||
|
||||
computer.ping(server.network_interface.get(1).ip_address)
|
||||
|
||||
observation_state = link_obs.observe(simulation.describe_state())
|
||||
assert observation_state["PROTOCOLS"]["ALL"] == 1
|
||||
|
||||
|
||||
def test_link_observation_again():
|
||||
def test_link_observation():
|
||||
"""Check the shape and contents of the link observation."""
|
||||
net = Network()
|
||||
sim = Simulation(network=net)
|
||||
switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON)
|
||||
@@ -102,6 +79,8 @@ def test_link_observation_again():
|
||||
assert "PROTOCOLS" in link_2_obs
|
||||
assert "ALL" in link_1_obs["PROTOCOLS"]
|
||||
assert "ALL" in link_2_obs["PROTOCOLS"]
|
||||
assert link_1_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11)
|
||||
assert link_2_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11)
|
||||
assert link_1_obs["PROTOCOLS"]["ALL"] == 0
|
||||
assert link_2_obs["PROTOCOLS"]["ALL"] == 0
|
||||
|
||||
|
||||
@@ -19,3 +19,8 @@ def test_file_observation():
|
||||
)
|
||||
assert dog_file_obs.observe(state) == {"health_status": 1}
|
||||
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
|
||||
|
||||
# TODO:
|
||||
# def test_file_num_access():
|
||||
# ...
|
||||
|
||||
Reference in New Issue
Block a user